Refactored to remove simple static class cache.
This commit is contained in:
parent
7cc9327fef
commit
1d45822c67
2
pom.xml
2
pom.xml
|
@ -6,7 +6,7 @@
|
|||
|
||||
<groupId>nl.andrewl</groupId>
|
||||
<artifactId>record-net</artifactId>
|
||||
<version>1.0.0</version>
|
||||
<version>1.1.0</version>
|
||||
|
||||
<properties>
|
||||
<maven.compiler.source>17</maven.compiler.source>
|
||||
|
|
|
@ -15,15 +15,7 @@ public interface Message {
|
|||
* @return The serializer to use to read and write messages of this type.
|
||||
*/
|
||||
@SuppressWarnings("unchecked")
|
||||
default <T extends Message> MessageTypeSerializer<T> getTypeSerializer() {
|
||||
return MessageTypeSerializer.get((Class<T>) this.getClass());
|
||||
}
|
||||
|
||||
/**
|
||||
* Convenience method to determine the size of this message in bytes.
|
||||
* @return The size of this message, in bytes.
|
||||
*/
|
||||
default int byteSize() {
|
||||
return getTypeSerializer().byteSizeFunction().apply(this);
|
||||
default <T extends Message> MessageTypeSerializer<T> getTypeSerializer(Serializer serializer) {
|
||||
return (MessageTypeSerializer<T>) serializer.getTypeSerializer(this.getClass());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
package nl.andrewl.record_net;
|
||||
|
||||
import nl.andrewl.record_net.util.Pair;
|
||||
|
||||
import java.lang.reflect.Constructor;
|
||||
import java.lang.reflect.RecordComponent;
|
||||
import java.util.Arrays;
|
||||
|
@ -25,18 +27,25 @@ public record MessageTypeSerializer<T extends Message>(
|
|||
MessageReader<T> reader,
|
||||
MessageWriter<T> writer
|
||||
) {
|
||||
private static final Map<Class<?>, MessageTypeSerializer<?>> generatedMessageTypes = new HashMap<>();
|
||||
/**
|
||||
* An internal cache for storing generated type serializers.
|
||||
*/
|
||||
private static final Map<Pair<Class<?>, Serializer>, MessageTypeSerializer<?>> generatedMessageTypes = new HashMap<>();
|
||||
|
||||
/**
|
||||
* Gets the {@link MessageTypeSerializer} instance for a given message class, and
|
||||
* generates a new implementation if none exists yet.
|
||||
* @param serializer The serializer context to get a type serializer for.
|
||||
* @param messageClass The class of the message to get a type for.
|
||||
* @param <T> The type of the message.
|
||||
* @return The message type.
|
||||
*/
|
||||
@SuppressWarnings("unchecked")
|
||||
public static <T extends Message> MessageTypeSerializer<T> get(Class<T> messageClass) {
|
||||
return (MessageTypeSerializer<T>) generatedMessageTypes.computeIfAbsent(messageClass, c -> generateForRecord((Class<T>) c));
|
||||
public static <T extends Message> MessageTypeSerializer<T> get(Serializer serializer, Class<T> messageClass) {
|
||||
return (MessageTypeSerializer<T>) generatedMessageTypes.computeIfAbsent(
|
||||
new Pair<>(messageClass, serializer),
|
||||
p -> generateForRecord(serializer, (Class<T>) p.first())
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -45,11 +54,12 @@ public record MessageTypeSerializer<T extends Message>(
|
|||
* <p>
|
||||
* Note that this only works for record-based messages.
|
||||
* </p>
|
||||
* @param serializer The serializer context to get a type serializer for.
|
||||
* @param messageTypeClass The class of the message type.
|
||||
* @param <T> The type of the message.
|
||||
* @return A message type instance.
|
||||
*/
|
||||
public static <T extends Message> MessageTypeSerializer<T> generateForRecord(Class<T> messageTypeClass) {
|
||||
public static <T extends Message> MessageTypeSerializer<T> generateForRecord(Serializer serializer, Class<T> messageTypeClass) {
|
||||
RecordComponent[] components = messageTypeClass.getRecordComponents();
|
||||
if (components == null) throw new IllegalArgumentException("Cannot generate a MessageTypeSerializer for non-record class " + messageTypeClass.getSimpleName());
|
||||
Constructor<T> constructor;
|
||||
|
@ -61,7 +71,7 @@ public record MessageTypeSerializer<T extends Message>(
|
|||
}
|
||||
return new MessageTypeSerializer<>(
|
||||
messageTypeClass,
|
||||
generateByteSizeFunction(components),
|
||||
generateByteSizeFunction(serializer, components),
|
||||
generateReader(constructor),
|
||||
generateWriter(components)
|
||||
);
|
||||
|
@ -70,17 +80,18 @@ public record MessageTypeSerializer<T extends Message>(
|
|||
/**
|
||||
* Generates a function implementation that counts the byte size of a
|
||||
* message based on the message's record component types.
|
||||
* @param serializer The serializer context to generate a function for.
|
||||
* @param components The list of components that make up the message.
|
||||
* @param <T> The message type.
|
||||
* @return A function that computes the byte size of a message of the given
|
||||
* type.
|
||||
*/
|
||||
private static <T extends Message> Function<T, Integer> generateByteSizeFunction(RecordComponent[] components) {
|
||||
private static <T extends Message> Function<T, Integer> generateByteSizeFunction(Serializer serializer, RecordComponent[] components) {
|
||||
return msg -> {
|
||||
int size = 0;
|
||||
for (var component : components) {
|
||||
try {
|
||||
size += MessageUtils.getByteSize(component.getAccessor().invoke(msg));
|
||||
size += MessageUtils.getByteSize(serializer, component.getAccessor().invoke(msg));
|
||||
} catch (ReflectiveOperationException e) {
|
||||
throw new IllegalStateException(e);
|
||||
}
|
||||
|
|
|
@ -35,19 +35,25 @@ public class MessageUtils {
|
|||
return size;
|
||||
}
|
||||
|
||||
public static int getByteSize(Message msg) {
|
||||
return 1 + (msg == null ? 0 : msg.byteSize());
|
||||
@SuppressWarnings("unchecked")
|
||||
public static <T extends Message> int getByteSize(Serializer serializer, T msg) {
|
||||
if (msg == null) {
|
||||
return 1;
|
||||
} else {
|
||||
MessageTypeSerializer<T> typeSerializer = (MessageTypeSerializer<T>) serializer.getTypeSerializer(msg.getClass());
|
||||
return 1 + typeSerializer.byteSizeFunction().apply(msg);
|
||||
}
|
||||
}
|
||||
|
||||
public static <T extends Message> int getByteSize(T[] items) {
|
||||
public static <T extends Message> int getByteSize(Serializer serializer, T[] items) {
|
||||
int count = Integer.BYTES;
|
||||
for (var item : items) {
|
||||
count += getByteSize(item);
|
||||
count += getByteSize(serializer, item);
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
public static int getByteSize(Object o) {
|
||||
public static int getByteSize(Serializer serializer, Object o) {
|
||||
if (o instanceof Integer) {
|
||||
return Integer.BYTES;
|
||||
} else if (o instanceof Long) {
|
||||
|
@ -61,18 +67,18 @@ public class MessageUtils {
|
|||
} else if (o instanceof byte[]) {
|
||||
return Integer.BYTES + ((byte[]) o).length;
|
||||
} else if (o.getClass().isArray() && Message.class.isAssignableFrom(o.getClass().getComponentType())) {
|
||||
return getByteSize((Message[]) o);
|
||||
return getByteSize(serializer, (Message[]) o);
|
||||
} else if (o instanceof Message) {
|
||||
return getByteSize((Message) o);
|
||||
return getByteSize(serializer, (Message) o);
|
||||
} else {
|
||||
throw new IllegalArgumentException("Unsupported object type: " + o.getClass().getSimpleName());
|
||||
}
|
||||
}
|
||||
|
||||
public static int getByteSize(Object... objects) {
|
||||
public static int getByteSize(Serializer serializer, Object... objects) {
|
||||
int size = 0;
|
||||
for (var o : objects) {
|
||||
size += getByteSize(o);
|
||||
size += getByteSize(serializer, o);
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
|
|
@ -19,6 +19,8 @@ public class Serializer {
|
|||
*/
|
||||
private final Map<Byte, MessageTypeSerializer<?>> messageTypes = new HashMap<>();
|
||||
|
||||
private final Map<Class<?>, MessageTypeSerializer<?>> messageTypeClasses = new HashMap<>();
|
||||
|
||||
/**
|
||||
* An inverse of {@link Serializer#messageTypes} which is used to look up a
|
||||
* message's byte value when you know the class of the message.
|
||||
|
@ -47,10 +49,25 @@ public class Serializer {
|
|||
* @param messageClass The type of message associated with the given id.
|
||||
*/
|
||||
public synchronized <T extends Message> void registerType(int id, Class<T> messageClass) {
|
||||
registerTypeSerializer(id, MessageTypeSerializer.generateForRecord(this, messageClass));
|
||||
}
|
||||
|
||||
public synchronized <T extends Message> void registerTypeSerializer(int id, MessageTypeSerializer<T> typeSerializer) {
|
||||
if (id < 0 || id > 127) throw new IllegalArgumentException("Invalid id.");
|
||||
MessageTypeSerializer<T> type = MessageTypeSerializer.get(messageClass);
|
||||
messageTypes.put((byte)id, type);
|
||||
inverseMessageTypes.put(type, (byte)id);
|
||||
messageTypes.put((byte) id, typeSerializer);
|
||||
inverseMessageTypes.put(typeSerializer, (byte) id);
|
||||
messageTypeClasses.put(typeSerializer.messageClass(), typeSerializer);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the {@link MessageTypeSerializer} for the given message class.
|
||||
* @param messageType The class of message to get the serializer for.
|
||||
* @return The message type serializer.
|
||||
* @param <T> The type of message.
|
||||
*/
|
||||
@SuppressWarnings("unchecked")
|
||||
public <T extends Message> MessageTypeSerializer<T> getTypeSerializer(Class<T> messageType) {
|
||||
return (MessageTypeSerializer<T>) messageTypeClasses.get(messageType);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -63,7 +80,7 @@ public class Serializer {
|
|||
* constructed for the incoming data.
|
||||
*/
|
||||
public Message readMessage(InputStream i) throws IOException {
|
||||
ExtendedDataInputStream d = new ExtendedDataInputStream(i);
|
||||
ExtendedDataInputStream d = new ExtendedDataInputStream(this, i);
|
||||
byte typeId = d.readByte();
|
||||
var type = messageTypes.get(typeId);
|
||||
if (type == null) {
|
||||
|
@ -99,12 +116,12 @@ public class Serializer {
|
|||
*/
|
||||
public <T extends Message> void writeMessage(T msg, OutputStream o) throws IOException {
|
||||
DataOutputStream d = new DataOutputStream(o);
|
||||
Byte typeId = inverseMessageTypes.get(msg.getTypeSerializer());
|
||||
Byte typeId = inverseMessageTypes.get(msg.getTypeSerializer(this));
|
||||
if (typeId == null) {
|
||||
throw new IOException("Unsupported message type: " + msg.getClass().getSimpleName());
|
||||
}
|
||||
d.writeByte(typeId);
|
||||
msg.getTypeSerializer().writer().write(msg, new ExtendedDataOutputStream(d));
|
||||
msg.getTypeSerializer(this).writer().write(msg, new ExtendedDataOutputStream(this, d));
|
||||
d.flush();
|
||||
}
|
||||
|
||||
|
@ -117,7 +134,8 @@ public class Serializer {
|
|||
* to write is not supported by this serializer.
|
||||
*/
|
||||
public <T extends Message> byte[] writeMessage(T msg) throws IOException {
|
||||
ByteArrayOutputStream out = new ByteArrayOutputStream(1 + msg.byteSize());
|
||||
int bytes = msg.getTypeSerializer(this).byteSizeFunction().apply(msg);
|
||||
ByteArrayOutputStream out = new ByteArrayOutputStream(1 + bytes);
|
||||
writeMessage(msg, out);
|
||||
return out.toByteArray();
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package nl.andrewl.record_net.util;
|
|||
|
||||
import nl.andrewl.record_net.Message;
|
||||
import nl.andrewl.record_net.MessageTypeSerializer;
|
||||
import nl.andrewl.record_net.Serializer;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.DataInputStream;
|
||||
|
@ -16,12 +17,15 @@ import java.util.UUID;
|
|||
* complex types that are used by the Concord system.
|
||||
*/
|
||||
public class ExtendedDataInputStream extends DataInputStream {
|
||||
public ExtendedDataInputStream(InputStream in) {
|
||||
private final Serializer serializer;
|
||||
|
||||
public ExtendedDataInputStream(Serializer serializer, InputStream in) {
|
||||
super(in);
|
||||
this.serializer = serializer;
|
||||
}
|
||||
|
||||
public ExtendedDataInputStream(byte[] data) {
|
||||
this(new ByteArrayInputStream(data));
|
||||
public ExtendedDataInputStream(Serializer serializer, byte[] data) {
|
||||
this(serializer, new ByteArrayInputStream(data));
|
||||
}
|
||||
|
||||
public String readString() throws IOException {
|
||||
|
@ -81,10 +85,10 @@ public class ExtendedDataInputStream extends DataInputStream {
|
|||
int length = this.readInt();
|
||||
return this.readNBytes(length);
|
||||
} else if (type.isArray() && Message.class.isAssignableFrom(type.getComponentType())) {
|
||||
var messageType = MessageTypeSerializer.get((Class<? extends Message>) type.getComponentType());
|
||||
var messageType = MessageTypeSerializer.get(serializer, (Class<? extends Message>) type.getComponentType());
|
||||
return this.readArray(messageType);
|
||||
} else if (Message.class.isAssignableFrom(type)) {
|
||||
var messageType = MessageTypeSerializer.get((Class<? extends Message>) type);
|
||||
var messageType = MessageTypeSerializer.get(serializer, (Class<? extends Message>) type);
|
||||
return messageType.reader().read(this);
|
||||
} else {
|
||||
throw new IOException("Unsupported object type: " + type.getSimpleName());
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package nl.andrewl.record_net.util;
|
||||
|
||||
import nl.andrewl.record_net.Message;
|
||||
import nl.andrewl.record_net.Serializer;
|
||||
|
||||
import java.io.DataOutputStream;
|
||||
import java.io.IOException;
|
||||
|
@ -13,8 +14,11 @@ import java.util.UUID;
|
|||
* that help us to write more data.
|
||||
*/
|
||||
public class ExtendedDataOutputStream extends DataOutputStream {
|
||||
public ExtendedDataOutputStream(OutputStream out) {
|
||||
private final Serializer serializer;
|
||||
|
||||
public ExtendedDataOutputStream(Serializer serializer, OutputStream out) {
|
||||
super(out);
|
||||
this.serializer = serializer;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -125,7 +129,7 @@ public class ExtendedDataOutputStream extends DataOutputStream {
|
|||
public <T extends Message> void writeMessage(Message msg) throws IOException {
|
||||
writeBoolean(msg != null);
|
||||
if (msg != null) {
|
||||
msg.getTypeSerializer().writer().write(msg, this);
|
||||
msg.getTypeSerializer(serializer).writer().write(msg, this);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -14,21 +14,22 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
|
|||
public class MessageTypeSerializerTest {
|
||||
@Test
|
||||
public void testGenerateForRecord() throws IOException {
|
||||
var s1 = MessageTypeSerializer.get(ChatMessage.class);
|
||||
Serializer serializer = new Serializer();
|
||||
var s1 = MessageTypeSerializer.get(serializer, ChatMessage.class);
|
||||
ChatMessage msg = new ChatMessage("andrew", 123, "Hello world!");
|
||||
int expectedByteSize = 4 + msg.username().length() + 8 + 4 + msg.message().length();
|
||||
assertEquals(expectedByteSize, s1.byteSizeFunction().apply(msg));
|
||||
assertEquals(expectedByteSize, msg.byteSize());
|
||||
|
||||
ByteArrayOutputStream bOut = new ByteArrayOutputStream();
|
||||
ExtendedDataOutputStream eOut = new ExtendedDataOutputStream(bOut);
|
||||
ExtendedDataOutputStream eOut = new ExtendedDataOutputStream(serializer, bOut);
|
||||
s1.writer().write(msg, eOut);
|
||||
byte[] data = bOut.toByteArray();
|
||||
assertEquals(expectedByteSize, data.length);
|
||||
ChatMessage readMsg = s1.reader().read(new ExtendedDataInputStream(data));
|
||||
ChatMessage readMsg = s1.reader().read(new ExtendedDataInputStream(serializer, data));
|
||||
assertEquals(msg, readMsg);
|
||||
|
||||
// Only record classes can be generated.
|
||||
class NonRecordMessage implements Message {}
|
||||
assertThrows(IllegalArgumentException.class, () -> MessageTypeSerializer.get(NonRecordMessage.class));
|
||||
assertThrows(IllegalArgumentException.class, () -> MessageTypeSerializer.get(serializer, NonRecordMessage.class));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,10 +18,12 @@ public class MessageUtilsTest {
|
|||
assertEquals(10, MessageUtils.getByteSize("a", "b"));
|
||||
Message msg = new ChatMessage("andrew", 123, "Hello world!");
|
||||
int expectedMsgSize = 1 + 4 + 6 + 8 + 4 + 12;
|
||||
assertEquals(1, MessageUtils.getByteSize((Message) null));
|
||||
assertEquals(expectedMsgSize, MessageUtils.getByteSize(msg));
|
||||
assertEquals(4 * expectedMsgSize, MessageUtils.getByteSize(msg, msg, msg, msg));
|
||||
assertEquals(16, MessageUtils.getByteSize(UUID.randomUUID()));
|
||||
assertEquals(4, MessageUtils.getByteSize(StandardCopyOption.ATOMIC_MOVE));
|
||||
Serializer serializer = new Serializer();
|
||||
serializer.registerType(1, ChatMessage.class);
|
||||
assertEquals(1, MessageUtils.getByteSize(serializer, (Message) null));
|
||||
assertEquals(expectedMsgSize, MessageUtils.getByteSize(serializer, msg));
|
||||
assertEquals(4 * expectedMsgSize, MessageUtils.getByteSize(serializer, msg, msg, msg, msg));
|
||||
assertEquals(16, MessageUtils.getByteSize(serializer, UUID.randomUUID()));
|
||||
assertEquals(4, MessageUtils.getByteSize(serializer, StandardCopyOption.ATOMIC_MOVE));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@ public class SerializerTest {
|
|||
ByteArrayOutputStream bOut = new ByteArrayOutputStream();
|
||||
s.writeMessage(msg, bOut);
|
||||
byte[] data = bOut.toByteArray();
|
||||
assertEquals(1 + msg.byteSize(), data.length);
|
||||
assertEquals(MessageUtils.getByteSize(s, msg), data.length);
|
||||
assertEquals(data[0], 1);
|
||||
|
||||
ChatMessage readMsg = (ChatMessage) s.readMessage(new ByteArrayInputStream(data));
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
package nl.andrewl.record_net.util;
|
||||
|
||||
import nl.andrewl.record_net.Serializer;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
|
@ -13,8 +14,9 @@ public class ExtendedDataOutputStreamTest {
|
|||
|
||||
@Test
|
||||
public void testWriteString() throws IOException {
|
||||
Serializer serializer = new Serializer();
|
||||
ByteArrayOutputStream bOut = new ByteArrayOutputStream();
|
||||
ExtendedDataOutputStream eOut = new ExtendedDataOutputStream(bOut);
|
||||
ExtendedDataOutputStream eOut = new ExtendedDataOutputStream(serializer, bOut);
|
||||
eOut.writeString("Hello world!");
|
||||
byte[] data = bOut.toByteArray();
|
||||
assertEquals(4 + "Hello world!".length(), data.length);
|
||||
|
|
Loading…
Reference in New Issue