From c7e0cc115898db259a732d2427c4848f67f843fd Mon Sep 17 00:00:00 2001 From: Moxie Marlinspike Date: Mon, 16 Mar 2015 15:05:33 -0700 Subject: [PATCH] Use a custom redis pubsub implementation rather than Jedis. // FREEBIE --- pom.xml | 1 - .../dispatch/DispatchChannel.java | 7 + .../dispatch/DispatchManager.java | 191 +++++++++++++ .../dispatch/io/RedisInputStream.java | 64 +++++ .../io/RedisPubSubConnectionFactory.java | 9 + .../dispatch/redis/PubSubConnection.java | 119 ++++++++ .../dispatch/redis/PubSubReply.java | 35 +++ .../redis/protocol/ArrayReplyHeader.java | 24 ++ .../dispatch/redis/protocol/IntReply.java | 24 ++ .../redis/protocol/StringReplyHeader.java | 24 ++ .../whispersystems/dispatch/util/Util.java | 36 +++ .../textsecuregcm/WhisperServerService.java | 16 +- .../providers/RedisClientFactory.java | 39 ++- .../textsecuregcm/storage/PubSubListener.java | 9 - .../textsecuregcm/storage/PubSubManager.java | 185 ++++-------- .../AuthenticatedConnectListener.java | 40 +-- .../websocket/DeadLetterHandler.java | 23 +- .../ProvisioningConnectListener.java | 8 +- .../websocket/ProvisioningConnection.java | 68 ++--- .../websocket/WebSocketConnection.java | 43 ++- .../dispatch/DispatchManagerTest.java | 127 +++++++++ .../dispatch/redis/PubSubConnectionTest.java | 263 ++++++++++++++++++ .../redis/protocol/ArrayReplyHeaderTest.java | 51 ++++ .../redis/protocol/IntReplyHeaderTest.java | 36 +++ .../redis/protocol/StringReplyHeaderTest.java | 47 ++++ .../websocket/WebSocketConnectionTest.java | 57 +--- 26 files changed, 1254 insertions(+), 292 deletions(-) create mode 100644 src/main/java/org/whispersystems/dispatch/DispatchChannel.java create mode 100644 src/main/java/org/whispersystems/dispatch/DispatchManager.java create mode 100644 src/main/java/org/whispersystems/dispatch/io/RedisInputStream.java create mode 100644 src/main/java/org/whispersystems/dispatch/io/RedisPubSubConnectionFactory.java create mode 100644 src/main/java/org/whispersystems/dispatch/redis/PubSubConnection.java create mode 100644 src/main/java/org/whispersystems/dispatch/redis/PubSubReply.java create mode 100644 src/main/java/org/whispersystems/dispatch/redis/protocol/ArrayReplyHeader.java create mode 100644 src/main/java/org/whispersystems/dispatch/redis/protocol/IntReply.java create mode 100644 src/main/java/org/whispersystems/dispatch/redis/protocol/StringReplyHeader.java create mode 100644 src/main/java/org/whispersystems/dispatch/util/Util.java delete mode 100644 src/main/java/org/whispersystems/textsecuregcm/storage/PubSubListener.java create mode 100644 src/test/java/org/whispersystems/dispatch/DispatchManagerTest.java create mode 100644 src/test/java/org/whispersystems/dispatch/redis/PubSubConnectionTest.java create mode 100644 src/test/java/org/whispersystems/dispatch/redis/protocol/ArrayReplyHeaderTest.java create mode 100644 src/test/java/org/whispersystems/dispatch/redis/protocol/IntReplyHeaderTest.java create mode 100644 src/test/java/org/whispersystems/dispatch/redis/protocol/StringReplyHeaderTest.java diff --git a/pom.xml b/pom.xml index 9859ea977..b121da047 100644 --- a/pom.xml +++ b/pom.xml @@ -132,7 +132,6 @@ 0.2.3 - diff --git a/src/main/java/org/whispersystems/dispatch/DispatchChannel.java b/src/main/java/org/whispersystems/dispatch/DispatchChannel.java new file mode 100644 index 000000000..9438f9492 --- /dev/null +++ b/src/main/java/org/whispersystems/dispatch/DispatchChannel.java @@ -0,0 +1,7 @@ +package org.whispersystems.dispatch; + +public interface DispatchChannel { + public void onDispatchMessage(String channel, byte[] message); + public void onDispatchSubscribed(String channel); + public void onDispatchUnsubscribed(String channel); +} diff --git a/src/main/java/org/whispersystems/dispatch/DispatchManager.java b/src/main/java/org/whispersystems/dispatch/DispatchManager.java new file mode 100644 index 000000000..fe9e1d267 --- /dev/null +++ b/src/main/java/org/whispersystems/dispatch/DispatchManager.java @@ -0,0 +1,191 @@ +package org.whispersystems.dispatch; + +import com.google.common.base.Optional; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.dispatch.io.RedisPubSubConnectionFactory; +import org.whispersystems.dispatch.redis.PubSubConnection; +import org.whispersystems.dispatch.redis.PubSubReply; + +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; + +public class DispatchManager extends Thread { + + private final Logger logger = LoggerFactory.getLogger(DispatchManager.class); + private final Executor executor = Executors.newCachedThreadPool(); + private final Map subscriptions = new HashMap<>(); + + private final Optional deadLetterChannel; + private final RedisPubSubConnectionFactory redisPubSubConnectionFactory; + + private PubSubConnection pubSubConnection; + private volatile boolean running; + + public DispatchManager(RedisPubSubConnectionFactory redisPubSubConnectionFactory, + Optional deadLetterChannel) + { + this.redisPubSubConnectionFactory = redisPubSubConnectionFactory; + this.deadLetterChannel = deadLetterChannel; + } + + @Override + public void start() { + this.pubSubConnection = redisPubSubConnectionFactory.connect(); + this.running = true; + super.start(); + } + + public void shutdown() { + this.running = false; + this.pubSubConnection.close(); + } + + public void subscribe(String name, DispatchChannel dispatchChannel) { + Optional previous; + + synchronized (subscriptions) { + previous = Optional.fromNullable(subscriptions.get(name)); + subscriptions.put(name, dispatchChannel); + } + + try { + pubSubConnection.subscribe(name); + } catch (IOException e) { + logger.warn("Subscription error", e); + } + + if (previous.isPresent()) { + dispatchUnsubscription(name, previous.get()); + } + } + + public void unsubscribe(String name, DispatchChannel channel) { + final Optional subscription; + + synchronized (subscriptions) { + subscription = Optional.fromNullable(subscriptions.get(name)); + + if (subscription.isPresent() && subscription.get() == channel) { + subscriptions.remove(name); + } + } + + if (subscription.isPresent()) { + try { + pubSubConnection.unsubscribe(name); + } catch (IOException e) { + logger.warn("Unsubscribe error", e); + } + + dispatchUnsubscription(name, subscription.get()); + } + } + + @Override + public void run() { + while (running) { + try { + PubSubReply reply = pubSubConnection.read(); + + switch (reply.getType()) { + case UNSUBSCRIBE: break; + case SUBSCRIBE: dispatchSubscribe(reply); break; + case MESSAGE: dispatchMessage(reply); break; + default: throw new AssertionError("Unknown pubsub reply type! " + reply.getType()); + } + } catch (IOException e) { + logger.warn("***** PubSub Connection Error *****", e); + if (running) { + this.pubSubConnection.close(); + this.pubSubConnection = redisPubSubConnectionFactory.connect(); + resubscribe(); + } + } + } + + logger.warn("DispatchManager Shutting Down..."); + } + + private void dispatchSubscribe(final PubSubReply reply) { + final Optional subscription; + + synchronized (subscriptions) { + subscription = Optional.fromNullable(subscriptions.get(reply.getChannel())); + } + + if (subscription.isPresent()) { + dispatchSubscription(reply.getChannel(), subscription.get()); + } else { + logger.warn("Received subscribe event for non-existing channel: " + reply.getChannel()); + } + } + + private void dispatchMessage(PubSubReply reply) { + Optional subscription; + + synchronized (subscriptions) { + subscription = Optional.fromNullable(subscriptions.get(reply.getChannel())); + } + + if (subscription.isPresent()) { + dispatchMessage(reply.getChannel(), subscription.get(), reply.getContent().get()); + } else if (deadLetterChannel.isPresent()) { + dispatchMessage(reply.getChannel(), deadLetterChannel.get(), reply.getContent().get()); + } else { + logger.warn("Received message for non-existing channel, with no dead letter handler: " + reply.getChannel()); + } + } + + private void resubscribe() { + final Collection names; + + synchronized (subscriptions) { + names = subscriptions.keySet(); + } + + new Thread() { + @Override + public void run() { + try { + for (String name : names) { + pubSubConnection.subscribe(name); + } + } catch (IOException e) { + logger.warn("***** RESUBSCRIPTION ERROR *****", e); + } + } + }.start(); + } + + private void dispatchMessage(final String name, final DispatchChannel channel, final byte[] message) { + executor.execute(new Runnable() { + @Override + public void run() { + channel.onDispatchMessage(name, message); + } + }); + } + + private void dispatchSubscription(final String name, final DispatchChannel channel) { + executor.execute(new Runnable() { + @Override + public void run() { + channel.onDispatchSubscribed(name); + } + }); + } + + private void dispatchUnsubscription(final String name, final DispatchChannel channel) { + executor.execute(new Runnable() { + @Override + public void run() { + channel.onDispatchUnsubscribed(name); + } + }); + } +} diff --git a/src/main/java/org/whispersystems/dispatch/io/RedisInputStream.java b/src/main/java/org/whispersystems/dispatch/io/RedisInputStream.java new file mode 100644 index 000000000..2245db9f3 --- /dev/null +++ b/src/main/java/org/whispersystems/dispatch/io/RedisInputStream.java @@ -0,0 +1,64 @@ +package org.whispersystems.dispatch.io; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; + +public class RedisInputStream { + + private static final byte CR = 0x0D; + private static final byte LF = 0x0A; + + private final InputStream inputStream; + + public RedisInputStream(InputStream inputStream) { + this.inputStream = inputStream; + } + + public String readLine() throws IOException { + ByteArrayOutputStream boas = new ByteArrayOutputStream(); + + boolean foundCr = false; + + while (true) { + int character = inputStream.read(); + + if (character == -1) { + throw new IOException("Stream closed!"); + } + + boas.write(character); + + if (foundCr && character == LF) break; + else if (character == CR) foundCr = true; + else if (foundCr) foundCr = false; + } + + byte[] data = boas.toByteArray(); + return new String(data, 0, data.length-2); + } + + public byte[] readFully(int size) throws IOException { + byte[] result = new byte[size]; + int offset = 0; + int remaining = result.length; + + while (remaining > 0) { + int read = inputStream.read(result, offset, remaining); + + if (read < 0) { + throw new IOException("Stream closed!"); + } + + offset += read; + remaining -= read; + } + + return result; + } + + public void close() throws IOException { + inputStream.close(); + } + +} diff --git a/src/main/java/org/whispersystems/dispatch/io/RedisPubSubConnectionFactory.java b/src/main/java/org/whispersystems/dispatch/io/RedisPubSubConnectionFactory.java new file mode 100644 index 000000000..d93b58072 --- /dev/null +++ b/src/main/java/org/whispersystems/dispatch/io/RedisPubSubConnectionFactory.java @@ -0,0 +1,9 @@ +package org.whispersystems.dispatch.io; + +import org.whispersystems.dispatch.redis.PubSubConnection; + +public interface RedisPubSubConnectionFactory { + + public PubSubConnection connect(); + +} diff --git a/src/main/java/org/whispersystems/dispatch/redis/PubSubConnection.java b/src/main/java/org/whispersystems/dispatch/redis/PubSubConnection.java new file mode 100644 index 000000000..599c3930a --- /dev/null +++ b/src/main/java/org/whispersystems/dispatch/redis/PubSubConnection.java @@ -0,0 +1,119 @@ +package org.whispersystems.dispatch.redis; + +import com.google.common.base.Optional; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.dispatch.io.RedisInputStream; +import org.whispersystems.dispatch.redis.protocol.ArrayReplyHeader; +import org.whispersystems.dispatch.redis.protocol.IntReply; +import org.whispersystems.dispatch.redis.protocol.StringReplyHeader; +import org.whispersystems.dispatch.util.Util; + +import java.io.BufferedInputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.net.Socket; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicBoolean; + +public class PubSubConnection { + + private final Logger logger = LoggerFactory.getLogger(PubSubConnection.class); + + private static final byte[] UNSUBSCRIBE_TYPE = {'u', 'n', 's', 'u', 'b', 's', 'c', 'r', 'i', 'b', 'e' }; + private static final byte[] SUBSCRIBE_TYPE = {'s', 'u', 'b', 's', 'c', 'r', 'i', 'b', 'e' }; + private static final byte[] MESSAGE_TYPE = {'m', 'e', 's', 's', 'a', 'g', 'e' }; + + private static final byte[] SUBSCRIBE_COMMAND = {'S', 'U', 'B', 'S', 'C', 'R', 'I', 'B', 'E', ' ' }; + private static final byte[] UNSUBSCRIBE_COMMAND = {'U', 'N', 'S', 'U', 'B', 'S', 'C', 'R', 'I', 'B', 'E', ' '}; + private static final byte[] CRLF = {'\r', '\n' }; + + private final OutputStream outputStream; + private final RedisInputStream inputStream; + private final Socket socket; + private final AtomicBoolean closed; + + public PubSubConnection(Socket socket) throws IOException { + this.socket = socket; + this.outputStream = socket.getOutputStream(); + this.inputStream = new RedisInputStream(new BufferedInputStream(socket.getInputStream())); + this.closed = new AtomicBoolean(false); + } + + public void subscribe(String channelName) throws IOException { + if (closed.get()) throw new IOException("Connection closed!"); + + byte[] command = Util.combine(SUBSCRIBE_COMMAND, channelName.getBytes(), CRLF); + outputStream.write(command); + } + + public void unsubscribe(String channelName) throws IOException { + if (closed.get()) throw new IOException("Connection closed!"); + + byte[] command = Util.combine(UNSUBSCRIBE_COMMAND, channelName.getBytes(), CRLF); + outputStream.write(command); + } + + public PubSubReply read() throws IOException { + if (closed.get()) throw new IOException("Connection closed!"); + + ArrayReplyHeader replyHeader = new ArrayReplyHeader(inputStream.readLine()); + + if (replyHeader.getElementCount() != 3) { + throw new IOException("Received array reply header with strange count: " + replyHeader.getElementCount()); + } + + StringReplyHeader replyTypeHeader = new StringReplyHeader(inputStream.readLine()); + byte[] replyType = inputStream.readFully(replyTypeHeader.getStringLength()); + inputStream.readLine(); + + if (Arrays.equals(SUBSCRIBE_TYPE, replyType)) return readSubscribeReply(); + else if (Arrays.equals(UNSUBSCRIBE_TYPE, replyType)) return readUnsubscribeReply(); + else if (Arrays.equals(MESSAGE_TYPE, replyType)) return readMessageReply(); + else throw new IOException("Unknown reply type: " + new String(replyType)); + } + + public void close() { + try { + this.closed.set(true); + this.inputStream.close(); + this.outputStream.close(); + this.socket.close(); + } catch (IOException e) { + logger.warn("Exception while closing", e); + } + } + + private PubSubReply readMessageReply() throws IOException { + StringReplyHeader channelNameHeader = new StringReplyHeader(inputStream.readLine()); + byte[] channelName = inputStream.readFully(channelNameHeader.getStringLength()); + inputStream.readLine(); + + StringReplyHeader messageHeader = new StringReplyHeader(inputStream.readLine()); + byte[] message = inputStream.readFully(messageHeader.getStringLength()); + inputStream.readLine(); + + return new PubSubReply(PubSubReply.Type.MESSAGE, new String(channelName), Optional.of(message)); + } + + private PubSubReply readUnsubscribeReply() throws IOException { + String channelName = readSubscriptionReply(); + return new PubSubReply(PubSubReply.Type.UNSUBSCRIBE, channelName, Optional.absent()); + } + + private PubSubReply readSubscribeReply() throws IOException { + String channelName = readSubscriptionReply(); + return new PubSubReply(PubSubReply.Type.SUBSCRIBE, channelName, Optional.absent()); + } + + private String readSubscriptionReply() throws IOException { + StringReplyHeader channelNameHeader = new StringReplyHeader(inputStream.readLine()); + byte[] channelName = inputStream.readFully(channelNameHeader.getStringLength()); + inputStream.readLine(); + + IntReply subscriptionCount = new IntReply(inputStream.readLine()); + + return new String(channelName); + } + +} diff --git a/src/main/java/org/whispersystems/dispatch/redis/PubSubReply.java b/src/main/java/org/whispersystems/dispatch/redis/PubSubReply.java new file mode 100644 index 000000000..d57797fae --- /dev/null +++ b/src/main/java/org/whispersystems/dispatch/redis/PubSubReply.java @@ -0,0 +1,35 @@ +package org.whispersystems.dispatch.redis; + +import com.google.common.base.Optional; + +public class PubSubReply { + + public enum Type { + MESSAGE, + SUBSCRIBE, + UNSUBSCRIBE + } + + private final Type type; + private final String channel; + private final Optional content; + + public PubSubReply(Type type, String channel, Optional content) { + this.type = type; + this.channel = channel; + this.content = content; + } + + public Type getType() { + return type; + } + + public String getChannel() { + return channel; + } + + public Optional getContent() { + return content; + } + +} diff --git a/src/main/java/org/whispersystems/dispatch/redis/protocol/ArrayReplyHeader.java b/src/main/java/org/whispersystems/dispatch/redis/protocol/ArrayReplyHeader.java new file mode 100644 index 000000000..1d3528c1b --- /dev/null +++ b/src/main/java/org/whispersystems/dispatch/redis/protocol/ArrayReplyHeader.java @@ -0,0 +1,24 @@ +package org.whispersystems.dispatch.redis.protocol; + +import java.io.IOException; + +public class ArrayReplyHeader { + + private final int elementCount; + + public ArrayReplyHeader(String header) throws IOException { + if (header == null || header.length() < 2 || header.charAt(0) != '*') { + throw new IOException("Invalid array reply header: " + header); + } + + try { + this.elementCount = Integer.parseInt(header.substring(1)); + } catch (NumberFormatException e) { + throw new IOException(e); + } + } + + public int getElementCount() { + return elementCount; + } +} diff --git a/src/main/java/org/whispersystems/dispatch/redis/protocol/IntReply.java b/src/main/java/org/whispersystems/dispatch/redis/protocol/IntReply.java new file mode 100644 index 000000000..7c7f775b3 --- /dev/null +++ b/src/main/java/org/whispersystems/dispatch/redis/protocol/IntReply.java @@ -0,0 +1,24 @@ +package org.whispersystems.dispatch.redis.protocol; + +import java.io.IOException; + +public class IntReply { + + private final int value; + + public IntReply(String reply) throws IOException { + if (reply == null || reply.length() < 2 || reply.charAt(0) != ':') { + throw new IOException("Invalid int reply: " + reply); + } + + try { + this.value = Integer.parseInt(reply.substring(1)); + } catch (NumberFormatException e) { + throw new IOException(e); + } + } + + public int getValue() { + return value; + } +} diff --git a/src/main/java/org/whispersystems/dispatch/redis/protocol/StringReplyHeader.java b/src/main/java/org/whispersystems/dispatch/redis/protocol/StringReplyHeader.java new file mode 100644 index 000000000..4a6030e21 --- /dev/null +++ b/src/main/java/org/whispersystems/dispatch/redis/protocol/StringReplyHeader.java @@ -0,0 +1,24 @@ +package org.whispersystems.dispatch.redis.protocol; + +import java.io.IOException; + +public class StringReplyHeader { + + private final int stringLength; + + public StringReplyHeader(String header) throws IOException { + if (header == null || header.length() < 2 || header.charAt(0) != '$') { + throw new IOException("Invalid string reply header: " + header); + } + + try { + this.stringLength = Integer.parseInt(header.substring(1)); + } catch (NumberFormatException e) { + throw new IOException(e); + } + } + + public int getStringLength() { + return stringLength; + } +} diff --git a/src/main/java/org/whispersystems/dispatch/util/Util.java b/src/main/java/org/whispersystems/dispatch/util/Util.java new file mode 100644 index 000000000..245466cea --- /dev/null +++ b/src/main/java/org/whispersystems/dispatch/util/Util.java @@ -0,0 +1,36 @@ +package org.whispersystems.dispatch.util; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; + +public class Util { + + public static byte[] combine(byte[]... elements) { + try { + int sum = 0; + + for (byte[] element : elements) { + sum += element.length; + } + + ByteArrayOutputStream baos = new ByteArrayOutputStream(sum); + + for (byte[] element : elements) { + baos.write(element); + } + + return baos.toByteArray(); + } catch (IOException e) { + throw new AssertionError(e); + } + } + + + public static void sleep(long millis) { + try { + Thread.sleep(millis); + } catch (InterruptedException e) { + throw new AssertionError(e); + } + } +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 8b79571d1..cb712e1bc 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -24,6 +24,8 @@ import com.sun.jersey.api.client.Client; import org.bouncycastle.jce.provider.BouncyCastleProvider; import org.eclipse.jetty.servlets.CrossOriginFilter; import org.skife.jdbi.v2.DBI; +import org.whispersystems.dispatch.DispatchChannel; +import org.whispersystems.dispatch.DispatchManager; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.FederatedPeerAuthenticator; import org.whispersystems.textsecuregcm.auth.MultiBasicAuthProvider; @@ -147,10 +149,11 @@ public class WhisperServerService extends Applicationof(deadLetterHandler)); + PubSubManager pubSubManager = new PubSubManager(cacheClient, dispatchManager); PushServiceClient pushServiceClient = new PushServiceClient(httpClient, config.getPushConfiguration()); WebsocketSender websocketSender = new WebsocketSender(messagesManager, pubSubManager); AccountAuthenticator deviceAuthenticator = new AccountAuthenticator(accountsManager); @@ -173,6 +177,7 @@ public class WhisperServerService extends Application authorizationKey = config.getRedphoneConfiguration().getAuthorizationKey(); + environment.lifecycle().manage(pubSubManager); environment.lifecycle().manage(feedbackHandler); AttachmentController attachmentController = new AttachmentController(rateLimiters, federatedClientManager, urlSigner); @@ -263,5 +268,4 @@ public class WhisperServerService extends Application listeners = new HashMap<>(); - private final Executor threaded = Executors.newCachedThreadPool(); + private final Logger logger = LoggerFactory.getLogger(PubSubManager.class); + private final DispatchManager dispatchManager; private final JedisPool jedisPool; - private final DeadLetterHandler deadLetterHandler; private boolean subscribed = false; - public PubSubManager(JedisPool jedisPool, DeadLetterHandler deadLetterHandler) { + public PubSubManager(JedisPool jedisPool, DispatchManager dispatchManager) { + this.dispatchManager = dispatchManager; this.jedisPool = jedisPool; - this.deadLetterHandler = deadLetterHandler; - initializePubSubWorker(); - waitForSubscription(); } - public synchronized void subscribe(WebsocketAddress address, PubSubListener listener) { - String serializedAddress = address.serialize(); + @Override + public void start() throws Exception { + this.dispatchManager.start(); - listeners.put(serializedAddress, listener); - baseListener.subscribe(serializedAddress.getBytes()); - } + KeepaliveDispatchChannel keepaliveDispatchChannel = new KeepaliveDispatchChannel(); + this.dispatchManager.subscribe(KEEPALIVE_CHANNEL, keepaliveDispatchChannel); - public synchronized void unsubscribe(WebsocketAddress address, PubSubListener listener) { - String serializedAddress = address.serialize(); - - if (listeners.get(serializedAddress) == listener) { - listeners.remove(serializedAddress); - baseListener.unsubscribe(serializedAddress.getBytes()); + synchronized (this) { + while (!subscribed) wait(0); } + + new KeepaliveSender().start(); } - public synchronized boolean publish(WebsocketAddress address, PubSubMessage message) { + @Override + public void stop() throws Exception { + dispatchManager.shutdown(); + } + + public void subscribe(WebsocketAddress address, DispatchChannel channel) { + String serializedAddress = address.serialize(); + dispatchManager.subscribe(serializedAddress, channel); + } + + public void unsubscribe(WebsocketAddress address, DispatchChannel dispatchChannel) { + String serializedAddress = address.serialize(); + dispatchManager.unsubscribe(serializedAddress, dispatchChannel); + } + + public boolean publish(WebsocketAddress address, PubSubMessage message) { return publish(address.serialize().getBytes(), message); } - private synchronized boolean publish(byte[] channel, PubSubMessage message) { + private boolean publish(byte[] channel, PubSubMessage message) { try (Jedis jedis = jedisPool.getResource()) { return jedis.publish(channel, message.toByteArray()) != 0; } } - private synchronized void resubscribeAll() { - for (String serializedAddress : listeners.keySet()) { - baseListener.subscribe(serializedAddress.getBytes()); - } - } - - private synchronized void waitForSubscription() { - try { - while (!subscribed) { - wait(); - } - } catch (InterruptedException e) { - throw new AssertionError(e); - } - } - - private void initializePubSubWorker() { - new Thread("PubSubListener") { - @Override - public void run() { - for (;;) { - logger.info("Starting Redis PubSub Subscriber..."); - - try (Jedis jedis = jedisPool.getResource()) { - jedis.subscribe(baseListener, KEEPALIVE_CHANNEL); - logger.warn("**** Unsubscribed from holding channel!!! ******"); - } catch (Throwable t) { - logger.warn("*** SUBSCRIBER CONNECTION CLOSED", t); - } - } - } - }.start(); - - new Thread("PubSubKeepAlive") { - @Override - public void run() { - for (;;) { - try { - Thread.sleep(20000); - publish(KEEPALIVE_CHANNEL, PubSubMessage.newBuilder() - .setType(PubSubMessage.Type.KEEPALIVE) - .build()); - } catch (Throwable e) { - logger.warn("KEEPALIVE PUBLISH EXCEPTION: ", e); - } - } - } - }.start(); - } - - private class SubscriptionListener extends BinaryJedisPubSub { + private class KeepaliveDispatchChannel implements DispatchChannel { @Override - public void onMessage(final byte[] channel, final byte[] message) { - if (Arrays.equals(KEEPALIVE_CHANNEL, channel)) { - return; - } - - final PubSubListener listener; - - synchronized (PubSubManager.this) { - listener = listeners.get(new String(channel)); - } - - threaded.execute(new Runnable() { - @Override - public void run() { - try { - PubSubMessage receivedMessage = PubSubMessage.parseFrom(message); - - if (listener != null) listener.onPubSubMessage(receivedMessage); - else deadLetterHandler.handle(channel, receivedMessage); - } catch (InvalidProtocolBufferException e) { - logger.warn("Error parsing PubSub protobuf", e); - } - } - }); + public void onDispatchMessage(String channel, byte[] message) { + // Good } @Override - public void onPMessage(byte[] s, byte[] s2, byte[] s3) { - logger.warn("Received PMessage!"); - } - - @Override - public void onSubscribe(byte[] channel, int count) { - if (Arrays.equals(KEEPALIVE_CHANNEL, channel)) { + public void onDispatchSubscribed(String channel) { + if (KEEPALIVE_CHANNEL.equals(channel)) { synchronized (PubSubManager.this) { subscribed = true; PubSubManager.this.notifyAll(); } - - threaded.execute(new Runnable() { - @Override - public void run() { - resubscribeAll(); - } - }); } } @Override - public void onUnsubscribe(byte[] s, int i) {} + public void onDispatchUnsubscribed(String channel) { + logger.warn("***** KEEPALIVE CHANNEL UNSUBSCRIBED *****"); + } + } + private class KeepaliveSender extends Thread { @Override - public void onPUnsubscribe(byte[] s, int i) {} - - @Override - public void onPSubscribe(byte[] s, int i) {} + public void run() { + while (true) { + try { + Thread.sleep(20000); + publish(KEEPALIVE_CHANNEL.getBytes(), PubSubMessage.newBuilder() + .setType(PubSubMessage.Type.KEEPALIVE) + .build()); + } catch (Throwable e) { + logger.warn("***** KEEPALIVE EXCEPTION ******", e); + } + } + } } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java index df11cf9be..8ae561cf5 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java +++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java @@ -1,5 +1,8 @@ package org.whispersystems.textsecuregcm.websocket; +import com.codahale.metrics.Histogram; +import com.codahale.metrics.MetricRegistry; +import com.codahale.metrics.SharedMetricRegistries; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.push.PushSender; @@ -8,14 +11,19 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.PubSubManager; -import org.whispersystems.textsecuregcm.storage.PubSubProtos; +import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.websocket.session.WebSocketSessionContext; import org.whispersystems.websocket.setup.WebSocketConnectListener; +import static com.codahale.metrics.MetricRegistry.name; + public class AuthenticatedConnectListener implements WebSocketConnectListener { - private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class); + private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class); + private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); + private static final Histogram durationHistogram = metricRegistry.histogram(name(WebSocketConnection.class, "connected_duration")); + private final AccountsManager accountsManager; private final PushSender pushSender; @@ -33,23 +41,22 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { @Override public void onWebSocketConnect(WebSocketSessionContext context) { - Account account = context.getAuthenticated(Account.class).get(); - Device device = account.getAuthenticatedDevice().get(); + final Account account = context.getAuthenticated(Account.class).get(); + final Device device = account.getAuthenticatedDevice().get(); + final long connectTime = System.currentTimeMillis(); + final WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId()); + final WebSocketConnection connection = new WebSocketConnection(accountsManager, pushSender, + messagesManager, account, device, + context.getClient()); updateLastSeen(account, device); - closeExistingDeviceConnection(account, device); - - final WebSocketConnection connection = new WebSocketConnection(accountsManager, pushSender, - messagesManager, pubSubManager, - account, device, - context.getClient()); - - connection.onConnected(); + pubSubManager.subscribe(address, connection); context.addListener(new WebSocketSessionContext.WebSocketEventListener() { @Override public void onWebSocketClose(WebSocketSessionContext context, int statusCode, String reason) { - connection.onConnectionLost(); + pubSubManager.unsubscribe(address, connection); + durationHistogram.update(System.currentTimeMillis() - connectTime); } }); } @@ -60,12 +67,5 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { accountsManager.update(account); } } - - private void closeExistingDeviceConnection(Account account, Device device) { - pubSubManager.publish(new WebsocketAddress(account.getNumber(), device.getId()), - PubSubProtos.PubSubMessage.newBuilder() - .setType(PubSubProtos.PubSubMessage.Type.CLOSE) - .build()); - } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/DeadLetterHandler.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/DeadLetterHandler.java index 3e8504150..516609a67 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/websocket/DeadLetterHandler.java +++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/DeadLetterHandler.java @@ -3,11 +3,13 @@ package org.whispersystems.textsecuregcm.websocket; import com.google.protobuf.InvalidProtocolBufferException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.dispatch.DispatchChannel; import org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.PubSubProtos; +import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage; -public class DeadLetterHandler { +public class DeadLetterHandler implements DispatchChannel { private final Logger logger = LoggerFactory.getLogger(DeadLetterHandler.class); @@ -17,14 +19,16 @@ public class DeadLetterHandler { this.messagesManager = messagesManager; } - public void handle(byte[] channel, PubSubProtos.PubSubMessage pubSubMessage) { + @Override + public void onDispatchMessage(String channel, byte[] data) { try { - WebsocketAddress address = new WebsocketAddress(new String(channel)); + logger.warn("Handling dead letter to: " + channel); - logger.warn("Handling dead letter to: " + address); + WebsocketAddress address = new WebsocketAddress(channel); + PubSubMessage pubSubMessage = PubSubMessage.parseFrom(data); switch (pubSubMessage.getType().getNumber()) { - case PubSubProtos.PubSubMessage.Type.DELIVER_VALUE: + case PubSubMessage.Type.DELIVER_VALUE: OutgoingMessageSignal message = OutgoingMessageSignal.parseFrom(pubSubMessage.getContent()); messagesManager.insert(address.getNumber(), address.getDeviceId(), message); break; @@ -36,4 +40,13 @@ public class DeadLetterHandler { } } + @Override + public void onDispatchSubscribed(String channel) { + logger.warn("DeadLetterHandler subscription notice! " + channel); + } + + @Override + public void onDispatchUnsubscribed(String channel) { + logger.warn("DeadLetterHandler unsubscribe notice! " + channel); + } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java index 405a86250..0c903ebc8 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java +++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java @@ -14,13 +14,15 @@ public class ProvisioningConnectListener implements WebSocketConnectListener { @Override public void onWebSocketConnect(WebSocketSessionContext context) { - final ProvisioningConnection connection = new ProvisioningConnection(pubSubManager, context.getClient()); - connection.onConnected(); + final ProvisioningConnection connection = new ProvisioningConnection(context.getClient()); + final ProvisioningAddress provisioningAddress = ProvisioningAddress.generate(); + + pubSubManager.subscribe(provisioningAddress, connection); context.addListener(new WebSocketSessionContext.WebSocketEventListener() { @Override public void onWebSocketClose(WebSocketSessionContext context, int statusCode, String reason) { - connection.onConnectionLost(); + pubSubManager.unsubscribe(provisioningAddress, connection); } }); } diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnection.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnection.java index 817758304..1fb0af3c2 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnection.java +++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnection.java @@ -4,58 +4,62 @@ import com.google.common.base.Optional; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; +import com.google.protobuf.InvalidProtocolBufferException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.dispatch.DispatchChannel; import org.whispersystems.textsecuregcm.entities.MessageProtos.ProvisioningUuid; -import org.whispersystems.textsecuregcm.storage.PubSubListener; -import org.whispersystems.textsecuregcm.storage.PubSubManager; import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage; import org.whispersystems.websocket.WebSocketClient; import org.whispersystems.websocket.messages.WebSocketResponseMessage; -public class ProvisioningConnection implements PubSubListener { +public class ProvisioningConnection implements DispatchChannel { - private final PubSubManager pubSubManager; - private final ProvisioningAddress provisioningAddress; - private final WebSocketClient client; + private final Logger logger = LoggerFactory.getLogger(ProvisioningConnection.class); - public ProvisioningConnection(PubSubManager pubSubManager, WebSocketClient client) { - this.pubSubManager = pubSubManager; - this.client = client; - this.provisioningAddress = ProvisioningAddress.generate(); + private final WebSocketClient client; + + public ProvisioningConnection(WebSocketClient client) { + this.client = client; } @Override - public void onPubSubMessage(PubSubMessage outgoingMessage) { - if (outgoingMessage.getType() == PubSubMessage.Type.DELIVER) { - Optional body = Optional.of(outgoingMessage.getContent().toByteArray()); + public void onDispatchMessage(String channel, byte[] message) { + try { + PubSubMessage outgoingMessage = PubSubMessage.parseFrom(message); - ListenableFuture response = client.sendRequest("PUT", "/v1/message", body); + if (outgoingMessage.getType() == PubSubMessage.Type.DELIVER) { + Optional body = Optional.of(outgoingMessage.getContent().toByteArray()); - Futures.addCallback(response, new FutureCallback() { - @Override - public void onSuccess(WebSocketResponseMessage webSocketResponseMessage) { - pubSubManager.unsubscribe(provisioningAddress, ProvisioningConnection.this); - client.close(1001, "All you get."); - } + ListenableFuture response = client.sendRequest("PUT", "/v1/message", body); - @Override - public void onFailure(Throwable throwable) { - pubSubManager.unsubscribe(provisioningAddress, ProvisioningConnection.this); - client.close(1001, "That's all!"); - } - }); + Futures.addCallback(response, new FutureCallback() { + @Override + public void onSuccess(WebSocketResponseMessage webSocketResponseMessage) { + client.close(1001, "All you get."); + } + + @Override + public void onFailure(Throwable throwable) { + client.close(1001, "That's all!"); + } + }); + } + } catch (InvalidProtocolBufferException e) { + logger.warn("Protobuf Error: ", e); } } - public void onConnected() { - this.pubSubManager.subscribe(provisioningAddress, this); + @Override + public void onDispatchSubscribed(String channel) { this.client.sendRequest("PUT", "/v1/address", Optional.of(ProvisioningUuid.newBuilder() - .setUuid(provisioningAddress.getAddress()) + .setUuid(channel) .build() .toByteArray())); } - public void onConnectionLost() { - this.pubSubManager.unsubscribe(provisioningAddress, this); - this.client.close(1001, "Done"); + @Override + public void onDispatchUnsubscribed(String channel) { + this.client.close(1001, "Closed"); } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index f204a43ca..612a77d38 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -10,6 +10,7 @@ import com.google.common.util.concurrent.ListenableFuture; import com.google.protobuf.InvalidProtocolBufferException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.dispatch.DispatchChannel; import org.whispersystems.textsecuregcm.entities.CryptoEncodingException; import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; @@ -19,8 +20,6 @@ import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; -import org.whispersystems.textsecuregcm.storage.PubSubListener; -import org.whispersystems.textsecuregcm.storage.PubSubManager; import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.websocket.WebSocketClient; @@ -34,21 +33,16 @@ import static com.codahale.metrics.MetricRegistry.name; import static org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal; import static org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage; -public class WebSocketConnection implements PubSubListener { +public class WebSocketConnection implements DispatchChannel { private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class); - private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); - private static final Histogram durationHistogram = metricRegistry.histogram(name(WebSocketConnection.class, "connected_duration")); - private final AccountsManager accountsManager; private final PushSender pushSender; private final MessagesManager messagesManager; - private final PubSubManager pubSubManager; private final Account account; private final Device device; - private final WebsocketAddress address; private final WebSocketClient client; private long connectionStartTime; @@ -56,7 +50,6 @@ public class WebSocketConnection implements PubSubListener { public WebSocketConnection(AccountsManager accountsManager, PushSender pushSender, MessagesManager messagesManager, - PubSubManager pubSubManager, Account account, Device device, WebSocketClient client) @@ -64,27 +57,16 @@ public class WebSocketConnection implements PubSubListener { this.accountsManager = accountsManager; this.pushSender = pushSender; this.messagesManager = messagesManager; - this.pubSubManager = pubSubManager; this.account = account; this.device = device; this.client = client; - this.address = new WebsocketAddress(account.getNumber(), device.getId()); - } - - public void onConnected() { - connectionStartTime = System.currentTimeMillis(); - pubSubManager.subscribe(address, this); - processStoredMessages(); - } - - public void onConnectionLost() { - durationHistogram.update(System.currentTimeMillis() - connectionStartTime); - pubSubManager.unsubscribe(address, this); } @Override - public void onPubSubMessage(PubSubMessage pubSubMessage) { + public void onDispatchMessage(String channel, byte[] message) { try { + PubSubMessage pubSubMessage = PubSubMessage.parseFrom(message); + switch (pubSubMessage.getType().getNumber()) { case PubSubMessage.Type.QUERY_DB_VALUE: processStoredMessages(); @@ -92,10 +74,6 @@ public class WebSocketConnection implements PubSubListener { case PubSubMessage.Type.DELIVER_VALUE: sendMessage(OutgoingMessageSignal.parseFrom(pubSubMessage.getContent()), Optional.absent()); break; - case PubSubMessage.Type.CLOSE_VALUE: - client.close(1000, "OK"); - pubSubManager.unsubscribe(address, this); - break; default: logger.warn("Unknown pubsub message: " + pubSubMessage.getType().getNumber()); } @@ -104,6 +82,15 @@ public class WebSocketConnection implements PubSubListener { } } + @Override + public void onDispatchUnsubscribed(String channel) { + client.close(1000, "OK"); + } + + public void onDispatchSubscribed(String channel) { + processStoredMessages(); + } + private void sendMessage(final OutgoingMessageSignal message, final Optional storedMessageId) { @@ -180,4 +167,6 @@ public class WebSocketConnection implements PubSubListener { sendMessage(message.second(), Optional.of(message.first())); } } + + } diff --git a/src/test/java/org/whispersystems/dispatch/DispatchManagerTest.java b/src/test/java/org/whispersystems/dispatch/DispatchManagerTest.java new file mode 100644 index 000000000..d6e81c518 --- /dev/null +++ b/src/test/java/org/whispersystems/dispatch/DispatchManagerTest.java @@ -0,0 +1,127 @@ +package org.whispersystems.dispatch; + +import com.google.common.base.Optional; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExternalResource; +import org.mockito.ArgumentCaptor; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import org.whispersystems.dispatch.io.RedisPubSubConnectionFactory; +import org.whispersystems.dispatch.redis.PubSubConnection; +import org.whispersystems.dispatch.redis.PubSubReply; + +import java.io.IOException; +import java.util.LinkedList; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.*; + +public class DispatchManagerTest { + + private PubSubConnection pubSubConnection; + private RedisPubSubConnectionFactory socketFactory; + private DispatchManager dispatchManager; + private PubSubReplyInputStream pubSubReplyInputStream; + + @Rule + public ExternalResource resource = new ExternalResource() { + @Override + protected void before() throws Throwable { + pubSubConnection = mock(PubSubConnection.class ); + socketFactory = mock(RedisPubSubConnectionFactory.class); + pubSubReplyInputStream = new PubSubReplyInputStream(); + + when(socketFactory.connect()).thenReturn(pubSubConnection); + when(pubSubConnection.read()).thenAnswer(new Answer() { + @Override + public PubSubReply answer(InvocationOnMock invocationOnMock) throws Throwable { + return pubSubReplyInputStream.read(); + } + }); + + dispatchManager = new DispatchManager(socketFactory, Optional.absent()); + dispatchManager.start(); + } + + @Override + protected void after() { + + } + }; + + @Test + public void testConnect() { + verify(socketFactory).connect(); + } + + @Test + public void testSubscribe() throws IOException { + DispatchChannel dispatchChannel = mock(DispatchChannel.class); + dispatchManager.subscribe("foo", dispatchChannel); + pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.SUBSCRIBE, "foo", Optional.absent())); + + verify(dispatchChannel, timeout(1000)).onDispatchSubscribed(eq("foo")); + } + + @Test + public void testSubscribeUnsubscribe() throws IOException { + DispatchChannel dispatchChannel = mock(DispatchChannel.class); + dispatchManager.subscribe("foo", dispatchChannel); + dispatchManager.unsubscribe("foo", dispatchChannel); + + pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.SUBSCRIBE, "foo", Optional.absent())); + pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.UNSUBSCRIBE, "foo", Optional.absent())); + + verify(dispatchChannel, timeout(1000)).onDispatchUnsubscribed(eq("foo")); + } + + @Test + public void testMessages() throws IOException { + DispatchChannel fooChannel = mock(DispatchChannel.class); + DispatchChannel barChannel = mock(DispatchChannel.class); + + dispatchManager.subscribe("foo", fooChannel); + dispatchManager.subscribe("bar", barChannel); + + pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.SUBSCRIBE, "foo", Optional.absent())); + pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.SUBSCRIBE, "bar", Optional.absent())); + + verify(fooChannel, timeout(1000)).onDispatchSubscribed(eq("foo")); + verify(barChannel, timeout(1000)).onDispatchSubscribed(eq("bar")); + + pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.MESSAGE, "foo", Optional.of("hello".getBytes()))); + pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.MESSAGE, "bar", Optional.of("there".getBytes()))); + + ArgumentCaptor captor = ArgumentCaptor.forClass(byte[].class); + verify(fooChannel, timeout(1000)).onDispatchMessage(eq("foo"), captor.capture()); + + assertArrayEquals("hello".getBytes(), captor.getValue()); + + verify(barChannel, timeout(1000)).onDispatchMessage(eq("bar"), captor.capture()); + + assertArrayEquals("there".getBytes(), captor.getValue()); + } + + private static class PubSubReplyInputStream { + + private final List pubSubReplyList = new LinkedList<>(); + + public synchronized PubSubReply read() { + try { + while (pubSubReplyList.isEmpty()) wait(); + return pubSubReplyList.remove(0); + } catch (InterruptedException e) { + throw new AssertionError(e); + } + } + + public synchronized void write(PubSubReply pubSubReply) { + pubSubReplyList.add(pubSubReply); + notifyAll(); + } + } + +} diff --git a/src/test/java/org/whispersystems/dispatch/redis/PubSubConnectionTest.java b/src/test/java/org/whispersystems/dispatch/redis/PubSubConnectionTest.java new file mode 100644 index 000000000..a1524aabb --- /dev/null +++ b/src/test/java/org/whispersystems/dispatch/redis/PubSubConnectionTest.java @@ -0,0 +1,263 @@ +package org.whispersystems.dispatch.redis; + +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.Socket; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; + +import static org.junit.Assert.*; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; +import static org.mockito.Mockito.*; + +public class PubSubConnectionTest { + + private static final String REPLY = "*3\r\n" + + "$9\r\n" + + "subscribe\r\n" + + "$5\r\n" + + "abcde\r\n" + + ":1\r\n" + + "*3\r\n" + + "$9\r\n" + + "subscribe\r\n" + + "$5\r\n" + + "fghij\r\n" + + ":2\r\n" + + "*3\r\n" + + "$9\r\n" + + "subscribe\r\n" + + "$5\r\n" + + "klmno\r\n" + + ":2\r\n" + + "*3\r\n" + + "$7\r\n" + + "message\r\n" + + "$5\r\n" + + "abcde\r\n" + + "$10\r\n" + + "1234567890\r\n" + + "*3\r\n" + + "$7\r\n" + + "message\r\n" + + "$5\r\n" + + "klmno\r\n" + + "$10\r\n" + + "0987654321\r\n"; + + + @Test + public void testSubscribe() throws IOException { +// ByteChannel byteChannel = mock(ByteChannel.class); + OutputStream outputStream = mock(OutputStream.class); + Socket socket = mock(Socket.class ); + when(socket.getOutputStream()).thenReturn(outputStream); + PubSubConnection connection = new PubSubConnection(socket); + + connection.subscribe("foobar"); + + ArgumentCaptor captor = ArgumentCaptor.forClass(byte[].class); + verify(outputStream).write(captor.capture()); + + assertArrayEquals(captor.getValue(), "SUBSCRIBE foobar\r\n".getBytes()); + } + + @Test + public void testUnsubscribe() throws IOException { + OutputStream outputStream = mock(OutputStream.class); + Socket socket = mock(Socket.class ); + when(socket.getOutputStream()).thenReturn(outputStream); + PubSubConnection connection = new PubSubConnection(socket); + + connection.unsubscribe("bazbar"); + + ArgumentCaptor captor = ArgumentCaptor.forClass(byte[].class); + verify(outputStream).write(captor.capture()); + + assertArrayEquals(captor.getValue(), "UNSUBSCRIBE bazbar\r\n".getBytes()); + } + + @Test + public void testTricklyResponse() throws Exception { + InputStream inputStream = mockInputStreamFor(new TrickleInputStream(REPLY.getBytes())); + OutputStream outputStream = mock(OutputStream.class); + Socket socket = mock(Socket.class ); + when(socket.getOutputStream()).thenReturn(outputStream); + when(socket.getInputStream()).thenReturn(inputStream); + + PubSubConnection pubSubConnection = new PubSubConnection(socket); + readResponses(pubSubConnection); + } + + @Test + public void testFullResponse() throws Exception { + InputStream inputStream = mockInputStreamFor(new FullInputStream(REPLY.getBytes())); + OutputStream outputStream = mock(OutputStream.class); + Socket socket = mock(Socket.class ); + when(socket.getOutputStream()).thenReturn(outputStream); + when(socket.getInputStream()).thenReturn(inputStream); + + PubSubConnection pubSubConnection = new PubSubConnection(socket); + readResponses(pubSubConnection); + } + + @Test + public void testRandomLengthResponse() throws Exception { + InputStream inputStream = mockInputStreamFor(new RandomInputStream(REPLY.getBytes())); + OutputStream outputStream = mock(OutputStream.class); + Socket socket = mock(Socket.class ); + when(socket.getOutputStream()).thenReturn(outputStream); + when(socket.getInputStream()).thenReturn(inputStream); + + PubSubConnection pubSubConnection = new PubSubConnection(socket); + readResponses(pubSubConnection); + } + + private InputStream mockInputStreamFor(final MockInputStream stub) throws IOException { + InputStream result = mock(InputStream.class); + + when(result.read()).thenAnswer(new Answer() { + @Override + public Integer answer(InvocationOnMock invocationOnMock) throws Throwable { + return stub.read(); + } + }); + + when(result.read(any(byte[].class))).thenAnswer(new Answer() { + @Override + public Integer answer(InvocationOnMock invocationOnMock) throws Throwable { + byte[] buffer = (byte[])invocationOnMock.getArguments()[0]; + return stub.read(buffer, 0, buffer.length); + } + }); + + when(result.read(any(byte[].class), anyInt(), anyInt())).thenAnswer(new Answer() { + @Override + public Integer answer(InvocationOnMock invocationOnMock) throws Throwable { + byte[] buffer = (byte[]) invocationOnMock.getArguments()[0]; + int offset = (int) invocationOnMock.getArguments()[1]; + int length = (int) invocationOnMock.getArguments()[2]; + + return stub.read(buffer, offset, length); + } + }); + + return result; + } + + private void readResponses(PubSubConnection pubSubConnection) throws Exception { + PubSubReply reply = pubSubConnection.read(); + + assertEquals(reply.getType(), PubSubReply.Type.SUBSCRIBE); + assertEquals(reply.getChannel(), "abcde"); + assertFalse(reply.getContent().isPresent()); + + reply = pubSubConnection.read(); + + assertEquals(reply.getType(), PubSubReply.Type.SUBSCRIBE); + assertEquals(reply.getChannel(), "fghij"); + assertFalse(reply.getContent().isPresent()); + + reply = pubSubConnection.read(); + + assertEquals(reply.getType(), PubSubReply.Type.SUBSCRIBE); + assertEquals(reply.getChannel(), "klmno"); + assertFalse(reply.getContent().isPresent()); + + reply = pubSubConnection.read(); + + assertEquals(reply.getType(), PubSubReply.Type.MESSAGE); + assertEquals(reply.getChannel(), "abcde"); + assertArrayEquals(reply.getContent().get(), "1234567890".getBytes()); + + reply = pubSubConnection.read(); + + assertEquals(reply.getType(), PubSubReply.Type.MESSAGE); + assertEquals(reply.getChannel(), "klmno"); + assertArrayEquals(reply.getContent().get(), "0987654321".getBytes()); + } + + private interface MockInputStream { + public int read(); + public int read(byte[] input, int offset, int length); + } + + private static class TrickleInputStream implements MockInputStream { + + private final byte[] data; + private int index = 0; + + private TrickleInputStream(byte[] data) { + this.data = data; + } + + public int read() { + return data[index++]; + } + + public int read(byte[] input, int offset, int length) { + input[offset] = data[index++]; + return 1; + } + + } + + private static class FullInputStream implements MockInputStream { + + private final byte[] data; + private int index = 0; + + private FullInputStream(byte[] data) { + this.data = data; + } + + public int read() { + return data[index++]; + } + + public int read(byte[] input, int offset, int length) { + int amount = Math.min(data.length - index, length); + System.arraycopy(data, index, input, offset, amount); + index += length; + + return amount; + } + } + + private static class RandomInputStream implements MockInputStream { + private final byte[] data; + private int index = 0; + + private RandomInputStream(byte[] data) { + this.data = data; + } + + public int read() { + return data[index++]; + } + + public int read(byte[] input, int offset, int length) { + try { + int maxCopy = Math.min(data.length - index, length); + int randomCopy = SecureRandom.getInstance("SHA1PRNG").nextInt(maxCopy) + 1; + int copyAmount = Math.min(maxCopy, randomCopy); + + System.arraycopy(data, index, input, offset, copyAmount); + index += copyAmount; + + return copyAmount; + } catch (NoSuchAlgorithmException e) { + throw new AssertionError(e); + } + } + + } + +} diff --git a/src/test/java/org/whispersystems/dispatch/redis/protocol/ArrayReplyHeaderTest.java b/src/test/java/org/whispersystems/dispatch/redis/protocol/ArrayReplyHeaderTest.java new file mode 100644 index 000000000..1c930ceb3 --- /dev/null +++ b/src/test/java/org/whispersystems/dispatch/redis/protocol/ArrayReplyHeaderTest.java @@ -0,0 +1,51 @@ +package org.whispersystems.dispatch.redis.protocol; + + +import org.junit.Test; + +import java.io.IOException; + +import static org.junit.Assert.assertEquals; + +public class ArrayReplyHeaderTest { + + + @Test(expected = IOException.class) + public void testNull() throws IOException { + new ArrayReplyHeader(null); + } + + @Test(expected = IOException.class) + public void testBadPrefix() throws IOException { + new ArrayReplyHeader(":3"); + } + + @Test(expected = IOException.class) + public void testEmpty() throws IOException { + new ArrayReplyHeader(""); + } + + @Test(expected = IOException.class) + public void testTruncated() throws IOException { + new ArrayReplyHeader("*"); + } + + @Test(expected = IOException.class) + public void testBadNumber() throws IOException { + new ArrayReplyHeader("*ABC"); + } + + @Test + public void testValid() throws IOException { + assertEquals(4, new ArrayReplyHeader("*4").getElementCount()); + } + + + + + + + + + +} diff --git a/src/test/java/org/whispersystems/dispatch/redis/protocol/IntReplyHeaderTest.java b/src/test/java/org/whispersystems/dispatch/redis/protocol/IntReplyHeaderTest.java new file mode 100644 index 000000000..24881c93c --- /dev/null +++ b/src/test/java/org/whispersystems/dispatch/redis/protocol/IntReplyHeaderTest.java @@ -0,0 +1,36 @@ +package org.whispersystems.dispatch.redis.protocol; + + +import org.junit.Test; + +import java.io.IOException; + +import static org.junit.Assert.assertEquals; + +public class IntReplyHeaderTest { + + @Test(expected = IOException.class) + public void testNull() throws IOException { + new IntReply(null); + } + + @Test(expected = IOException.class) + public void testEmpty() throws IOException { + new IntReply(""); + } + + @Test(expected = IOException.class) + public void testBadNumber() throws IOException { + new IntReply(":A"); + } + + @Test(expected = IOException.class) + public void testBadFormat() throws IOException { + new IntReply("*"); + } + + @Test + public void testValid() throws IOException { + assertEquals(23, new IntReply(":23").getValue()); + } +} diff --git a/src/test/java/org/whispersystems/dispatch/redis/protocol/StringReplyHeaderTest.java b/src/test/java/org/whispersystems/dispatch/redis/protocol/StringReplyHeaderTest.java new file mode 100644 index 000000000..a49517df9 --- /dev/null +++ b/src/test/java/org/whispersystems/dispatch/redis/protocol/StringReplyHeaderTest.java @@ -0,0 +1,47 @@ +package org.whispersystems.dispatch.redis.protocol; + +import org.junit.Test; + +import java.io.IOException; + +import static org.junit.Assert.assertEquals; + +public class StringReplyHeaderTest { + + @Test + public void testNull() { + try { + new StringReplyHeader(null); + throw new AssertionError(); + } catch (IOException e) { + // good + } + } + + @Test + public void testBadNumber() { + try { + new StringReplyHeader("$100A"); + throw new AssertionError(); + } catch (IOException e) { + // good + } + } + + @Test + public void testBadPrefix() { + try { + new StringReplyHeader("*"); + throw new AssertionError(); + } catch (IOException e) { + // good + } + } + + @Test + public void testValid() throws IOException { + assertEquals(1000, new StringReplyHeader("$1000").getStringLength()); + } + + +} diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java index 32833cd2f..40cf98200 100644 --- a/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java @@ -44,45 +44,20 @@ import static org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMe public class WebSocketConnectionTest { -// private static final ObjectMapper mapper = new ObjectMapper(); - private static final String VALID_USER = "+14152222222"; private static final String INVALID_USER = "+14151111111"; private static final String VALID_PASSWORD = "secure"; private static final String INVALID_PASSWORD = "insecure"; -// private static final StoredMessages storedMessages = mock(StoredMessages.class); private static final AccountAuthenticator accountAuthenticator = mock(AccountAuthenticator.class); private static final AccountsManager accountsManager = mock(AccountsManager.class); private static final PubSubManager pubSubManager = mock(PubSubManager.class ); private static final Account account = mock(Account.class ); private static final Device device = mock(Device.class ); private static final UpgradeRequest upgradeRequest = mock(UpgradeRequest.class ); -// private static final Session session = mock(Session.class ); private static final PushSender pushSender = mock(PushSender.class); - @Test - public void testCloseExisting() throws Exception { - MessagesManager storedMessages = mock(MessagesManager.class ); - WebSocketConnectListener connectListener = new AuthenticatedConnectListener(accountsManager, pushSender, storedMessages, pubSubManager); - WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); - Account account = mock(Account.class ); - Device device = mock(Device.class ); - - when(sessionContext.getAuthenticated(Account.class)).thenReturn(Optional.of(account)); - when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device)); - when(account.getNumber()).thenReturn("+14157777777"); - when(device.getId()).thenReturn(1L); - - connectListener.onWebSocketConnect(sessionContext); - - ArgumentCaptor message = ArgumentCaptor.forClass(PubSubProtos.PubSubMessage.class); - - verify(pubSubManager).publish(eq(new WebsocketAddress("+14157777777", 1L)), message.capture()); - assertEquals(message.getValue().getType().getNumber(), PubSubProtos.PubSubMessage.Type.CLOSE_VALUE); - } - @Test public void testCredentials() throws Exception { MessagesManager storedMessages = mock(MessagesManager.class); @@ -98,10 +73,6 @@ public class WebSocketConnectionTest { when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device)); -// when(session.getUpgradeRequest()).thenReturn(upgradeRequest); -// -// WebsocketController controller = new WebsocketController(accountAuthenticator, accountsManager, pushSender, pubSubManager, storedMessages); - when(upgradeRequest.getParameterMap()).thenReturn(new HashMap() {{ put("login", new String[] {VALID_USER}); put("password", new String[] {VALID_PASSWORD}); @@ -114,13 +85,6 @@ public class WebSocketConnectionTest { verify(sessionContext).addListener(any(WebSocketSessionContext.WebSocketEventListener.class)); -// -// controller.onWebSocketConnect(session); - -// verify(session, never()).close(); -// verify(session, never()).close(any(CloseStatus.class)); -// verify(session, never()).close(anyInt(), anyString()); - when(upgradeRequest.getParameterMap()).thenReturn(new HashMap() {{ put("login", new String[] {INVALID_USER}); put("password", new String[] {INVALID_PASSWORD}); @@ -128,15 +92,6 @@ public class WebSocketConnectionTest { account = webSocketAuthenticator.authenticate(upgradeRequest); assertFalse(account.isPresent()); -// when(sessionContext.getAuthenticated(Account.class)).thenReturn(account); -// -// WebSocketClient client = mock(WebSocketClient.class); -// when(sessionContext.getClient()).thenReturn(client); -// -// connectListener.onWebSocketConnect(sessionContext); -// -// verify(sessionContext, times(1)).addListener(any(WebSocketSessionContext.WebSocketEventListener.class)); -// verify(client).close(eq(4001), anyString()); } @Test @@ -183,12 +138,11 @@ public class WebSocketConnectionTest { } }); + WebsocketAddress websocketAddress = new WebsocketAddress(account.getNumber(), device.getId()); WebSocketConnection connection = new WebSocketConnection(accountsManager, pushSender, storedMessages, - pubSubManager, account, device, client); + account, device, client); - connection.onConnected(); - - verify(pubSubManager).subscribe(eq(new WebsocketAddress("+14152222222", 2L)), eq((connection))); + connection.onDispatchSubscribed(websocketAddress.serialize()); verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(Optional.class)); assertTrue(futures.size() == 3); @@ -205,11 +159,10 @@ public class WebSocketConnectionTest { add(createMessage("sender2", 3333, false, "third")); }}; -// verify(pushSender, times(2)).sendMessage(eq(account), eq(device), any(OutgoingMessageSignal.class)); verify(pushSender, times(1)).sendMessage(eq(sender1), eq(sender1device), any(OutgoingMessageSignal.class)); - connection.onConnectionLost(); - verify(pubSubManager).unsubscribe(eq(new WebsocketAddress("+14152222222", 2L)), eq(connection)); + connection.onDispatchUnsubscribed(websocketAddress.serialize()); + verify(client).close(anyInt(), anyString()); } private OutgoingMessageSignal createMessage(String sender, long timestamp, boolean receipt, String content) {