(
*
* Note that this only works for record-based messages.
*
+ * @param serializer The serializer context to get a type serializer for.
* @param messageTypeClass The class of the message type.
* @param The type of the message.
* @return A message type instance.
*/
- public static MessageTypeSerializer generateForRecord(Class messageTypeClass) {
+ public static MessageTypeSerializer generateForRecord(Serializer serializer, Class messageTypeClass) {
RecordComponent[] components = messageTypeClass.getRecordComponents();
if (components == null) throw new IllegalArgumentException("Cannot generate a MessageTypeSerializer for non-record class " + messageTypeClass.getSimpleName());
Constructor constructor;
@@ -61,7 +71,7 @@ public record MessageTypeSerializer(
}
return new MessageTypeSerializer<>(
messageTypeClass,
- generateByteSizeFunction(components),
+ generateByteSizeFunction(serializer, components),
generateReader(constructor),
generateWriter(components)
);
@@ -70,17 +80,18 @@ public record MessageTypeSerializer(
/**
* Generates a function implementation that counts the byte size of a
* message based on the message's record component types.
+ * @param serializer The serializer context to generate a function for.
* @param components The list of components that make up the message.
* @param The message type.
* @return A function that computes the byte size of a message of the given
* type.
*/
- private static Function generateByteSizeFunction(RecordComponent[] components) {
+ private static Function generateByteSizeFunction(Serializer serializer, RecordComponent[] components) {
return msg -> {
int size = 0;
for (var component : components) {
try {
- size += MessageUtils.getByteSize(component.getAccessor().invoke(msg));
+ size += MessageUtils.getByteSize(serializer, component.getAccessor().invoke(msg));
} catch (ReflectiveOperationException e) {
throw new IllegalStateException(e);
}
diff --git a/src/main/java/nl/andrewl/record_net/MessageUtils.java b/src/main/java/nl/andrewl/record_net/MessageUtils.java
index 97e9226..642b742 100644
--- a/src/main/java/nl/andrewl/record_net/MessageUtils.java
+++ b/src/main/java/nl/andrewl/record_net/MessageUtils.java
@@ -35,19 +35,25 @@ public class MessageUtils {
return size;
}
- public static int getByteSize(Message msg) {
- return 1 + (msg == null ? 0 : msg.byteSize());
+ @SuppressWarnings("unchecked")
+ public static int getByteSize(Serializer serializer, T msg) {
+ if (msg == null) {
+ return 1;
+ } else {
+ MessageTypeSerializer typeSerializer = (MessageTypeSerializer) serializer.getTypeSerializer(msg.getClass());
+ return 1 + typeSerializer.byteSizeFunction().apply(msg);
+ }
}
- public static int getByteSize(T[] items) {
+ public static int getByteSize(Serializer serializer, T[] items) {
int count = Integer.BYTES;
for (var item : items) {
- count += getByteSize(item);
+ count += getByteSize(serializer, item);
}
return count;
}
- public static int getByteSize(Object o) {
+ public static int getByteSize(Serializer serializer, Object o) {
if (o instanceof Integer) {
return Integer.BYTES;
} else if (o instanceof Long) {
@@ -61,18 +67,18 @@ public class MessageUtils {
} else if (o instanceof byte[]) {
return Integer.BYTES + ((byte[]) o).length;
} else if (o.getClass().isArray() && Message.class.isAssignableFrom(o.getClass().getComponentType())) {
- return getByteSize((Message[]) o);
+ return getByteSize(serializer, (Message[]) o);
} else if (o instanceof Message) {
- return getByteSize((Message) o);
+ return getByteSize(serializer, (Message) o);
} else {
throw new IllegalArgumentException("Unsupported object type: " + o.getClass().getSimpleName());
}
}
- public static int getByteSize(Object... objects) {
+ public static int getByteSize(Serializer serializer, Object... objects) {
int size = 0;
for (var o : objects) {
- size += getByteSize(o);
+ size += getByteSize(serializer, o);
}
return size;
}
diff --git a/src/main/java/nl/andrewl/record_net/Serializer.java b/src/main/java/nl/andrewl/record_net/Serializer.java
index 82c2af2..b25622f 100644
--- a/src/main/java/nl/andrewl/record_net/Serializer.java
+++ b/src/main/java/nl/andrewl/record_net/Serializer.java
@@ -19,6 +19,8 @@ public class Serializer {
*/
private final Map> messageTypes = new HashMap<>();
+ private final Map, MessageTypeSerializer>> messageTypeClasses = new HashMap<>();
+
/**
* An inverse of {@link Serializer#messageTypes} which is used to look up a
* message's byte value when you know the class of the message.
@@ -47,10 +49,25 @@ public class Serializer {
* @param messageClass The type of message associated with the given id.
*/
public synchronized void registerType(int id, Class messageClass) {
+ registerTypeSerializer(id, MessageTypeSerializer.generateForRecord(this, messageClass));
+ }
+
+ public synchronized void registerTypeSerializer(int id, MessageTypeSerializer typeSerializer) {
if (id < 0 || id > 127) throw new IllegalArgumentException("Invalid id.");
- MessageTypeSerializer type = MessageTypeSerializer.get(messageClass);
- messageTypes.put((byte)id, type);
- inverseMessageTypes.put(type, (byte)id);
+ messageTypes.put((byte) id, typeSerializer);
+ inverseMessageTypes.put(typeSerializer, (byte) id);
+ messageTypeClasses.put(typeSerializer.messageClass(), typeSerializer);
+ }
+
+ /**
+ * Gets the {@link MessageTypeSerializer} for the given message class.
+ * @param messageType The class of message to get the serializer for.
+ * @return The message type serializer.
+ * @param The type of message.
+ */
+ @SuppressWarnings("unchecked")
+ public MessageTypeSerializer getTypeSerializer(Class messageType) {
+ return (MessageTypeSerializer) messageTypeClasses.get(messageType);
}
/**
@@ -63,7 +80,7 @@ public class Serializer {
* constructed for the incoming data.
*/
public Message readMessage(InputStream i) throws IOException {
- ExtendedDataInputStream d = new ExtendedDataInputStream(i);
+ ExtendedDataInputStream d = new ExtendedDataInputStream(this, i);
byte typeId = d.readByte();
var type = messageTypes.get(typeId);
if (type == null) {
@@ -99,12 +116,12 @@ public class Serializer {
*/
public void writeMessage(T msg, OutputStream o) throws IOException {
DataOutputStream d = new DataOutputStream(o);
- Byte typeId = inverseMessageTypes.get(msg.getTypeSerializer());
+ Byte typeId = inverseMessageTypes.get(msg.getTypeSerializer(this));
if (typeId == null) {
throw new IOException("Unsupported message type: " + msg.getClass().getSimpleName());
}
d.writeByte(typeId);
- msg.getTypeSerializer().writer().write(msg, new ExtendedDataOutputStream(d));
+ msg.getTypeSerializer(this).writer().write(msg, new ExtendedDataOutputStream(this, d));
d.flush();
}
@@ -117,7 +134,8 @@ public class Serializer {
* to write is not supported by this serializer.
*/
public byte[] writeMessage(T msg) throws IOException {
- ByteArrayOutputStream out = new ByteArrayOutputStream(1 + msg.byteSize());
+ int bytes = msg.getTypeSerializer(this).byteSizeFunction().apply(msg);
+ ByteArrayOutputStream out = new ByteArrayOutputStream(1 + bytes);
writeMessage(msg, out);
return out.toByteArray();
}
diff --git a/src/main/java/nl/andrewl/record_net/util/ExtendedDataInputStream.java b/src/main/java/nl/andrewl/record_net/util/ExtendedDataInputStream.java
index 8921d5f..d40c723 100644
--- a/src/main/java/nl/andrewl/record_net/util/ExtendedDataInputStream.java
+++ b/src/main/java/nl/andrewl/record_net/util/ExtendedDataInputStream.java
@@ -2,6 +2,7 @@ package nl.andrewl.record_net.util;
import nl.andrewl.record_net.Message;
import nl.andrewl.record_net.MessageTypeSerializer;
+import nl.andrewl.record_net.Serializer;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
@@ -16,12 +17,15 @@ import java.util.UUID;
* complex types that are used by the Concord system.
*/
public class ExtendedDataInputStream extends DataInputStream {
- public ExtendedDataInputStream(InputStream in) {
+ private final Serializer serializer;
+
+ public ExtendedDataInputStream(Serializer serializer, InputStream in) {
super(in);
+ this.serializer = serializer;
}
- public ExtendedDataInputStream(byte[] data) {
- this(new ByteArrayInputStream(data));
+ public ExtendedDataInputStream(Serializer serializer, byte[] data) {
+ this(serializer, new ByteArrayInputStream(data));
}
public String readString() throws IOException {
@@ -81,10 +85,10 @@ public class ExtendedDataInputStream extends DataInputStream {
int length = this.readInt();
return this.readNBytes(length);
} else if (type.isArray() && Message.class.isAssignableFrom(type.getComponentType())) {
- var messageType = MessageTypeSerializer.get((Class extends Message>) type.getComponentType());
+ var messageType = MessageTypeSerializer.get(serializer, (Class extends Message>) type.getComponentType());
return this.readArray(messageType);
} else if (Message.class.isAssignableFrom(type)) {
- var messageType = MessageTypeSerializer.get((Class extends Message>) type);
+ var messageType = MessageTypeSerializer.get(serializer, (Class extends Message>) type);
return messageType.reader().read(this);
} else {
throw new IOException("Unsupported object type: " + type.getSimpleName());
diff --git a/src/main/java/nl/andrewl/record_net/util/ExtendedDataOutputStream.java b/src/main/java/nl/andrewl/record_net/util/ExtendedDataOutputStream.java
index 03922c6..955d7f7 100644
--- a/src/main/java/nl/andrewl/record_net/util/ExtendedDataOutputStream.java
+++ b/src/main/java/nl/andrewl/record_net/util/ExtendedDataOutputStream.java
@@ -1,6 +1,7 @@
package nl.andrewl.record_net.util;
import nl.andrewl.record_net.Message;
+import nl.andrewl.record_net.Serializer;
import java.io.DataOutputStream;
import java.io.IOException;
@@ -13,8 +14,11 @@ import java.util.UUID;
* that help us to write more data.
*/
public class ExtendedDataOutputStream extends DataOutputStream {
- public ExtendedDataOutputStream(OutputStream out) {
+ private final Serializer serializer;
+
+ public ExtendedDataOutputStream(Serializer serializer, OutputStream out) {
super(out);
+ this.serializer = serializer;
}
/**
@@ -125,7 +129,7 @@ public class ExtendedDataOutputStream extends DataOutputStream {
public void writeMessage(Message msg) throws IOException {
writeBoolean(msg != null);
if (msg != null) {
- msg.getTypeSerializer().writer().write(msg, this);
+ msg.getTypeSerializer(serializer).writer().write(msg, this);
}
}
diff --git a/src/test/java/nl/andrewl/record_net/MessageTypeSerializerTest.java b/src/test/java/nl/andrewl/record_net/MessageTypeSerializerTest.java
index 7d8b974..cd1c006 100644
--- a/src/test/java/nl/andrewl/record_net/MessageTypeSerializerTest.java
+++ b/src/test/java/nl/andrewl/record_net/MessageTypeSerializerTest.java
@@ -14,21 +14,22 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
public class MessageTypeSerializerTest {
@Test
public void testGenerateForRecord() throws IOException {
- var s1 = MessageTypeSerializer.get(ChatMessage.class);
+ Serializer serializer = new Serializer();
+ var s1 = MessageTypeSerializer.get(serializer, ChatMessage.class);
ChatMessage msg = new ChatMessage("andrew", 123, "Hello world!");
int expectedByteSize = 4 + msg.username().length() + 8 + 4 + msg.message().length();
assertEquals(expectedByteSize, s1.byteSizeFunction().apply(msg));
- assertEquals(expectedByteSize, msg.byteSize());
+
ByteArrayOutputStream bOut = new ByteArrayOutputStream();
- ExtendedDataOutputStream eOut = new ExtendedDataOutputStream(bOut);
+ ExtendedDataOutputStream eOut = new ExtendedDataOutputStream(serializer, bOut);
s1.writer().write(msg, eOut);
byte[] data = bOut.toByteArray();
assertEquals(expectedByteSize, data.length);
- ChatMessage readMsg = s1.reader().read(new ExtendedDataInputStream(data));
+ ChatMessage readMsg = s1.reader().read(new ExtendedDataInputStream(serializer, data));
assertEquals(msg, readMsg);
// Only record classes can be generated.
class NonRecordMessage implements Message {}
- assertThrows(IllegalArgumentException.class, () -> MessageTypeSerializer.get(NonRecordMessage.class));
+ assertThrows(IllegalArgumentException.class, () -> MessageTypeSerializer.get(serializer, NonRecordMessage.class));
}
}
diff --git a/src/test/java/nl/andrewl/record_net/MessageUtilsTest.java b/src/test/java/nl/andrewl/record_net/MessageUtilsTest.java
index 20f5fae..e155d9a 100644
--- a/src/test/java/nl/andrewl/record_net/MessageUtilsTest.java
+++ b/src/test/java/nl/andrewl/record_net/MessageUtilsTest.java
@@ -18,10 +18,12 @@ public class MessageUtilsTest {
assertEquals(10, MessageUtils.getByteSize("a", "b"));
Message msg = new ChatMessage("andrew", 123, "Hello world!");
int expectedMsgSize = 1 + 4 + 6 + 8 + 4 + 12;
- assertEquals(1, MessageUtils.getByteSize((Message) null));
- assertEquals(expectedMsgSize, MessageUtils.getByteSize(msg));
- assertEquals(4 * expectedMsgSize, MessageUtils.getByteSize(msg, msg, msg, msg));
- assertEquals(16, MessageUtils.getByteSize(UUID.randomUUID()));
- assertEquals(4, MessageUtils.getByteSize(StandardCopyOption.ATOMIC_MOVE));
+ Serializer serializer = new Serializer();
+ serializer.registerType(1, ChatMessage.class);
+ assertEquals(1, MessageUtils.getByteSize(serializer, (Message) null));
+ assertEquals(expectedMsgSize, MessageUtils.getByteSize(serializer, msg));
+ assertEquals(4 * expectedMsgSize, MessageUtils.getByteSize(serializer, msg, msg, msg, msg));
+ assertEquals(16, MessageUtils.getByteSize(serializer, UUID.randomUUID()));
+ assertEquals(4, MessageUtils.getByteSize(serializer, StandardCopyOption.ATOMIC_MOVE));
}
}
diff --git a/src/test/java/nl/andrewl/record_net/SerializerTest.java b/src/test/java/nl/andrewl/record_net/SerializerTest.java
index 02d1e98..d6fcc16 100644
--- a/src/test/java/nl/andrewl/record_net/SerializerTest.java
+++ b/src/test/java/nl/andrewl/record_net/SerializerTest.java
@@ -20,7 +20,7 @@ public class SerializerTest {
ByteArrayOutputStream bOut = new ByteArrayOutputStream();
s.writeMessage(msg, bOut);
byte[] data = bOut.toByteArray();
- assertEquals(1 + msg.byteSize(), data.length);
+ assertEquals(MessageUtils.getByteSize(s, msg), data.length);
assertEquals(data[0], 1);
ChatMessage readMsg = (ChatMessage) s.readMessage(new ByteArrayInputStream(data));
diff --git a/src/test/java/nl/andrewl/record_net/util/ExtendedDataOutputStreamTest.java b/src/test/java/nl/andrewl/record_net/util/ExtendedDataOutputStreamTest.java
index f79540c..fbb9968 100644
--- a/src/test/java/nl/andrewl/record_net/util/ExtendedDataOutputStreamTest.java
+++ b/src/test/java/nl/andrewl/record_net/util/ExtendedDataOutputStreamTest.java
@@ -1,5 +1,6 @@
package nl.andrewl.record_net.util;
+import nl.andrewl.record_net.Serializer;
import org.junit.jupiter.api.Test;
import java.io.ByteArrayInputStream;
@@ -13,8 +14,9 @@ public class ExtendedDataOutputStreamTest {
@Test
public void testWriteString() throws IOException {
+ Serializer serializer = new Serializer();
ByteArrayOutputStream bOut = new ByteArrayOutputStream();
- ExtendedDataOutputStream eOut = new ExtendedDataOutputStream(bOut);
+ ExtendedDataOutputStream eOut = new ExtendedDataOutputStream(serializer, bOut);
eOut.writeString("Hello world!");
byte[] data = bOut.toByteArray();
assertEquals(4 + "Hello world!".length(), data.length);