Added starter implementation, some rare and weird bugs with messages sometimes failing.

This commit is contained in:
Andrew Lalis 2021-09-11 15:23:36 +02:00
parent 6770418c66
commit c6a2bb15da
14 changed files with 251 additions and 39 deletions

View File

@ -16,24 +16,27 @@ import nl.andrewl.concord_client.event.handlers.ServerMetaDataHandler;
import nl.andrewl.concord_client.event.handlers.ServerUsersHandler; import nl.andrewl.concord_client.event.handlers.ServerUsersHandler;
import nl.andrewl.concord_client.gui.MainWindow; import nl.andrewl.concord_client.gui.MainWindow;
import nl.andrewl.concord_client.model.ClientModel; import nl.andrewl.concord_client.model.ClientModel;
import nl.andrewl.concord_core.msg.Encryption;
import nl.andrewl.concord_core.msg.Message; import nl.andrewl.concord_core.msg.Message;
import nl.andrewl.concord_core.msg.Serializer; import nl.andrewl.concord_core.msg.Serializer;
import nl.andrewl.concord_core.msg.types.Error;
import nl.andrewl.concord_core.msg.types.*; import nl.andrewl.concord_core.msg.types.*;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket; import java.net.Socket;
import java.nio.file.Files; import java.nio.file.Files;
import java.nio.file.Path; import java.nio.file.Path;
import java.security.GeneralSecurityException;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
public class ConcordClient implements Runnable { public class ConcordClient implements Runnable {
private final Socket socket; private final Socket socket;
private final DataInputStream in; private InputStream in;
private final DataOutputStream out; private OutputStream out;
private final Serializer serializer; private final Serializer serializer;
@Getter @Getter
@ -46,8 +49,8 @@ public class ConcordClient implements Runnable {
public ConcordClient(String host, int port, String nickname, Path tokensFile) throws IOException { public ConcordClient(String host, int port, String nickname, Path tokensFile) throws IOException {
this.eventManager = new EventManager(this); this.eventManager = new EventManager(this);
this.socket = new Socket(host, port); this.socket = new Socket(host, port);
this.in = new DataInputStream(this.socket.getInputStream()); this.in = this.socket.getInputStream();
this.out = new DataOutputStream(this.socket.getOutputStream()); this.out = this.socket.getOutputStream();
this.serializer = new Serializer(); this.serializer = new Serializer();
this.model = this.initializeConnectionToServer(nickname, tokensFile); this.model = this.initializeConnectionToServer(nickname, tokensFile);
@ -73,6 +76,15 @@ public class ConcordClient implements Runnable {
* messages, or if the server sends an unexpected response. * messages, or if the server sends an unexpected response.
*/ */
private ClientModel initializeConnectionToServer(String nickname, Path tokensFile) throws IOException { private ClientModel initializeConnectionToServer(String nickname, Path tokensFile) throws IOException {
try {
System.out.println("Initializing end-to-end encryption with the server...");
var streams = Encryption.upgrade(this.in, this.out, this.serializer);
this.in = streams.first();
this.out = streams.second();
System.out.println("Successfully established cipher streams.");
} catch (GeneralSecurityException e) {
throw new IOException(e);
}
String token = this.getSessionToken(tokensFile); String token = this.getSessionToken(tokensFile);
this.serializer.writeMessage(new Identification(nickname, token), this.out); this.serializer.writeMessage(new Identification(nickname, token), this.out);
Message reply = this.serializer.readMessage(this.in); Message reply = this.serializer.readMessage(this.in);
@ -83,7 +95,7 @@ public class ConcordClient implements Runnable {
this.sendMessage(new ChatHistoryRequest(model.getCurrentChannelId(), "")); this.sendMessage(new ChatHistoryRequest(model.getCurrentChannelId(), ""));
return model; return model;
} else { } else {
throw new IOException("Unexpected response from the server after sending identification message."); throw new IOException("Unexpected response from the server after sending identification message: " + reply);
} }
} }

View File

@ -5,6 +5,8 @@ import nl.andrewl.concord_client.event.MessageHandler;
import nl.andrewl.concord_core.msg.types.ChatHistoryRequest; import nl.andrewl.concord_core.msg.types.ChatHistoryRequest;
import nl.andrewl.concord_core.msg.types.MoveToChannel; import nl.andrewl.concord_core.msg.types.MoveToChannel;
import java.util.Map;
/** /**
* When the client receives a {@link MoveToChannel} message, it means that the * When the client receives a {@link MoveToChannel} message, it means that the
* server has told the client that it has been moved to the indicated channel. * server has told the client that it has been moved to the indicated channel.
@ -15,6 +17,6 @@ public class ChannelMovedHandler implements MessageHandler<MoveToChannel> {
@Override @Override
public void handle(MoveToChannel msg, ConcordClient client) throws Exception { public void handle(MoveToChannel msg, ConcordClient client) throws Exception {
client.getModel().setCurrentChannel(msg.getId(), msg.getChannelName()); client.getModel().setCurrentChannel(msg.getId(), msg.getChannelName());
client.sendMessage(new ChatHistoryRequest(msg.getId(), "")); client.sendMessage(new ChatHistoryRequest(msg.getId()));
} }
} }

View File

@ -56,7 +56,9 @@ public class ChatList extends AbstractListBox<Chat, ChatList> implements ChatHis
public void chatUpdated(ChatHistory history) { public void chatUpdated(ChatHistory history) {
this.getTextGUI().getGUIThread().invokeLater(() -> { this.getTextGUI().getGUIThread().invokeLater(() -> {
this.clearItems(); this.clearItems();
System.out.println("Cleared chats");
for (var chat : history.getChats()) { for (var chat : history.getChats()) {
System.out.println("Adding chat: " + chat);
this.addItem(chat); this.addItem(chat);
} }
}); });

View File

@ -1,6 +1,7 @@
module concord_core { module concord_core {
requires static lombok; requires static lombok;
exports nl.andrewl.concord_core.util to concord_server, concord_client;
exports nl.andrewl.concord_core.msg to concord_server, concord_client; exports nl.andrewl.concord_core.msg to concord_server, concord_client;
exports nl.andrewl.concord_core.msg.types to concord_server, concord_client; exports nl.andrewl.concord_core.msg.types to concord_server, concord_client;
} }

View File

@ -0,0 +1,73 @@
package nl.andrewl.concord_core.msg;
import nl.andrewl.concord_core.msg.types.KeyData;
import nl.andrewl.concord_core.util.Pair;
import javax.crypto.Cipher;
import javax.crypto.CipherInputStream;
import javax.crypto.CipherOutputStream;
import javax.crypto.KeyAgreement;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.security.*;
import java.security.spec.X509EncodedKeySpec;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
public class Encryption {
public static Pair<CipherInputStream, CipherOutputStream> upgrade(
InputStream in,
OutputStream out,
Serializer serializer
) throws GeneralSecurityException, IOException {
// Generate our own key pair.
KeyPairGenerator kpg = KeyPairGenerator.getInstance("EC");
kpg.initialize(256);
KeyPair keyPair = kpg.generateKeyPair();
byte[] publicKey = keyPair.getPublic().getEncoded();
var random = new SecureRandom();
byte[] iv = new byte[16];
random.nextBytes(iv);
byte[] salt = new byte[8];
random.nextBytes(salt);
// Send our public key and related data to the client, unencrypted.
serializer.writeMessage(new KeyData(iv, salt, publicKey), out);
// Receive and decode client's unencrypted key data.
KeyData clientKeyData = (KeyData) serializer.readMessage(in);
PublicKey clientPublicKey = KeyFactory.getInstance("EC")
.generatePublic(new X509EncodedKeySpec(clientKeyData.getPublicKey()));
// Compute secret key from client's public key and our private key.
KeyAgreement ka = KeyAgreement.getInstance("ECDH");
ka.init(keyPair.getPrivate());
ka.doPhase(clientPublicKey, true);
byte[] secretKey = computeSecretKey(ka.generateSecret(), publicKey, clientKeyData.getPublicKey());
// Initialize cipher streams.
Cipher writeCipher = Cipher.getInstance("AES/CFB8/NoPadding");
Cipher readCipher = Cipher.getInstance("AES/CFB8/NoPadding");
Key cipherKey = new SecretKeySpec(secretKey, "AES");
writeCipher.init(Cipher.ENCRYPT_MODE, cipherKey, new IvParameterSpec(iv));
readCipher.init(Cipher.DECRYPT_MODE, cipherKey, new IvParameterSpec(clientKeyData.getIv()));
return new Pair<>(
new CipherInputStream(in, readCipher),
new CipherOutputStream(out, writeCipher)
);
}
private static byte[] computeSecretKey(byte[] sharedSecret, byte[] pk1, byte[] pk2) throws NoSuchAlgorithmException {
MessageDigest hash = MessageDigest.getInstance("SHA-256");
hash.update(sharedSecret);
List<ByteBuffer> keys = Arrays.asList(ByteBuffer.wrap(pk1), ByteBuffer.wrap(pk2));
Collections.sort(keys);
hash.update(keys.get(0));
hash.update(keys.get(1));
return hash.digest();
}
}

View File

@ -15,6 +15,14 @@ import java.util.UUID;
public class MessageUtils { public class MessageUtils {
public static final int UUID_BYTES = 2 * Long.BYTES; public static final int UUID_BYTES = 2 * Long.BYTES;
public static final char MIN_HIGH_SURROGATE = '\uD800';
public static final char MAX_HIGH_SURROGATE = '\uDBFF';
public static final char MIN_LOW_SURROGATE = '\uDC00';
public static final char MAX_LOW_SURROGATE = '\uDFFF';
public static final int MIN_SUPPLEMENTARY_CODE_POINT = 0x010000;
private static final int SUR_CALC = (MIN_SUPPLEMENTARY_CODE_POINT - (MIN_HIGH_SURROGATE << 10)) - MIN_LOW_SURROGATE;
/** /**
* Gets the number of bytes that the given string will occupy when it is * Gets the number of bytes that the given string will occupy when it is
* serialized. * serialized.
@ -23,6 +31,53 @@ public class MessageUtils {
*/ */
public static int getByteSize(String s) { public static int getByteSize(String s) {
return Integer.BYTES + (s == null ? 0 : s.getBytes(StandardCharsets.UTF_8).length); return Integer.BYTES + (s == null ? 0 : s.getBytes(StandardCharsets.UTF_8).length);
// int length = s.length();
// int i = 0;
// int counter = 0;
//
// char c;
// while (i < length && (c = s.charAt(i)) < '\u0080') {
// // ascii fast loop;
// counter++;
// i++;
// }
//
// while (i < length) {
// c = s.charAt(i++);
// if (c < 0x80) {
// counter++;
// } else if (c < 0x800) {
// counter += 2;
// } else if (Character.isSurrogate(c)) {
// int uc = -1;
// char c2;
// if (isHighSurrogate(c) && i < length && isLowSurrogate(c2 = s.charAt(i))) {
// uc = toCodePoint(c, c2);
// }
// if (uc < 0) {
//
// } else {
// counter += 4;
// i++; // 2 chars
// }
// } else {
// // 3 bytes, 16 bits
// counter += 3;
// }
// }
//
// return Integer.BYTES + counter;
}
public static boolean isHighSurrogate(char ch) {
return ch >= MIN_HIGH_SURROGATE && ch < (MAX_HIGH_SURROGATE + 1);
}
public static boolean isLowSurrogate(char ch) {
return ch >= MIN_LOW_SURROGATE && ch < (MAX_LOW_SURROGATE + 1);
}
public static int toCodePoint(int high, int low) {
return ((high << 10) + low) + SUR_CALC;
} }
/** /**
@ -54,6 +109,7 @@ public class MessageUtils {
public static String readString(DataInputStream i) throws IOException { public static String readString(DataInputStream i) throws IOException {
int length = i.readInt(); int length = i.readInt();
if (length == -1) return null; if (length == -1) return null;
if (length == 0) return "";
byte[] data = new byte[length]; byte[] data = new byte[length];
int read = i.read(data); int read = i.read(data);
if (read != length) throw new IOException("Not all bytes of a string of length " + length + " could be read."); if (read != length) throw new IOException("Not all bytes of a string of length " + length + " could be read.");
@ -105,11 +161,13 @@ public class MessageUtils {
o.writeInt(items.size()); o.writeInt(items.size());
for (var i : items) { for (var i : items) {
i.write(o); i.write(o);
System.out.println("Wrote " + i);
} }
} }
public static <T extends Message> List<T> readList(Class<T> type, DataInputStream i) throws IOException { public static <T extends Message> List<T> readList(Class<T> type, DataInputStream i) throws IOException {
int size = i.readInt(); int size = i.readInt();
System.out.println("Read a size of " + size + " items of type " + type.getSimpleName());
try { try {
var constructor = type.getConstructor(); var constructor = type.getConstructor();
List<T> items = new ArrayList<>(size); List<T> items = new ArrayList<>(size);
@ -117,6 +175,7 @@ public class MessageUtils {
var item = constructor.newInstance(); var item = constructor.newInstance();
item.read(i); item.read(i);
items.add(item); items.add(item);
System.out.println("Read item " + (k+1) + " of " + size + ": " + item);
} }
return items; return items;
} catch (ReflectiveOperationException e) { } catch (ReflectiveOperationException e) {

View File

@ -39,6 +39,7 @@ public class Serializer {
registerType(8, ServerMetaData.class); registerType(8, ServerMetaData.class);
registerType(9, Error.class); registerType(9, Error.class);
registerType(10, CreateThread.class); registerType(10, CreateThread.class);
registerType(11, KeyData.class);
} }
/** /**

View File

@ -42,7 +42,7 @@ public class Chat implements Message {
@Override @Override
public int getByteCount() { public int getByteCount() {
return UUID_BYTES + Long.BYTES + getByteSize(this.senderNickname) + getByteSize(this.message); return 2 * UUID_BYTES + Long.BYTES + getByteSize(this.senderNickname) + getByteSize(this.message);
} }
@Override @Override
@ -61,6 +61,7 @@ public class Chat implements Message {
this.senderNickname = readString(i); this.senderNickname = readString(i);
this.timestamp = i.readLong(); this.timestamp = i.readLong();
this.message = readString(i); this.message = readString(i);
System.out.println("Read chat: " + this);
} }
@Override @Override

View File

@ -58,6 +58,10 @@ public class ChatHistoryRequest implements Message {
private UUID channelId; private UUID channelId;
private String query; private String query;
public ChatHistoryRequest(UUID channelId) {
this(channelId, "");
}
public ChatHistoryRequest(UUID channelId, Map<String, String> params) { public ChatHistoryRequest(UUID channelId, Map<String, String> params) {
this.channelId = channelId; this.channelId = channelId;
this.query = params.entrySet().stream() this.query = params.entrySet().stream()
@ -87,7 +91,7 @@ public class ChatHistoryRequest implements Message {
@Override @Override
public int getByteCount() { public int getByteCount() {
return UUID_BYTES + Integer.BYTES + getByteSize(this.query); return UUID_BYTES + getByteSize(this.query);
} }
@Override @Override

View File

@ -4,7 +4,6 @@ import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import nl.andrewl.concord_core.msg.Message; import nl.andrewl.concord_core.msg.Message;
import nl.andrewl.concord_core.msg.MessageUtils;
import java.io.DataInputStream; import java.io.DataInputStream;
import java.io.DataOutputStream; import java.io.DataOutputStream;
@ -12,6 +11,8 @@ import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.UUID; import java.util.UUID;
import static nl.andrewl.concord_core.msg.MessageUtils.*;
/** /**
* The response that a server sends to a {@link ChatHistoryRequest}. The list of * The response that a server sends to a {@link ChatHistoryRequest}. The list of
* messages is ordered by timestamp, with the newest messages appearing first. * messages is ordered by timestamp, with the newest messages appearing first.
@ -25,32 +26,19 @@ public class ChatHistoryResponse implements Message {
@Override @Override
public int getByteCount() { public int getByteCount() {
int count = Long.BYTES + Integer.BYTES; return UUID_BYTES + getByteSize(messages);
for (var message : this.messages) {
count += message.getByteCount();
}
return count;
} }
@Override @Override
public void write(DataOutputStream o) throws IOException { public void write(DataOutputStream o) throws IOException {
MessageUtils.writeUUID(this.channelId, o); writeUUID(this.channelId, o);
o.writeInt(messages.size()); writeList(this.messages, o);
for (var message : this.messages) {
message.write(o);
}
} }
@Override @Override
public void read(DataInputStream i) throws IOException { public void read(DataInputStream i) throws IOException {
this.channelId = MessageUtils.readUUID(i); this.channelId = readUUID(i);
int messageCount = i.readInt(); System.out.println("Reading list of chats...");
Chat[] messages = new Chat[messageCount]; this.messages = readList(Chat.class, i);
for (int k = 0; k < messageCount; k++) {
Chat c = new Chat();
c.read(i);
messages[k] = c;
}
this.messages = List.of(messages);
} }
} }

View File

@ -0,0 +1,52 @@
package nl.andrewl.concord_core.msg.types;
import lombok.Getter;
import lombok.NoArgsConstructor;
import nl.andrewl.concord_core.msg.Message;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
/**
* This message is sent as the first message from both the server and the client
* to establish an end-to-end encryption via a key exchange.
*/
@Getter
@NoArgsConstructor
public class KeyData implements Message {
private byte[] iv;
private byte[] salt;
private byte[] publicKey;
public KeyData(byte[] iv, byte[] salt, byte[] publicKey) {
this.iv = iv;
this.salt = salt;
this.publicKey = publicKey;
}
@Override
public int getByteCount() {
return Integer.BYTES * 3 + iv.length + salt.length + publicKey.length;
}
@Override
public void write(DataOutputStream o) throws IOException {
o.writeInt(iv.length);
o.write(iv);
o.writeInt(salt.length);
o.write(salt);
o.writeInt(publicKey.length);
o.write(publicKey);
}
@Override
public void read(DataInputStream i) throws IOException {
int ivLength = i.readInt();
this.iv = i.readNBytes(ivLength);
int saltLength = i.readInt();
this.salt = i.readNBytes(saltLength);
int publicKeyLength = i.readInt();
this.publicKey = i.readNBytes(publicKeyLength);
}
}

View File

@ -0,0 +1,3 @@
package nl.andrewl.concord_core.util;
public record Pair<A, B>(A first, B second) {}

View File

@ -2,15 +2,16 @@ package nl.andrewl.concord_server.client;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import nl.andrewl.concord_core.msg.Encryption;
import nl.andrewl.concord_core.msg.Message; import nl.andrewl.concord_core.msg.Message;
import nl.andrewl.concord_core.msg.types.Identification; import nl.andrewl.concord_core.msg.types.Identification;
import nl.andrewl.concord_core.msg.types.UserData; import nl.andrewl.concord_core.msg.types.UserData;
import nl.andrewl.concord_server.channel.Channel;
import nl.andrewl.concord_server.ConcordServer; import nl.andrewl.concord_server.ConcordServer;
import nl.andrewl.concord_server.channel.Channel;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket; import java.net.Socket;
import java.util.UUID; import java.util.UUID;
@ -20,8 +21,8 @@ import java.util.UUID;
*/ */
public class ClientThread extends Thread { public class ClientThread extends Thread {
private final Socket socket; private final Socket socket;
private final DataInputStream in; private InputStream in;
private final DataOutputStream out; private OutputStream out;
private final ConcordServer server; private final ConcordServer server;
@ -48,8 +49,8 @@ public class ClientThread extends Thread {
public ClientThread(Socket socket, ConcordServer server) throws IOException { public ClientThread(Socket socket, ConcordServer server) throws IOException {
this.socket = socket; this.socket = socket;
this.server = server; this.server = server;
this.in = new DataInputStream(socket.getInputStream()); this.in = socket.getInputStream();
this.out = new DataOutputStream(socket.getOutputStream()); this.out = socket.getOutputStream();
} }
/** /**
@ -76,6 +77,7 @@ public class ClientThread extends Thread {
this.out.flush(); this.out.flush();
} catch (IOException e) { } catch (IOException e) {
e.printStackTrace(); e.printStackTrace();
System.err.printf("Could not send to client %s(%s): %s", this.clientId, this.clientNickname, e.getMessage());
} }
} }
@ -104,6 +106,7 @@ public class ClientThread extends Thread {
while (this.running) { while (this.running) {
try { try {
var msg = this.server.getSerializer().readMessage(this.in); var msg = this.server.getSerializer().readMessage(this.in);
System.out.println("Received " + msg.getClass().getSimpleName() + " from " + this.clientNickname);
this.server.getEventManager().handle(msg, this); this.server.getEventManager().handle(msg, this);
} catch (IOException e) { } catch (IOException e) {
this.running = false; this.running = false;
@ -131,6 +134,16 @@ public class ClientThread extends Thread {
*/ */
private boolean identifyClient() { private boolean identifyClient() {
int attempts = 0; int attempts = 0;
try {
System.out.println("Initializing end-to-end encryption with the client...");
var streams = Encryption.upgrade(this.in, this.out, server.getSerializer());
this.in = streams.first();
this.out = streams.second();
System.out.println("Successfully established cipher streams.");
} catch (Exception e) {
e.printStackTrace();
return false;
}
while (attempts < 5) { while (attempts < 5) {
try { try {
var msg = this.server.getSerializer().readMessage(this.in); var msg = this.server.getSerializer().readMessage(this.in);

View File

@ -59,7 +59,7 @@ public class ChatHistoryRequestHandler implements MessageHandler<ChatHistoryRequ
} }
private ChatHistoryResponse getResponse(Channel channel, long count, Long from, Long to) { private ChatHistoryResponse getResponse(Channel channel, long count, Long from, Long to) {
var col = channel.getServer().getDb().getCollection("channel-" + channel.getId()); var col = channel.getMessageCollection();
Cursor cursor; Cursor cursor;
FindOptions options = FindOptions.sort("timestamp", SortOrder.Descending).thenLimit(0, (int) count); FindOptions options = FindOptions.sort("timestamp", SortOrder.Descending).thenLimit(0, (int) count);
List<Filter> filters = new ArrayList<>(2); List<Filter> filters = new ArrayList<>(2);
@ -74,12 +74,13 @@ public class ChatHistoryRequestHandler implements MessageHandler<ChatHistoryRequ
} else { } else {
cursor = col.find(Filters.and(filters.toArray(new Filter[0])), options); cursor = col.find(Filters.and(filters.toArray(new Filter[0])), options);
} }
System.out.println("Found " + cursor.size() + " chats");
List<Chat> chats = new ArrayList<>((int) count); List<Chat> chats = new ArrayList<>((int) count);
for (Document doc : cursor) { for (Document doc : cursor) {
chats.add(this.read(doc)); chats.add(this.read(doc));
} }
col.close(); System.out.println(chats);
chats.sort(Comparator.comparingLong(Chat::getTimestamp)); chats.sort(Comparator.comparingLong(Chat::getTimestamp));
return new ChatHistoryResponse(channel.getId(), chats); return new ChatHistoryResponse(channel.getId(), chats);
} }