Use a custom redis pubsub implementation rather than Jedis.

// FREEBIE
This commit is contained in:
Moxie Marlinspike 2015-03-16 15:05:33 -07:00
parent e79861c30a
commit c7e0cc1158
26 changed files with 1254 additions and 292 deletions

View File

@ -132,7 +132,6 @@
<version>0.2.3</version>
</dependency>
</dependencies>
<dependencyManagement>

View File

@ -0,0 +1,7 @@
package org.whispersystems.dispatch;
public interface DispatchChannel {
public void onDispatchMessage(String channel, byte[] message);
public void onDispatchSubscribed(String channel);
public void onDispatchUnsubscribed(String channel);
}

View File

@ -0,0 +1,191 @@
package org.whispersystems.dispatch;
import com.google.common.base.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.io.RedisPubSubConnectionFactory;
import org.whispersystems.dispatch.redis.PubSubConnection;
import org.whispersystems.dispatch.redis.PubSubReply;
import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
public class DispatchManager extends Thread {
private final Logger logger = LoggerFactory.getLogger(DispatchManager.class);
private final Executor executor = Executors.newCachedThreadPool();
private final Map<String, DispatchChannel> subscriptions = new HashMap<>();
private final Optional<DispatchChannel> deadLetterChannel;
private final RedisPubSubConnectionFactory redisPubSubConnectionFactory;
private PubSubConnection pubSubConnection;
private volatile boolean running;
public DispatchManager(RedisPubSubConnectionFactory redisPubSubConnectionFactory,
Optional<DispatchChannel> deadLetterChannel)
{
this.redisPubSubConnectionFactory = redisPubSubConnectionFactory;
this.deadLetterChannel = deadLetterChannel;
}
@Override
public void start() {
this.pubSubConnection = redisPubSubConnectionFactory.connect();
this.running = true;
super.start();
}
public void shutdown() {
this.running = false;
this.pubSubConnection.close();
}
public void subscribe(String name, DispatchChannel dispatchChannel) {
Optional<DispatchChannel> previous;
synchronized (subscriptions) {
previous = Optional.fromNullable(subscriptions.get(name));
subscriptions.put(name, dispatchChannel);
}
try {
pubSubConnection.subscribe(name);
} catch (IOException e) {
logger.warn("Subscription error", e);
}
if (previous.isPresent()) {
dispatchUnsubscription(name, previous.get());
}
}
public void unsubscribe(String name, DispatchChannel channel) {
final Optional<DispatchChannel> subscription;
synchronized (subscriptions) {
subscription = Optional.fromNullable(subscriptions.get(name));
if (subscription.isPresent() && subscription.get() == channel) {
subscriptions.remove(name);
}
}
if (subscription.isPresent()) {
try {
pubSubConnection.unsubscribe(name);
} catch (IOException e) {
logger.warn("Unsubscribe error", e);
}
dispatchUnsubscription(name, subscription.get());
}
}
@Override
public void run() {
while (running) {
try {
PubSubReply reply = pubSubConnection.read();
switch (reply.getType()) {
case UNSUBSCRIBE: break;
case SUBSCRIBE: dispatchSubscribe(reply); break;
case MESSAGE: dispatchMessage(reply); break;
default: throw new AssertionError("Unknown pubsub reply type! " + reply.getType());
}
} catch (IOException e) {
logger.warn("***** PubSub Connection Error *****", e);
if (running) {
this.pubSubConnection.close();
this.pubSubConnection = redisPubSubConnectionFactory.connect();
resubscribe();
}
}
}
logger.warn("DispatchManager Shutting Down...");
}
private void dispatchSubscribe(final PubSubReply reply) {
final Optional<DispatchChannel> subscription;
synchronized (subscriptions) {
subscription = Optional.fromNullable(subscriptions.get(reply.getChannel()));
}
if (subscription.isPresent()) {
dispatchSubscription(reply.getChannel(), subscription.get());
} else {
logger.warn("Received subscribe event for non-existing channel: " + reply.getChannel());
}
}
private void dispatchMessage(PubSubReply reply) {
Optional<DispatchChannel> subscription;
synchronized (subscriptions) {
subscription = Optional.fromNullable(subscriptions.get(reply.getChannel()));
}
if (subscription.isPresent()) {
dispatchMessage(reply.getChannel(), subscription.get(), reply.getContent().get());
} else if (deadLetterChannel.isPresent()) {
dispatchMessage(reply.getChannel(), deadLetterChannel.get(), reply.getContent().get());
} else {
logger.warn("Received message for non-existing channel, with no dead letter handler: " + reply.getChannel());
}
}
private void resubscribe() {
final Collection<String> names;
synchronized (subscriptions) {
names = subscriptions.keySet();
}
new Thread() {
@Override
public void run() {
try {
for (String name : names) {
pubSubConnection.subscribe(name);
}
} catch (IOException e) {
logger.warn("***** RESUBSCRIPTION ERROR *****", e);
}
}
}.start();
}
private void dispatchMessage(final String name, final DispatchChannel channel, final byte[] message) {
executor.execute(new Runnable() {
@Override
public void run() {
channel.onDispatchMessage(name, message);
}
});
}
private void dispatchSubscription(final String name, final DispatchChannel channel) {
executor.execute(new Runnable() {
@Override
public void run() {
channel.onDispatchSubscribed(name);
}
});
}
private void dispatchUnsubscription(final String name, final DispatchChannel channel) {
executor.execute(new Runnable() {
@Override
public void run() {
channel.onDispatchUnsubscribed(name);
}
});
}
}

View File

@ -0,0 +1,64 @@
package org.whispersystems.dispatch.io;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
public class RedisInputStream {
private static final byte CR = 0x0D;
private static final byte LF = 0x0A;
private final InputStream inputStream;
public RedisInputStream(InputStream inputStream) {
this.inputStream = inputStream;
}
public String readLine() throws IOException {
ByteArrayOutputStream boas = new ByteArrayOutputStream();
boolean foundCr = false;
while (true) {
int character = inputStream.read();
if (character == -1) {
throw new IOException("Stream closed!");
}
boas.write(character);
if (foundCr && character == LF) break;
else if (character == CR) foundCr = true;
else if (foundCr) foundCr = false;
}
byte[] data = boas.toByteArray();
return new String(data, 0, data.length-2);
}
public byte[] readFully(int size) throws IOException {
byte[] result = new byte[size];
int offset = 0;
int remaining = result.length;
while (remaining > 0) {
int read = inputStream.read(result, offset, remaining);
if (read < 0) {
throw new IOException("Stream closed!");
}
offset += read;
remaining -= read;
}
return result;
}
public void close() throws IOException {
inputStream.close();
}
}

View File

@ -0,0 +1,9 @@
package org.whispersystems.dispatch.io;
import org.whispersystems.dispatch.redis.PubSubConnection;
public interface RedisPubSubConnectionFactory {
public PubSubConnection connect();
}

View File

@ -0,0 +1,119 @@
package org.whispersystems.dispatch.redis;
import com.google.common.base.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.io.RedisInputStream;
import org.whispersystems.dispatch.redis.protocol.ArrayReplyHeader;
import org.whispersystems.dispatch.redis.protocol.IntReply;
import org.whispersystems.dispatch.redis.protocol.StringReplyHeader;
import org.whispersystems.dispatch.util.Util;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.net.Socket;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicBoolean;
public class PubSubConnection {
private final Logger logger = LoggerFactory.getLogger(PubSubConnection.class);
private static final byte[] UNSUBSCRIBE_TYPE = {'u', 'n', 's', 'u', 'b', 's', 'c', 'r', 'i', 'b', 'e' };
private static final byte[] SUBSCRIBE_TYPE = {'s', 'u', 'b', 's', 'c', 'r', 'i', 'b', 'e' };
private static final byte[] MESSAGE_TYPE = {'m', 'e', 's', 's', 'a', 'g', 'e' };
private static final byte[] SUBSCRIBE_COMMAND = {'S', 'U', 'B', 'S', 'C', 'R', 'I', 'B', 'E', ' ' };
private static final byte[] UNSUBSCRIBE_COMMAND = {'U', 'N', 'S', 'U', 'B', 'S', 'C', 'R', 'I', 'B', 'E', ' '};
private static final byte[] CRLF = {'\r', '\n' };
private final OutputStream outputStream;
private final RedisInputStream inputStream;
private final Socket socket;
private final AtomicBoolean closed;
public PubSubConnection(Socket socket) throws IOException {
this.socket = socket;
this.outputStream = socket.getOutputStream();
this.inputStream = new RedisInputStream(new BufferedInputStream(socket.getInputStream()));
this.closed = new AtomicBoolean(false);
}
public void subscribe(String channelName) throws IOException {
if (closed.get()) throw new IOException("Connection closed!");
byte[] command = Util.combine(SUBSCRIBE_COMMAND, channelName.getBytes(), CRLF);
outputStream.write(command);
}
public void unsubscribe(String channelName) throws IOException {
if (closed.get()) throw new IOException("Connection closed!");
byte[] command = Util.combine(UNSUBSCRIBE_COMMAND, channelName.getBytes(), CRLF);
outputStream.write(command);
}
public PubSubReply read() throws IOException {
if (closed.get()) throw new IOException("Connection closed!");
ArrayReplyHeader replyHeader = new ArrayReplyHeader(inputStream.readLine());
if (replyHeader.getElementCount() != 3) {
throw new IOException("Received array reply header with strange count: " + replyHeader.getElementCount());
}
StringReplyHeader replyTypeHeader = new StringReplyHeader(inputStream.readLine());
byte[] replyType = inputStream.readFully(replyTypeHeader.getStringLength());
inputStream.readLine();
if (Arrays.equals(SUBSCRIBE_TYPE, replyType)) return readSubscribeReply();
else if (Arrays.equals(UNSUBSCRIBE_TYPE, replyType)) return readUnsubscribeReply();
else if (Arrays.equals(MESSAGE_TYPE, replyType)) return readMessageReply();
else throw new IOException("Unknown reply type: " + new String(replyType));
}
public void close() {
try {
this.closed.set(true);
this.inputStream.close();
this.outputStream.close();
this.socket.close();
} catch (IOException e) {
logger.warn("Exception while closing", e);
}
}
private PubSubReply readMessageReply() throws IOException {
StringReplyHeader channelNameHeader = new StringReplyHeader(inputStream.readLine());
byte[] channelName = inputStream.readFully(channelNameHeader.getStringLength());
inputStream.readLine();
StringReplyHeader messageHeader = new StringReplyHeader(inputStream.readLine());
byte[] message = inputStream.readFully(messageHeader.getStringLength());
inputStream.readLine();
return new PubSubReply(PubSubReply.Type.MESSAGE, new String(channelName), Optional.of(message));
}
private PubSubReply readUnsubscribeReply() throws IOException {
String channelName = readSubscriptionReply();
return new PubSubReply(PubSubReply.Type.UNSUBSCRIBE, channelName, Optional.<byte[]>absent());
}
private PubSubReply readSubscribeReply() throws IOException {
String channelName = readSubscriptionReply();
return new PubSubReply(PubSubReply.Type.SUBSCRIBE, channelName, Optional.<byte[]>absent());
}
private String readSubscriptionReply() throws IOException {
StringReplyHeader channelNameHeader = new StringReplyHeader(inputStream.readLine());
byte[] channelName = inputStream.readFully(channelNameHeader.getStringLength());
inputStream.readLine();
IntReply subscriptionCount = new IntReply(inputStream.readLine());
return new String(channelName);
}
}

View File

@ -0,0 +1,35 @@
package org.whispersystems.dispatch.redis;
import com.google.common.base.Optional;
public class PubSubReply {
public enum Type {
MESSAGE,
SUBSCRIBE,
UNSUBSCRIBE
}
private final Type type;
private final String channel;
private final Optional<byte[]> content;
public PubSubReply(Type type, String channel, Optional<byte[]> content) {
this.type = type;
this.channel = channel;
this.content = content;
}
public Type getType() {
return type;
}
public String getChannel() {
return channel;
}
public Optional<byte[]> getContent() {
return content;
}
}

View File

@ -0,0 +1,24 @@
package org.whispersystems.dispatch.redis.protocol;
import java.io.IOException;
public class ArrayReplyHeader {
private final int elementCount;
public ArrayReplyHeader(String header) throws IOException {
if (header == null || header.length() < 2 || header.charAt(0) != '*') {
throw new IOException("Invalid array reply header: " + header);
}
try {
this.elementCount = Integer.parseInt(header.substring(1));
} catch (NumberFormatException e) {
throw new IOException(e);
}
}
public int getElementCount() {
return elementCount;
}
}

View File

@ -0,0 +1,24 @@
package org.whispersystems.dispatch.redis.protocol;
import java.io.IOException;
public class IntReply {
private final int value;
public IntReply(String reply) throws IOException {
if (reply == null || reply.length() < 2 || reply.charAt(0) != ':') {
throw new IOException("Invalid int reply: " + reply);
}
try {
this.value = Integer.parseInt(reply.substring(1));
} catch (NumberFormatException e) {
throw new IOException(e);
}
}
public int getValue() {
return value;
}
}

View File

@ -0,0 +1,24 @@
package org.whispersystems.dispatch.redis.protocol;
import java.io.IOException;
public class StringReplyHeader {
private final int stringLength;
public StringReplyHeader(String header) throws IOException {
if (header == null || header.length() < 2 || header.charAt(0) != '$') {
throw new IOException("Invalid string reply header: " + header);
}
try {
this.stringLength = Integer.parseInt(header.substring(1));
} catch (NumberFormatException e) {
throw new IOException(e);
}
}
public int getStringLength() {
return stringLength;
}
}

View File

@ -0,0 +1,36 @@
package org.whispersystems.dispatch.util;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
public class Util {
public static byte[] combine(byte[]... elements) {
try {
int sum = 0;
for (byte[] element : elements) {
sum += element.length;
}
ByteArrayOutputStream baos = new ByteArrayOutputStream(sum);
for (byte[] element : elements) {
baos.write(element);
}
return baos.toByteArray();
} catch (IOException e) {
throw new AssertionError(e);
}
}
public static void sleep(long millis) {
try {
Thread.sleep(millis);
} catch (InterruptedException e) {
throw new AssertionError(e);
}
}
}

View File

@ -24,6 +24,8 @@ import com.sun.jersey.api.client.Client;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.eclipse.jetty.servlets.CrossOriginFilter;
import org.skife.jdbi.v2.DBI;
import org.whispersystems.dispatch.DispatchChannel;
import org.whispersystems.dispatch.DispatchManager;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.FederatedPeerAuthenticator;
import org.whispersystems.textsecuregcm.auth.MultiBasicAuthProvider;
@ -147,10 +149,11 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
Keys keys = database.onDemand(Keys.class);
Messages messages = messagedb.onDemand(Messages.class);
JedisPool cacheClient = new RedisClientFactory(config.getCacheConfiguration().getUrl()).getRedisClientPool();
JedisPool directoryClient = new RedisClientFactory(config.getDirectoryConfiguration().getUrl()).getRedisClientPool();
Client httpClient = new JerseyClientBuilder(environment).using(config.getJerseyClientConfiguration())
.build(getName());
RedisClientFactory cacheClientFactory = new RedisClientFactory(config.getCacheConfiguration().getUrl());
JedisPool cacheClient = cacheClientFactory.getRedisClientPool();
JedisPool directoryClient = new RedisClientFactory(config.getDirectoryConfiguration().getUrl()).getRedisClientPool();
Client httpClient = new JerseyClientBuilder(environment).using(config.getJerseyClientConfiguration())
.build(getName());
DirectoryManager directory = new DirectoryManager(directoryClient);
PendingAccountsManager pendingAccountsManager = new PendingAccountsManager(pendingAccounts, cacheClient);
@ -159,7 +162,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
FederatedClientManager federatedClientManager = new FederatedClientManager(config.getFederationConfiguration());
MessagesManager messagesManager = new MessagesManager(messages);
DeadLetterHandler deadLetterHandler = new DeadLetterHandler(messagesManager);
PubSubManager pubSubManager = new PubSubManager(cacheClient, deadLetterHandler);
DispatchManager dispatchManager = new DispatchManager(cacheClientFactory, Optional.<DispatchChannel>of(deadLetterHandler));
PubSubManager pubSubManager = new PubSubManager(cacheClient, dispatchManager);
PushServiceClient pushServiceClient = new PushServiceClient(httpClient, config.getPushConfiguration());
WebsocketSender websocketSender = new WebsocketSender(messagesManager, pubSubManager);
AccountAuthenticator deviceAuthenticator = new AccountAuthenticator(accountsManager);
@ -173,6 +177,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
FeedbackHandler feedbackHandler = new FeedbackHandler(pushServiceClient, accountsManager);
Optional<byte[]> authorizationKey = config.getRedphoneConfiguration().getAuthorizationKey();
environment.lifecycle().manage(pubSubManager);
environment.lifecycle().manage(feedbackHandler);
AttachmentController attachmentController = new AttachmentController(rateLimiters, federatedClientManager, urlSigner);
@ -263,5 +268,4 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
public static void main(String[] args) throws Exception {
new WhisperServerService().run(args);
}
}

View File

@ -16,8 +16,14 @@
*/
package org.whispersystems.textsecuregcm.providers;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.io.RedisPubSubConnectionFactory;
import org.whispersystems.dispatch.redis.PubSubConnection;
import org.whispersystems.textsecuregcm.util.Util;
import java.io.IOException;
import java.net.Socket;
import java.net.URI;
import java.net.URISyntaxException;
@ -25,29 +31,40 @@ import redis.clients.jedis.JedisPool;
import redis.clients.jedis.JedisPoolConfig;
import redis.clients.jedis.Protocol;
public class RedisClientFactory {
public class RedisClientFactory implements RedisPubSubConnectionFactory {
private final Logger logger = LoggerFactory.getLogger(RedisClientFactory.class);
private final String host;
private final int port;
private final JedisPool jedisPool;
public RedisClientFactory(String url) throws URISyntaxException {
JedisPoolConfig poolConfig = new JedisPoolConfig();
poolConfig.setTestOnBorrow(true);
URI redisURI = new URI(url);
String redisHost = redisURI.getHost();
int redisPort = redisURI.getPort();
String redisPassword = null;
URI redisURI = new URI(url);
if (!Util.isEmpty(redisURI.getUserInfo())) {
redisPassword = redisURI.getUserInfo().split(":",2)[1];
}
this.jedisPool = new JedisPool(poolConfig, redisHost, redisPort,
Protocol.DEFAULT_TIMEOUT, redisPassword);
this.host = redisURI.getHost();
this.port = redisURI.getPort();
this.jedisPool = new JedisPool(poolConfig, host, port,
Protocol.DEFAULT_TIMEOUT, null);
}
public JedisPool getRedisClientPool() {
return jedisPool;
}
@Override
public PubSubConnection connect() {
while (true) {
try {
Socket socket = new Socket(host, port);
return new PubSubConnection(socket);
} catch (IOException e) {
logger.warn("Error connecting", e);
Util.sleep(200);
}
}
}
}

View File

@ -1,9 +0,0 @@
package org.whispersystems.textsecuregcm.storage;
import static org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage;
public interface PubSubListener {
public void onPubSubMessage(PubSubMessage outgoingMessage);
}

View File

@ -1,177 +1,110 @@
package org.whispersystems.textsecuregcm.storage;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.ByteString;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.websocket.DeadLetterHandler;
import org.whispersystems.dispatch.DispatchChannel;
import org.whispersystems.dispatch.DispatchManager;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import io.dropwizard.lifecycle.Managed;
import static org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage;
import redis.clients.jedis.BinaryJedisPubSub;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
public class PubSubManager {
public class PubSubManager implements Managed {
private static final byte[] KEEPALIVE_CHANNEL = "KEEPALIVE".getBytes();
private static final String KEEPALIVE_CHANNEL = "KEEPALIVE";
private final Logger logger = LoggerFactory.getLogger(PubSubManager.class);
private final SubscriptionListener baseListener = new SubscriptionListener();
private final Map<String, PubSubListener> listeners = new HashMap<>();
private final Executor threaded = Executors.newCachedThreadPool();
private final Logger logger = LoggerFactory.getLogger(PubSubManager.class);
private final DispatchManager dispatchManager;
private final JedisPool jedisPool;
private final DeadLetterHandler deadLetterHandler;
private boolean subscribed = false;
public PubSubManager(JedisPool jedisPool, DeadLetterHandler deadLetterHandler) {
public PubSubManager(JedisPool jedisPool, DispatchManager dispatchManager) {
this.dispatchManager = dispatchManager;
this.jedisPool = jedisPool;
this.deadLetterHandler = deadLetterHandler;
initializePubSubWorker();
waitForSubscription();
}
public synchronized void subscribe(WebsocketAddress address, PubSubListener listener) {
String serializedAddress = address.serialize();
@Override
public void start() throws Exception {
this.dispatchManager.start();
listeners.put(serializedAddress, listener);
baseListener.subscribe(serializedAddress.getBytes());
}
KeepaliveDispatchChannel keepaliveDispatchChannel = new KeepaliveDispatchChannel();
this.dispatchManager.subscribe(KEEPALIVE_CHANNEL, keepaliveDispatchChannel);
public synchronized void unsubscribe(WebsocketAddress address, PubSubListener listener) {
String serializedAddress = address.serialize();
if (listeners.get(serializedAddress) == listener) {
listeners.remove(serializedAddress);
baseListener.unsubscribe(serializedAddress.getBytes());
synchronized (this) {
while (!subscribed) wait(0);
}
new KeepaliveSender().start();
}
public synchronized boolean publish(WebsocketAddress address, PubSubMessage message) {
@Override
public void stop() throws Exception {
dispatchManager.shutdown();
}
public void subscribe(WebsocketAddress address, DispatchChannel channel) {
String serializedAddress = address.serialize();
dispatchManager.subscribe(serializedAddress, channel);
}
public void unsubscribe(WebsocketAddress address, DispatchChannel dispatchChannel) {
String serializedAddress = address.serialize();
dispatchManager.unsubscribe(serializedAddress, dispatchChannel);
}
public boolean publish(WebsocketAddress address, PubSubMessage message) {
return publish(address.serialize().getBytes(), message);
}
private synchronized boolean publish(byte[] channel, PubSubMessage message) {
private boolean publish(byte[] channel, PubSubMessage message) {
try (Jedis jedis = jedisPool.getResource()) {
return jedis.publish(channel, message.toByteArray()) != 0;
}
}
private synchronized void resubscribeAll() {
for (String serializedAddress : listeners.keySet()) {
baseListener.subscribe(serializedAddress.getBytes());
}
}
private synchronized void waitForSubscription() {
try {
while (!subscribed) {
wait();
}
} catch (InterruptedException e) {
throw new AssertionError(e);
}
}
private void initializePubSubWorker() {
new Thread("PubSubListener") {
@Override
public void run() {
for (;;) {
logger.info("Starting Redis PubSub Subscriber...");
try (Jedis jedis = jedisPool.getResource()) {
jedis.subscribe(baseListener, KEEPALIVE_CHANNEL);
logger.warn("**** Unsubscribed from holding channel!!! ******");
} catch (Throwable t) {
logger.warn("*** SUBSCRIBER CONNECTION CLOSED", t);
}
}
}
}.start();
new Thread("PubSubKeepAlive") {
@Override
public void run() {
for (;;) {
try {
Thread.sleep(20000);
publish(KEEPALIVE_CHANNEL, PubSubMessage.newBuilder()
.setType(PubSubMessage.Type.KEEPALIVE)
.build());
} catch (Throwable e) {
logger.warn("KEEPALIVE PUBLISH EXCEPTION: ", e);
}
}
}
}.start();
}
private class SubscriptionListener extends BinaryJedisPubSub {
private class KeepaliveDispatchChannel implements DispatchChannel {
@Override
public void onMessage(final byte[] channel, final byte[] message) {
if (Arrays.equals(KEEPALIVE_CHANNEL, channel)) {
return;
}
final PubSubListener listener;
synchronized (PubSubManager.this) {
listener = listeners.get(new String(channel));
}
threaded.execute(new Runnable() {
@Override
public void run() {
try {
PubSubMessage receivedMessage = PubSubMessage.parseFrom(message);
if (listener != null) listener.onPubSubMessage(receivedMessage);
else deadLetterHandler.handle(channel, receivedMessage);
} catch (InvalidProtocolBufferException e) {
logger.warn("Error parsing PubSub protobuf", e);
}
}
});
public void onDispatchMessage(String channel, byte[] message) {
// Good
}
@Override
public void onPMessage(byte[] s, byte[] s2, byte[] s3) {
logger.warn("Received PMessage!");
}
@Override
public void onSubscribe(byte[] channel, int count) {
if (Arrays.equals(KEEPALIVE_CHANNEL, channel)) {
public void onDispatchSubscribed(String channel) {
if (KEEPALIVE_CHANNEL.equals(channel)) {
synchronized (PubSubManager.this) {
subscribed = true;
PubSubManager.this.notifyAll();
}
threaded.execute(new Runnable() {
@Override
public void run() {
resubscribeAll();
}
});
}
}
@Override
public void onUnsubscribe(byte[] s, int i) {}
public void onDispatchUnsubscribed(String channel) {
logger.warn("***** KEEPALIVE CHANNEL UNSUBSCRIBED *****");
}
}
private class KeepaliveSender extends Thread {
@Override
public void onPUnsubscribe(byte[] s, int i) {}
@Override
public void onPSubscribe(byte[] s, int i) {}
public void run() {
while (true) {
try {
Thread.sleep(20000);
publish(KEEPALIVE_CHANNEL.getBytes(), PubSubMessage.newBuilder()
.setType(PubSubMessage.Type.KEEPALIVE)
.build());
} catch (Throwable e) {
logger.warn("***** KEEPALIVE EXCEPTION ******", e);
}
}
}
}
}

View File

@ -1,5 +1,8 @@
package org.whispersystems.textsecuregcm.websocket;
import com.codahale.metrics.Histogram;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.push.PushSender;
@ -8,14 +11,19 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.PubSubProtos;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import org.whispersystems.websocket.setup.WebSocketConnectListener;
import static com.codahale.metrics.MetricRegistry.name;
public class AuthenticatedConnectListener implements WebSocketConnectListener {
private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class);
private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class);
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private static final Histogram durationHistogram = metricRegistry.histogram(name(WebSocketConnection.class, "connected_duration"));
private final AccountsManager accountsManager;
private final PushSender pushSender;
@ -33,23 +41,22 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
@Override
public void onWebSocketConnect(WebSocketSessionContext context) {
Account account = context.getAuthenticated(Account.class).get();
Device device = account.getAuthenticatedDevice().get();
final Account account = context.getAuthenticated(Account.class).get();
final Device device = account.getAuthenticatedDevice().get();
final long connectTime = System.currentTimeMillis();
final WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId());
final WebSocketConnection connection = new WebSocketConnection(accountsManager, pushSender,
messagesManager, account, device,
context.getClient());
updateLastSeen(account, device);
closeExistingDeviceConnection(account, device);
final WebSocketConnection connection = new WebSocketConnection(accountsManager, pushSender,
messagesManager, pubSubManager,
account, device,
context.getClient());
connection.onConnected();
pubSubManager.subscribe(address, connection);
context.addListener(new WebSocketSessionContext.WebSocketEventListener() {
@Override
public void onWebSocketClose(WebSocketSessionContext context, int statusCode, String reason) {
connection.onConnectionLost();
pubSubManager.unsubscribe(address, connection);
durationHistogram.update(System.currentTimeMillis() - connectTime);
}
});
}
@ -60,12 +67,5 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
accountsManager.update(account);
}
}
private void closeExistingDeviceConnection(Account account, Device device) {
pubSubManager.publish(new WebsocketAddress(account.getNumber(), device.getId()),
PubSubProtos.PubSubMessage.newBuilder()
.setType(PubSubProtos.PubSubMessage.Type.CLOSE)
.build());
}
}

View File

@ -3,11 +3,13 @@ package org.whispersystems.textsecuregcm.websocket;
import com.google.protobuf.InvalidProtocolBufferException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.DispatchChannel;
import org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PubSubProtos;
import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage;
public class DeadLetterHandler {
public class DeadLetterHandler implements DispatchChannel {
private final Logger logger = LoggerFactory.getLogger(DeadLetterHandler.class);
@ -17,14 +19,16 @@ public class DeadLetterHandler {
this.messagesManager = messagesManager;
}
public void handle(byte[] channel, PubSubProtos.PubSubMessage pubSubMessage) {
@Override
public void onDispatchMessage(String channel, byte[] data) {
try {
WebsocketAddress address = new WebsocketAddress(new String(channel));
logger.warn("Handling dead letter to: " + channel);
logger.warn("Handling dead letter to: " + address);
WebsocketAddress address = new WebsocketAddress(channel);
PubSubMessage pubSubMessage = PubSubMessage.parseFrom(data);
switch (pubSubMessage.getType().getNumber()) {
case PubSubProtos.PubSubMessage.Type.DELIVER_VALUE:
case PubSubMessage.Type.DELIVER_VALUE:
OutgoingMessageSignal message = OutgoingMessageSignal.parseFrom(pubSubMessage.getContent());
messagesManager.insert(address.getNumber(), address.getDeviceId(), message);
break;
@ -36,4 +40,13 @@ public class DeadLetterHandler {
}
}
@Override
public void onDispatchSubscribed(String channel) {
logger.warn("DeadLetterHandler subscription notice! " + channel);
}
@Override
public void onDispatchUnsubscribed(String channel) {
logger.warn("DeadLetterHandler unsubscribe notice! " + channel);
}
}

View File

@ -14,13 +14,15 @@ public class ProvisioningConnectListener implements WebSocketConnectListener {
@Override
public void onWebSocketConnect(WebSocketSessionContext context) {
final ProvisioningConnection connection = new ProvisioningConnection(pubSubManager, context.getClient());
connection.onConnected();
final ProvisioningConnection connection = new ProvisioningConnection(context.getClient());
final ProvisioningAddress provisioningAddress = ProvisioningAddress.generate();
pubSubManager.subscribe(provisioningAddress, connection);
context.addListener(new WebSocketSessionContext.WebSocketEventListener() {
@Override
public void onWebSocketClose(WebSocketSessionContext context, int statusCode, String reason) {
connection.onConnectionLost();
pubSubManager.unsubscribe(provisioningAddress, connection);
}
});
}

View File

@ -4,58 +4,62 @@ import com.google.common.base.Optional;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.protobuf.InvalidProtocolBufferException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.DispatchChannel;
import org.whispersystems.textsecuregcm.entities.MessageProtos.ProvisioningUuid;
import org.whispersystems.textsecuregcm.storage.PubSubListener;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
public class ProvisioningConnection implements PubSubListener {
public class ProvisioningConnection implements DispatchChannel {
private final PubSubManager pubSubManager;
private final ProvisioningAddress provisioningAddress;
private final WebSocketClient client;
private final Logger logger = LoggerFactory.getLogger(ProvisioningConnection.class);
public ProvisioningConnection(PubSubManager pubSubManager, WebSocketClient client) {
this.pubSubManager = pubSubManager;
this.client = client;
this.provisioningAddress = ProvisioningAddress.generate();
private final WebSocketClient client;
public ProvisioningConnection(WebSocketClient client) {
this.client = client;
}
@Override
public void onPubSubMessage(PubSubMessage outgoingMessage) {
if (outgoingMessage.getType() == PubSubMessage.Type.DELIVER) {
Optional<byte[]> body = Optional.of(outgoingMessage.getContent().toByteArray());
public void onDispatchMessage(String channel, byte[] message) {
try {
PubSubMessage outgoingMessage = PubSubMessage.parseFrom(message);
ListenableFuture<WebSocketResponseMessage> response = client.sendRequest("PUT", "/v1/message", body);
if (outgoingMessage.getType() == PubSubMessage.Type.DELIVER) {
Optional<byte[]> body = Optional.of(outgoingMessage.getContent().toByteArray());
Futures.addCallback(response, new FutureCallback<WebSocketResponseMessage>() {
@Override
public void onSuccess(WebSocketResponseMessage webSocketResponseMessage) {
pubSubManager.unsubscribe(provisioningAddress, ProvisioningConnection.this);
client.close(1001, "All you get.");
}
ListenableFuture<WebSocketResponseMessage> response = client.sendRequest("PUT", "/v1/message", body);
@Override
public void onFailure(Throwable throwable) {
pubSubManager.unsubscribe(provisioningAddress, ProvisioningConnection.this);
client.close(1001, "That's all!");
}
});
Futures.addCallback(response, new FutureCallback<WebSocketResponseMessage>() {
@Override
public void onSuccess(WebSocketResponseMessage webSocketResponseMessage) {
client.close(1001, "All you get.");
}
@Override
public void onFailure(Throwable throwable) {
client.close(1001, "That's all!");
}
});
}
} catch (InvalidProtocolBufferException e) {
logger.warn("Protobuf Error: ", e);
}
}
public void onConnected() {
this.pubSubManager.subscribe(provisioningAddress, this);
@Override
public void onDispatchSubscribed(String channel) {
this.client.sendRequest("PUT", "/v1/address", Optional.of(ProvisioningUuid.newBuilder()
.setUuid(provisioningAddress.getAddress())
.setUuid(channel)
.build()
.toByteArray()));
}
public void onConnectionLost() {
this.pubSubManager.unsubscribe(provisioningAddress, this);
this.client.close(1001, "Done");
@Override
public void onDispatchUnsubscribed(String channel) {
this.client.close(1001, "Closed");
}
}

View File

@ -10,6 +10,7 @@ import com.google.common.util.concurrent.ListenableFuture;
import com.google.protobuf.InvalidProtocolBufferException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.DispatchChannel;
import org.whispersystems.textsecuregcm.entities.CryptoEncodingException;
import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
@ -19,8 +20,6 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PubSubListener;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.websocket.WebSocketClient;
@ -34,21 +33,16 @@ import static com.codahale.metrics.MetricRegistry.name;
import static org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
import static org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessage;
public class WebSocketConnection implements PubSubListener {
public class WebSocketConnection implements DispatchChannel {
private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class);
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private static final Histogram durationHistogram = metricRegistry.histogram(name(WebSocketConnection.class, "connected_duration"));
private final AccountsManager accountsManager;
private final PushSender pushSender;
private final MessagesManager messagesManager;
private final PubSubManager pubSubManager;
private final Account account;
private final Device device;
private final WebsocketAddress address;
private final WebSocketClient client;
private long connectionStartTime;
@ -56,7 +50,6 @@ public class WebSocketConnection implements PubSubListener {
public WebSocketConnection(AccountsManager accountsManager,
PushSender pushSender,
MessagesManager messagesManager,
PubSubManager pubSubManager,
Account account,
Device device,
WebSocketClient client)
@ -64,27 +57,16 @@ public class WebSocketConnection implements PubSubListener {
this.accountsManager = accountsManager;
this.pushSender = pushSender;
this.messagesManager = messagesManager;
this.pubSubManager = pubSubManager;
this.account = account;
this.device = device;
this.client = client;
this.address = new WebsocketAddress(account.getNumber(), device.getId());
}
public void onConnected() {
connectionStartTime = System.currentTimeMillis();
pubSubManager.subscribe(address, this);
processStoredMessages();
}
public void onConnectionLost() {
durationHistogram.update(System.currentTimeMillis() - connectionStartTime);
pubSubManager.unsubscribe(address, this);
}
@Override
public void onPubSubMessage(PubSubMessage pubSubMessage) {
public void onDispatchMessage(String channel, byte[] message) {
try {
PubSubMessage pubSubMessage = PubSubMessage.parseFrom(message);
switch (pubSubMessage.getType().getNumber()) {
case PubSubMessage.Type.QUERY_DB_VALUE:
processStoredMessages();
@ -92,10 +74,6 @@ public class WebSocketConnection implements PubSubListener {
case PubSubMessage.Type.DELIVER_VALUE:
sendMessage(OutgoingMessageSignal.parseFrom(pubSubMessage.getContent()), Optional.<Long>absent());
break;
case PubSubMessage.Type.CLOSE_VALUE:
client.close(1000, "OK");
pubSubManager.unsubscribe(address, this);
break;
default:
logger.warn("Unknown pubsub message: " + pubSubMessage.getType().getNumber());
}
@ -104,6 +82,15 @@ public class WebSocketConnection implements PubSubListener {
}
}
@Override
public void onDispatchUnsubscribed(String channel) {
client.close(1000, "OK");
}
public void onDispatchSubscribed(String channel) {
processStoredMessages();
}
private void sendMessage(final OutgoingMessageSignal message,
final Optional<Long> storedMessageId)
{
@ -180,4 +167,6 @@ public class WebSocketConnection implements PubSubListener {
sendMessage(message.second(), Optional.of(message.first()));
}
}
}

View File

@ -0,0 +1,127 @@
package org.whispersystems.dispatch;
import com.google.common.base.Optional;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExternalResource;
import org.mockito.ArgumentCaptor;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.whispersystems.dispatch.io.RedisPubSubConnectionFactory;
import org.whispersystems.dispatch.redis.PubSubConnection;
import org.whispersystems.dispatch.redis.PubSubReply;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.*;
public class DispatchManagerTest {
private PubSubConnection pubSubConnection;
private RedisPubSubConnectionFactory socketFactory;
private DispatchManager dispatchManager;
private PubSubReplyInputStream pubSubReplyInputStream;
@Rule
public ExternalResource resource = new ExternalResource() {
@Override
protected void before() throws Throwable {
pubSubConnection = mock(PubSubConnection.class );
socketFactory = mock(RedisPubSubConnectionFactory.class);
pubSubReplyInputStream = new PubSubReplyInputStream();
when(socketFactory.connect()).thenReturn(pubSubConnection);
when(pubSubConnection.read()).thenAnswer(new Answer<PubSubReply>() {
@Override
public PubSubReply answer(InvocationOnMock invocationOnMock) throws Throwable {
return pubSubReplyInputStream.read();
}
});
dispatchManager = new DispatchManager(socketFactory, Optional.<DispatchChannel>absent());
dispatchManager.start();
}
@Override
protected void after() {
}
};
@Test
public void testConnect() {
verify(socketFactory).connect();
}
@Test
public void testSubscribe() throws IOException {
DispatchChannel dispatchChannel = mock(DispatchChannel.class);
dispatchManager.subscribe("foo", dispatchChannel);
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.SUBSCRIBE, "foo", Optional.<byte[]>absent()));
verify(dispatchChannel, timeout(1000)).onDispatchSubscribed(eq("foo"));
}
@Test
public void testSubscribeUnsubscribe() throws IOException {
DispatchChannel dispatchChannel = mock(DispatchChannel.class);
dispatchManager.subscribe("foo", dispatchChannel);
dispatchManager.unsubscribe("foo", dispatchChannel);
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.SUBSCRIBE, "foo", Optional.<byte[]>absent()));
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.UNSUBSCRIBE, "foo", Optional.<byte[]>absent()));
verify(dispatchChannel, timeout(1000)).onDispatchUnsubscribed(eq("foo"));
}
@Test
public void testMessages() throws IOException {
DispatchChannel fooChannel = mock(DispatchChannel.class);
DispatchChannel barChannel = mock(DispatchChannel.class);
dispatchManager.subscribe("foo", fooChannel);
dispatchManager.subscribe("bar", barChannel);
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.SUBSCRIBE, "foo", Optional.<byte[]>absent()));
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.SUBSCRIBE, "bar", Optional.<byte[]>absent()));
verify(fooChannel, timeout(1000)).onDispatchSubscribed(eq("foo"));
verify(barChannel, timeout(1000)).onDispatchSubscribed(eq("bar"));
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.MESSAGE, "foo", Optional.of("hello".getBytes())));
pubSubReplyInputStream.write(new PubSubReply(PubSubReply.Type.MESSAGE, "bar", Optional.of("there".getBytes())));
ArgumentCaptor<byte[]> captor = ArgumentCaptor.forClass(byte[].class);
verify(fooChannel, timeout(1000)).onDispatchMessage(eq("foo"), captor.capture());
assertArrayEquals("hello".getBytes(), captor.getValue());
verify(barChannel, timeout(1000)).onDispatchMessage(eq("bar"), captor.capture());
assertArrayEquals("there".getBytes(), captor.getValue());
}
private static class PubSubReplyInputStream {
private final List<PubSubReply> pubSubReplyList = new LinkedList<>();
public synchronized PubSubReply read() {
try {
while (pubSubReplyList.isEmpty()) wait();
return pubSubReplyList.remove(0);
} catch (InterruptedException e) {
throw new AssertionError(e);
}
}
public synchronized void write(PubSubReply pubSubReply) {
pubSubReplyList.add(pubSubReply);
notifyAll();
}
}
}

View File

@ -0,0 +1,263 @@
package org.whispersystems.dispatch.redis;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import static org.junit.Assert.*;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Mockito.*;
public class PubSubConnectionTest {
private static final String REPLY = "*3\r\n" +
"$9\r\n" +
"subscribe\r\n" +
"$5\r\n" +
"abcde\r\n" +
":1\r\n" +
"*3\r\n" +
"$9\r\n" +
"subscribe\r\n" +
"$5\r\n" +
"fghij\r\n" +
":2\r\n" +
"*3\r\n" +
"$9\r\n" +
"subscribe\r\n" +
"$5\r\n" +
"klmno\r\n" +
":2\r\n" +
"*3\r\n" +
"$7\r\n" +
"message\r\n" +
"$5\r\n" +
"abcde\r\n" +
"$10\r\n" +
"1234567890\r\n" +
"*3\r\n" +
"$7\r\n" +
"message\r\n" +
"$5\r\n" +
"klmno\r\n" +
"$10\r\n" +
"0987654321\r\n";
@Test
public void testSubscribe() throws IOException {
// ByteChannel byteChannel = mock(ByteChannel.class);
OutputStream outputStream = mock(OutputStream.class);
Socket socket = mock(Socket.class );
when(socket.getOutputStream()).thenReturn(outputStream);
PubSubConnection connection = new PubSubConnection(socket);
connection.subscribe("foobar");
ArgumentCaptor<byte[]> captor = ArgumentCaptor.forClass(byte[].class);
verify(outputStream).write(captor.capture());
assertArrayEquals(captor.getValue(), "SUBSCRIBE foobar\r\n".getBytes());
}
@Test
public void testUnsubscribe() throws IOException {
OutputStream outputStream = mock(OutputStream.class);
Socket socket = mock(Socket.class );
when(socket.getOutputStream()).thenReturn(outputStream);
PubSubConnection connection = new PubSubConnection(socket);
connection.unsubscribe("bazbar");
ArgumentCaptor<byte[]> captor = ArgumentCaptor.forClass(byte[].class);
verify(outputStream).write(captor.capture());
assertArrayEquals(captor.getValue(), "UNSUBSCRIBE bazbar\r\n".getBytes());
}
@Test
public void testTricklyResponse() throws Exception {
InputStream inputStream = mockInputStreamFor(new TrickleInputStream(REPLY.getBytes()));
OutputStream outputStream = mock(OutputStream.class);
Socket socket = mock(Socket.class );
when(socket.getOutputStream()).thenReturn(outputStream);
when(socket.getInputStream()).thenReturn(inputStream);
PubSubConnection pubSubConnection = new PubSubConnection(socket);
readResponses(pubSubConnection);
}
@Test
public void testFullResponse() throws Exception {
InputStream inputStream = mockInputStreamFor(new FullInputStream(REPLY.getBytes()));
OutputStream outputStream = mock(OutputStream.class);
Socket socket = mock(Socket.class );
when(socket.getOutputStream()).thenReturn(outputStream);
when(socket.getInputStream()).thenReturn(inputStream);
PubSubConnection pubSubConnection = new PubSubConnection(socket);
readResponses(pubSubConnection);
}
@Test
public void testRandomLengthResponse() throws Exception {
InputStream inputStream = mockInputStreamFor(new RandomInputStream(REPLY.getBytes()));
OutputStream outputStream = mock(OutputStream.class);
Socket socket = mock(Socket.class );
when(socket.getOutputStream()).thenReturn(outputStream);
when(socket.getInputStream()).thenReturn(inputStream);
PubSubConnection pubSubConnection = new PubSubConnection(socket);
readResponses(pubSubConnection);
}
private InputStream mockInputStreamFor(final MockInputStream stub) throws IOException {
InputStream result = mock(InputStream.class);
when(result.read()).thenAnswer(new Answer<Integer>() {
@Override
public Integer answer(InvocationOnMock invocationOnMock) throws Throwable {
return stub.read();
}
});
when(result.read(any(byte[].class))).thenAnswer(new Answer<Integer>() {
@Override
public Integer answer(InvocationOnMock invocationOnMock) throws Throwable {
byte[] buffer = (byte[])invocationOnMock.getArguments()[0];
return stub.read(buffer, 0, buffer.length);
}
});
when(result.read(any(byte[].class), anyInt(), anyInt())).thenAnswer(new Answer<Integer>() {
@Override
public Integer answer(InvocationOnMock invocationOnMock) throws Throwable {
byte[] buffer = (byte[]) invocationOnMock.getArguments()[0];
int offset = (int) invocationOnMock.getArguments()[1];
int length = (int) invocationOnMock.getArguments()[2];
return stub.read(buffer, offset, length);
}
});
return result;
}
private void readResponses(PubSubConnection pubSubConnection) throws Exception {
PubSubReply reply = pubSubConnection.read();
assertEquals(reply.getType(), PubSubReply.Type.SUBSCRIBE);
assertEquals(reply.getChannel(), "abcde");
assertFalse(reply.getContent().isPresent());
reply = pubSubConnection.read();
assertEquals(reply.getType(), PubSubReply.Type.SUBSCRIBE);
assertEquals(reply.getChannel(), "fghij");
assertFalse(reply.getContent().isPresent());
reply = pubSubConnection.read();
assertEquals(reply.getType(), PubSubReply.Type.SUBSCRIBE);
assertEquals(reply.getChannel(), "klmno");
assertFalse(reply.getContent().isPresent());
reply = pubSubConnection.read();
assertEquals(reply.getType(), PubSubReply.Type.MESSAGE);
assertEquals(reply.getChannel(), "abcde");
assertArrayEquals(reply.getContent().get(), "1234567890".getBytes());
reply = pubSubConnection.read();
assertEquals(reply.getType(), PubSubReply.Type.MESSAGE);
assertEquals(reply.getChannel(), "klmno");
assertArrayEquals(reply.getContent().get(), "0987654321".getBytes());
}
private interface MockInputStream {
public int read();
public int read(byte[] input, int offset, int length);
}
private static class TrickleInputStream implements MockInputStream {
private final byte[] data;
private int index = 0;
private TrickleInputStream(byte[] data) {
this.data = data;
}
public int read() {
return data[index++];
}
public int read(byte[] input, int offset, int length) {
input[offset] = data[index++];
return 1;
}
}
private static class FullInputStream implements MockInputStream {
private final byte[] data;
private int index = 0;
private FullInputStream(byte[] data) {
this.data = data;
}
public int read() {
return data[index++];
}
public int read(byte[] input, int offset, int length) {
int amount = Math.min(data.length - index, length);
System.arraycopy(data, index, input, offset, amount);
index += length;
return amount;
}
}
private static class RandomInputStream implements MockInputStream {
private final byte[] data;
private int index = 0;
private RandomInputStream(byte[] data) {
this.data = data;
}
public int read() {
return data[index++];
}
public int read(byte[] input, int offset, int length) {
try {
int maxCopy = Math.min(data.length - index, length);
int randomCopy = SecureRandom.getInstance("SHA1PRNG").nextInt(maxCopy) + 1;
int copyAmount = Math.min(maxCopy, randomCopy);
System.arraycopy(data, index, input, offset, copyAmount);
index += copyAmount;
return copyAmount;
} catch (NoSuchAlgorithmException e) {
throw new AssertionError(e);
}
}
}
}

View File

@ -0,0 +1,51 @@
package org.whispersystems.dispatch.redis.protocol;
import org.junit.Test;
import java.io.IOException;
import static org.junit.Assert.assertEquals;
public class ArrayReplyHeaderTest {
@Test(expected = IOException.class)
public void testNull() throws IOException {
new ArrayReplyHeader(null);
}
@Test(expected = IOException.class)
public void testBadPrefix() throws IOException {
new ArrayReplyHeader(":3");
}
@Test(expected = IOException.class)
public void testEmpty() throws IOException {
new ArrayReplyHeader("");
}
@Test(expected = IOException.class)
public void testTruncated() throws IOException {
new ArrayReplyHeader("*");
}
@Test(expected = IOException.class)
public void testBadNumber() throws IOException {
new ArrayReplyHeader("*ABC");
}
@Test
public void testValid() throws IOException {
assertEquals(4, new ArrayReplyHeader("*4").getElementCount());
}
}

View File

@ -0,0 +1,36 @@
package org.whispersystems.dispatch.redis.protocol;
import org.junit.Test;
import java.io.IOException;
import static org.junit.Assert.assertEquals;
public class IntReplyHeaderTest {
@Test(expected = IOException.class)
public void testNull() throws IOException {
new IntReply(null);
}
@Test(expected = IOException.class)
public void testEmpty() throws IOException {
new IntReply("");
}
@Test(expected = IOException.class)
public void testBadNumber() throws IOException {
new IntReply(":A");
}
@Test(expected = IOException.class)
public void testBadFormat() throws IOException {
new IntReply("*");
}
@Test
public void testValid() throws IOException {
assertEquals(23, new IntReply(":23").getValue());
}
}

View File

@ -0,0 +1,47 @@
package org.whispersystems.dispatch.redis.protocol;
import org.junit.Test;
import java.io.IOException;
import static org.junit.Assert.assertEquals;
public class StringReplyHeaderTest {
@Test
public void testNull() {
try {
new StringReplyHeader(null);
throw new AssertionError();
} catch (IOException e) {
// good
}
}
@Test
public void testBadNumber() {
try {
new StringReplyHeader("$100A");
throw new AssertionError();
} catch (IOException e) {
// good
}
}
@Test
public void testBadPrefix() {
try {
new StringReplyHeader("*");
throw new AssertionError();
} catch (IOException e) {
// good
}
}
@Test
public void testValid() throws IOException {
assertEquals(1000, new StringReplyHeader("$1000").getStringLength());
}
}

View File

@ -44,45 +44,20 @@ import static org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMe
public class WebSocketConnectionTest {
// private static final ObjectMapper mapper = new ObjectMapper();
private static final String VALID_USER = "+14152222222";
private static final String INVALID_USER = "+14151111111";
private static final String VALID_PASSWORD = "secure";
private static final String INVALID_PASSWORD = "insecure";
// private static final StoredMessages storedMessages = mock(StoredMessages.class);
private static final AccountAuthenticator accountAuthenticator = mock(AccountAuthenticator.class);
private static final AccountsManager accountsManager = mock(AccountsManager.class);
private static final PubSubManager pubSubManager = mock(PubSubManager.class );
private static final Account account = mock(Account.class );
private static final Device device = mock(Device.class );
private static final UpgradeRequest upgradeRequest = mock(UpgradeRequest.class );
// private static final Session session = mock(Session.class );
private static final PushSender pushSender = mock(PushSender.class);
@Test
public void testCloseExisting() throws Exception {
MessagesManager storedMessages = mock(MessagesManager.class );
WebSocketConnectListener connectListener = new AuthenticatedConnectListener(accountsManager, pushSender, storedMessages, pubSubManager);
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
Account account = mock(Account.class );
Device device = mock(Device.class );
when(sessionContext.getAuthenticated(Account.class)).thenReturn(Optional.of(account));
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn("+14157777777");
when(device.getId()).thenReturn(1L);
connectListener.onWebSocketConnect(sessionContext);
ArgumentCaptor<PubSubProtos.PubSubMessage> message = ArgumentCaptor.forClass(PubSubProtos.PubSubMessage.class);
verify(pubSubManager).publish(eq(new WebsocketAddress("+14157777777", 1L)), message.capture());
assertEquals(message.getValue().getType().getNumber(), PubSubProtos.PubSubMessage.Type.CLOSE_VALUE);
}
@Test
public void testCredentials() throws Exception {
MessagesManager storedMessages = mock(MessagesManager.class);
@ -98,10 +73,6 @@ public class WebSocketConnectionTest {
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
// when(session.getUpgradeRequest()).thenReturn(upgradeRequest);
//
// WebsocketController controller = new WebsocketController(accountAuthenticator, accountsManager, pushSender, pubSubManager, storedMessages);
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<String, String[]>() {{
put("login", new String[] {VALID_USER});
put("password", new String[] {VALID_PASSWORD});
@ -114,13 +85,6 @@ public class WebSocketConnectionTest {
verify(sessionContext).addListener(any(WebSocketSessionContext.WebSocketEventListener.class));
//
// controller.onWebSocketConnect(session);
// verify(session, never()).close();
// verify(session, never()).close(any(CloseStatus.class));
// verify(session, never()).close(anyInt(), anyString());
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<String, String[]>() {{
put("login", new String[] {INVALID_USER});
put("password", new String[] {INVALID_PASSWORD});
@ -128,15 +92,6 @@ public class WebSocketConnectionTest {
account = webSocketAuthenticator.authenticate(upgradeRequest);
assertFalse(account.isPresent());
// when(sessionContext.getAuthenticated(Account.class)).thenReturn(account);
//
// WebSocketClient client = mock(WebSocketClient.class);
// when(sessionContext.getClient()).thenReturn(client);
//
// connectListener.onWebSocketConnect(sessionContext);
//
// verify(sessionContext, times(1)).addListener(any(WebSocketSessionContext.WebSocketEventListener.class));
// verify(client).close(eq(4001), anyString());
}
@Test
@ -183,12 +138,11 @@ public class WebSocketConnectionTest {
}
});
WebsocketAddress websocketAddress = new WebsocketAddress(account.getNumber(), device.getId());
WebSocketConnection connection = new WebSocketConnection(accountsManager, pushSender, storedMessages,
pubSubManager, account, device, client);
account, device, client);
connection.onConnected();
verify(pubSubManager).subscribe(eq(new WebsocketAddress("+14152222222", 2L)), eq((connection)));
connection.onDispatchSubscribed(websocketAddress.serialize());
verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(Optional.class));
assertTrue(futures.size() == 3);
@ -205,11 +159,10 @@ public class WebSocketConnectionTest {
add(createMessage("sender2", 3333, false, "third"));
}};
// verify(pushSender, times(2)).sendMessage(eq(account), eq(device), any(OutgoingMessageSignal.class));
verify(pushSender, times(1)).sendMessage(eq(sender1), eq(sender1device), any(OutgoingMessageSignal.class));
connection.onConnectionLost();
verify(pubSubManager).unsubscribe(eq(new WebsocketAddress("+14152222222", 2L)), eq(connection));
connection.onDispatchUnsubscribed(websocketAddress.serialize());
verify(client).close(anyInt(), anyString());
}
private OutgoingMessageSignal createMessage(String sender, long timestamp, boolean receipt, String content) {