Refactored to remove simple static class cache.

This commit is contained in:
Andrew Lalis 2022-04-16 13:41:01 +02:00
parent 7cc9327fef
commit 1d45822c67
11 changed files with 93 additions and 53 deletions

View File

@ -6,7 +6,7 @@
<groupId>nl.andrewl</groupId> <groupId>nl.andrewl</groupId>
<artifactId>record-net</artifactId> <artifactId>record-net</artifactId>
<version>1.0.0</version> <version>1.1.0</version>
<properties> <properties>
<maven.compiler.source>17</maven.compiler.source> <maven.compiler.source>17</maven.compiler.source>

View File

@ -15,15 +15,7 @@ public interface Message {
* @return The serializer to use to read and write messages of this type. * @return The serializer to use to read and write messages of this type.
*/ */
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
default <T extends Message> MessageTypeSerializer<T> getTypeSerializer() { default <T extends Message> MessageTypeSerializer<T> getTypeSerializer(Serializer serializer) {
return MessageTypeSerializer.get((Class<T>) this.getClass()); return (MessageTypeSerializer<T>) serializer.getTypeSerializer(this.getClass());
}
/**
* Convenience method to determine the size of this message in bytes.
* @return The size of this message, in bytes.
*/
default int byteSize() {
return getTypeSerializer().byteSizeFunction().apply(this);
} }
} }

View File

@ -1,5 +1,7 @@
package nl.andrewl.record_net; package nl.andrewl.record_net;
import nl.andrewl.record_net.util.Pair;
import java.lang.reflect.Constructor; import java.lang.reflect.Constructor;
import java.lang.reflect.RecordComponent; import java.lang.reflect.RecordComponent;
import java.util.Arrays; import java.util.Arrays;
@ -25,18 +27,25 @@ public record MessageTypeSerializer<T extends Message>(
MessageReader<T> reader, MessageReader<T> reader,
MessageWriter<T> writer MessageWriter<T> writer
) { ) {
private static final Map<Class<?>, MessageTypeSerializer<?>> generatedMessageTypes = new HashMap<>(); /**
* An internal cache for storing generated type serializers.
*/
private static final Map<Pair<Class<?>, Serializer>, MessageTypeSerializer<?>> generatedMessageTypes = new HashMap<>();
/** /**
* Gets the {@link MessageTypeSerializer} instance for a given message class, and * Gets the {@link MessageTypeSerializer} instance for a given message class, and
* generates a new implementation if none exists yet. * generates a new implementation if none exists yet.
* @param serializer The serializer context to get a type serializer for.
* @param messageClass The class of the message to get a type for. * @param messageClass The class of the message to get a type for.
* @param <T> The type of the message. * @param <T> The type of the message.
* @return The message type. * @return The message type.
*/ */
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public static <T extends Message> MessageTypeSerializer<T> get(Class<T> messageClass) { public static <T extends Message> MessageTypeSerializer<T> get(Serializer serializer, Class<T> messageClass) {
return (MessageTypeSerializer<T>) generatedMessageTypes.computeIfAbsent(messageClass, c -> generateForRecord((Class<T>) c)); return (MessageTypeSerializer<T>) generatedMessageTypes.computeIfAbsent(
new Pair<>(messageClass, serializer),
p -> generateForRecord(serializer, (Class<T>) p.first())
);
} }
/** /**
@ -45,11 +54,12 @@ public record MessageTypeSerializer<T extends Message>(
* <p> * <p>
* Note that this only works for record-based messages. * Note that this only works for record-based messages.
* </p> * </p>
* @param serializer The serializer context to get a type serializer for.
* @param messageTypeClass The class of the message type. * @param messageTypeClass The class of the message type.
* @param <T> The type of the message. * @param <T> The type of the message.
* @return A message type instance. * @return A message type instance.
*/ */
public static <T extends Message> MessageTypeSerializer<T> generateForRecord(Class<T> messageTypeClass) { public static <T extends Message> MessageTypeSerializer<T> generateForRecord(Serializer serializer, Class<T> messageTypeClass) {
RecordComponent[] components = messageTypeClass.getRecordComponents(); RecordComponent[] components = messageTypeClass.getRecordComponents();
if (components == null) throw new IllegalArgumentException("Cannot generate a MessageTypeSerializer for non-record class " + messageTypeClass.getSimpleName()); if (components == null) throw new IllegalArgumentException("Cannot generate a MessageTypeSerializer for non-record class " + messageTypeClass.getSimpleName());
Constructor<T> constructor; Constructor<T> constructor;
@ -61,7 +71,7 @@ public record MessageTypeSerializer<T extends Message>(
} }
return new MessageTypeSerializer<>( return new MessageTypeSerializer<>(
messageTypeClass, messageTypeClass,
generateByteSizeFunction(components), generateByteSizeFunction(serializer, components),
generateReader(constructor), generateReader(constructor),
generateWriter(components) generateWriter(components)
); );
@ -70,17 +80,18 @@ public record MessageTypeSerializer<T extends Message>(
/** /**
* Generates a function implementation that counts the byte size of a * Generates a function implementation that counts the byte size of a
* message based on the message's record component types. * message based on the message's record component types.
* @param serializer The serializer context to generate a function for.
* @param components The list of components that make up the message. * @param components The list of components that make up the message.
* @param <T> The message type. * @param <T> The message type.
* @return A function that computes the byte size of a message of the given * @return A function that computes the byte size of a message of the given
* type. * type.
*/ */
private static <T extends Message> Function<T, Integer> generateByteSizeFunction(RecordComponent[] components) { private static <T extends Message> Function<T, Integer> generateByteSizeFunction(Serializer serializer, RecordComponent[] components) {
return msg -> { return msg -> {
int size = 0; int size = 0;
for (var component : components) { for (var component : components) {
try { try {
size += MessageUtils.getByteSize(component.getAccessor().invoke(msg)); size += MessageUtils.getByteSize(serializer, component.getAccessor().invoke(msg));
} catch (ReflectiveOperationException e) { } catch (ReflectiveOperationException e) {
throw new IllegalStateException(e); throw new IllegalStateException(e);
} }

View File

@ -35,19 +35,25 @@ public class MessageUtils {
return size; return size;
} }
public static int getByteSize(Message msg) { @SuppressWarnings("unchecked")
return 1 + (msg == null ? 0 : msg.byteSize()); public static <T extends Message> int getByteSize(Serializer serializer, T msg) {
if (msg == null) {
return 1;
} else {
MessageTypeSerializer<T> typeSerializer = (MessageTypeSerializer<T>) serializer.getTypeSerializer(msg.getClass());
return 1 + typeSerializer.byteSizeFunction().apply(msg);
}
} }
public static <T extends Message> int getByteSize(T[] items) { public static <T extends Message> int getByteSize(Serializer serializer, T[] items) {
int count = Integer.BYTES; int count = Integer.BYTES;
for (var item : items) { for (var item : items) {
count += getByteSize(item); count += getByteSize(serializer, item);
} }
return count; return count;
} }
public static int getByteSize(Object o) { public static int getByteSize(Serializer serializer, Object o) {
if (o instanceof Integer) { if (o instanceof Integer) {
return Integer.BYTES; return Integer.BYTES;
} else if (o instanceof Long) { } else if (o instanceof Long) {
@ -61,18 +67,18 @@ public class MessageUtils {
} else if (o instanceof byte[]) { } else if (o instanceof byte[]) {
return Integer.BYTES + ((byte[]) o).length; return Integer.BYTES + ((byte[]) o).length;
} else if (o.getClass().isArray() && Message.class.isAssignableFrom(o.getClass().getComponentType())) { } else if (o.getClass().isArray() && Message.class.isAssignableFrom(o.getClass().getComponentType())) {
return getByteSize((Message[]) o); return getByteSize(serializer, (Message[]) o);
} else if (o instanceof Message) { } else if (o instanceof Message) {
return getByteSize((Message) o); return getByteSize(serializer, (Message) o);
} else { } else {
throw new IllegalArgumentException("Unsupported object type: " + o.getClass().getSimpleName()); throw new IllegalArgumentException("Unsupported object type: " + o.getClass().getSimpleName());
} }
} }
public static int getByteSize(Object... objects) { public static int getByteSize(Serializer serializer, Object... objects) {
int size = 0; int size = 0;
for (var o : objects) { for (var o : objects) {
size += getByteSize(o); size += getByteSize(serializer, o);
} }
return size; return size;
} }

View File

@ -19,6 +19,8 @@ public class Serializer {
*/ */
private final Map<Byte, MessageTypeSerializer<?>> messageTypes = new HashMap<>(); private final Map<Byte, MessageTypeSerializer<?>> messageTypes = new HashMap<>();
private final Map<Class<?>, MessageTypeSerializer<?>> messageTypeClasses = new HashMap<>();
/** /**
* An inverse of {@link Serializer#messageTypes} which is used to look up a * An inverse of {@link Serializer#messageTypes} which is used to look up a
* message's byte value when you know the class of the message. * message's byte value when you know the class of the message.
@ -47,10 +49,25 @@ public class Serializer {
* @param messageClass The type of message associated with the given id. * @param messageClass The type of message associated with the given id.
*/ */
public synchronized <T extends Message> void registerType(int id, Class<T> messageClass) { public synchronized <T extends Message> void registerType(int id, Class<T> messageClass) {
registerTypeSerializer(id, MessageTypeSerializer.generateForRecord(this, messageClass));
}
public synchronized <T extends Message> void registerTypeSerializer(int id, MessageTypeSerializer<T> typeSerializer) {
if (id < 0 || id > 127) throw new IllegalArgumentException("Invalid id."); if (id < 0 || id > 127) throw new IllegalArgumentException("Invalid id.");
MessageTypeSerializer<T> type = MessageTypeSerializer.get(messageClass); messageTypes.put((byte) id, typeSerializer);
messageTypes.put((byte)id, type); inverseMessageTypes.put(typeSerializer, (byte) id);
inverseMessageTypes.put(type, (byte)id); messageTypeClasses.put(typeSerializer.messageClass(), typeSerializer);
}
/**
* Gets the {@link MessageTypeSerializer} for the given message class.
* @param messageType The class of message to get the serializer for.
* @return The message type serializer.
* @param <T> The type of message.
*/
@SuppressWarnings("unchecked")
public <T extends Message> MessageTypeSerializer<T> getTypeSerializer(Class<T> messageType) {
return (MessageTypeSerializer<T>) messageTypeClasses.get(messageType);
} }
/** /**
@ -63,7 +80,7 @@ public class Serializer {
* constructed for the incoming data. * constructed for the incoming data.
*/ */
public Message readMessage(InputStream i) throws IOException { public Message readMessage(InputStream i) throws IOException {
ExtendedDataInputStream d = new ExtendedDataInputStream(i); ExtendedDataInputStream d = new ExtendedDataInputStream(this, i);
byte typeId = d.readByte(); byte typeId = d.readByte();
var type = messageTypes.get(typeId); var type = messageTypes.get(typeId);
if (type == null) { if (type == null) {
@ -99,12 +116,12 @@ public class Serializer {
*/ */
public <T extends Message> void writeMessage(T msg, OutputStream o) throws IOException { public <T extends Message> void writeMessage(T msg, OutputStream o) throws IOException {
DataOutputStream d = new DataOutputStream(o); DataOutputStream d = new DataOutputStream(o);
Byte typeId = inverseMessageTypes.get(msg.getTypeSerializer()); Byte typeId = inverseMessageTypes.get(msg.getTypeSerializer(this));
if (typeId == null) { if (typeId == null) {
throw new IOException("Unsupported message type: " + msg.getClass().getSimpleName()); throw new IOException("Unsupported message type: " + msg.getClass().getSimpleName());
} }
d.writeByte(typeId); d.writeByte(typeId);
msg.getTypeSerializer().writer().write(msg, new ExtendedDataOutputStream(d)); msg.getTypeSerializer(this).writer().write(msg, new ExtendedDataOutputStream(this, d));
d.flush(); d.flush();
} }
@ -117,7 +134,8 @@ public class Serializer {
* to write is not supported by this serializer. * to write is not supported by this serializer.
*/ */
public <T extends Message> byte[] writeMessage(T msg) throws IOException { public <T extends Message> byte[] writeMessage(T msg) throws IOException {
ByteArrayOutputStream out = new ByteArrayOutputStream(1 + msg.byteSize()); int bytes = msg.getTypeSerializer(this).byteSizeFunction().apply(msg);
ByteArrayOutputStream out = new ByteArrayOutputStream(1 + bytes);
writeMessage(msg, out); writeMessage(msg, out);
return out.toByteArray(); return out.toByteArray();
} }

View File

@ -2,6 +2,7 @@ package nl.andrewl.record_net.util;
import nl.andrewl.record_net.Message; import nl.andrewl.record_net.Message;
import nl.andrewl.record_net.MessageTypeSerializer; import nl.andrewl.record_net.MessageTypeSerializer;
import nl.andrewl.record_net.Serializer;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.DataInputStream; import java.io.DataInputStream;
@ -16,12 +17,15 @@ import java.util.UUID;
* complex types that are used by the Concord system. * complex types that are used by the Concord system.
*/ */
public class ExtendedDataInputStream extends DataInputStream { public class ExtendedDataInputStream extends DataInputStream {
public ExtendedDataInputStream(InputStream in) { private final Serializer serializer;
public ExtendedDataInputStream(Serializer serializer, InputStream in) {
super(in); super(in);
this.serializer = serializer;
} }
public ExtendedDataInputStream(byte[] data) { public ExtendedDataInputStream(Serializer serializer, byte[] data) {
this(new ByteArrayInputStream(data)); this(serializer, new ByteArrayInputStream(data));
} }
public String readString() throws IOException { public String readString() throws IOException {
@ -81,10 +85,10 @@ public class ExtendedDataInputStream extends DataInputStream {
int length = this.readInt(); int length = this.readInt();
return this.readNBytes(length); return this.readNBytes(length);
} else if (type.isArray() && Message.class.isAssignableFrom(type.getComponentType())) { } else if (type.isArray() && Message.class.isAssignableFrom(type.getComponentType())) {
var messageType = MessageTypeSerializer.get((Class<? extends Message>) type.getComponentType()); var messageType = MessageTypeSerializer.get(serializer, (Class<? extends Message>) type.getComponentType());
return this.readArray(messageType); return this.readArray(messageType);
} else if (Message.class.isAssignableFrom(type)) { } else if (Message.class.isAssignableFrom(type)) {
var messageType = MessageTypeSerializer.get((Class<? extends Message>) type); var messageType = MessageTypeSerializer.get(serializer, (Class<? extends Message>) type);
return messageType.reader().read(this); return messageType.reader().read(this);
} else { } else {
throw new IOException("Unsupported object type: " + type.getSimpleName()); throw new IOException("Unsupported object type: " + type.getSimpleName());

View File

@ -1,6 +1,7 @@
package nl.andrewl.record_net.util; package nl.andrewl.record_net.util;
import nl.andrewl.record_net.Message; import nl.andrewl.record_net.Message;
import nl.andrewl.record_net.Serializer;
import java.io.DataOutputStream; import java.io.DataOutputStream;
import java.io.IOException; import java.io.IOException;
@ -13,8 +14,11 @@ import java.util.UUID;
* that help us to write more data. * that help us to write more data.
*/ */
public class ExtendedDataOutputStream extends DataOutputStream { public class ExtendedDataOutputStream extends DataOutputStream {
public ExtendedDataOutputStream(OutputStream out) { private final Serializer serializer;
public ExtendedDataOutputStream(Serializer serializer, OutputStream out) {
super(out); super(out);
this.serializer = serializer;
} }
/** /**
@ -125,7 +129,7 @@ public class ExtendedDataOutputStream extends DataOutputStream {
public <T extends Message> void writeMessage(Message msg) throws IOException { public <T extends Message> void writeMessage(Message msg) throws IOException {
writeBoolean(msg != null); writeBoolean(msg != null);
if (msg != null) { if (msg != null) {
msg.getTypeSerializer().writer().write(msg, this); msg.getTypeSerializer(serializer).writer().write(msg, this);
} }
} }

View File

@ -14,21 +14,22 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
public class MessageTypeSerializerTest { public class MessageTypeSerializerTest {
@Test @Test
public void testGenerateForRecord() throws IOException { public void testGenerateForRecord() throws IOException {
var s1 = MessageTypeSerializer.get(ChatMessage.class); Serializer serializer = new Serializer();
var s1 = MessageTypeSerializer.get(serializer, ChatMessage.class);
ChatMessage msg = new ChatMessage("andrew", 123, "Hello world!"); ChatMessage msg = new ChatMessage("andrew", 123, "Hello world!");
int expectedByteSize = 4 + msg.username().length() + 8 + 4 + msg.message().length(); int expectedByteSize = 4 + msg.username().length() + 8 + 4 + msg.message().length();
assertEquals(expectedByteSize, s1.byteSizeFunction().apply(msg)); assertEquals(expectedByteSize, s1.byteSizeFunction().apply(msg));
assertEquals(expectedByteSize, msg.byteSize());
ByteArrayOutputStream bOut = new ByteArrayOutputStream(); ByteArrayOutputStream bOut = new ByteArrayOutputStream();
ExtendedDataOutputStream eOut = new ExtendedDataOutputStream(bOut); ExtendedDataOutputStream eOut = new ExtendedDataOutputStream(serializer, bOut);
s1.writer().write(msg, eOut); s1.writer().write(msg, eOut);
byte[] data = bOut.toByteArray(); byte[] data = bOut.toByteArray();
assertEquals(expectedByteSize, data.length); assertEquals(expectedByteSize, data.length);
ChatMessage readMsg = s1.reader().read(new ExtendedDataInputStream(data)); ChatMessage readMsg = s1.reader().read(new ExtendedDataInputStream(serializer, data));
assertEquals(msg, readMsg); assertEquals(msg, readMsg);
// Only record classes can be generated. // Only record classes can be generated.
class NonRecordMessage implements Message {} class NonRecordMessage implements Message {}
assertThrows(IllegalArgumentException.class, () -> MessageTypeSerializer.get(NonRecordMessage.class)); assertThrows(IllegalArgumentException.class, () -> MessageTypeSerializer.get(serializer, NonRecordMessage.class));
} }
} }

View File

@ -18,10 +18,12 @@ public class MessageUtilsTest {
assertEquals(10, MessageUtils.getByteSize("a", "b")); assertEquals(10, MessageUtils.getByteSize("a", "b"));
Message msg = new ChatMessage("andrew", 123, "Hello world!"); Message msg = new ChatMessage("andrew", 123, "Hello world!");
int expectedMsgSize = 1 + 4 + 6 + 8 + 4 + 12; int expectedMsgSize = 1 + 4 + 6 + 8 + 4 + 12;
assertEquals(1, MessageUtils.getByteSize((Message) null)); Serializer serializer = new Serializer();
assertEquals(expectedMsgSize, MessageUtils.getByteSize(msg)); serializer.registerType(1, ChatMessage.class);
assertEquals(4 * expectedMsgSize, MessageUtils.getByteSize(msg, msg, msg, msg)); assertEquals(1, MessageUtils.getByteSize(serializer, (Message) null));
assertEquals(16, MessageUtils.getByteSize(UUID.randomUUID())); assertEquals(expectedMsgSize, MessageUtils.getByteSize(serializer, msg));
assertEquals(4, MessageUtils.getByteSize(StandardCopyOption.ATOMIC_MOVE)); assertEquals(4 * expectedMsgSize, MessageUtils.getByteSize(serializer, msg, msg, msg, msg));
assertEquals(16, MessageUtils.getByteSize(serializer, UUID.randomUUID()));
assertEquals(4, MessageUtils.getByteSize(serializer, StandardCopyOption.ATOMIC_MOVE));
} }
} }

View File

@ -20,7 +20,7 @@ public class SerializerTest {
ByteArrayOutputStream bOut = new ByteArrayOutputStream(); ByteArrayOutputStream bOut = new ByteArrayOutputStream();
s.writeMessage(msg, bOut); s.writeMessage(msg, bOut);
byte[] data = bOut.toByteArray(); byte[] data = bOut.toByteArray();
assertEquals(1 + msg.byteSize(), data.length); assertEquals(MessageUtils.getByteSize(s, msg), data.length);
assertEquals(data[0], 1); assertEquals(data[0], 1);
ChatMessage readMsg = (ChatMessage) s.readMessage(new ByteArrayInputStream(data)); ChatMessage readMsg = (ChatMessage) s.readMessage(new ByteArrayInputStream(data));

View File

@ -1,5 +1,6 @@
package nl.andrewl.record_net.util; package nl.andrewl.record_net.util;
import nl.andrewl.record_net.Serializer;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
@ -13,8 +14,9 @@ public class ExtendedDataOutputStreamTest {
@Test @Test
public void testWriteString() throws IOException { public void testWriteString() throws IOException {
Serializer serializer = new Serializer();
ByteArrayOutputStream bOut = new ByteArrayOutputStream(); ByteArrayOutputStream bOut = new ByteArrayOutputStream();
ExtendedDataOutputStream eOut = new ExtendedDataOutputStream(bOut); ExtendedDataOutputStream eOut = new ExtendedDataOutputStream(serializer, bOut);
eOut.writeString("Hello world!"); eOut.writeString("Hello world!");
byte[] data = bOut.toByteArray(); byte[] data = bOut.toByteArray();
assertEquals(4 + "Hello world!".length(), data.length); assertEquals(4 + "Hello world!".length(), data.length);