diff --git a/pom.xml b/pom.xml
index 9859ea977..b121da047 100644
--- a/pom.xml
+++ b/pom.xml
@@ -132,7 +132,6 @@
0.2.3
-
diff --git a/src/main/java/org/whispersystems/dispatch/DispatchChannel.java b/src/main/java/org/whispersystems/dispatch/DispatchChannel.java
new file mode 100644
index 000000000..9438f9492
--- /dev/null
+++ b/src/main/java/org/whispersystems/dispatch/DispatchChannel.java
@@ -0,0 +1,7 @@
+package org.whispersystems.dispatch;
+
+public interface DispatchChannel {
+ public void onDispatchMessage(String channel, byte[] message);
+ public void onDispatchSubscribed(String channel);
+ public void onDispatchUnsubscribed(String channel);
+}
diff --git a/src/main/java/org/whispersystems/dispatch/DispatchManager.java b/src/main/java/org/whispersystems/dispatch/DispatchManager.java
new file mode 100644
index 000000000..fe9e1d267
--- /dev/null
+++ b/src/main/java/org/whispersystems/dispatch/DispatchManager.java
@@ -0,0 +1,191 @@
+package org.whispersystems.dispatch;
+
+import com.google.common.base.Optional;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.whispersystems.dispatch.io.RedisPubSubConnectionFactory;
+import org.whispersystems.dispatch.redis.PubSubConnection;
+import org.whispersystems.dispatch.redis.PubSubReply;
+
+import java.io.IOException;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.Executor;
+import java.util.concurrent.Executors;
+
+public class DispatchManager extends Thread {
+
+ private final Logger logger = LoggerFactory.getLogger(DispatchManager.class);
+ private final Executor executor = Executors.newCachedThreadPool();
+ private final Map subscriptions = new HashMap<>();
+
+ private final Optional deadLetterChannel;
+ private final RedisPubSubConnectionFactory redisPubSubConnectionFactory;
+
+ private PubSubConnection pubSubConnection;
+ private volatile boolean running;
+
+ public DispatchManager(RedisPubSubConnectionFactory redisPubSubConnectionFactory,
+ Optional deadLetterChannel)
+ {
+ this.redisPubSubConnectionFactory = redisPubSubConnectionFactory;
+ this.deadLetterChannel = deadLetterChannel;
+ }
+
+ @Override
+ public void start() {
+ this.pubSubConnection = redisPubSubConnectionFactory.connect();
+ this.running = true;
+ super.start();
+ }
+
+ public void shutdown() {
+ this.running = false;
+ this.pubSubConnection.close();
+ }
+
+ public void subscribe(String name, DispatchChannel dispatchChannel) {
+ Optional previous;
+
+ synchronized (subscriptions) {
+ previous = Optional.fromNullable(subscriptions.get(name));
+ subscriptions.put(name, dispatchChannel);
+ }
+
+ try {
+ pubSubConnection.subscribe(name);
+ } catch (IOException e) {
+ logger.warn("Subscription error", e);
+ }
+
+ if (previous.isPresent()) {
+ dispatchUnsubscription(name, previous.get());
+ }
+ }
+
+ public void unsubscribe(String name, DispatchChannel channel) {
+ final Optional subscription;
+
+ synchronized (subscriptions) {
+ subscription = Optional.fromNullable(subscriptions.get(name));
+
+ if (subscription.isPresent() && subscription.get() == channel) {
+ subscriptions.remove(name);
+ }
+ }
+
+ if (subscription.isPresent()) {
+ try {
+ pubSubConnection.unsubscribe(name);
+ } catch (IOException e) {
+ logger.warn("Unsubscribe error", e);
+ }
+
+ dispatchUnsubscription(name, subscription.get());
+ }
+ }
+
+ @Override
+ public void run() {
+ while (running) {
+ try {
+ PubSubReply reply = pubSubConnection.read();
+
+ switch (reply.getType()) {
+ case UNSUBSCRIBE: break;
+ case SUBSCRIBE: dispatchSubscribe(reply); break;
+ case MESSAGE: dispatchMessage(reply); break;
+ default: throw new AssertionError("Unknown pubsub reply type! " + reply.getType());
+ }
+ } catch (IOException e) {
+ logger.warn("***** PubSub Connection Error *****", e);
+ if (running) {
+ this.pubSubConnection.close();
+ this.pubSubConnection = redisPubSubConnectionFactory.connect();
+ resubscribe();
+ }
+ }
+ }
+
+ logger.warn("DispatchManager Shutting Down...");
+ }
+
+ private void dispatchSubscribe(final PubSubReply reply) {
+ final Optional subscription;
+
+ synchronized (subscriptions) {
+ subscription = Optional.fromNullable(subscriptions.get(reply.getChannel()));
+ }
+
+ if (subscription.isPresent()) {
+ dispatchSubscription(reply.getChannel(), subscription.get());
+ } else {
+ logger.warn("Received subscribe event for non-existing channel: " + reply.getChannel());
+ }
+ }
+
+ private void dispatchMessage(PubSubReply reply) {
+ Optional subscription;
+
+ synchronized (subscriptions) {
+ subscription = Optional.fromNullable(subscriptions.get(reply.getChannel()));
+ }
+
+ if (subscription.isPresent()) {
+ dispatchMessage(reply.getChannel(), subscription.get(), reply.getContent().get());
+ } else if (deadLetterChannel.isPresent()) {
+ dispatchMessage(reply.getChannel(), deadLetterChannel.get(), reply.getContent().get());
+ } else {
+ logger.warn("Received message for non-existing channel, with no dead letter handler: " + reply.getChannel());
+ }
+ }
+
+ private void resubscribe() {
+ final Collection names;
+
+ synchronized (subscriptions) {
+ names = subscriptions.keySet();
+ }
+
+ new Thread() {
+ @Override
+ public void run() {
+ try {
+ for (String name : names) {
+ pubSubConnection.subscribe(name);
+ }
+ } catch (IOException e) {
+ logger.warn("***** RESUBSCRIPTION ERROR *****", e);
+ }
+ }
+ }.start();
+ }
+
+ private void dispatchMessage(final String name, final DispatchChannel channel, final byte[] message) {
+ executor.execute(new Runnable() {
+ @Override
+ public void run() {
+ channel.onDispatchMessage(name, message);
+ }
+ });
+ }
+
+ private void dispatchSubscription(final String name, final DispatchChannel channel) {
+ executor.execute(new Runnable() {
+ @Override
+ public void run() {
+ channel.onDispatchSubscribed(name);
+ }
+ });
+ }
+
+ private void dispatchUnsubscription(final String name, final DispatchChannel channel) {
+ executor.execute(new Runnable() {
+ @Override
+ public void run() {
+ channel.onDispatchUnsubscribed(name);
+ }
+ });
+ }
+}
diff --git a/src/main/java/org/whispersystems/dispatch/io/RedisInputStream.java b/src/main/java/org/whispersystems/dispatch/io/RedisInputStream.java
new file mode 100644
index 000000000..2245db9f3
--- /dev/null
+++ b/src/main/java/org/whispersystems/dispatch/io/RedisInputStream.java
@@ -0,0 +1,64 @@
+package org.whispersystems.dispatch.io;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+
+public class RedisInputStream {
+
+ private static final byte CR = 0x0D;
+ private static final byte LF = 0x0A;
+
+ private final InputStream inputStream;
+
+ public RedisInputStream(InputStream inputStream) {
+ this.inputStream = inputStream;
+ }
+
+ public String readLine() throws IOException {
+ ByteArrayOutputStream boas = new ByteArrayOutputStream();
+
+ boolean foundCr = false;
+
+ while (true) {
+ int character = inputStream.read();
+
+ if (character == -1) {
+ throw new IOException("Stream closed!");
+ }
+
+ boas.write(character);
+
+ if (foundCr && character == LF) break;
+ else if (character == CR) foundCr = true;
+ else if (foundCr) foundCr = false;
+ }
+
+ byte[] data = boas.toByteArray();
+ return new String(data, 0, data.length-2);
+ }
+
+ public byte[] readFully(int size) throws IOException {
+ byte[] result = new byte[size];
+ int offset = 0;
+ int remaining = result.length;
+
+ while (remaining > 0) {
+ int read = inputStream.read(result, offset, remaining);
+
+ if (read < 0) {
+ throw new IOException("Stream closed!");
+ }
+
+ offset += read;
+ remaining -= read;
+ }
+
+ return result;
+ }
+
+ public void close() throws IOException {
+ inputStream.close();
+ }
+
+}
diff --git a/src/main/java/org/whispersystems/dispatch/io/RedisPubSubConnectionFactory.java b/src/main/java/org/whispersystems/dispatch/io/RedisPubSubConnectionFactory.java
new file mode 100644
index 000000000..d93b58072
--- /dev/null
+++ b/src/main/java/org/whispersystems/dispatch/io/RedisPubSubConnectionFactory.java
@@ -0,0 +1,9 @@
+package org.whispersystems.dispatch.io;
+
+import org.whispersystems.dispatch.redis.PubSubConnection;
+
+public interface RedisPubSubConnectionFactory {
+
+ public PubSubConnection connect();
+
+}
diff --git a/src/main/java/org/whispersystems/dispatch/redis/PubSubConnection.java b/src/main/java/org/whispersystems/dispatch/redis/PubSubConnection.java
new file mode 100644
index 000000000..599c3930a
--- /dev/null
+++ b/src/main/java/org/whispersystems/dispatch/redis/PubSubConnection.java
@@ -0,0 +1,119 @@
+package org.whispersystems.dispatch.redis;
+
+import com.google.common.base.Optional;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.whispersystems.dispatch.io.RedisInputStream;
+import org.whispersystems.dispatch.redis.protocol.ArrayReplyHeader;
+import org.whispersystems.dispatch.redis.protocol.IntReply;
+import org.whispersystems.dispatch.redis.protocol.StringReplyHeader;
+import org.whispersystems.dispatch.util.Util;
+
+import java.io.BufferedInputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.net.Socket;
+import java.util.Arrays;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+public class PubSubConnection {
+
+ private final Logger logger = LoggerFactory.getLogger(PubSubConnection.class);
+
+ private static final byte[] UNSUBSCRIBE_TYPE = {'u', 'n', 's', 'u', 'b', 's', 'c', 'r', 'i', 'b', 'e' };
+ private static final byte[] SUBSCRIBE_TYPE = {'s', 'u', 'b', 's', 'c', 'r', 'i', 'b', 'e' };
+ private static final byte[] MESSAGE_TYPE = {'m', 'e', 's', 's', 'a', 'g', 'e' };
+
+ private static final byte[] SUBSCRIBE_COMMAND = {'S', 'U', 'B', 'S', 'C', 'R', 'I', 'B', 'E', ' ' };
+ private static final byte[] UNSUBSCRIBE_COMMAND = {'U', 'N', 'S', 'U', 'B', 'S', 'C', 'R', 'I', 'B', 'E', ' '};
+ private static final byte[] CRLF = {'\r', '\n' };
+
+ private final OutputStream outputStream;
+ private final RedisInputStream inputStream;
+ private final Socket socket;
+ private final AtomicBoolean closed;
+
+ public PubSubConnection(Socket socket) throws IOException {
+ this.socket = socket;
+ this.outputStream = socket.getOutputStream();
+ this.inputStream = new RedisInputStream(new BufferedInputStream(socket.getInputStream()));
+ this.closed = new AtomicBoolean(false);
+ }
+
+ public void subscribe(String channelName) throws IOException {
+ if (closed.get()) throw new IOException("Connection closed!");
+
+ byte[] command = Util.combine(SUBSCRIBE_COMMAND, channelName.getBytes(), CRLF);
+ outputStream.write(command);
+ }
+
+ public void unsubscribe(String channelName) throws IOException {
+ if (closed.get()) throw new IOException("Connection closed!");
+
+ byte[] command = Util.combine(UNSUBSCRIBE_COMMAND, channelName.getBytes(), CRLF);
+ outputStream.write(command);
+ }
+
+ public PubSubReply read() throws IOException {
+ if (closed.get()) throw new IOException("Connection closed!");
+
+ ArrayReplyHeader replyHeader = new ArrayReplyHeader(inputStream.readLine());
+
+ if (replyHeader.getElementCount() != 3) {
+ throw new IOException("Received array reply header with strange count: " + replyHeader.getElementCount());
+ }
+
+ StringReplyHeader replyTypeHeader = new StringReplyHeader(inputStream.readLine());
+ byte[] replyType = inputStream.readFully(replyTypeHeader.getStringLength());
+ inputStream.readLine();
+
+ if (Arrays.equals(SUBSCRIBE_TYPE, replyType)) return readSubscribeReply();
+ else if (Arrays.equals(UNSUBSCRIBE_TYPE, replyType)) return readUnsubscribeReply();
+ else if (Arrays.equals(MESSAGE_TYPE, replyType)) return readMessageReply();
+ else throw new IOException("Unknown reply type: " + new String(replyType));
+ }
+
+ public void close() {
+ try {
+ this.closed.set(true);
+ this.inputStream.close();
+ this.outputStream.close();
+ this.socket.close();
+ } catch (IOException e) {
+ logger.warn("Exception while closing", e);
+ }
+ }
+
+ private PubSubReply readMessageReply() throws IOException {
+ StringReplyHeader channelNameHeader = new StringReplyHeader(inputStream.readLine());
+ byte[] channelName = inputStream.readFully(channelNameHeader.getStringLength());
+ inputStream.readLine();
+
+ StringReplyHeader messageHeader = new StringReplyHeader(inputStream.readLine());
+ byte[] message = inputStream.readFully(messageHeader.getStringLength());
+ inputStream.readLine();
+
+ return new PubSubReply(PubSubReply.Type.MESSAGE, new String(channelName), Optional.of(message));
+ }
+
+ private PubSubReply readUnsubscribeReply() throws IOException {
+ String channelName = readSubscriptionReply();
+ return new PubSubReply(PubSubReply.Type.UNSUBSCRIBE, channelName, Optional.absent());
+ }
+
+ private PubSubReply readSubscribeReply() throws IOException {
+ String channelName = readSubscriptionReply();
+ return new PubSubReply(PubSubReply.Type.SUBSCRIBE, channelName, Optional.absent());
+ }
+
+ private String readSubscriptionReply() throws IOException {
+ StringReplyHeader channelNameHeader = new StringReplyHeader(inputStream.readLine());
+ byte[] channelName = inputStream.readFully(channelNameHeader.getStringLength());
+ inputStream.readLine();
+
+ IntReply subscriptionCount = new IntReply(inputStream.readLine());
+
+ return new String(channelName);
+ }
+
+}
diff --git a/src/main/java/org/whispersystems/dispatch/redis/PubSubReply.java b/src/main/java/org/whispersystems/dispatch/redis/PubSubReply.java
new file mode 100644
index 000000000..d57797fae
--- /dev/null
+++ b/src/main/java/org/whispersystems/dispatch/redis/PubSubReply.java
@@ -0,0 +1,35 @@
+package org.whispersystems.dispatch.redis;
+
+import com.google.common.base.Optional;
+
+public class PubSubReply {
+
+ public enum Type {
+ MESSAGE,
+ SUBSCRIBE,
+ UNSUBSCRIBE
+ }
+
+ private final Type type;
+ private final String channel;
+ private final Optional content;
+
+ public PubSubReply(Type type, String channel, Optional content) {
+ this.type = type;
+ this.channel = channel;
+ this.content = content;
+ }
+
+ public Type getType() {
+ return type;
+ }
+
+ public String getChannel() {
+ return channel;
+ }
+
+ public Optional getContent() {
+ return content;
+ }
+
+}
diff --git a/src/main/java/org/whispersystems/dispatch/redis/protocol/ArrayReplyHeader.java b/src/main/java/org/whispersystems/dispatch/redis/protocol/ArrayReplyHeader.java
new file mode 100644
index 000000000..1d3528c1b
--- /dev/null
+++ b/src/main/java/org/whispersystems/dispatch/redis/protocol/ArrayReplyHeader.java
@@ -0,0 +1,24 @@
+package org.whispersystems.dispatch.redis.protocol;
+
+import java.io.IOException;
+
+public class ArrayReplyHeader {
+
+ private final int elementCount;
+
+ public ArrayReplyHeader(String header) throws IOException {
+ if (header == null || header.length() < 2 || header.charAt(0) != '*') {
+ throw new IOException("Invalid array reply header: " + header);
+ }
+
+ try {
+ this.elementCount = Integer.parseInt(header.substring(1));
+ } catch (NumberFormatException e) {
+ throw new IOException(e);
+ }
+ }
+
+ public int getElementCount() {
+ return elementCount;
+ }
+}
diff --git a/src/main/java/org/whispersystems/dispatch/redis/protocol/IntReply.java b/src/main/java/org/whispersystems/dispatch/redis/protocol/IntReply.java
new file mode 100644
index 000000000..7c7f775b3
--- /dev/null
+++ b/src/main/java/org/whispersystems/dispatch/redis/protocol/IntReply.java
@@ -0,0 +1,24 @@
+package org.whispersystems.dispatch.redis.protocol;
+
+import java.io.IOException;
+
+public class IntReply {
+
+ private final int value;
+
+ public IntReply(String reply) throws IOException {
+ if (reply == null || reply.length() < 2 || reply.charAt(0) != ':') {
+ throw new IOException("Invalid int reply: " + reply);
+ }
+
+ try {
+ this.value = Integer.parseInt(reply.substring(1));
+ } catch (NumberFormatException e) {
+ throw new IOException(e);
+ }
+ }
+
+ public int getValue() {
+ return value;
+ }
+}
diff --git a/src/main/java/org/whispersystems/dispatch/redis/protocol/StringReplyHeader.java b/src/main/java/org/whispersystems/dispatch/redis/protocol/StringReplyHeader.java
new file mode 100644
index 000000000..4a6030e21
--- /dev/null
+++ b/src/main/java/org/whispersystems/dispatch/redis/protocol/StringReplyHeader.java
@@ -0,0 +1,24 @@
+package org.whispersystems.dispatch.redis.protocol;
+
+import java.io.IOException;
+
+public class StringReplyHeader {
+
+ private final int stringLength;
+
+ public StringReplyHeader(String header) throws IOException {
+ if (header == null || header.length() < 2 || header.charAt(0) != '$') {
+ throw new IOException("Invalid string reply header: " + header);
+ }
+
+ try {
+ this.stringLength = Integer.parseInt(header.substring(1));
+ } catch (NumberFormatException e) {
+ throw new IOException(e);
+ }
+ }
+
+ public int getStringLength() {
+ return stringLength;
+ }
+}
diff --git a/src/main/java/org/whispersystems/dispatch/util/Util.java b/src/main/java/org/whispersystems/dispatch/util/Util.java
new file mode 100644
index 000000000..245466cea
--- /dev/null
+++ b/src/main/java/org/whispersystems/dispatch/util/Util.java
@@ -0,0 +1,36 @@
+package org.whispersystems.dispatch.util;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+
+public class Util {
+
+ public static byte[] combine(byte[]... elements) {
+ try {
+ int sum = 0;
+
+ for (byte[] element : elements) {
+ sum += element.length;
+ }
+
+ ByteArrayOutputStream baos = new ByteArrayOutputStream(sum);
+
+ for (byte[] element : elements) {
+ baos.write(element);
+ }
+
+ return baos.toByteArray();
+ } catch (IOException e) {
+ throw new AssertionError(e);
+ }
+ }
+
+
+ public static void sleep(long millis) {
+ try {
+ Thread.sleep(millis);
+ } catch (InterruptedException e) {
+ throw new AssertionError(e);
+ }
+ }
+}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
index 8b79571d1..cb712e1bc 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
@@ -24,6 +24,8 @@ import com.sun.jersey.api.client.Client;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.eclipse.jetty.servlets.CrossOriginFilter;
import org.skife.jdbi.v2.DBI;
+import org.whispersystems.dispatch.DispatchChannel;
+import org.whispersystems.dispatch.DispatchManager;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.FederatedPeerAuthenticator;
import org.whispersystems.textsecuregcm.auth.MultiBasicAuthProvider;
@@ -147,10 +149,11 @@ public class WhisperServerService extends Applicationof(deadLetterHandler));
+ PubSubManager pubSubManager = new PubSubManager(cacheClient, dispatchManager);
PushServiceClient pushServiceClient = new PushServiceClient(httpClient, config.getPushConfiguration());
WebsocketSender websocketSender = new WebsocketSender(messagesManager, pubSubManager);
AccountAuthenticator deviceAuthenticator = new AccountAuthenticator(accountsManager);
@@ -173,6 +177,7 @@ public class WhisperServerService extends Application authorizationKey = config.getRedphoneConfiguration().getAuthorizationKey();
+ environment.lifecycle().manage(pubSubManager);
environment.lifecycle().manage(feedbackHandler);
AttachmentController attachmentController = new AttachmentController(rateLimiters, federatedClientManager, urlSigner);
@@ -263,5 +268,4 @@ public class WhisperServerService extends Application listeners = new HashMap<>();
- private final Executor threaded = Executors.newCachedThreadPool();
+ private final Logger logger = LoggerFactory.getLogger(PubSubManager.class);
+ private final DispatchManager dispatchManager;
private final JedisPool jedisPool;
- private final DeadLetterHandler deadLetterHandler;
private boolean subscribed = false;
- public PubSubManager(JedisPool jedisPool, DeadLetterHandler deadLetterHandler) {
+ public PubSubManager(JedisPool jedisPool, DispatchManager dispatchManager) {
+ this.dispatchManager = dispatchManager;
this.jedisPool = jedisPool;
- this.deadLetterHandler = deadLetterHandler;
- initializePubSubWorker();
- waitForSubscription();
}
- public synchronized void subscribe(WebsocketAddress address, PubSubListener listener) {
- String serializedAddress = address.serialize();
+ @Override
+ public void start() throws Exception {
+ this.dispatchManager.start();
- listeners.put(serializedAddress, listener);
- baseListener.subscribe(serializedAddress.getBytes());
- }
+ KeepaliveDispatchChannel keepaliveDispatchChannel = new KeepaliveDispatchChannel();
+ this.dispatchManager.subscribe(KEEPALIVE_CHANNEL, keepaliveDispatchChannel);
- public synchronized void unsubscribe(WebsocketAddress address, PubSubListener listener) {
- String serializedAddress = address.serialize();
-
- if (listeners.get(serializedAddress) == listener) {
- listeners.remove(serializedAddress);
- baseListener.unsubscribe(serializedAddress.getBytes());
+ synchronized (this) {
+ while (!subscribed) wait(0);
}
+
+ new KeepaliveSender().start();
}
- public synchronized boolean publish(WebsocketAddress address, PubSubMessage message) {
+ @Override
+ public void stop() throws Exception {
+ dispatchManager.shutdown();
+ }
+
+ public void subscribe(WebsocketAddress address, DispatchChannel channel) {
+ String serializedAddress = address.serialize();
+ dispatchManager.subscribe(serializedAddress, channel);
+ }
+
+ public void unsubscribe(WebsocketAddress address, DispatchChannel dispatchChannel) {
+ String serializedAddress = address.serialize();
+ dispatchManager.unsubscribe(serializedAddress, dispatchChannel);
+ }
+
+ public boolean publish(WebsocketAddress address, PubSubMessage message) {
return publish(address.serialize().getBytes(), message);
}
- private synchronized boolean publish(byte[] channel, PubSubMessage message) {
+ private boolean publish(byte[] channel, PubSubMessage message) {
try (Jedis jedis = jedisPool.getResource()) {
return jedis.publish(channel, message.toByteArray()) != 0;
}
}
- private synchronized void resubscribeAll() {
- for (String serializedAddress : listeners.keySet()) {
- baseListener.subscribe(serializedAddress.getBytes());
- }
- }
-
- private synchronized void waitForSubscription() {
- try {
- while (!subscribed) {
- wait();
- }
- } catch (InterruptedException e) {
- throw new AssertionError(e);
- }
- }
-
- private void initializePubSubWorker() {
- new Thread("PubSubListener") {
- @Override
- public void run() {
- for (;;) {
- logger.info("Starting Redis PubSub Subscriber...");
-
- try (Jedis jedis = jedisPool.getResource()) {
- jedis.subscribe(baseListener, KEEPALIVE_CHANNEL);
- logger.warn("**** Unsubscribed from holding channel!!! ******");
- } catch (Throwable t) {
- logger.warn("*** SUBSCRIBER CONNECTION CLOSED", t);
- }
- }
- }
- }.start();
-
- new Thread("PubSubKeepAlive") {
- @Override
- public void run() {
- for (;;) {
- try {
- Thread.sleep(20000);
- publish(KEEPALIVE_CHANNEL, PubSubMessage.newBuilder()
- .setType(PubSubMessage.Type.KEEPALIVE)
- .build());
- } catch (Throwable e) {
- logger.warn("KEEPALIVE PUBLISH EXCEPTION: ", e);
- }
- }
- }
- }.start();
- }
-
- private class SubscriptionListener extends BinaryJedisPubSub {
+ private class KeepaliveDispatchChannel implements DispatchChannel {
@Override
- public void onMessage(final byte[] channel, final byte[] message) {
- if (Arrays.equals(KEEPALIVE_CHANNEL, channel)) {
- return;
- }
-
- final PubSubListener listener;
-
- synchronized (PubSubManager.this) {
- listener = listeners.get(new String(channel));
- }
-
- threaded.execute(new Runnable() {
- @Override
- public void run() {
- try {
- PubSubMessage receivedMessage = PubSubMessage.parseFrom(message);
-
- if (listener != null) listener.onPubSubMessage(receivedMessage);
- else deadLetterHandler.handle(channel, receivedMessage);
- } catch (InvalidProtocolBufferException e) {
- logger.warn("Error parsing PubSub protobuf", e);
- }
- }
- });
+ public void onDispatchMessage(String channel, byte[] message) {
+ // Good
}
@Override
- public void onPMessage(byte[] s, byte[] s2, byte[] s3) {
- logger.warn("Received PMessage!");
- }
-
- @Override
- public void onSubscribe(byte[] channel, int count) {
- if (Arrays.equals(KEEPALIVE_CHANNEL, channel)) {
+ public void onDispatchSubscribed(String channel) {
+ if (KEEPALIVE_CHANNEL.equals(channel)) {
synchronized (PubSubManager.this) {
subscribed = true;
PubSubManager.this.notifyAll();
}
-
- threaded.execute(new Runnable() {
- @Override
- public void run() {
- resubscribeAll();
- }
- });
}
}
@Override
- public void onUnsubscribe(byte[] s, int i) {}
+ public void onDispatchUnsubscribed(String channel) {
+ logger.warn("***** KEEPALIVE CHANNEL UNSUBSCRIBED *****");
+ }
+ }
+ private class KeepaliveSender extends Thread {
@Override
- public void onPUnsubscribe(byte[] s, int i) {}
-
- @Override
- public void onPSubscribe(byte[] s, int i) {}
+ public void run() {
+ while (true) {
+ try {
+ Thread.sleep(20000);
+ publish(KEEPALIVE_CHANNEL.getBytes(), PubSubMessage.newBuilder()
+ .setType(PubSubMessage.Type.KEEPALIVE)
+ .build());
+ } catch (Throwable e) {
+ logger.warn("***** KEEPALIVE EXCEPTION ******", e);
+ }
+ }
+ }
}
}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java
index df11cf9be..8ae561cf5 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java
@@ -1,5 +1,8 @@
package org.whispersystems.textsecuregcm.websocket;
+import com.codahale.metrics.Histogram;
+import com.codahale.metrics.MetricRegistry;
+import com.codahale.metrics.SharedMetricRegistries;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.push.PushSender;
@@ -8,14 +11,19 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
-import org.whispersystems.textsecuregcm.storage.PubSubProtos;
+import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import org.whispersystems.websocket.setup.WebSocketConnectListener;
+import static com.codahale.metrics.MetricRegistry.name;
+
public class AuthenticatedConnectListener implements WebSocketConnectListener {
- private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class);
+ private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class);
+ private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
+ private static final Histogram durationHistogram = metricRegistry.histogram(name(WebSocketConnection.class, "connected_duration"));
+
private final AccountsManager accountsManager;
private final PushSender pushSender;
@@ -33,23 +41,22 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
@Override
public void onWebSocketConnect(WebSocketSessionContext context) {
- Account account = context.getAuthenticated(Account.class).get();
- Device device = account.getAuthenticatedDevice().get();
+ final Account account = context.getAuthenticated(Account.class).get();
+ final Device device = account.getAuthenticatedDevice().get();
+ final long connectTime = System.currentTimeMillis();
+ final WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId());
+ final WebSocketConnection connection = new WebSocketConnection(accountsManager, pushSender,
+ messagesManager, account, device,
+ context.getClient());
updateLastSeen(account, device);
- closeExistingDeviceConnection(account, device);
-
- final WebSocketConnection connection = new WebSocketConnection(accountsManager, pushSender,
- messagesManager, pubSubManager,
- account, device,
- context.getClient());
-
- connection.onConnected();
+ pubSubManager.subscribe(address, connection);
context.addListener(new WebSocketSessionContext.WebSocketEventListener() {
@Override
public void onWebSocketClose(WebSocketSessionContext context, int statusCode, String reason) {
- connection.onConnectionLost();
+ pubSubManager.unsubscribe(address, connection);
+ durationHistogram.update(System.currentTimeMillis() - connectTime);
}
});
}
@@ -60,12 +67,5 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
accountsManager.update(account);
}
}
-
- private void closeExistingDeviceConnection(Account account, Device device) {
- pubSubManager.publish(new WebsocketAddress(account.getNumber(), device.getId()),
- PubSubProtos.PubSubMessage.newBuilder()
- .setType(PubSubProtos.PubSubMessage.Type.CLOSE)
- .build());
- }
}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/DeadLetterHandler.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/DeadLetterHandler.java
index 3e8504150..516609a67 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/websocket/DeadLetterHandler.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/DeadLetterHandler.java
@@ -3,11 +3,13 @@ package org.whispersystems.textsecuregcm.websocket;
import com.google.protobuf.InvalidProtocolBufferException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.whispersystems.dispatch.DispatchChannel;
import org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PubSubProtos;
+import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage;
-public class DeadLetterHandler {
+public class DeadLetterHandler implements DispatchChannel {
private final Logger logger = LoggerFactory.getLogger(DeadLetterHandler.class);
@@ -17,14 +19,16 @@ public class DeadLetterHandler {
this.messagesManager = messagesManager;
}
- public void handle(byte[] channel, PubSubProtos.PubSubMessage pubSubMessage) {
+ @Override
+ public void onDispatchMessage(String channel, byte[] data) {
try {
- WebsocketAddress address = new WebsocketAddress(new String(channel));
+ logger.warn("Handling dead letter to: " + channel);
- logger.warn("Handling dead letter to: " + address);
+ WebsocketAddress address = new WebsocketAddress(channel);
+ PubSubMessage pubSubMessage = PubSubMessage.parseFrom(data);
switch (pubSubMessage.getType().getNumber()) {
- case PubSubProtos.PubSubMessage.Type.DELIVER_VALUE:
+ case PubSubMessage.Type.DELIVER_VALUE:
OutgoingMessageSignal message = OutgoingMessageSignal.parseFrom(pubSubMessage.getContent());
messagesManager.insert(address.getNumber(), address.getDeviceId(), message);
break;
@@ -36,4 +40,13 @@ public class DeadLetterHandler {
}
}
+ @Override
+ public void onDispatchSubscribed(String channel) {
+ logger.warn("DeadLetterHandler subscription notice! " + channel);
+ }
+
+ @Override
+ public void onDispatchUnsubscribed(String channel) {
+ logger.warn("DeadLetterHandler unsubscribe notice! " + channel);
+ }
}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java
index 405a86250..0c903ebc8 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java
@@ -14,13 +14,15 @@ public class ProvisioningConnectListener implements WebSocketConnectListener {
@Override
public void onWebSocketConnect(WebSocketSessionContext context) {
- final ProvisioningConnection connection = new ProvisioningConnection(pubSubManager, context.getClient());
- connection.onConnected();
+ final ProvisioningConnection connection = new ProvisioningConnection(context.getClient());
+ final ProvisioningAddress provisioningAddress = ProvisioningAddress.generate();
+
+ pubSubManager.subscribe(provisioningAddress, connection);
context.addListener(new WebSocketSessionContext.WebSocketEventListener() {
@Override
public void onWebSocketClose(WebSocketSessionContext context, int statusCode, String reason) {
- connection.onConnectionLost();
+ pubSubManager.unsubscribe(provisioningAddress, connection);
}
});
}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnection.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnection.java
index 817758304..1fb0af3c2 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnection.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnection.java
@@ -4,58 +4,62 @@ import com.google.common.base.Optional;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
+import com.google.protobuf.InvalidProtocolBufferException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.whispersystems.dispatch.DispatchChannel;
import org.whispersystems.textsecuregcm.entities.MessageProtos.ProvisioningUuid;
-import org.whispersystems.textsecuregcm.storage.PubSubListener;
-import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
-public class ProvisioningConnection implements PubSubListener {
+public class ProvisioningConnection implements DispatchChannel {
- private final PubSubManager pubSubManager;
- private final ProvisioningAddress provisioningAddress;
- private final WebSocketClient client;
+ private final Logger logger = LoggerFactory.getLogger(ProvisioningConnection.class);
- public ProvisioningConnection(PubSubManager pubSubManager, WebSocketClient client) {
- this.pubSubManager = pubSubManager;
- this.client = client;
- this.provisioningAddress = ProvisioningAddress.generate();
+ private final WebSocketClient client;
+
+ public ProvisioningConnection(WebSocketClient client) {
+ this.client = client;
}
@Override
- public void onPubSubMessage(PubSubMessage outgoingMessage) {
- if (outgoingMessage.getType() == PubSubMessage.Type.DELIVER) {
- Optional body = Optional.of(outgoingMessage.getContent().toByteArray());
+ public void onDispatchMessage(String channel, byte[] message) {
+ try {
+ PubSubMessage outgoingMessage = PubSubMessage.parseFrom(message);
- ListenableFuture response = client.sendRequest("PUT", "/v1/message", body);
+ if (outgoingMessage.getType() == PubSubMessage.Type.DELIVER) {
+ Optional body = Optional.of(outgoingMessage.getContent().toByteArray());
- Futures.addCallback(response, new FutureCallback() {
- @Override
- public void onSuccess(WebSocketResponseMessage webSocketResponseMessage) {
- pubSubManager.unsubscribe(provisioningAddress, ProvisioningConnection.this);
- client.close(1001, "All you get.");
- }
+ ListenableFuture response = client.sendRequest("PUT", "/v1/message", body);
- @Override
- public void onFailure(Throwable throwable) {
- pubSubManager.unsubscribe(provisioningAddress, ProvisioningConnection.this);
- client.close(1001, "That's all!");
- }
- });
+ Futures.addCallback(response, new FutureCallback() {
+ @Override
+ public void onSuccess(WebSocketResponseMessage webSocketResponseMessage) {
+ client.close(1001, "All you get.");
+ }
+
+ @Override
+ public void onFailure(Throwable throwable) {
+ client.close(1001, "That's all!");
+ }
+ });
+ }
+ } catch (InvalidProtocolBufferException e) {
+ logger.warn("Protobuf Error: ", e);
}
}
- public void onConnected() {
- this.pubSubManager.subscribe(provisioningAddress, this);
+ @Override
+ public void onDispatchSubscribed(String channel) {
this.client.sendRequest("PUT", "/v1/address", Optional.of(ProvisioningUuid.newBuilder()
- .setUuid(provisioningAddress.getAddress())
+ .setUuid(channel)
.build()
.toByteArray()));
}
- public void onConnectionLost() {
- this.pubSubManager.unsubscribe(provisioningAddress, this);
- this.client.close(1001, "Done");
+ @Override
+ public void onDispatchUnsubscribed(String channel) {
+ this.client.close(1001, "Closed");
}
}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java
index f204a43ca..612a77d38 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java
@@ -10,6 +10,7 @@ import com.google.common.util.concurrent.ListenableFuture;
import com.google.protobuf.InvalidProtocolBufferException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.whispersystems.dispatch.DispatchChannel;
import org.whispersystems.textsecuregcm.entities.CryptoEncodingException;
import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
@@ -19,8 +20,6 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
-import org.whispersystems.textsecuregcm.storage.PubSubListener;
-import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.websocket.WebSocketClient;
@@ -34,21 +33,16 @@ import static com.codahale.metrics.MetricRegistry.name;
import static org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
import static org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage;
-public class WebSocketConnection implements PubSubListener {
+public class WebSocketConnection implements DispatchChannel {
private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class);
- private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
- private static final Histogram durationHistogram = metricRegistry.histogram(name(WebSocketConnection.class, "connected_duration"));
-
private final AccountsManager accountsManager;
private final PushSender pushSender;
private final MessagesManager messagesManager;
- private final PubSubManager pubSubManager;
private final Account account;
private final Device device;
- private final WebsocketAddress address;
private final WebSocketClient client;
private long connectionStartTime;
@@ -56,7 +50,6 @@ public class WebSocketConnection implements PubSubListener {
public WebSocketConnection(AccountsManager accountsManager,
PushSender pushSender,
MessagesManager messagesManager,
- PubSubManager pubSubManager,
Account account,
Device device,
WebSocketClient client)
@@ -64,27 +57,16 @@ public class WebSocketConnection implements PubSubListener {
this.accountsManager = accountsManager;
this.pushSender = pushSender;
this.messagesManager = messagesManager;
- this.pubSubManager = pubSubManager;
this.account = account;
this.device = device;
this.client = client;
- this.address = new WebsocketAddress(account.getNumber(), device.getId());
- }
-
- public void onConnected() {
- connectionStartTime = System.currentTimeMillis();
- pubSubManager.subscribe(address, this);
- processStoredMessages();
- }
-
- public void onConnectionLost() {
- durationHistogram.update(System.currentTimeMillis() - connectionStartTime);
- pubSubManager.unsubscribe(address, this);
}
@Override
- public void onPubSubMessage(PubSubMessage pubSubMessage) {
+ public void onDispatchMessage(String channel, byte[] message) {
try {
+ PubSubMessage pubSubMessage = PubSubMessage.parseFrom(message);
+
switch (pubSubMessage.getType().getNumber()) {
case PubSubMessage.Type.QUERY_DB_VALUE:
processStoredMessages();
@@ -92,10 +74,6 @@ public class WebSocketConnection implements PubSubListener {
case PubSubMessage.Type.DELIVER_VALUE:
sendMessage(OutgoingMessageSignal.parseFrom(pubSubMessage.getContent()), Optional.absent());
break;
- case PubSubMessage.Type.CLOSE_VALUE:
- client.close(1000, "OK");
- pubSubManager.unsubscribe(address, this);
- break;
default:
logger.warn("Unknown pubsub message: " + pubSubMessage.getType().getNumber());
}
@@ -104,6 +82,15 @@ public class WebSocketConnection implements PubSubListener {
}
}
+ @Override
+ public void onDispatchUnsubscribed(String channel) {
+ client.close(1000, "OK");
+ }
+
+ public void onDispatchSubscribed(String channel) {
+ processStoredMessages();
+ }
+
private void sendMessage(final OutgoingMessageSignal message,
final Optional storedMessageId)
{
@@ -180,4 +167,6 @@ public class WebSocketConnection implements PubSubListener {
sendMessage(message.second(), Optional.of(message.first()));
}
}
+
+
}
diff --git a/src/test/java/org/whispersystems/dispatch/DispatchManagerTest.java b/src/test/java/org/whispersystems/dispatch/DispatchManagerTest.java
new file mode 100644
index 000000000..d6e81c518
--- /dev/null
+++ b/src/test/java/org/whispersystems/dispatch/DispatchManagerTest.java
@@ -0,0 +1,127 @@
+package org.whispersystems.dispatch;
+
+import com.google.common.base.Optional;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExternalResource;
+import org.mockito.ArgumentCaptor;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import org.whispersystems.dispatch.io.RedisPubSubConnectionFactory;
+import org.whispersystems.dispatch.redis.PubSubConnection;
+import org.whispersystems.dispatch.redis.PubSubReply;
+
+import java.io.IOException;
+import java.util.LinkedList;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.*;
+
+public class DispatchManagerTest {
+
+ private PubSubConnection pubSubConnection;
+ private RedisPubSubConnectionFactory socketFactory;
+ private DispatchManager dispatchManager;
+ private PubSubReplyInputStream pubSubReplyInputStream;
+
+ @Rule
+ public ExternalResource resource = new ExternalResource() {
+ @Override
+ protected void before() throws Throwable {
+ pubSubConnection = mock(PubSubConnection.class );
+ socketFactory = mock(RedisPubSubConnectionFactory.class);
+ pubSubReplyInputStream = new PubSubReplyInputStream();
+
+ when(socketFactory.connect()).thenReturn(pubSubConnection);
+ when(pubSubConnection.read()).thenAnswer(new Answer() {
+ @Override
+ public PubSubReply answer(InvocationOnMock invocationOnMock) throws Throwable {
+ return pubSubReplyInputStream.read();
+ }
+ });
+
+ dispatchManager = new DispatchManager(socketFactory, Optional.absent());
+ dispatchManager.start();
+ }
+
+ @Override
+ protected void after() {
+
+ }
+ };
+
+ @Test
+ public void testConnect() {
+ verify(socketFactory).connect();
+ }
+
+ @Test
+ public void testSubscribe() throws IOException {
+ DispatchChannel dispatchChannel = mock(DispatchChannel.class);
+ dispatchManager.subscribe("foo", dispatchChannel);
+ pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.SUBSCRIBE, "foo", Optional.absent()));
+
+ verify(dispatchChannel, timeout(1000)).onDispatchSubscribed(eq("foo"));
+ }
+
+ @Test
+ public void testSubscribeUnsubscribe() throws IOException {
+ DispatchChannel dispatchChannel = mock(DispatchChannel.class);
+ dispatchManager.subscribe("foo", dispatchChannel);
+ dispatchManager.unsubscribe("foo", dispatchChannel);
+
+ pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.SUBSCRIBE, "foo", Optional.absent()));
+ pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.UNSUBSCRIBE, "foo", Optional.absent()));
+
+ verify(dispatchChannel, timeout(1000)).onDispatchUnsubscribed(eq("foo"));
+ }
+
+ @Test
+ public void testMessages() throws IOException {
+ DispatchChannel fooChannel = mock(DispatchChannel.class);
+ DispatchChannel barChannel = mock(DispatchChannel.class);
+
+ dispatchManager.subscribe("foo", fooChannel);
+ dispatchManager.subscribe("bar", barChannel);
+
+ pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.SUBSCRIBE, "foo", Optional.absent()));
+ pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.SUBSCRIBE, "bar", Optional.absent()));
+
+ verify(fooChannel, timeout(1000)).onDispatchSubscribed(eq("foo"));
+ verify(barChannel, timeout(1000)).onDispatchSubscribed(eq("bar"));
+
+ pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.MESSAGE, "foo", Optional.of("hello".getBytes())));
+ pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.MESSAGE, "bar", Optional.of("there".getBytes())));
+
+ ArgumentCaptor captor = ArgumentCaptor.forClass(byte[].class);
+ verify(fooChannel, timeout(1000)).onDispatchMessage(eq("foo"), captor.capture());
+
+ assertArrayEquals("hello".getBytes(), captor.getValue());
+
+ verify(barChannel, timeout(1000)).onDispatchMessage(eq("bar"), captor.capture());
+
+ assertArrayEquals("there".getBytes(), captor.getValue());
+ }
+
+ private static class PubSubReplyInputStream {
+
+ private final List pubSubReplyList = new LinkedList<>();
+
+ public synchronized PubSubReply read() {
+ try {
+ while (pubSubReplyList.isEmpty()) wait();
+ return pubSubReplyList.remove(0);
+ } catch (InterruptedException e) {
+ throw new AssertionError(e);
+ }
+ }
+
+ public synchronized void write(PubSubReply pubSubReply) {
+ pubSubReplyList.add(pubSubReply);
+ notifyAll();
+ }
+ }
+
+}
diff --git a/src/test/java/org/whispersystems/dispatch/redis/PubSubConnectionTest.java b/src/test/java/org/whispersystems/dispatch/redis/PubSubConnectionTest.java
new file mode 100644
index 000000000..a1524aabb
--- /dev/null
+++ b/src/test/java/org/whispersystems/dispatch/redis/PubSubConnectionTest.java
@@ -0,0 +1,263 @@
+package org.whispersystems.dispatch.redis;
+
+import org.junit.Test;
+import org.mockito.ArgumentCaptor;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.net.Socket;
+import java.security.NoSuchAlgorithmException;
+import java.security.SecureRandom;
+
+import static org.junit.Assert.*;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyInt;
+import static org.mockito.Mockito.*;
+
+public class PubSubConnectionTest {
+
+ private static final String REPLY = "*3\r\n" +
+ "$9\r\n" +
+ "subscribe\r\n" +
+ "$5\r\n" +
+ "abcde\r\n" +
+ ":1\r\n" +
+ "*3\r\n" +
+ "$9\r\n" +
+ "subscribe\r\n" +
+ "$5\r\n" +
+ "fghij\r\n" +
+ ":2\r\n" +
+ "*3\r\n" +
+ "$9\r\n" +
+ "subscribe\r\n" +
+ "$5\r\n" +
+ "klmno\r\n" +
+ ":2\r\n" +
+ "*3\r\n" +
+ "$7\r\n" +
+ "message\r\n" +
+ "$5\r\n" +
+ "abcde\r\n" +
+ "$10\r\n" +
+ "1234567890\r\n" +
+ "*3\r\n" +
+ "$7\r\n" +
+ "message\r\n" +
+ "$5\r\n" +
+ "klmno\r\n" +
+ "$10\r\n" +
+ "0987654321\r\n";
+
+
+ @Test
+ public void testSubscribe() throws IOException {
+// ByteChannel byteChannel = mock(ByteChannel.class);
+ OutputStream outputStream = mock(OutputStream.class);
+ Socket socket = mock(Socket.class );
+ when(socket.getOutputStream()).thenReturn(outputStream);
+ PubSubConnection connection = new PubSubConnection(socket);
+
+ connection.subscribe("foobar");
+
+ ArgumentCaptor captor = ArgumentCaptor.forClass(byte[].class);
+ verify(outputStream).write(captor.capture());
+
+ assertArrayEquals(captor.getValue(), "SUBSCRIBE foobar\r\n".getBytes());
+ }
+
+ @Test
+ public void testUnsubscribe() throws IOException {
+ OutputStream outputStream = mock(OutputStream.class);
+ Socket socket = mock(Socket.class );
+ when(socket.getOutputStream()).thenReturn(outputStream);
+ PubSubConnection connection = new PubSubConnection(socket);
+
+ connection.unsubscribe("bazbar");
+
+ ArgumentCaptor captor = ArgumentCaptor.forClass(byte[].class);
+ verify(outputStream).write(captor.capture());
+
+ assertArrayEquals(captor.getValue(), "UNSUBSCRIBE bazbar\r\n".getBytes());
+ }
+
+ @Test
+ public void testTricklyResponse() throws Exception {
+ InputStream inputStream = mockInputStreamFor(new TrickleInputStream(REPLY.getBytes()));
+ OutputStream outputStream = mock(OutputStream.class);
+ Socket socket = mock(Socket.class );
+ when(socket.getOutputStream()).thenReturn(outputStream);
+ when(socket.getInputStream()).thenReturn(inputStream);
+
+ PubSubConnection pubSubConnection = new PubSubConnection(socket);
+ readResponses(pubSubConnection);
+ }
+
+ @Test
+ public void testFullResponse() throws Exception {
+ InputStream inputStream = mockInputStreamFor(new FullInputStream(REPLY.getBytes()));
+ OutputStream outputStream = mock(OutputStream.class);
+ Socket socket = mock(Socket.class );
+ when(socket.getOutputStream()).thenReturn(outputStream);
+ when(socket.getInputStream()).thenReturn(inputStream);
+
+ PubSubConnection pubSubConnection = new PubSubConnection(socket);
+ readResponses(pubSubConnection);
+ }
+
+ @Test
+ public void testRandomLengthResponse() throws Exception {
+ InputStream inputStream = mockInputStreamFor(new RandomInputStream(REPLY.getBytes()));
+ OutputStream outputStream = mock(OutputStream.class);
+ Socket socket = mock(Socket.class );
+ when(socket.getOutputStream()).thenReturn(outputStream);
+ when(socket.getInputStream()).thenReturn(inputStream);
+
+ PubSubConnection pubSubConnection = new PubSubConnection(socket);
+ readResponses(pubSubConnection);
+ }
+
+ private InputStream mockInputStreamFor(final MockInputStream stub) throws IOException {
+ InputStream result = mock(InputStream.class);
+
+ when(result.read()).thenAnswer(new Answer() {
+ @Override
+ public Integer answer(InvocationOnMock invocationOnMock) throws Throwable {
+ return stub.read();
+ }
+ });
+
+ when(result.read(any(byte[].class))).thenAnswer(new Answer() {
+ @Override
+ public Integer answer(InvocationOnMock invocationOnMock) throws Throwable {
+ byte[] buffer = (byte[])invocationOnMock.getArguments()[0];
+ return stub.read(buffer, 0, buffer.length);
+ }
+ });
+
+ when(result.read(any(byte[].class), anyInt(), anyInt())).thenAnswer(new Answer() {
+ @Override
+ public Integer answer(InvocationOnMock invocationOnMock) throws Throwable {
+ byte[] buffer = (byte[]) invocationOnMock.getArguments()[0];
+ int offset = (int) invocationOnMock.getArguments()[1];
+ int length = (int) invocationOnMock.getArguments()[2];
+
+ return stub.read(buffer, offset, length);
+ }
+ });
+
+ return result;
+ }
+
+ private void readResponses(PubSubConnection pubSubConnection) throws Exception {
+ PubSubReply reply = pubSubConnection.read();
+
+ assertEquals(reply.getType(), PubSubReply.Type.SUBSCRIBE);
+ assertEquals(reply.getChannel(), "abcde");
+ assertFalse(reply.getContent().isPresent());
+
+ reply = pubSubConnection.read();
+
+ assertEquals(reply.getType(), PubSubReply.Type.SUBSCRIBE);
+ assertEquals(reply.getChannel(), "fghij");
+ assertFalse(reply.getContent().isPresent());
+
+ reply = pubSubConnection.read();
+
+ assertEquals(reply.getType(), PubSubReply.Type.SUBSCRIBE);
+ assertEquals(reply.getChannel(), "klmno");
+ assertFalse(reply.getContent().isPresent());
+
+ reply = pubSubConnection.read();
+
+ assertEquals(reply.getType(), PubSubReply.Type.MESSAGE);
+ assertEquals(reply.getChannel(), "abcde");
+ assertArrayEquals(reply.getContent().get(), "1234567890".getBytes());
+
+ reply = pubSubConnection.read();
+
+ assertEquals(reply.getType(), PubSubReply.Type.MESSAGE);
+ assertEquals(reply.getChannel(), "klmno");
+ assertArrayEquals(reply.getContent().get(), "0987654321".getBytes());
+ }
+
+ private interface MockInputStream {
+ public int read();
+ public int read(byte[] input, int offset, int length);
+ }
+
+ private static class TrickleInputStream implements MockInputStream {
+
+ private final byte[] data;
+ private int index = 0;
+
+ private TrickleInputStream(byte[] data) {
+ this.data = data;
+ }
+
+ public int read() {
+ return data[index++];
+ }
+
+ public int read(byte[] input, int offset, int length) {
+ input[offset] = data[index++];
+ return 1;
+ }
+
+ }
+
+ private static class FullInputStream implements MockInputStream {
+
+ private final byte[] data;
+ private int index = 0;
+
+ private FullInputStream(byte[] data) {
+ this.data = data;
+ }
+
+ public int read() {
+ return data[index++];
+ }
+
+ public int read(byte[] input, int offset, int length) {
+ int amount = Math.min(data.length - index, length);
+ System.arraycopy(data, index, input, offset, amount);
+ index += length;
+
+ return amount;
+ }
+ }
+
+ private static class RandomInputStream implements MockInputStream {
+ private final byte[] data;
+ private int index = 0;
+
+ private RandomInputStream(byte[] data) {
+ this.data = data;
+ }
+
+ public int read() {
+ return data[index++];
+ }
+
+ public int read(byte[] input, int offset, int length) {
+ try {
+ int maxCopy = Math.min(data.length - index, length);
+ int randomCopy = SecureRandom.getInstance("SHA1PRNG").nextInt(maxCopy) + 1;
+ int copyAmount = Math.min(maxCopy, randomCopy);
+
+ System.arraycopy(data, index, input, offset, copyAmount);
+ index += copyAmount;
+
+ return copyAmount;
+ } catch (NoSuchAlgorithmException e) {
+ throw new AssertionError(e);
+ }
+ }
+
+ }
+
+}
diff --git a/src/test/java/org/whispersystems/dispatch/redis/protocol/ArrayReplyHeaderTest.java b/src/test/java/org/whispersystems/dispatch/redis/protocol/ArrayReplyHeaderTest.java
new file mode 100644
index 000000000..1c930ceb3
--- /dev/null
+++ b/src/test/java/org/whispersystems/dispatch/redis/protocol/ArrayReplyHeaderTest.java
@@ -0,0 +1,51 @@
+package org.whispersystems.dispatch.redis.protocol;
+
+
+import org.junit.Test;
+
+import java.io.IOException;
+
+import static org.junit.Assert.assertEquals;
+
+public class ArrayReplyHeaderTest {
+
+
+ @Test(expected = IOException.class)
+ public void testNull() throws IOException {
+ new ArrayReplyHeader(null);
+ }
+
+ @Test(expected = IOException.class)
+ public void testBadPrefix() throws IOException {
+ new ArrayReplyHeader(":3");
+ }
+
+ @Test(expected = IOException.class)
+ public void testEmpty() throws IOException {
+ new ArrayReplyHeader("");
+ }
+
+ @Test(expected = IOException.class)
+ public void testTruncated() throws IOException {
+ new ArrayReplyHeader("*");
+ }
+
+ @Test(expected = IOException.class)
+ public void testBadNumber() throws IOException {
+ new ArrayReplyHeader("*ABC");
+ }
+
+ @Test
+ public void testValid() throws IOException {
+ assertEquals(4, new ArrayReplyHeader("*4").getElementCount());
+ }
+
+
+
+
+
+
+
+
+
+}
diff --git a/src/test/java/org/whispersystems/dispatch/redis/protocol/IntReplyHeaderTest.java b/src/test/java/org/whispersystems/dispatch/redis/protocol/IntReplyHeaderTest.java
new file mode 100644
index 000000000..24881c93c
--- /dev/null
+++ b/src/test/java/org/whispersystems/dispatch/redis/protocol/IntReplyHeaderTest.java
@@ -0,0 +1,36 @@
+package org.whispersystems.dispatch.redis.protocol;
+
+
+import org.junit.Test;
+
+import java.io.IOException;
+
+import static org.junit.Assert.assertEquals;
+
+public class IntReplyHeaderTest {
+
+ @Test(expected = IOException.class)
+ public void testNull() throws IOException {
+ new IntReply(null);
+ }
+
+ @Test(expected = IOException.class)
+ public void testEmpty() throws IOException {
+ new IntReply("");
+ }
+
+ @Test(expected = IOException.class)
+ public void testBadNumber() throws IOException {
+ new IntReply(":A");
+ }
+
+ @Test(expected = IOException.class)
+ public void testBadFormat() throws IOException {
+ new IntReply("*");
+ }
+
+ @Test
+ public void testValid() throws IOException {
+ assertEquals(23, new IntReply(":23").getValue());
+ }
+}
diff --git a/src/test/java/org/whispersystems/dispatch/redis/protocol/StringReplyHeaderTest.java b/src/test/java/org/whispersystems/dispatch/redis/protocol/StringReplyHeaderTest.java
new file mode 100644
index 000000000..a49517df9
--- /dev/null
+++ b/src/test/java/org/whispersystems/dispatch/redis/protocol/StringReplyHeaderTest.java
@@ -0,0 +1,47 @@
+package org.whispersystems.dispatch.redis.protocol;
+
+import org.junit.Test;
+
+import java.io.IOException;
+
+import static org.junit.Assert.assertEquals;
+
+public class StringReplyHeaderTest {
+
+ @Test
+ public void testNull() {
+ try {
+ new StringReplyHeader(null);
+ throw new AssertionError();
+ } catch (IOException e) {
+ // good
+ }
+ }
+
+ @Test
+ public void testBadNumber() {
+ try {
+ new StringReplyHeader("$100A");
+ throw new AssertionError();
+ } catch (IOException e) {
+ // good
+ }
+ }
+
+ @Test
+ public void testBadPrefix() {
+ try {
+ new StringReplyHeader("*");
+ throw new AssertionError();
+ } catch (IOException e) {
+ // good
+ }
+ }
+
+ @Test
+ public void testValid() throws IOException {
+ assertEquals(1000, new StringReplyHeader("$1000").getStringLength());
+ }
+
+
+}
diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java
index 32833cd2f..40cf98200 100644
--- a/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java
+++ b/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java
@@ -44,45 +44,20 @@ import static org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMe
public class WebSocketConnectionTest {
-// private static final ObjectMapper mapper = new ObjectMapper();
-
private static final String VALID_USER = "+14152222222";
private static final String INVALID_USER = "+14151111111";
private static final String VALID_PASSWORD = "secure";
private static final String INVALID_PASSWORD = "insecure";
-// private static final StoredMessages storedMessages = mock(StoredMessages.class);
private static final AccountAuthenticator accountAuthenticator = mock(AccountAuthenticator.class);
private static final AccountsManager accountsManager = mock(AccountsManager.class);
private static final PubSubManager pubSubManager = mock(PubSubManager.class );
private static final Account account = mock(Account.class );
private static final Device device = mock(Device.class );
private static final UpgradeRequest upgradeRequest = mock(UpgradeRequest.class );
-// private static final Session session = mock(Session.class );
private static final PushSender pushSender = mock(PushSender.class);
- @Test
- public void testCloseExisting() throws Exception {
- MessagesManager storedMessages = mock(MessagesManager.class );
- WebSocketConnectListener connectListener = new AuthenticatedConnectListener(accountsManager, pushSender, storedMessages, pubSubManager);
- WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
- Account account = mock(Account.class );
- Device device = mock(Device.class );
-
- when(sessionContext.getAuthenticated(Account.class)).thenReturn(Optional.of(account));
- when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
- when(account.getNumber()).thenReturn("+14157777777");
- when(device.getId()).thenReturn(1L);
-
- connectListener.onWebSocketConnect(sessionContext);
-
- ArgumentCaptor message = ArgumentCaptor.forClass(PubSubProtos.PubSubMessage.class);
-
- verify(pubSubManager).publish(eq(new WebsocketAddress("+14157777777", 1L)), message.capture());
- assertEquals(message.getValue().getType().getNumber(), PubSubProtos.PubSubMessage.Type.CLOSE_VALUE);
- }
-
@Test
public void testCredentials() throws Exception {
MessagesManager storedMessages = mock(MessagesManager.class);
@@ -98,10 +73,6 @@ public class WebSocketConnectionTest {
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
-// when(session.getUpgradeRequest()).thenReturn(upgradeRequest);
-//
-// WebsocketController controller = new WebsocketController(accountAuthenticator, accountsManager, pushSender, pubSubManager, storedMessages);
-
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap() {{
put("login", new String[] {VALID_USER});
put("password", new String[] {VALID_PASSWORD});
@@ -114,13 +85,6 @@ public class WebSocketConnectionTest {
verify(sessionContext).addListener(any(WebSocketSessionContext.WebSocketEventListener.class));
-//
-// controller.onWebSocketConnect(session);
-
-// verify(session, never()).close();
-// verify(session, never()).close(any(CloseStatus.class));
-// verify(session, never()).close(anyInt(), anyString());
-
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap() {{
put("login", new String[] {INVALID_USER});
put("password", new String[] {INVALID_PASSWORD});
@@ -128,15 +92,6 @@ public class WebSocketConnectionTest {
account = webSocketAuthenticator.authenticate(upgradeRequest);
assertFalse(account.isPresent());
-// when(sessionContext.getAuthenticated(Account.class)).thenReturn(account);
-//
-// WebSocketClient client = mock(WebSocketClient.class);
-// when(sessionContext.getClient()).thenReturn(client);
-//
-// connectListener.onWebSocketConnect(sessionContext);
-//
-// verify(sessionContext, times(1)).addListener(any(WebSocketSessionContext.WebSocketEventListener.class));
-// verify(client).close(eq(4001), anyString());
}
@Test
@@ -183,12 +138,11 @@ public class WebSocketConnectionTest {
}
});
+ WebsocketAddress websocketAddress = new WebsocketAddress(account.getNumber(), device.getId());
WebSocketConnection connection = new WebSocketConnection(accountsManager, pushSender, storedMessages,
- pubSubManager, account, device, client);
+ account, device, client);
- connection.onConnected();
-
- verify(pubSubManager).subscribe(eq(new WebsocketAddress("+14152222222", 2L)), eq((connection)));
+ connection.onDispatchSubscribed(websocketAddress.serialize());
verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(Optional.class));
assertTrue(futures.size() == 3);
@@ -205,11 +159,10 @@ public class WebSocketConnectionTest {
add(createMessage("sender2", 3333, false, "third"));
}};
-// verify(pushSender, times(2)).sendMessage(eq(account), eq(device), any(OutgoingMessageSignal.class));
verify(pushSender, times(1)).sendMessage(eq(sender1), eq(sender1device), any(OutgoingMessageSignal.class));
- connection.onConnectionLost();
- verify(pubSubManager).unsubscribe(eq(new WebsocketAddress("+14152222222", 2L)), eq(connection));
+ connection.onDispatchUnsubscribed(websocketAddress.serialize());
+ verify(client).close(anyInt(), anyString());
}
private OutgoingMessageSignal createMessage(String sender, long timestamp, boolean receipt, String content) {