diff --git a/pom.xml b/pom.xml index a9fd0d1..d16dc78 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ nl.andrewl record-net - 1.0.0 + 1.1.0 17 diff --git a/src/main/java/nl/andrewl/record_net/Message.java b/src/main/java/nl/andrewl/record_net/Message.java index 167d317..1c7fc7c 100644 --- a/src/main/java/nl/andrewl/record_net/Message.java +++ b/src/main/java/nl/andrewl/record_net/Message.java @@ -15,15 +15,7 @@ public interface Message { * @return The serializer to use to read and write messages of this type. */ @SuppressWarnings("unchecked") - default MessageTypeSerializer getTypeSerializer() { - return MessageTypeSerializer.get((Class) this.getClass()); - } - - /** - * Convenience method to determine the size of this message in bytes. - * @return The size of this message, in bytes. - */ - default int byteSize() { - return getTypeSerializer().byteSizeFunction().apply(this); + default MessageTypeSerializer getTypeSerializer(Serializer serializer) { + return (MessageTypeSerializer) serializer.getTypeSerializer(this.getClass()); } } diff --git a/src/main/java/nl/andrewl/record_net/MessageTypeSerializer.java b/src/main/java/nl/andrewl/record_net/MessageTypeSerializer.java index 9b25c9f..406f7f3 100644 --- a/src/main/java/nl/andrewl/record_net/MessageTypeSerializer.java +++ b/src/main/java/nl/andrewl/record_net/MessageTypeSerializer.java @@ -1,5 +1,7 @@ package nl.andrewl.record_net; +import nl.andrewl.record_net.util.Pair; + import java.lang.reflect.Constructor; import java.lang.reflect.RecordComponent; import java.util.Arrays; @@ -25,18 +27,25 @@ public record MessageTypeSerializer( MessageReader reader, MessageWriter writer ) { - private static final Map, MessageTypeSerializer> generatedMessageTypes = new HashMap<>(); + /** + * An internal cache for storing generated type serializers. + */ + private static final Map, Serializer>, MessageTypeSerializer> generatedMessageTypes = new HashMap<>(); /** * Gets the {@link MessageTypeSerializer} instance for a given message class, and * generates a new implementation if none exists yet. + * @param serializer The serializer context to get a type serializer for. * @param messageClass The class of the message to get a type for. * @param The type of the message. * @return The message type. */ @SuppressWarnings("unchecked") - public static MessageTypeSerializer get(Class messageClass) { - return (MessageTypeSerializer) generatedMessageTypes.computeIfAbsent(messageClass, c -> generateForRecord((Class) c)); + public static MessageTypeSerializer get(Serializer serializer, Class messageClass) { + return (MessageTypeSerializer) generatedMessageTypes.computeIfAbsent( + new Pair<>(messageClass, serializer), + p -> generateForRecord(serializer, (Class) p.first()) + ); } /** @@ -45,11 +54,12 @@ public record MessageTypeSerializer( *

* Note that this only works for record-based messages. *

+ * @param serializer The serializer context to get a type serializer for. * @param messageTypeClass The class of the message type. * @param The type of the message. * @return A message type instance. */ - public static MessageTypeSerializer generateForRecord(Class messageTypeClass) { + public static MessageTypeSerializer generateForRecord(Serializer serializer, Class messageTypeClass) { RecordComponent[] components = messageTypeClass.getRecordComponents(); if (components == null) throw new IllegalArgumentException("Cannot generate a MessageTypeSerializer for non-record class " + messageTypeClass.getSimpleName()); Constructor constructor; @@ -61,7 +71,7 @@ public record MessageTypeSerializer( } return new MessageTypeSerializer<>( messageTypeClass, - generateByteSizeFunction(components), + generateByteSizeFunction(serializer, components), generateReader(constructor), generateWriter(components) ); @@ -70,17 +80,18 @@ public record MessageTypeSerializer( /** * Generates a function implementation that counts the byte size of a * message based on the message's record component types. + * @param serializer The serializer context to generate a function for. * @param components The list of components that make up the message. * @param The message type. * @return A function that computes the byte size of a message of the given * type. */ - private static Function generateByteSizeFunction(RecordComponent[] components) { + private static Function generateByteSizeFunction(Serializer serializer, RecordComponent[] components) { return msg -> { int size = 0; for (var component : components) { try { - size += MessageUtils.getByteSize(component.getAccessor().invoke(msg)); + size += MessageUtils.getByteSize(serializer, component.getAccessor().invoke(msg)); } catch (ReflectiveOperationException e) { throw new IllegalStateException(e); } diff --git a/src/main/java/nl/andrewl/record_net/MessageUtils.java b/src/main/java/nl/andrewl/record_net/MessageUtils.java index 97e9226..642b742 100644 --- a/src/main/java/nl/andrewl/record_net/MessageUtils.java +++ b/src/main/java/nl/andrewl/record_net/MessageUtils.java @@ -35,19 +35,25 @@ public class MessageUtils { return size; } - public static int getByteSize(Message msg) { - return 1 + (msg == null ? 0 : msg.byteSize()); + @SuppressWarnings("unchecked") + public static int getByteSize(Serializer serializer, T msg) { + if (msg == null) { + return 1; + } else { + MessageTypeSerializer typeSerializer = (MessageTypeSerializer) serializer.getTypeSerializer(msg.getClass()); + return 1 + typeSerializer.byteSizeFunction().apply(msg); + } } - public static int getByteSize(T[] items) { + public static int getByteSize(Serializer serializer, T[] items) { int count = Integer.BYTES; for (var item : items) { - count += getByteSize(item); + count += getByteSize(serializer, item); } return count; } - public static int getByteSize(Object o) { + public static int getByteSize(Serializer serializer, Object o) { if (o instanceof Integer) { return Integer.BYTES; } else if (o instanceof Long) { @@ -61,18 +67,18 @@ public class MessageUtils { } else if (o instanceof byte[]) { return Integer.BYTES + ((byte[]) o).length; } else if (o.getClass().isArray() && Message.class.isAssignableFrom(o.getClass().getComponentType())) { - return getByteSize((Message[]) o); + return getByteSize(serializer, (Message[]) o); } else if (o instanceof Message) { - return getByteSize((Message) o); + return getByteSize(serializer, (Message) o); } else { throw new IllegalArgumentException("Unsupported object type: " + o.getClass().getSimpleName()); } } - public static int getByteSize(Object... objects) { + public static int getByteSize(Serializer serializer, Object... objects) { int size = 0; for (var o : objects) { - size += getByteSize(o); + size += getByteSize(serializer, o); } return size; } diff --git a/src/main/java/nl/andrewl/record_net/Serializer.java b/src/main/java/nl/andrewl/record_net/Serializer.java index 82c2af2..b25622f 100644 --- a/src/main/java/nl/andrewl/record_net/Serializer.java +++ b/src/main/java/nl/andrewl/record_net/Serializer.java @@ -19,6 +19,8 @@ public class Serializer { */ private final Map> messageTypes = new HashMap<>(); + private final Map, MessageTypeSerializer> messageTypeClasses = new HashMap<>(); + /** * An inverse of {@link Serializer#messageTypes} which is used to look up a * message's byte value when you know the class of the message. @@ -47,10 +49,25 @@ public class Serializer { * @param messageClass The type of message associated with the given id. */ public synchronized void registerType(int id, Class messageClass) { + registerTypeSerializer(id, MessageTypeSerializer.generateForRecord(this, messageClass)); + } + + public synchronized void registerTypeSerializer(int id, MessageTypeSerializer typeSerializer) { if (id < 0 || id > 127) throw new IllegalArgumentException("Invalid id."); - MessageTypeSerializer type = MessageTypeSerializer.get(messageClass); - messageTypes.put((byte)id, type); - inverseMessageTypes.put(type, (byte)id); + messageTypes.put((byte) id, typeSerializer); + inverseMessageTypes.put(typeSerializer, (byte) id); + messageTypeClasses.put(typeSerializer.messageClass(), typeSerializer); + } + + /** + * Gets the {@link MessageTypeSerializer} for the given message class. + * @param messageType The class of message to get the serializer for. + * @return The message type serializer. + * @param The type of message. + */ + @SuppressWarnings("unchecked") + public MessageTypeSerializer getTypeSerializer(Class messageType) { + return (MessageTypeSerializer) messageTypeClasses.get(messageType); } /** @@ -63,7 +80,7 @@ public class Serializer { * constructed for the incoming data. */ public Message readMessage(InputStream i) throws IOException { - ExtendedDataInputStream d = new ExtendedDataInputStream(i); + ExtendedDataInputStream d = new ExtendedDataInputStream(this, i); byte typeId = d.readByte(); var type = messageTypes.get(typeId); if (type == null) { @@ -99,12 +116,12 @@ public class Serializer { */ public void writeMessage(T msg, OutputStream o) throws IOException { DataOutputStream d = new DataOutputStream(o); - Byte typeId = inverseMessageTypes.get(msg.getTypeSerializer()); + Byte typeId = inverseMessageTypes.get(msg.getTypeSerializer(this)); if (typeId == null) { throw new IOException("Unsupported message type: " + msg.getClass().getSimpleName()); } d.writeByte(typeId); - msg.getTypeSerializer().writer().write(msg, new ExtendedDataOutputStream(d)); + msg.getTypeSerializer(this).writer().write(msg, new ExtendedDataOutputStream(this, d)); d.flush(); } @@ -117,7 +134,8 @@ public class Serializer { * to write is not supported by this serializer. */ public byte[] writeMessage(T msg) throws IOException { - ByteArrayOutputStream out = new ByteArrayOutputStream(1 + msg.byteSize()); + int bytes = msg.getTypeSerializer(this).byteSizeFunction().apply(msg); + ByteArrayOutputStream out = new ByteArrayOutputStream(1 + bytes); writeMessage(msg, out); return out.toByteArray(); } diff --git a/src/main/java/nl/andrewl/record_net/util/ExtendedDataInputStream.java b/src/main/java/nl/andrewl/record_net/util/ExtendedDataInputStream.java index 8921d5f..d40c723 100644 --- a/src/main/java/nl/andrewl/record_net/util/ExtendedDataInputStream.java +++ b/src/main/java/nl/andrewl/record_net/util/ExtendedDataInputStream.java @@ -2,6 +2,7 @@ package nl.andrewl.record_net.util; import nl.andrewl.record_net.Message; import nl.andrewl.record_net.MessageTypeSerializer; +import nl.andrewl.record_net.Serializer; import java.io.ByteArrayInputStream; import java.io.DataInputStream; @@ -16,12 +17,15 @@ import java.util.UUID; * complex types that are used by the Concord system. */ public class ExtendedDataInputStream extends DataInputStream { - public ExtendedDataInputStream(InputStream in) { + private final Serializer serializer; + + public ExtendedDataInputStream(Serializer serializer, InputStream in) { super(in); + this.serializer = serializer; } - public ExtendedDataInputStream(byte[] data) { - this(new ByteArrayInputStream(data)); + public ExtendedDataInputStream(Serializer serializer, byte[] data) { + this(serializer, new ByteArrayInputStream(data)); } public String readString() throws IOException { @@ -81,10 +85,10 @@ public class ExtendedDataInputStream extends DataInputStream { int length = this.readInt(); return this.readNBytes(length); } else if (type.isArray() && Message.class.isAssignableFrom(type.getComponentType())) { - var messageType = MessageTypeSerializer.get((Class) type.getComponentType()); + var messageType = MessageTypeSerializer.get(serializer, (Class) type.getComponentType()); return this.readArray(messageType); } else if (Message.class.isAssignableFrom(type)) { - var messageType = MessageTypeSerializer.get((Class) type); + var messageType = MessageTypeSerializer.get(serializer, (Class) type); return messageType.reader().read(this); } else { throw new IOException("Unsupported object type: " + type.getSimpleName()); diff --git a/src/main/java/nl/andrewl/record_net/util/ExtendedDataOutputStream.java b/src/main/java/nl/andrewl/record_net/util/ExtendedDataOutputStream.java index 03922c6..955d7f7 100644 --- a/src/main/java/nl/andrewl/record_net/util/ExtendedDataOutputStream.java +++ b/src/main/java/nl/andrewl/record_net/util/ExtendedDataOutputStream.java @@ -1,6 +1,7 @@ package nl.andrewl.record_net.util; import nl.andrewl.record_net.Message; +import nl.andrewl.record_net.Serializer; import java.io.DataOutputStream; import java.io.IOException; @@ -13,8 +14,11 @@ import java.util.UUID; * that help us to write more data. */ public class ExtendedDataOutputStream extends DataOutputStream { - public ExtendedDataOutputStream(OutputStream out) { + private final Serializer serializer; + + public ExtendedDataOutputStream(Serializer serializer, OutputStream out) { super(out); + this.serializer = serializer; } /** @@ -125,7 +129,7 @@ public class ExtendedDataOutputStream extends DataOutputStream { public void writeMessage(Message msg) throws IOException { writeBoolean(msg != null); if (msg != null) { - msg.getTypeSerializer().writer().write(msg, this); + msg.getTypeSerializer(serializer).writer().write(msg, this); } } diff --git a/src/test/java/nl/andrewl/record_net/MessageTypeSerializerTest.java b/src/test/java/nl/andrewl/record_net/MessageTypeSerializerTest.java index 7d8b974..cd1c006 100644 --- a/src/test/java/nl/andrewl/record_net/MessageTypeSerializerTest.java +++ b/src/test/java/nl/andrewl/record_net/MessageTypeSerializerTest.java @@ -14,21 +14,22 @@ import static org.junit.jupiter.api.Assertions.assertThrows; public class MessageTypeSerializerTest { @Test public void testGenerateForRecord() throws IOException { - var s1 = MessageTypeSerializer.get(ChatMessage.class); + Serializer serializer = new Serializer(); + var s1 = MessageTypeSerializer.get(serializer, ChatMessage.class); ChatMessage msg = new ChatMessage("andrew", 123, "Hello world!"); int expectedByteSize = 4 + msg.username().length() + 8 + 4 + msg.message().length(); assertEquals(expectedByteSize, s1.byteSizeFunction().apply(msg)); - assertEquals(expectedByteSize, msg.byteSize()); + ByteArrayOutputStream bOut = new ByteArrayOutputStream(); - ExtendedDataOutputStream eOut = new ExtendedDataOutputStream(bOut); + ExtendedDataOutputStream eOut = new ExtendedDataOutputStream(serializer, bOut); s1.writer().write(msg, eOut); byte[] data = bOut.toByteArray(); assertEquals(expectedByteSize, data.length); - ChatMessage readMsg = s1.reader().read(new ExtendedDataInputStream(data)); + ChatMessage readMsg = s1.reader().read(new ExtendedDataInputStream(serializer, data)); assertEquals(msg, readMsg); // Only record classes can be generated. class NonRecordMessage implements Message {} - assertThrows(IllegalArgumentException.class, () -> MessageTypeSerializer.get(NonRecordMessage.class)); + assertThrows(IllegalArgumentException.class, () -> MessageTypeSerializer.get(serializer, NonRecordMessage.class)); } } diff --git a/src/test/java/nl/andrewl/record_net/MessageUtilsTest.java b/src/test/java/nl/andrewl/record_net/MessageUtilsTest.java index 20f5fae..e155d9a 100644 --- a/src/test/java/nl/andrewl/record_net/MessageUtilsTest.java +++ b/src/test/java/nl/andrewl/record_net/MessageUtilsTest.java @@ -18,10 +18,12 @@ public class MessageUtilsTest { assertEquals(10, MessageUtils.getByteSize("a", "b")); Message msg = new ChatMessage("andrew", 123, "Hello world!"); int expectedMsgSize = 1 + 4 + 6 + 8 + 4 + 12; - assertEquals(1, MessageUtils.getByteSize((Message) null)); - assertEquals(expectedMsgSize, MessageUtils.getByteSize(msg)); - assertEquals(4 * expectedMsgSize, MessageUtils.getByteSize(msg, msg, msg, msg)); - assertEquals(16, MessageUtils.getByteSize(UUID.randomUUID())); - assertEquals(4, MessageUtils.getByteSize(StandardCopyOption.ATOMIC_MOVE)); + Serializer serializer = new Serializer(); + serializer.registerType(1, ChatMessage.class); + assertEquals(1, MessageUtils.getByteSize(serializer, (Message) null)); + assertEquals(expectedMsgSize, MessageUtils.getByteSize(serializer, msg)); + assertEquals(4 * expectedMsgSize, MessageUtils.getByteSize(serializer, msg, msg, msg, msg)); + assertEquals(16, MessageUtils.getByteSize(serializer, UUID.randomUUID())); + assertEquals(4, MessageUtils.getByteSize(serializer, StandardCopyOption.ATOMIC_MOVE)); } } diff --git a/src/test/java/nl/andrewl/record_net/SerializerTest.java b/src/test/java/nl/andrewl/record_net/SerializerTest.java index 02d1e98..d6fcc16 100644 --- a/src/test/java/nl/andrewl/record_net/SerializerTest.java +++ b/src/test/java/nl/andrewl/record_net/SerializerTest.java @@ -20,7 +20,7 @@ public class SerializerTest { ByteArrayOutputStream bOut = new ByteArrayOutputStream(); s.writeMessage(msg, bOut); byte[] data = bOut.toByteArray(); - assertEquals(1 + msg.byteSize(), data.length); + assertEquals(MessageUtils.getByteSize(s, msg), data.length); assertEquals(data[0], 1); ChatMessage readMsg = (ChatMessage) s.readMessage(new ByteArrayInputStream(data)); diff --git a/src/test/java/nl/andrewl/record_net/util/ExtendedDataOutputStreamTest.java b/src/test/java/nl/andrewl/record_net/util/ExtendedDataOutputStreamTest.java index f79540c..fbb9968 100644 --- a/src/test/java/nl/andrewl/record_net/util/ExtendedDataOutputStreamTest.java +++ b/src/test/java/nl/andrewl/record_net/util/ExtendedDataOutputStreamTest.java @@ -1,5 +1,6 @@ package nl.andrewl.record_net.util; +import nl.andrewl.record_net.Serializer; import org.junit.jupiter.api.Test; import java.io.ByteArrayInputStream; @@ -13,8 +14,9 @@ public class ExtendedDataOutputStreamTest { @Test public void testWriteString() throws IOException { + Serializer serializer = new Serializer(); ByteArrayOutputStream bOut = new ByteArrayOutputStream(); - ExtendedDataOutputStream eOut = new ExtendedDataOutputStream(bOut); + ExtendedDataOutputStream eOut = new ExtendedDataOutputStream(serializer, bOut); eOut.writeString("Hello world!"); byte[] data = bOut.toByteArray(); assertEquals(4 + "Hello world!".length(), data.length);