Refactored to remove simple static class cache.

This commit is contained in:
Andrew Lalis 2022-04-16 13:41:01 +02:00
parent 7cc9327fef
commit 1d45822c67
11 changed files with 93 additions and 53 deletions

View File

@ -6,7 +6,7 @@
<groupId>nl.andrewl</groupId>
<artifactId>record-net</artifactId>
<version>1.0.0</version>
<version>1.1.0</version>
<properties>
<maven.compiler.source>17</maven.compiler.source>

View File

@ -15,15 +15,7 @@ public interface Message {
* @return The serializer to use to read and write messages of this type.
*/
@SuppressWarnings("unchecked")
default <T extends Message> MessageTypeSerializer<T> getTypeSerializer() {
return MessageTypeSerializer.get((Class<T>) this.getClass());
}
/**
* Convenience method to determine the size of this message in bytes.
* @return The size of this message, in bytes.
*/
default int byteSize() {
return getTypeSerializer().byteSizeFunction().apply(this);
default <T extends Message> MessageTypeSerializer<T> getTypeSerializer(Serializer serializer) {
return (MessageTypeSerializer<T>) serializer.getTypeSerializer(this.getClass());
}
}

View File

@ -1,5 +1,7 @@
package nl.andrewl.record_net;
import nl.andrewl.record_net.util.Pair;
import java.lang.reflect.Constructor;
import java.lang.reflect.RecordComponent;
import java.util.Arrays;
@ -25,18 +27,25 @@ public record MessageTypeSerializer<T extends Message>(
MessageReader<T> reader,
MessageWriter<T> writer
) {
private static final Map<Class<?>, MessageTypeSerializer<?>> generatedMessageTypes = new HashMap<>();
/**
* An internal cache for storing generated type serializers.
*/
private static final Map<Pair<Class<?>, Serializer>, MessageTypeSerializer<?>> generatedMessageTypes = new HashMap<>();
/**
* Gets the {@link MessageTypeSerializer} instance for a given message class, and
* generates a new implementation if none exists yet.
* @param serializer The serializer context to get a type serializer for.
* @param messageClass The class of the message to get a type for.
* @param <T> The type of the message.
* @return The message type.
*/
@SuppressWarnings("unchecked")
public static <T extends Message> MessageTypeSerializer<T> get(Class<T> messageClass) {
return (MessageTypeSerializer<T>) generatedMessageTypes.computeIfAbsent(messageClass, c -> generateForRecord((Class<T>) c));
public static <T extends Message> MessageTypeSerializer<T> get(Serializer serializer, Class<T> messageClass) {
return (MessageTypeSerializer<T>) generatedMessageTypes.computeIfAbsent(
new Pair<>(messageClass, serializer),
p -> generateForRecord(serializer, (Class<T>) p.first())
);
}
/**
@ -45,11 +54,12 @@ public record MessageTypeSerializer<T extends Message>(
* <p>
* Note that this only works for record-based messages.
* </p>
* @param serializer The serializer context to get a type serializer for.
* @param messageTypeClass The class of the message type.
* @param <T> The type of the message.
* @return A message type instance.
*/
public static <T extends Message> MessageTypeSerializer<T> generateForRecord(Class<T> messageTypeClass) {
public static <T extends Message> MessageTypeSerializer<T> generateForRecord(Serializer serializer, Class<T> messageTypeClass) {
RecordComponent[] components = messageTypeClass.getRecordComponents();
if (components == null) throw new IllegalArgumentException("Cannot generate a MessageTypeSerializer for non-record class " + messageTypeClass.getSimpleName());
Constructor<T> constructor;
@ -61,7 +71,7 @@ public record MessageTypeSerializer<T extends Message>(
}
return new MessageTypeSerializer<>(
messageTypeClass,
generateByteSizeFunction(components),
generateByteSizeFunction(serializer, components),
generateReader(constructor),
generateWriter(components)
);
@ -70,17 +80,18 @@ public record MessageTypeSerializer<T extends Message>(
/**
* Generates a function implementation that counts the byte size of a
* message based on the message's record component types.
* @param serializer The serializer context to generate a function for.
* @param components The list of components that make up the message.
* @param <T> The message type.
* @return A function that computes the byte size of a message of the given
* type.
*/
private static <T extends Message> Function<T, Integer> generateByteSizeFunction(RecordComponent[] components) {
private static <T extends Message> Function<T, Integer> generateByteSizeFunction(Serializer serializer, RecordComponent[] components) {
return msg -> {
int size = 0;
for (var component : components) {
try {
size += MessageUtils.getByteSize(component.getAccessor().invoke(msg));
size += MessageUtils.getByteSize(serializer, component.getAccessor().invoke(msg));
} catch (ReflectiveOperationException e) {
throw new IllegalStateException(e);
}

View File

@ -35,19 +35,25 @@ public class MessageUtils {
return size;
}
public static int getByteSize(Message msg) {
return 1 + (msg == null ? 0 : msg.byteSize());
@SuppressWarnings("unchecked")
public static <T extends Message> int getByteSize(Serializer serializer, T msg) {
if (msg == null) {
return 1;
} else {
MessageTypeSerializer<T> typeSerializer = (MessageTypeSerializer<T>) serializer.getTypeSerializer(msg.getClass());
return 1 + typeSerializer.byteSizeFunction().apply(msg);
}
}
public static <T extends Message> int getByteSize(T[] items) {
public static <T extends Message> int getByteSize(Serializer serializer, T[] items) {
int count = Integer.BYTES;
for (var item : items) {
count += getByteSize(item);
count += getByteSize(serializer, item);
}
return count;
}
public static int getByteSize(Object o) {
public static int getByteSize(Serializer serializer, Object o) {
if (o instanceof Integer) {
return Integer.BYTES;
} else if (o instanceof Long) {
@ -61,18 +67,18 @@ public class MessageUtils {
} else if (o instanceof byte[]) {
return Integer.BYTES + ((byte[]) o).length;
} else if (o.getClass().isArray() && Message.class.isAssignableFrom(o.getClass().getComponentType())) {
return getByteSize((Message[]) o);
return getByteSize(serializer, (Message[]) o);
} else if (o instanceof Message) {
return getByteSize((Message) o);
return getByteSize(serializer, (Message) o);
} else {
throw new IllegalArgumentException("Unsupported object type: " + o.getClass().getSimpleName());
}
}
public static int getByteSize(Object... objects) {
public static int getByteSize(Serializer serializer, Object... objects) {
int size = 0;
for (var o : objects) {
size += getByteSize(o);
size += getByteSize(serializer, o);
}
return size;
}

View File

@ -19,6 +19,8 @@ public class Serializer {
*/
private final Map<Byte, MessageTypeSerializer<?>> messageTypes = new HashMap<>();
private final Map<Class<?>, MessageTypeSerializer<?>> messageTypeClasses = new HashMap<>();
/**
* An inverse of {@link Serializer#messageTypes} which is used to look up a
* message's byte value when you know the class of the message.
@ -47,10 +49,25 @@ public class Serializer {
* @param messageClass The type of message associated with the given id.
*/
public synchronized <T extends Message> void registerType(int id, Class<T> messageClass) {
registerTypeSerializer(id, MessageTypeSerializer.generateForRecord(this, messageClass));
}
public synchronized <T extends Message> void registerTypeSerializer(int id, MessageTypeSerializer<T> typeSerializer) {
if (id < 0 || id > 127) throw new IllegalArgumentException("Invalid id.");
MessageTypeSerializer<T> type = MessageTypeSerializer.get(messageClass);
messageTypes.put((byte)id, type);
inverseMessageTypes.put(type, (byte)id);
messageTypes.put((byte) id, typeSerializer);
inverseMessageTypes.put(typeSerializer, (byte) id);
messageTypeClasses.put(typeSerializer.messageClass(), typeSerializer);
}
/**
* Gets the {@link MessageTypeSerializer} for the given message class.
* @param messageType The class of message to get the serializer for.
* @return The message type serializer.
* @param <T> The type of message.
*/
@SuppressWarnings("unchecked")
public <T extends Message> MessageTypeSerializer<T> getTypeSerializer(Class<T> messageType) {
return (MessageTypeSerializer<T>) messageTypeClasses.get(messageType);
}
/**
@ -63,7 +80,7 @@ public class Serializer {
* constructed for the incoming data.
*/
public Message readMessage(InputStream i) throws IOException {
ExtendedDataInputStream d = new ExtendedDataInputStream(i);
ExtendedDataInputStream d = new ExtendedDataInputStream(this, i);
byte typeId = d.readByte();
var type = messageTypes.get(typeId);
if (type == null) {
@ -99,12 +116,12 @@ public class Serializer {
*/
public <T extends Message> void writeMessage(T msg, OutputStream o) throws IOException {
DataOutputStream d = new DataOutputStream(o);
Byte typeId = inverseMessageTypes.get(msg.getTypeSerializer());
Byte typeId = inverseMessageTypes.get(msg.getTypeSerializer(this));
if (typeId == null) {
throw new IOException("Unsupported message type: " + msg.getClass().getSimpleName());
}
d.writeByte(typeId);
msg.getTypeSerializer().writer().write(msg, new ExtendedDataOutputStream(d));
msg.getTypeSerializer(this).writer().write(msg, new ExtendedDataOutputStream(this, d));
d.flush();
}
@ -117,7 +134,8 @@ public class Serializer {
* to write is not supported by this serializer.
*/
public <T extends Message> byte[] writeMessage(T msg) throws IOException {
ByteArrayOutputStream out = new ByteArrayOutputStream(1 + msg.byteSize());
int bytes = msg.getTypeSerializer(this).byteSizeFunction().apply(msg);
ByteArrayOutputStream out = new ByteArrayOutputStream(1 + bytes);
writeMessage(msg, out);
return out.toByteArray();
}

View File

@ -2,6 +2,7 @@ package nl.andrewl.record_net.util;
import nl.andrewl.record_net.Message;
import nl.andrewl.record_net.MessageTypeSerializer;
import nl.andrewl.record_net.Serializer;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
@ -16,12 +17,15 @@ import java.util.UUID;
* complex types that are used by the Concord system.
*/
public class ExtendedDataInputStream extends DataInputStream {
public ExtendedDataInputStream(InputStream in) {
private final Serializer serializer;
public ExtendedDataInputStream(Serializer serializer, InputStream in) {
super(in);
this.serializer = serializer;
}
public ExtendedDataInputStream(byte[] data) {
this(new ByteArrayInputStream(data));
public ExtendedDataInputStream(Serializer serializer, byte[] data) {
this(serializer, new ByteArrayInputStream(data));
}
public String readString() throws IOException {
@ -81,10 +85,10 @@ public class ExtendedDataInputStream extends DataInputStream {
int length = this.readInt();
return this.readNBytes(length);
} else if (type.isArray() && Message.class.isAssignableFrom(type.getComponentType())) {
var messageType = MessageTypeSerializer.get((Class<? extends Message>) type.getComponentType());
var messageType = MessageTypeSerializer.get(serializer, (Class<? extends Message>) type.getComponentType());
return this.readArray(messageType);
} else if (Message.class.isAssignableFrom(type)) {
var messageType = MessageTypeSerializer.get((Class<? extends Message>) type);
var messageType = MessageTypeSerializer.get(serializer, (Class<? extends Message>) type);
return messageType.reader().read(this);
} else {
throw new IOException("Unsupported object type: " + type.getSimpleName());

View File

@ -1,6 +1,7 @@
package nl.andrewl.record_net.util;
import nl.andrewl.record_net.Message;
import nl.andrewl.record_net.Serializer;
import java.io.DataOutputStream;
import java.io.IOException;
@ -13,8 +14,11 @@ import java.util.UUID;
* that help us to write more data.
*/
public class ExtendedDataOutputStream extends DataOutputStream {
public ExtendedDataOutputStream(OutputStream out) {
private final Serializer serializer;
public ExtendedDataOutputStream(Serializer serializer, OutputStream out) {
super(out);
this.serializer = serializer;
}
/**
@ -125,7 +129,7 @@ public class ExtendedDataOutputStream extends DataOutputStream {
public <T extends Message> void writeMessage(Message msg) throws IOException {
writeBoolean(msg != null);
if (msg != null) {
msg.getTypeSerializer().writer().write(msg, this);
msg.getTypeSerializer(serializer).writer().write(msg, this);
}
}

View File

@ -14,21 +14,22 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
public class MessageTypeSerializerTest {
@Test
public void testGenerateForRecord() throws IOException {
var s1 = MessageTypeSerializer.get(ChatMessage.class);
Serializer serializer = new Serializer();
var s1 = MessageTypeSerializer.get(serializer, ChatMessage.class);
ChatMessage msg = new ChatMessage("andrew", 123, "Hello world!");
int expectedByteSize = 4 + msg.username().length() + 8 + 4 + msg.message().length();
assertEquals(expectedByteSize, s1.byteSizeFunction().apply(msg));
assertEquals(expectedByteSize, msg.byteSize());
ByteArrayOutputStream bOut = new ByteArrayOutputStream();
ExtendedDataOutputStream eOut = new ExtendedDataOutputStream(bOut);
ExtendedDataOutputStream eOut = new ExtendedDataOutputStream(serializer, bOut);
s1.writer().write(msg, eOut);
byte[] data = bOut.toByteArray();
assertEquals(expectedByteSize, data.length);
ChatMessage readMsg = s1.reader().read(new ExtendedDataInputStream(data));
ChatMessage readMsg = s1.reader().read(new ExtendedDataInputStream(serializer, data));
assertEquals(msg, readMsg);
// Only record classes can be generated.
class NonRecordMessage implements Message {}
assertThrows(IllegalArgumentException.class, () -> MessageTypeSerializer.get(NonRecordMessage.class));
assertThrows(IllegalArgumentException.class, () -> MessageTypeSerializer.get(serializer, NonRecordMessage.class));
}
}

View File

@ -18,10 +18,12 @@ public class MessageUtilsTest {
assertEquals(10, MessageUtils.getByteSize("a", "b"));
Message msg = new ChatMessage("andrew", 123, "Hello world!");
int expectedMsgSize = 1 + 4 + 6 + 8 + 4 + 12;
assertEquals(1, MessageUtils.getByteSize((Message) null));
assertEquals(expectedMsgSize, MessageUtils.getByteSize(msg));
assertEquals(4 * expectedMsgSize, MessageUtils.getByteSize(msg, msg, msg, msg));
assertEquals(16, MessageUtils.getByteSize(UUID.randomUUID()));
assertEquals(4, MessageUtils.getByteSize(StandardCopyOption.ATOMIC_MOVE));
Serializer serializer = new Serializer();
serializer.registerType(1, ChatMessage.class);
assertEquals(1, MessageUtils.getByteSize(serializer, (Message) null));
assertEquals(expectedMsgSize, MessageUtils.getByteSize(serializer, msg));
assertEquals(4 * expectedMsgSize, MessageUtils.getByteSize(serializer, msg, msg, msg, msg));
assertEquals(16, MessageUtils.getByteSize(serializer, UUID.randomUUID()));
assertEquals(4, MessageUtils.getByteSize(serializer, StandardCopyOption.ATOMIC_MOVE));
}
}

View File

@ -20,7 +20,7 @@ public class SerializerTest {
ByteArrayOutputStream bOut = new ByteArrayOutputStream();
s.writeMessage(msg, bOut);
byte[] data = bOut.toByteArray();
assertEquals(1 + msg.byteSize(), data.length);
assertEquals(MessageUtils.getByteSize(s, msg), data.length);
assertEquals(data[0], 1);
ChatMessage readMsg = (ChatMessage) s.readMessage(new ByteArrayInputStream(data));

View File

@ -1,5 +1,6 @@
package nl.andrewl.record_net.util;
import nl.andrewl.record_net.Serializer;
import org.junit.jupiter.api.Test;
import java.io.ByteArrayInputStream;
@ -13,8 +14,9 @@ public class ExtendedDataOutputStreamTest {
@Test
public void testWriteString() throws IOException {
Serializer serializer = new Serializer();
ByteArrayOutputStream bOut = new ByteArrayOutputStream();
ExtendedDataOutputStream eOut = new ExtendedDataOutputStream(bOut);
ExtendedDataOutputStream eOut = new ExtendedDataOutputStream(serializer, bOut);
eOut.writeString("Hello world!");
byte[] data = bOut.toByteArray();
assertEquals(4 + "Hello world!".length(), data.length);