diff --git a/pom.xml b/pom.xml
index de6cd7651..4ce3ae012 100644
--- a/pom.xml
+++ b/pom.xml
@@ -108,6 +108,12 @@
jersey-json
1.17.1
+
+
+ org.eclipse.jetty
+ jetty-websocket
+ 8.1.14.v20131031
+
diff --git a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
index f2381c304..66db201d7 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
@@ -20,6 +20,7 @@ import com.google.common.base.Optional;
import com.yammer.dropwizard.Service;
import com.yammer.dropwizard.config.Bootstrap;
import com.yammer.dropwizard.config.Environment;
+import com.yammer.dropwizard.config.HttpConfiguration;
import com.yammer.dropwizard.db.DatabaseConfiguration;
import com.yammer.dropwizard.jdbi.DBIFactory;
import com.yammer.dropwizard.migrations.MigrationsBundle;
@@ -32,12 +33,13 @@ import org.whispersystems.textsecuregcm.auth.FederatedPeerAuthenticator;
import org.whispersystems.textsecuregcm.auth.MultiBasicAuthProvider;
import org.whispersystems.textsecuregcm.configuration.NexmoConfiguration;
import org.whispersystems.textsecuregcm.controllers.AccountController;
-import org.whispersystems.textsecuregcm.controllers.DeviceController;
import org.whispersystems.textsecuregcm.controllers.AttachmentController;
+import org.whispersystems.textsecuregcm.controllers.DeviceController;
import org.whispersystems.textsecuregcm.controllers.DirectoryController;
import org.whispersystems.textsecuregcm.controllers.FederationController;
import org.whispersystems.textsecuregcm.controllers.KeysController;
import org.whispersystems.textsecuregcm.controllers.MessageController;
+import org.whispersystems.textsecuregcm.controllers.WebsocketControllerFactory;
import org.whispersystems.textsecuregcm.federation.FederatedClientManager;
import org.whispersystems.textsecuregcm.federation.FederatedPeer;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
@@ -51,15 +53,16 @@ import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.sms.NexmoSmsSender;
import org.whispersystems.textsecuregcm.sms.SmsSender;
import org.whispersystems.textsecuregcm.sms.TwilioSmsSender;
-import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Accounts;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
+import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DirectoryManager;
import org.whispersystems.textsecuregcm.storage.Keys;
import org.whispersystems.textsecuregcm.storage.PendingAccounts;
import org.whispersystems.textsecuregcm.storage.PendingAccountsManager;
import org.whispersystems.textsecuregcm.storage.PendingDevices;
import org.whispersystems.textsecuregcm.storage.PendingDevicesManager;
+import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.StoredMessageManager;
import org.whispersystems.textsecuregcm.storage.StoredMessages;
import org.whispersystems.textsecuregcm.util.CORSHeaderFilter;
@@ -93,6 +96,8 @@ public class WhisperServerService extends Service {
public void run(WhisperServerConfiguration config, Environment environment)
throws Exception
{
+ config.getHttpConfiguration().setConnectorType(HttpConfiguration.ConnectorType.NONBLOCKING);
+
DBIFactory dbiFactory = new DBIFactory();
DBI jdbi = dbiFactory.build(environment, config.getDatabaseConfiguration(), "postgresql");
@@ -110,7 +115,8 @@ public class WhisperServerService extends Service {
PendingDevicesManager pendingDevicesManager = new PendingDevicesManager(pendingDevices, memcachedClient);
AccountsManager accountsManager = new AccountsManager(accounts, directory, memcachedClient);
FederatedClientManager federatedClientManager = new FederatedClientManager(config.getFederationConfiguration());
- StoredMessageManager storedMessageManager = new StoredMessageManager(storedMessages);
+ PubSubManager pubSubManager = new PubSubManager(redisClient);
+ StoredMessageManager storedMessageManager = new StoredMessageManager(storedMessages, pubSubManager);
AccountAuthenticator deviceAuthenticator = new AccountAuthenticator(accountsManager);
RateLimiters rateLimiters = new RateLimiters(config.getLimitsConfiguration(), memcachedClient);
@@ -141,6 +147,9 @@ public class WhisperServerService extends Service {
environment.addResource(keysController);
environment.addResource(messageController);
+ environment.addServlet(new WebsocketControllerFactory(deviceAuthenticator, storedMessageManager, pubSubManager),
+ "/v1/websocket/");
+
environment.addHealthCheck(new RedisHealthCheck(redisClient));
environment.addHealthCheck(new MemcacheHealthCheck(memcachedClient));
diff --git a/src/main/java/org/whispersystems/textsecuregcm/controllers/WebsocketController.java b/src/main/java/org/whispersystems/textsecuregcm/controllers/WebsocketController.java
new file mode 100644
index 000000000..ea6022f39
--- /dev/null
+++ b/src/main/java/org/whispersystems/textsecuregcm/controllers/WebsocketController.java
@@ -0,0 +1,155 @@
+package org.whispersystems.textsecuregcm.controllers;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.eclipse.jetty.websocket.WebSocket;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.whispersystems.textsecuregcm.entities.AcknowledgeWebsocketMessage;
+import org.whispersystems.textsecuregcm.entities.IncomingWebsocketMessage;
+import org.whispersystems.textsecuregcm.storage.Account;
+import org.whispersystems.textsecuregcm.storage.Device;
+import org.whispersystems.textsecuregcm.storage.PubSubListener;
+import org.whispersystems.textsecuregcm.storage.PubSubManager;
+import org.whispersystems.textsecuregcm.storage.PubSubMessage;
+import org.whispersystems.textsecuregcm.storage.StoredMessageManager;
+import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
+import org.whispersystems.textsecuregcm.websocket.WebsocketMessage;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+
+public class WebsocketController implements WebSocket.OnTextMessage, PubSubListener {
+
+ private static final Logger logger = LoggerFactory.getLogger(WebsocketController.class);
+ private static final ObjectMapper mapper = new ObjectMapper();
+ private static final Map pendingMessages = new HashMap<>();
+
+ private final StoredMessageManager storedMessageManager;
+ private final PubSubManager pubSubManager;
+
+ private final Account account;
+ private final Device device;
+
+ private Connection connection;
+ private long pendingMessageSequence;
+
+ public WebsocketController(StoredMessageManager storedMessageManager,
+ PubSubManager pubSubManager,
+ Account account)
+ {
+ this.storedMessageManager = storedMessageManager;
+ this.pubSubManager = pubSubManager;
+ this.account = account;
+ this.device = account.getAuthenticatedDevice().get();
+ }
+
+
+ @Override
+ public void onOpen(Connection connection) {
+ this.connection = connection;
+ pubSubManager.subscribe(new WebsocketAddress(this.account.getId(), this.device.getId()), this);
+ handleQueryDatabase();
+ }
+
+ @Override
+ public void onClose(int i, String s) {
+ handleClose();
+ }
+
+ @Override
+ public void onMessage(String body) {
+ try {
+ IncomingWebsocketMessage incomingMessage = mapper.readValue(body, IncomingWebsocketMessage.class);
+
+ switch (incomingMessage.getType()) {
+ case IncomingWebsocketMessage.TYPE_ACKNOWLEDGE_MESSAGE: handleMessageAck(body); break;
+ case IncomingWebsocketMessage.TYPE_PING_MESSAGE: handlePing(); break;
+ default: handleClose(); break;
+ }
+ } catch (IOException e) {
+ logger.debug("Parse", e);
+ handleClose();
+ }
+ }
+
+ @Override
+ public void onPubSubMessage(PubSubMessage outgoingMessage) {
+ switch (outgoingMessage.getType()) {
+ case PubSubMessage.TYPE_DELIVER: handleDeliverOutgoingMessage(outgoingMessage.getContents()); break;
+ case PubSubMessage.TYPE_QUERY_DB: handleQueryDatabase(); break;
+ default:
+ logger.warn("Unknown pubsub message: " + outgoingMessage.getType());
+ }
+ }
+
+ private void handleDeliverOutgoingMessage(String message) {
+ try {
+ long messageSequence;
+
+ synchronized (pendingMessages) {
+ messageSequence = pendingMessageSequence++;
+ pendingMessages.put(messageSequence, message);
+ }
+
+ connection.sendMessage(mapper.writeValueAsString(new WebsocketMessage(messageSequence, message)));
+ } catch (IOException e) {
+ logger.debug("Response failed", e);
+ handleClose();
+ }
+ }
+
+ private void handleMessageAck(String message) {
+ try {
+ AcknowledgeWebsocketMessage ack = mapper.readValue(message, AcknowledgeWebsocketMessage.class);
+
+ synchronized (pendingMessages) {
+ pendingMessages.remove(ack.getId());
+ }
+ } catch (IOException e) {
+ logger.warn("Mapping", e);
+ }
+ }
+
+ private void handlePing() {
+ try {
+ IncomingWebsocketMessage pongMessage = new IncomingWebsocketMessage(IncomingWebsocketMessage.TYPE_PONG_MESSAGE);
+ connection.sendMessage(mapper.writeValueAsString(pongMessage));
+ } catch (IOException e) {
+ logger.warn("Pong failed", e);
+ handleClose();
+ }
+ }
+
+ private void handleClose() {
+ pubSubManager.unsubscribe(new WebsocketAddress(account.getId(), device.getId()), this);
+ connection.close();
+
+ List remainingMessages = new LinkedList<>();
+
+ synchronized (pendingMessages) {
+ Long[] pendingKeys = pendingMessages.keySet().toArray(new Long[0]);
+ Arrays.sort(pendingKeys);
+
+ for (long pendingKey : pendingKeys) {
+ remainingMessages.add(pendingMessages.get(pendingKey));
+ }
+
+ pendingMessages.clear();
+ }
+
+ storedMessageManager.storeMessages(account, device, remainingMessages);
+ }
+
+ private void handleQueryDatabase() {
+ List messages = storedMessageManager.getOutgoingMessages(account, device);
+
+ for (String message : messages) {
+ handleDeliverOutgoingMessage(message);
+ }
+ }
+
+}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/controllers/WebsocketControllerFactory.java b/src/main/java/org/whispersystems/textsecuregcm/controllers/WebsocketControllerFactory.java
new file mode 100644
index 000000000..825ebb2c9
--- /dev/null
+++ b/src/main/java/org/whispersystems/textsecuregcm/controllers/WebsocketControllerFactory.java
@@ -0,0 +1,98 @@
+package org.whispersystems.textsecuregcm.controllers;
+
+import com.google.common.base.Optional;
+import com.yammer.dropwizard.auth.AuthenticationException;
+import com.yammer.dropwizard.auth.basic.BasicCredentials;
+import org.eclipse.jetty.websocket.WebSocket;
+import org.eclipse.jetty.websocket.WebSocketServlet;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
+import org.whispersystems.textsecuregcm.storage.Account;
+import org.whispersystems.textsecuregcm.storage.PubSubManager;
+import org.whispersystems.textsecuregcm.storage.StoredMessageManager;
+
+import javax.servlet.http.HttpServletRequest;
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+
+public class WebsocketControllerFactory extends WebSocketServlet {
+
+ private final Logger logger = LoggerFactory.getLogger(WebsocketControllerFactory.class);
+
+ private final StoredMessageManager storedMessageManager;
+ private final PubSubManager pubSubManager;
+ private final AccountAuthenticator accountAuthenticator;
+
+ private final LinkedHashMap> cache =
+ new LinkedHashMap>() {
+ @Override
+ protected boolean removeEldestEntry(Map.Entry> eldest) {
+ return size() > 10;
+ }
+ };
+
+ public WebsocketControllerFactory(AccountAuthenticator accountAuthenticator,
+ StoredMessageManager storedMessageManager,
+ PubSubManager pubSubManager)
+ {
+ this.accountAuthenticator = accountAuthenticator;
+ this.storedMessageManager = storedMessageManager;
+ this.pubSubManager = pubSubManager;
+ }
+
+ @Override
+ public WebSocket doWebSocketConnect(HttpServletRequest request, String s) {
+ try {
+ String username = request.getParameter("user");
+ String password = request.getParameter("password");
+
+ if (username == null || password == null) {
+ return null;
+ }
+
+ BasicCredentials credentials = new BasicCredentials(username, password);
+
+ Optional account = cache.remove(credentials);
+
+ if (account == null) {
+ account = accountAuthenticator.authenticate(new BasicCredentials(username, password));
+ }
+
+ if (!account.isPresent()) {
+ return null;
+ }
+
+ return new WebsocketController(storedMessageManager, pubSubManager, account.get());
+ } catch (AuthenticationException e) {
+ throw new AssertionError(e);
+ }
+ }
+
+ @Override
+ public boolean checkOrigin(HttpServletRequest request, String origin) {
+ try {
+ String username = request.getParameter("user");
+ String password = request.getParameter("password");
+
+ if (username == null || password == null) {
+ return false;
+ }
+
+ BasicCredentials credentials = new BasicCredentials(username, password);
+ Optional account = accountAuthenticator.authenticate(credentials);
+
+ if (!account.isPresent()) {
+ return false;
+ }
+
+ cache.put(credentials, account);
+
+ return true;
+ } catch (AuthenticationException e) {
+ logger.warn("Auth Failure", e);
+ return false;
+ }
+ }
+}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/entities/AcknowledgeWebsocketMessage.java b/src/main/java/org/whispersystems/textsecuregcm/entities/AcknowledgeWebsocketMessage.java
new file mode 100644
index 000000000..2963678a2
--- /dev/null
+++ b/src/main/java/org/whispersystems/textsecuregcm/entities/AcknowledgeWebsocketMessage.java
@@ -0,0 +1,16 @@
+package org.whispersystems.textsecuregcm.entities;
+
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+@JsonIgnoreProperties(ignoreUnknown = true)
+public class AcknowledgeWebsocketMessage extends IncomingWebsocketMessage {
+
+ @JsonProperty
+ private long id;
+
+ public long getId() {
+ return id;
+ }
+
+}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingWebsocketMessage.java b/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingWebsocketMessage.java
new file mode 100644
index 000000000..db3ba3e7f
--- /dev/null
+++ b/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingWebsocketMessage.java
@@ -0,0 +1,25 @@
+package org.whispersystems.textsecuregcm.entities;
+
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+@JsonIgnoreProperties(ignoreUnknown = false)
+public class IncomingWebsocketMessage {
+
+ public static final int TYPE_ACKNOWLEDGE_MESSAGE = 1;
+ public static final int TYPE_PING_MESSAGE = 2;
+ public static final int TYPE_PONG_MESSAGE = 3;
+
+ @JsonProperty
+ private int type;
+
+ public IncomingWebsocketMessage() {}
+
+ public IncomingWebsocketMessage(int type) {
+ this.type = type;
+ }
+
+ public int getType() {
+ return type;
+ }
+}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java b/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java
index fcb221cf5..8f212a79d 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java
@@ -62,7 +62,7 @@ public class PushSender {
if (device.getGcmId() != null) sendGcmMessage(account, device, message);
else if (device.getApnId() != null) sendApnMessage(account, device, message);
- else if (device.getFetchesMessages()) storeFetchedMessage(device, message);
+ else if (device.getFetchesMessages()) storeFetchedMessage(account, device, message);
else throw new NotPushRegisteredException("No delivery possible!");
}
@@ -97,11 +97,11 @@ public class PushSender {
}
}
- private void storeFetchedMessage(Device device, EncryptedOutgoingMessage outgoingMessage)
+ private void storeFetchedMessage(Account account, Device device, EncryptedOutgoingMessage outgoingMessage)
throws NotPushRegisteredException
{
try {
- storedMessageManager.storeMessage(device, outgoingMessage);
+ storedMessageManager.storeMessage(account, device, outgoingMessage);
} catch (CryptoEncodingException e) {
throw new NotPushRegisteredException(e);
}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java b/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java
index d61832ae4..bd8738239 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java
@@ -29,6 +29,9 @@ public class Account implements Serializable {
public static final int MEMCACHE_VERION = 2;
+ @JsonIgnore
+ private long id;
+
@JsonProperty
private String number;
@@ -48,6 +51,14 @@ public class Account implements Serializable {
this.supportsSms = supportsSms;
}
+ public long getId() {
+ return id;
+ }
+
+ public void setId(long id) {
+ this.id = id;
+ }
+
public Optional getAuthenticatedDevice() {
return authenticatedDevice;
}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java b/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java
index 5d74452ab..62c18b69c 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java
@@ -95,7 +95,10 @@ public abstract class Accounts {
throws SQLException
{
try {
- return mapper.readValue(resultSet.getString(DATA), Account.class);
+ Account account = mapper.readValue(resultSet.getString(DATA), Account.class);
+ account.setId(resultSet.getLong(ID));
+
+ return account;
} catch (IOException e) {
throw new SQLException(e);
}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubListener.java b/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubListener.java
new file mode 100644
index 000000000..d9a24d595
--- /dev/null
+++ b/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubListener.java
@@ -0,0 +1,7 @@
+package org.whispersystems.textsecuregcm.storage;
+
+public interface PubSubListener {
+
+ public void onPubSubMessage(PubSubMessage outgoingMessage);
+
+}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubManager.java b/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubManager.java
new file mode 100644
index 000000000..b9e10b7e8
--- /dev/null
+++ b/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubManager.java
@@ -0,0 +1,157 @@
+package org.whispersystems.textsecuregcm.storage;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.whispersystems.textsecuregcm.websocket.InvalidWebsocketAddressException;
+import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import redis.clients.jedis.Jedis;
+import redis.clients.jedis.JedisPool;
+import redis.clients.jedis.JedisPubSub;
+
+public class PubSubManager {
+
+ private final Logger logger = LoggerFactory.getLogger(PubSubManager.class);
+ private final ObjectMapper mapper = new ObjectMapper();
+ private final SubscriptionListener baseListener = new SubscriptionListener();
+ private final Map listeners = new HashMap<>();
+
+ private final JedisPool jedisPool;
+ private boolean subscribed = false;
+
+ public PubSubManager(final JedisPool jedisPool) {
+ this.jedisPool = jedisPool;
+ initializePubSubWorker();
+ waitForSubscription();
+ }
+
+ public synchronized void subscribe(WebsocketAddress address, PubSubListener listener) {
+ listeners.put(address, listener);
+ baseListener.subscribe(address.toString());
+ }
+
+ public synchronized void unsubscribe(WebsocketAddress address, PubSubListener listener) {
+ if (listeners.get(address) == listener) {
+ listeners.remove(address);
+ baseListener.unsubscribe(address.toString());
+ }
+ }
+
+ public synchronized boolean publish(WebsocketAddress address, PubSubMessage message) {
+ try {
+ String serialized = mapper.writeValueAsString(message);
+ Jedis jedis = null;
+
+ try {
+ jedis = jedisPool.getResource();
+ return jedis.publish(address.toString(), serialized) != 0;
+ } finally {
+ if (jedis != null)
+ jedisPool.returnResource(jedis);
+ }
+ } catch (JsonProcessingException e) {
+ throw new AssertionError(e);
+ }
+ }
+
+ private synchronized void waitForSubscription() {
+ try {
+ while (!subscribed) {
+ wait();
+ }
+ } catch (InterruptedException e) {
+ throw new AssertionError(e);
+ }
+ }
+
+ private void initializePubSubWorker() {
+ new Thread("PubSubListener") {
+ @Override
+ public void run() {
+ for (;;) {
+ Jedis jedis = null;
+ try {
+ jedis = jedisPool.getResource();
+ jedis.subscribe(baseListener, new WebsocketAddress(0, 0).toString());
+ logger.warn("**** Unsubscribed from holding channel!!! ******");
+ } finally {
+ if (jedis != null)
+ jedisPool.returnResource(jedis);
+ }
+ }
+ }
+ }.start();
+
+ new Thread("PubSubKeepAlive") {
+ @Override
+ public void run() {
+ for (;;) {
+ try {
+ Thread.sleep(20000);
+ publish(new WebsocketAddress(0, 0), new PubSubMessage(0, "foo"));
+ } catch (InterruptedException e) {
+ throw new AssertionError(e);
+ }
+ }
+ }
+ }.start();
+ }
+
+ private class SubscriptionListener extends JedisPubSub {
+
+ @Override
+ public void onMessage(String channel, String message) {
+ try {
+ WebsocketAddress address = new WebsocketAddress(channel);
+ PubSubListener listener;
+
+ synchronized (PubSubManager.this) {
+ listener = listeners.get(address);
+ }
+
+ if (listener != null) {
+ listener.onPubSubMessage(mapper.readValue(message, PubSubMessage.class));
+ }
+ } catch (InvalidWebsocketAddressException e) {
+ logger.warn("Address", e);
+ } catch (IOException e) {
+ logger.warn("IOE", e);
+ }
+ }
+
+ @Override
+ public void onPMessage(String s, String s2, String s3) {
+ logger.warn("Received PMessage!");
+ }
+
+ @Override
+ public void onSubscribe(String channel, int count) {
+ try {
+ WebsocketAddress address = new WebsocketAddress(channel);
+ if (address.getAccountId() == 0 && address.getDeviceId() == 0) {
+ synchronized (PubSubManager.this) {
+ subscribed = true;
+ PubSubManager.this.notifyAll();
+ }
+ }
+ } catch (InvalidWebsocketAddressException e) {
+ logger.warn("Weird address", e);
+ }
+ }
+
+ @Override
+ public void onUnsubscribe(String s, int i) {}
+
+ @Override
+ public void onPUnsubscribe(String s, int i) {}
+
+ @Override
+ public void onPSubscribe(String s, int i) {}
+ }
+}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubMessage.java b/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubMessage.java
new file mode 100644
index 000000000..fa24dbd48
--- /dev/null
+++ b/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubMessage.java
@@ -0,0 +1,32 @@
+package org.whispersystems.textsecuregcm.storage;
+
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+@JsonIgnoreProperties(ignoreUnknown = true)
+public class PubSubMessage {
+
+ public static final int TYPE_QUERY_DB = 1;
+ public static final int TYPE_DELIVER = 2;
+
+ @JsonProperty
+ private int type;
+
+ @JsonProperty
+ private String contents;
+
+ public PubSubMessage() {}
+
+ public PubSubMessage(int type, String contents) {
+ this.type = type;
+ this.contents = contents;
+ }
+
+ public int getType() {
+ return type;
+ }
+
+ public String getContents() {
+ return contents;
+ }
+}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessageManager.java b/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessageManager.java
index 7782e8969..b0de973af 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessageManager.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessageManager.java
@@ -18,23 +18,49 @@ package org.whispersystems.textsecuregcm.storage;
import org.whispersystems.textsecuregcm.entities.CryptoEncodingException;
import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage;
+import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
-import java.io.IOException;
import java.util.List;
public class StoredMessageManager {
- StoredMessages storedMessages;
- public StoredMessageManager(StoredMessages storedMessages) {
+
+ private final StoredMessages storedMessages;
+ private final PubSubManager pubSubManager;
+
+ public StoredMessageManager(StoredMessages storedMessages, PubSubManager pubSubManager) {
this.storedMessages = storedMessages;
+ this.pubSubManager = pubSubManager;
}
- public void storeMessage(Device device, EncryptedOutgoingMessage outgoingMessage)
+ public void storeMessage(Account account, Device device, EncryptedOutgoingMessage outgoingMessage)
throws CryptoEncodingException
{
- storedMessages.insert(device.getId(), outgoingMessage.serialize());
+ storeMessage(account, device, outgoingMessage.serialize());
}
- public List getStoredMessage(Device device) {
- return storedMessages.getMessagesForAccountId(device.getId());
+ public void storeMessages(Account account, Device device, List serializedMessages) {
+ for (String serializedMessage : serializedMessages) {
+ storeMessage(account, device, serializedMessage);
+ }
+ }
+
+ private void storeMessage(Account account, Device device, String serializedMessage) {
+ if (device.getFetchesMessages()) {
+ WebsocketAddress address = new WebsocketAddress(account.getId(), device.getId());
+ PubSubMessage pubSubMessage = new PubSubMessage(PubSubMessage.TYPE_DELIVER, serializedMessage);
+
+ if (!pubSubManager.publish(address, pubSubMessage)) {
+ storedMessages.insert(account.getId(), device.getId(), serializedMessage);
+ pubSubManager.publish(address, new PubSubMessage(PubSubMessage.TYPE_QUERY_DB, null));
+ }
+
+ return;
+ }
+
+ storedMessages.insert(account.getId(), device.getId(), serializedMessage);
+ }
+
+ public List getOutgoingMessages(Account account, Device device) {
+ return storedMessages.getMessagesForDevice(account.getId(), device.getId());
}
}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessages.java b/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessages.java
index 9bd1f4f9c..fdc18b8f1 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessages.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessages.java
@@ -17,6 +17,7 @@
package org.whispersystems.textsecuregcm.storage;
import org.skife.jdbi.v2.sqlobject.Bind;
+import org.skife.jdbi.v2.sqlobject.SqlBatch;
import org.skife.jdbi.v2.sqlobject.SqlQuery;
import org.skife.jdbi.v2.sqlobject.SqlUpdate;
@@ -24,9 +25,12 @@ import java.util.List;
public interface StoredMessages {
- @SqlUpdate("INSERT INTO stored_messages (destination_id, encrypted_message) VALUES (:destination_id, :encrypted_message)")
- void insert(@Bind("destination_id") long destinationAccountId, @Bind("encrypted_message") String encryptedOutgoingMessage);
+ @SqlUpdate("INSERT INTO messages (account_id, device_id, encrypted_message) VALUES (:account_id, :device_id, :encrypted_message)")
+ void insert(@Bind("account_id") long accountId, @Bind("device_id") long deviceId, @Bind("encrypted_message") String encryptedOutgoingMessage);
- @SqlQuery("SELECT encrypted_message FROM stored_messages WHERE destination_id = :account_id")
- List getMessagesForAccountId(@Bind("account_id") long accountId);
+ @SqlBatch("INSERT INTO messages (account_id, device_id, encrypted_message) VALUES (:account_id, :device_id, :encrypted_message)")
+ void insert(@Bind("account_id") long accountId, @Bind("device_id") long deviceId, @Bind("encrypted_message") List encryptedOutgoingMessages);
+
+ @SqlQuery("DELETE FROM messages WHERE account_id = :account_id AND device_id = :device_id RETURNING encrypted_message")
+ List getMessagesForDevice(@Bind("account_id") long accountId, @Bind("device_id") long deviceId);
}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/InvalidWebsocketAddressException.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/InvalidWebsocketAddressException.java
new file mode 100644
index 000000000..788283a57
--- /dev/null
+++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/InvalidWebsocketAddressException.java
@@ -0,0 +1,11 @@
+package org.whispersystems.textsecuregcm.websocket;
+
+public class InvalidWebsocketAddressException extends Exception {
+ public InvalidWebsocketAddressException(String serialized) {
+ super(serialized);
+ }
+
+ public InvalidWebsocketAddressException(Exception e) {
+ super(e);
+ }
+}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketAddress.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketAddress.java
new file mode 100644
index 000000000..72f72778c
--- /dev/null
+++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketAddress.java
@@ -0,0 +1,57 @@
+package org.whispersystems.textsecuregcm.websocket;
+
+public class WebsocketAddress {
+
+ private final long accountId;
+ private final long deviceId;
+
+ public WebsocketAddress(String serialized) throws InvalidWebsocketAddressException {
+ try {
+ String[] parts = serialized.split(":");
+
+ if (parts == null || parts.length != 2) {
+ throw new InvalidWebsocketAddressException(serialized);
+ }
+
+ this.accountId = Long.parseLong(parts[0]);
+ this.deviceId = Long.parseLong(parts[1]);
+ } catch (NumberFormatException e) {
+ throw new InvalidWebsocketAddressException(e);
+ }
+ }
+
+ public WebsocketAddress(long accountId, long deviceId) {
+ this.accountId = accountId;
+ this.deviceId = deviceId;
+ }
+
+ public long getAccountId() {
+ return accountId;
+ }
+
+ public long getDeviceId() {
+ return deviceId;
+ }
+
+ public String toString() {
+ return accountId + ":" + deviceId;
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other == null) return false;
+ if (!(other instanceof WebsocketAddress)) return false;
+
+ WebsocketAddress that = (WebsocketAddress)other;
+
+ return
+ this.accountId == that.accountId &&
+ this.deviceId == that.deviceId;
+ }
+
+ @Override
+ public int hashCode() {
+ return (int)accountId ^ (int)deviceId;
+ }
+
+}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketMessage.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketMessage.java
new file mode 100644
index 000000000..04e5587b2
--- /dev/null
+++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketMessage.java
@@ -0,0 +1,18 @@
+package org.whispersystems.textsecuregcm.websocket;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+public class WebsocketMessage {
+
+ @JsonProperty
+ private long id;
+
+ @JsonProperty
+ private String message;
+
+ public WebsocketMessage(long id, String message) {
+ this.id = id;
+ this.message = message;
+ }
+
+}
diff --git a/src/main/resources/migrations.xml b/src/main/resources/migrations.xml
index 38284dcac..9a6e4857b 100644
--- a/src/main/resources/migrations.xml
+++ b/src/main/resources/migrations.xml
@@ -104,27 +104,37 @@
-
+
-
+
-
+
-
-
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java
index 852fdb393..88f1333c7 100644
--- a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java
+++ b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java
@@ -62,6 +62,8 @@ public class DeviceControllerTest extends ResourceTest {
when(rateLimiters.getSmsDestinationLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getVoiceDestinationLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getVerifyLimiter()).thenReturn(rateLimiter);
+ when(rateLimiters.getAllocateDeviceLimiter()).thenReturn(rateLimiter);
+ when(rateLimiters.getVerifyDeviceLimiter()).thenReturn(rateLimiter);
when(account.getNextDeviceId()).thenReturn(42L);