More refactoring to make the system more extensible.

This commit is contained in:
Andrew Lalis 2022-04-16 13:58:53 +02:00
parent 1d45822c67
commit c25232c3ec
7 changed files with 191 additions and 142 deletions

View File

@ -1,142 +1,35 @@
package nl.andrewl.record_net; package nl.andrewl.record_net;
import nl.andrewl.record_net.util.Pair;
import java.lang.reflect.Constructor;
import java.lang.reflect.RecordComponent;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function; import java.util.function.Function;
/** /**
* Record containing the components needed to read and write a given message. * A type serializer provides the basic components needed to read and write
* <p> * instances of the given message type.
* Also contains methods for automatically generating message type * @param <T> The message type.
* implementations for standard record-based messages.
* </p>
* @param <T> The type of message.
* @param messageClass The class of the message.
* @param byteSizeFunction A function that computes the byte size of the message.
* @param reader A reader that can read messages from an input stream.
* @param writer A writer that write messages from an input stream.
*/ */
public record MessageTypeSerializer<T extends Message>( public interface MessageTypeSerializer<T extends Message> {
Class<T> messageClass, /**
Function<T, Integer> byteSizeFunction, * Gets the class of the message type that this serializer handles.
MessageReader<T> reader, * @return The message class.
MessageWriter<T> writer */
) { Class<T> messageClass();
/**
* An internal cache for storing generated type serializers.
*/
private static final Map<Pair<Class<?>, Serializer>, MessageTypeSerializer<?>> generatedMessageTypes = new HashMap<>();
/** /**
* Gets the {@link MessageTypeSerializer} instance for a given message class, and * Gets a function that computes the size, in bytes, of messages of this
* generates a new implementation if none exists yet. * serializer's type.
* @param serializer The serializer context to get a type serializer for. * @return A byte size function.
* @param messageClass The class of the message to get a type for. */
* @param <T> The type of the message. Function<T, Integer> byteSizeFunction();
* @return The message type.
*/
@SuppressWarnings("unchecked")
public static <T extends Message> MessageTypeSerializer<T> get(Serializer serializer, Class<T> messageClass) {
return (MessageTypeSerializer<T>) generatedMessageTypes.computeIfAbsent(
new Pair<>(messageClass, serializer),
p -> generateForRecord(serializer, (Class<T>) p.first())
);
}
/** /**
* Generates a message type instance for a given class, using reflection to * Gets a component that can read messages from an input stream.
* introspect the fields of the message. * @return The message reader.
* <p> */
* Note that this only works for record-based messages. MessageReader<T> reader();
* </p>
* @param serializer The serializer context to get a type serializer for.
* @param messageTypeClass The class of the message type.
* @param <T> The type of the message.
* @return A message type instance.
*/
public static <T extends Message> MessageTypeSerializer<T> generateForRecord(Serializer serializer, Class<T> messageTypeClass) {
RecordComponent[] components = messageTypeClass.getRecordComponents();
if (components == null) throw new IllegalArgumentException("Cannot generate a MessageTypeSerializer for non-record class " + messageTypeClass.getSimpleName());
Constructor<T> constructor;
try {
constructor = messageTypeClass.getDeclaredConstructor(Arrays.stream(components)
.map(RecordComponent::getType).toArray(Class<?>[]::new));
} catch (NoSuchMethodException e) {
throw new IllegalArgumentException(e);
}
return new MessageTypeSerializer<>(
messageTypeClass,
generateByteSizeFunction(serializer, components),
generateReader(constructor),
generateWriter(components)
);
}
/** /**
* Generates a function implementation that counts the byte size of a * Gets a component that can write messages to an output stream.
* message based on the message's record component types. * @return The message writer.
* @param serializer The serializer context to generate a function for. */
* @param components The list of components that make up the message. MessageWriter<T> writer();
* @param <T> The message type.
* @return A function that computes the byte size of a message of the given
* type.
*/
private static <T extends Message> Function<T, Integer> generateByteSizeFunction(Serializer serializer, RecordComponent[] components) {
return msg -> {
int size = 0;
for (var component : components) {
try {
size += MessageUtils.getByteSize(serializer, component.getAccessor().invoke(msg));
} catch (ReflectiveOperationException e) {
throw new IllegalStateException(e);
}
}
return size;
};
}
/**
* Generates a message reader for the given message constructor method. It
* will try to read objects from the input stream according to the
* parameters of the canonical constructor of a message record.
* @param constructor The canonical constructor of the message record.
* @param <T> The message type.
* @return A message reader for the given type.
*/
private static <T extends Message> MessageReader<T> generateReader(Constructor<T> constructor) {
return in -> {
Object[] values = new Object[constructor.getParameterCount()];
for (int i = 0; i < values.length; i++) {
values[i] = in.readObject(constructor.getParameterTypes()[i]);
}
try {
return constructor.newInstance(values);
} catch (ReflectiveOperationException e) {
throw new IllegalStateException(e);
}
};
}
/**
* Generates a message writer for the given message record components.
* @param components The record components to write.
* @param <T> The type of message.
* @return The message writer for the given type.
*/
private static <T extends Message> MessageWriter<T> generateWriter(RecordComponent[] components) {
return (msg, out) -> {
for (var component: components) {
try {
out.writeObject(component.getAccessor().invoke(msg), component.getType());
} catch (ReflectiveOperationException e) {
throw new IllegalStateException(e);
}
}
};
}
} }

View File

@ -0,0 +1,18 @@
package nl.andrewl.record_net;
import java.util.function.Function;
/**
* Record containing the components needed to read and write a given message.
* @param <T> The type of message.
* @param messageClass The class of the message.
* @param byteSizeFunction A function that computes the byte size of the message.
* @param reader A reader that can read messages from an input stream.
* @param writer A writer that write messages from an input stream.
*/
public record MessageTypeSerializerImpl<T extends Message>(
Class<T> messageClass,
Function<T, Integer> byteSizeFunction,
MessageReader<T> reader,
MessageWriter<T> writer
) implements MessageTypeSerializer<T> {}

View File

@ -40,7 +40,7 @@ public class MessageUtils {
if (msg == null) { if (msg == null) {
return 1; return 1;
} else { } else {
MessageTypeSerializer<T> typeSerializer = (MessageTypeSerializer<T>) serializer.getTypeSerializer(msg.getClass()); MessageTypeSerializerImpl<T> typeSerializer = (MessageTypeSerializerImpl<T>) serializer.getTypeSerializer(msg.getClass());
return 1 + typeSerializer.byteSizeFunction().apply(msg); return 1 + typeSerializer.byteSizeFunction().apply(msg);
} }
} }

View File

@ -2,6 +2,7 @@ package nl.andrewl.record_net;
import nl.andrewl.record_net.util.ExtendedDataInputStream; import nl.andrewl.record_net.util.ExtendedDataInputStream;
import nl.andrewl.record_net.util.ExtendedDataOutputStream; import nl.andrewl.record_net.util.ExtendedDataOutputStream;
import nl.andrewl.record_net.util.RecordMessageTypeSerializer;
import java.io.*; import java.io.*;
import java.util.HashMap; import java.util.HashMap;
@ -37,21 +38,28 @@ public class Serializer {
* their ids. * their ids.
* @param messageTypes A map containing message types mapped to their ids. * @param messageTypes A map containing message types mapped to their ids.
*/ */
public Serializer(Map<Byte, Class<? extends Message>> messageTypes) { public Serializer(Map<Integer, Class<? extends Message>> messageTypes) {
messageTypes.forEach(this::registerType); messageTypes.forEach(this::registerType);
} }
/** /**
* Helper method which registers a message type to be supported by the * Helper method for registering a message type serializer for a record
* serializer, by adding it to the normal and inverse mappings. * class, using {@link RecordMessageTypeSerializer#generateForRecord(Serializer, Class)}.
* @param id The byte which will be used to identify messages of the given * @param id The byte which will be used to identify messages of the given
* class. The value should from 0 to 127. * class. The value should from 0 to 127.
* @param messageClass The type of message associated with the given id. * @param messageClass The type of message associated with the given id.
*/ */
public synchronized <T extends Message> void registerType(int id, Class<T> messageClass) { public synchronized <T extends Message> void registerType(int id, Class<T> messageClass) {
registerTypeSerializer(id, MessageTypeSerializer.generateForRecord(this, messageClass)); registerTypeSerializer(id, RecordMessageTypeSerializer.generateForRecord(this, messageClass));
} }
/**
* Registers the given type serializer with the given id.
* @param id The id to use.
* @param typeSerializer The type serializer that will be associated with
* the given id.
* @param <T> The message type.
*/
public synchronized <T extends Message> void registerTypeSerializer(int id, MessageTypeSerializer<T> typeSerializer) { public synchronized <T extends Message> void registerTypeSerializer(int id, MessageTypeSerializer<T> typeSerializer) {
if (id < 0 || id > 127) throw new IllegalArgumentException("Invalid id."); if (id < 0 || id > 127) throw new IllegalArgumentException("Invalid id.");
messageTypes.put((byte) id, typeSerializer); messageTypes.put((byte) id, typeSerializer);

View File

@ -85,10 +85,10 @@ public class ExtendedDataInputStream extends DataInputStream {
int length = this.readInt(); int length = this.readInt();
return this.readNBytes(length); return this.readNBytes(length);
} else if (type.isArray() && Message.class.isAssignableFrom(type.getComponentType())) { } else if (type.isArray() && Message.class.isAssignableFrom(type.getComponentType())) {
var messageType = MessageTypeSerializer.get(serializer, (Class<? extends Message>) type.getComponentType()); var messageType = RecordMessageTypeSerializer.get(serializer, (Class<? extends Message>) type.getComponentType());
return this.readArray(messageType); return this.readArray(messageType);
} else if (Message.class.isAssignableFrom(type)) { } else if (Message.class.isAssignableFrom(type)) {
var messageType = MessageTypeSerializer.get(serializer, (Class<? extends Message>) type); var messageType = RecordMessageTypeSerializer.get(serializer, (Class<? extends Message>) type);
return messageType.reader().read(this); return messageType.reader().read(this);
} else { } else {
throw new IOException("Unsupported object type: " + type.getSimpleName()); throw new IOException("Unsupported object type: " + type.getSimpleName());

View File

@ -0,0 +1,129 @@
package nl.andrewl.record_net.util;
import nl.andrewl.record_net.*;
import java.lang.reflect.Constructor;
import java.lang.reflect.RecordComponent;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
/**
* Helper class that contains logic for generating {@link MessageTypeSerializerImpl}
* implementations for record classes.
*/
public class RecordMessageTypeSerializer {
/**
* An internal cache for storing generated type serializers.
*/
private static final Map<Pair<Class<?>, Serializer>, MessageTypeSerializer<?>> generatedMessageTypes = new HashMap<>();
/**
* Gets the {@link MessageTypeSerializer} instance for a given message class, and
* generates a new implementation if none exists yet.
* @param serializer The serializer context to get a type serializer for.
* @param messageClass The class of the message to get a type for.
* @param <T> The type of the message.
* @return The message type.
*/
@SuppressWarnings("unchecked")
public static <T extends Message> MessageTypeSerializer<T> get(Serializer serializer, Class<T> messageClass) {
return (MessageTypeSerializer<T>) generatedMessageTypes.computeIfAbsent(
new Pair<>(messageClass, serializer),
p -> generateForRecord(serializer, (Class<T>) p.first())
);
}
/**
* Generates a message type instance for a given class, using reflection to
* introspect the fields of the message.
* <p>
* Note that this only works for record-based messages.
* </p>
* @param serializer The serializer context to get a type serializer for.
* @param messageTypeClass The class of the message type.
* @param <T> The type of the message.
* @return A message type instance.
*/
public static <T extends Message> MessageTypeSerializerImpl<T> generateForRecord(Serializer serializer, Class<T> messageTypeClass) {
RecordComponent[] components = messageTypeClass.getRecordComponents();
if (components == null) throw new IllegalArgumentException("Cannot generate a MessageTypeSerializer for non-record class " + messageTypeClass.getSimpleName());
Constructor<T> constructor;
try {
constructor = messageTypeClass.getDeclaredConstructor(Arrays.stream(components)
.map(RecordComponent::getType).toArray(Class<?>[]::new));
} catch (NoSuchMethodException e) {
throw new IllegalArgumentException(e);
}
return new MessageTypeSerializerImpl<>(
messageTypeClass,
generateByteSizeFunction(serializer, components),
generateReader(constructor),
generateWriter(components)
);
}
/**
* Generates a function implementation that counts the byte size of a
* message based on the message's record component types.
* @param serializer The serializer context to generate a function for.
* @param components The list of components that make up the message.
* @param <T> The message type.
* @return A function that computes the byte size of a message of the given
* type.
*/
private static <T extends Message> Function<T, Integer> generateByteSizeFunction(Serializer serializer, RecordComponent[] components) {
return msg -> {
int size = 0;
for (var component : components) {
try {
size += MessageUtils.getByteSize(serializer, component.getAccessor().invoke(msg));
} catch (ReflectiveOperationException e) {
throw new IllegalStateException(e);
}
}
return size;
};
}
/**
* Generates a message reader for the given message constructor method. It
* will try to read objects from the input stream according to the
* parameters of the canonical constructor of a message record.
* @param constructor The canonical constructor of the message record.
* @param <T> The message type.
* @return A message reader for the given type.
*/
private static <T extends Message> MessageReader<T> generateReader(Constructor<T> constructor) {
return in -> {
Object[] values = new Object[constructor.getParameterCount()];
for (int i = 0; i < values.length; i++) {
values[i] = in.readObject(constructor.getParameterTypes()[i]);
}
try {
return constructor.newInstance(values);
} catch (ReflectiveOperationException e) {
throw new IllegalStateException(e);
}
};
}
/**
* Generates a message writer for the given message record components.
* @param components The record components to write.
* @param <T> The type of message.
* @return The message writer for the given type.
*/
private static <T extends Message> MessageWriter<T> generateWriter(RecordComponent[] components) {
return (msg, out) -> {
for (var component: components) {
try {
out.writeObject(component.getAccessor().invoke(msg), component.getType());
} catch (ReflectiveOperationException e) {
throw new IllegalStateException(e);
}
}
};
}
}

View File

@ -3,6 +3,7 @@ package nl.andrewl.record_net;
import nl.andrewl.record_net.msg.ChatMessage; import nl.andrewl.record_net.msg.ChatMessage;
import nl.andrewl.record_net.util.ExtendedDataInputStream; import nl.andrewl.record_net.util.ExtendedDataInputStream;
import nl.andrewl.record_net.util.ExtendedDataOutputStream; import nl.andrewl.record_net.util.ExtendedDataOutputStream;
import nl.andrewl.record_net.util.RecordMessageTypeSerializer;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
@ -11,11 +12,11 @@ import java.io.IOException;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
public class MessageTypeSerializerTest { public class RecordMessageTypeSerializerTest {
@Test @Test
public void testGenerateForRecord() throws IOException { public void testGenerateForRecord() throws IOException {
Serializer serializer = new Serializer(); Serializer serializer = new Serializer();
var s1 = MessageTypeSerializer.get(serializer, ChatMessage.class); var s1 = RecordMessageTypeSerializer.get(serializer, ChatMessage.class);
ChatMessage msg = new ChatMessage("andrew", 123, "Hello world!"); ChatMessage msg = new ChatMessage("andrew", 123, "Hello world!");
int expectedByteSize = 4 + msg.username().length() + 8 + 4 + msg.message().length(); int expectedByteSize = 4 + msg.username().length() + 8 + 4 + msg.message().length();
assertEquals(expectedByteSize, s1.byteSizeFunction().apply(msg)); assertEquals(expectedByteSize, s1.byteSizeFunction().apply(msg));
@ -30,6 +31,6 @@ public class MessageTypeSerializerTest {
// Only record classes can be generated. // Only record classes can be generated.
class NonRecordMessage implements Message {} class NonRecordMessage implements Message {}
assertThrows(IllegalArgumentException.class, () -> MessageTypeSerializer.get(serializer, NonRecordMessage.class)); assertThrows(IllegalArgumentException.class, () -> RecordMessageTypeSerializer.get(serializer, NonRecordMessage.class));
} }
} }