From 7bb505db4c322ca41f10854c71a109dd441625e1 Mon Sep 17 00:00:00 2001 From: Moxie Marlinspike Date: Fri, 24 Jan 2014 12:33:40 -0800 Subject: [PATCH] Refactor WebSocket support to use Redis for pubsub communication. --- pom.xml | 6 + .../textsecuregcm/WhisperServerService.java | 15 +- .../controllers/WebsocketController.java | 155 +++++++++++++++++ .../WebsocketControllerFactory.java | 98 +++++++++++ .../entities/AcknowledgeWebsocketMessage.java | 16 ++ .../entities/IncomingWebsocketMessage.java | 25 +++ .../textsecuregcm/push/PushSender.java | 6 +- .../textsecuregcm/storage/Account.java | 11 ++ .../textsecuregcm/storage/Accounts.java | 5 +- .../textsecuregcm/storage/PubSubListener.java | 7 + .../textsecuregcm/storage/PubSubManager.java | 157 ++++++++++++++++++ .../textsecuregcm/storage/PubSubMessage.java | 32 ++++ .../storage/StoredMessageManager.java | 40 ++++- .../textsecuregcm/storage/StoredMessages.java | 12 +- .../InvalidWebsocketAddressException.java | 11 ++ .../websocket/WebsocketAddress.java | 57 +++++++ .../websocket/WebsocketMessage.java | 18 ++ src/main/resources/migrations.xml | 20 ++- .../controllers/DeviceControllerTest.java | 2 + 19 files changed, 670 insertions(+), 23 deletions(-) create mode 100644 src/main/java/org/whispersystems/textsecuregcm/controllers/WebsocketController.java create mode 100644 src/main/java/org/whispersystems/textsecuregcm/controllers/WebsocketControllerFactory.java create mode 100644 src/main/java/org/whispersystems/textsecuregcm/entities/AcknowledgeWebsocketMessage.java create mode 100644 src/main/java/org/whispersystems/textsecuregcm/entities/IncomingWebsocketMessage.java create mode 100644 src/main/java/org/whispersystems/textsecuregcm/storage/PubSubListener.java create mode 100644 src/main/java/org/whispersystems/textsecuregcm/storage/PubSubManager.java create mode 100644 src/main/java/org/whispersystems/textsecuregcm/storage/PubSubMessage.java create mode 100644 src/main/java/org/whispersystems/textsecuregcm/websocket/InvalidWebsocketAddressException.java create mode 100644 src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketAddress.java create mode 100644 src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketMessage.java diff --git a/pom.xml b/pom.xml index de6cd7651..4ce3ae012 100644 --- a/pom.xml +++ b/pom.xml @@ -108,6 +108,12 @@ jersey-json 1.17.1 + + + org.eclipse.jetty + jetty-websocket + 8.1.14.v20131031 + diff --git a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index f2381c304..66db201d7 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -20,6 +20,7 @@ import com.google.common.base.Optional; import com.yammer.dropwizard.Service; import com.yammer.dropwizard.config.Bootstrap; import com.yammer.dropwizard.config.Environment; +import com.yammer.dropwizard.config.HttpConfiguration; import com.yammer.dropwizard.db.DatabaseConfiguration; import com.yammer.dropwizard.jdbi.DBIFactory; import com.yammer.dropwizard.migrations.MigrationsBundle; @@ -32,12 +33,13 @@ import org.whispersystems.textsecuregcm.auth.FederatedPeerAuthenticator; import org.whispersystems.textsecuregcm.auth.MultiBasicAuthProvider; import org.whispersystems.textsecuregcm.configuration.NexmoConfiguration; import org.whispersystems.textsecuregcm.controllers.AccountController; -import org.whispersystems.textsecuregcm.controllers.DeviceController; import org.whispersystems.textsecuregcm.controllers.AttachmentController; +import org.whispersystems.textsecuregcm.controllers.DeviceController; import org.whispersystems.textsecuregcm.controllers.DirectoryController; import org.whispersystems.textsecuregcm.controllers.FederationController; import org.whispersystems.textsecuregcm.controllers.KeysController; import org.whispersystems.textsecuregcm.controllers.MessageController; +import org.whispersystems.textsecuregcm.controllers.WebsocketControllerFactory; import org.whispersystems.textsecuregcm.federation.FederatedClientManager; import org.whispersystems.textsecuregcm.federation.FederatedPeer; import org.whispersystems.textsecuregcm.limits.RateLimiters; @@ -51,15 +53,16 @@ import org.whispersystems.textsecuregcm.push.PushSender; import org.whispersystems.textsecuregcm.sms.NexmoSmsSender; import org.whispersystems.textsecuregcm.sms.SmsSender; import org.whispersystems.textsecuregcm.sms.TwilioSmsSender; -import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Accounts; import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.DirectoryManager; import org.whispersystems.textsecuregcm.storage.Keys; import org.whispersystems.textsecuregcm.storage.PendingAccounts; import org.whispersystems.textsecuregcm.storage.PendingAccountsManager; import org.whispersystems.textsecuregcm.storage.PendingDevices; import org.whispersystems.textsecuregcm.storage.PendingDevicesManager; +import org.whispersystems.textsecuregcm.storage.PubSubManager; import org.whispersystems.textsecuregcm.storage.StoredMessageManager; import org.whispersystems.textsecuregcm.storage.StoredMessages; import org.whispersystems.textsecuregcm.util.CORSHeaderFilter; @@ -93,6 +96,8 @@ public class WhisperServerService extends Service { public void run(WhisperServerConfiguration config, Environment environment) throws Exception { + config.getHttpConfiguration().setConnectorType(HttpConfiguration.ConnectorType.NONBLOCKING); + DBIFactory dbiFactory = new DBIFactory(); DBI jdbi = dbiFactory.build(environment, config.getDatabaseConfiguration(), "postgresql"); @@ -110,7 +115,8 @@ public class WhisperServerService extends Service { PendingDevicesManager pendingDevicesManager = new PendingDevicesManager(pendingDevices, memcachedClient); AccountsManager accountsManager = new AccountsManager(accounts, directory, memcachedClient); FederatedClientManager federatedClientManager = new FederatedClientManager(config.getFederationConfiguration()); - StoredMessageManager storedMessageManager = new StoredMessageManager(storedMessages); + PubSubManager pubSubManager = new PubSubManager(redisClient); + StoredMessageManager storedMessageManager = new StoredMessageManager(storedMessages, pubSubManager); AccountAuthenticator deviceAuthenticator = new AccountAuthenticator(accountsManager); RateLimiters rateLimiters = new RateLimiters(config.getLimitsConfiguration(), memcachedClient); @@ -141,6 +147,9 @@ public class WhisperServerService extends Service { environment.addResource(keysController); environment.addResource(messageController); + environment.addServlet(new WebsocketControllerFactory(deviceAuthenticator, storedMessageManager, pubSubManager), + "/v1/websocket/"); + environment.addHealthCheck(new RedisHealthCheck(redisClient)); environment.addHealthCheck(new MemcacheHealthCheck(memcachedClient)); diff --git a/src/main/java/org/whispersystems/textsecuregcm/controllers/WebsocketController.java b/src/main/java/org/whispersystems/textsecuregcm/controllers/WebsocketController.java new file mode 100644 index 000000000..ea6022f39 --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/controllers/WebsocketController.java @@ -0,0 +1,155 @@ +package org.whispersystems.textsecuregcm.controllers; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.eclipse.jetty.websocket.WebSocket; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.entities.AcknowledgeWebsocketMessage; +import org.whispersystems.textsecuregcm.entities.IncomingWebsocketMessage; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.PubSubListener; +import org.whispersystems.textsecuregcm.storage.PubSubManager; +import org.whispersystems.textsecuregcm.storage.PubSubMessage; +import org.whispersystems.textsecuregcm.storage.StoredMessageManager; +import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; +import org.whispersystems.textsecuregcm.websocket.WebsocketMessage; + +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; + +public class WebsocketController implements WebSocket.OnTextMessage, PubSubListener { + + private static final Logger logger = LoggerFactory.getLogger(WebsocketController.class); + private static final ObjectMapper mapper = new ObjectMapper(); + private static final Map pendingMessages = new HashMap<>(); + + private final StoredMessageManager storedMessageManager; + private final PubSubManager pubSubManager; + + private final Account account; + private final Device device; + + private Connection connection; + private long pendingMessageSequence; + + public WebsocketController(StoredMessageManager storedMessageManager, + PubSubManager pubSubManager, + Account account) + { + this.storedMessageManager = storedMessageManager; + this.pubSubManager = pubSubManager; + this.account = account; + this.device = account.getAuthenticatedDevice().get(); + } + + + @Override + public void onOpen(Connection connection) { + this.connection = connection; + pubSubManager.subscribe(new WebsocketAddress(this.account.getId(), this.device.getId()), this); + handleQueryDatabase(); + } + + @Override + public void onClose(int i, String s) { + handleClose(); + } + + @Override + public void onMessage(String body) { + try { + IncomingWebsocketMessage incomingMessage = mapper.readValue(body, IncomingWebsocketMessage.class); + + switch (incomingMessage.getType()) { + case IncomingWebsocketMessage.TYPE_ACKNOWLEDGE_MESSAGE: handleMessageAck(body); break; + case IncomingWebsocketMessage.TYPE_PING_MESSAGE: handlePing(); break; + default: handleClose(); break; + } + } catch (IOException e) { + logger.debug("Parse", e); + handleClose(); + } + } + + @Override + public void onPubSubMessage(PubSubMessage outgoingMessage) { + switch (outgoingMessage.getType()) { + case PubSubMessage.TYPE_DELIVER: handleDeliverOutgoingMessage(outgoingMessage.getContents()); break; + case PubSubMessage.TYPE_QUERY_DB: handleQueryDatabase(); break; + default: + logger.warn("Unknown pubsub message: " + outgoingMessage.getType()); + } + } + + private void handleDeliverOutgoingMessage(String message) { + try { + long messageSequence; + + synchronized (pendingMessages) { + messageSequence = pendingMessageSequence++; + pendingMessages.put(messageSequence, message); + } + + connection.sendMessage(mapper.writeValueAsString(new WebsocketMessage(messageSequence, message))); + } catch (IOException e) { + logger.debug("Response failed", e); + handleClose(); + } + } + + private void handleMessageAck(String message) { + try { + AcknowledgeWebsocketMessage ack = mapper.readValue(message, AcknowledgeWebsocketMessage.class); + + synchronized (pendingMessages) { + pendingMessages.remove(ack.getId()); + } + } catch (IOException e) { + logger.warn("Mapping", e); + } + } + + private void handlePing() { + try { + IncomingWebsocketMessage pongMessage = new IncomingWebsocketMessage(IncomingWebsocketMessage.TYPE_PONG_MESSAGE); + connection.sendMessage(mapper.writeValueAsString(pongMessage)); + } catch (IOException e) { + logger.warn("Pong failed", e); + handleClose(); + } + } + + private void handleClose() { + pubSubManager.unsubscribe(new WebsocketAddress(account.getId(), device.getId()), this); + connection.close(); + + List remainingMessages = new LinkedList<>(); + + synchronized (pendingMessages) { + Long[] pendingKeys = pendingMessages.keySet().toArray(new Long[0]); + Arrays.sort(pendingKeys); + + for (long pendingKey : pendingKeys) { + remainingMessages.add(pendingMessages.get(pendingKey)); + } + + pendingMessages.clear(); + } + + storedMessageManager.storeMessages(account, device, remainingMessages); + } + + private void handleQueryDatabase() { + List messages = storedMessageManager.getOutgoingMessages(account, device); + + for (String message : messages) { + handleDeliverOutgoingMessage(message); + } + } + +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/controllers/WebsocketControllerFactory.java b/src/main/java/org/whispersystems/textsecuregcm/controllers/WebsocketControllerFactory.java new file mode 100644 index 000000000..825ebb2c9 --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/controllers/WebsocketControllerFactory.java @@ -0,0 +1,98 @@ +package org.whispersystems.textsecuregcm.controllers; + +import com.google.common.base.Optional; +import com.yammer.dropwizard.auth.AuthenticationException; +import com.yammer.dropwizard.auth.basic.BasicCredentials; +import org.eclipse.jetty.websocket.WebSocket; +import org.eclipse.jetty.websocket.WebSocketServlet; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.PubSubManager; +import org.whispersystems.textsecuregcm.storage.StoredMessageManager; + +import javax.servlet.http.HttpServletRequest; +import java.util.LinkedHashMap; +import java.util.Map; + + +public class WebsocketControllerFactory extends WebSocketServlet { + + private final Logger logger = LoggerFactory.getLogger(WebsocketControllerFactory.class); + + private final StoredMessageManager storedMessageManager; + private final PubSubManager pubSubManager; + private final AccountAuthenticator accountAuthenticator; + + private final LinkedHashMap> cache = + new LinkedHashMap>() { + @Override + protected boolean removeEldestEntry(Map.Entry> eldest) { + return size() > 10; + } + }; + + public WebsocketControllerFactory(AccountAuthenticator accountAuthenticator, + StoredMessageManager storedMessageManager, + PubSubManager pubSubManager) + { + this.accountAuthenticator = accountAuthenticator; + this.storedMessageManager = storedMessageManager; + this.pubSubManager = pubSubManager; + } + + @Override + public WebSocket doWebSocketConnect(HttpServletRequest request, String s) { + try { + String username = request.getParameter("user"); + String password = request.getParameter("password"); + + if (username == null || password == null) { + return null; + } + + BasicCredentials credentials = new BasicCredentials(username, password); + + Optional account = cache.remove(credentials); + + if (account == null) { + account = accountAuthenticator.authenticate(new BasicCredentials(username, password)); + } + + if (!account.isPresent()) { + return null; + } + + return new WebsocketController(storedMessageManager, pubSubManager, account.get()); + } catch (AuthenticationException e) { + throw new AssertionError(e); + } + } + + @Override + public boolean checkOrigin(HttpServletRequest request, String origin) { + try { + String username = request.getParameter("user"); + String password = request.getParameter("password"); + + if (username == null || password == null) { + return false; + } + + BasicCredentials credentials = new BasicCredentials(username, password); + Optional account = accountAuthenticator.authenticate(credentials); + + if (!account.isPresent()) { + return false; + } + + cache.put(credentials, account); + + return true; + } catch (AuthenticationException e) { + logger.warn("Auth Failure", e); + return false; + } + } +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/entities/AcknowledgeWebsocketMessage.java b/src/main/java/org/whispersystems/textsecuregcm/entities/AcknowledgeWebsocketMessage.java new file mode 100644 index 000000000..2963678a2 --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/entities/AcknowledgeWebsocketMessage.java @@ -0,0 +1,16 @@ +package org.whispersystems.textsecuregcm.entities; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; + +@JsonIgnoreProperties(ignoreUnknown = true) +public class AcknowledgeWebsocketMessage extends IncomingWebsocketMessage { + + @JsonProperty + private long id; + + public long getId() { + return id; + } + +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingWebsocketMessage.java b/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingWebsocketMessage.java new file mode 100644 index 000000000..db3ba3e7f --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingWebsocketMessage.java @@ -0,0 +1,25 @@ +package org.whispersystems.textsecuregcm.entities; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; + +@JsonIgnoreProperties(ignoreUnknown = false) +public class IncomingWebsocketMessage { + + public static final int TYPE_ACKNOWLEDGE_MESSAGE = 1; + public static final int TYPE_PING_MESSAGE = 2; + public static final int TYPE_PONG_MESSAGE = 3; + + @JsonProperty + private int type; + + public IncomingWebsocketMessage() {} + + public IncomingWebsocketMessage(int type) { + this.type = type; + } + + public int getType() { + return type; + } +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java b/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java index fcb221cf5..8f212a79d 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java +++ b/src/main/java/org/whispersystems/textsecuregcm/push/PushSender.java @@ -62,7 +62,7 @@ public class PushSender { if (device.getGcmId() != null) sendGcmMessage(account, device, message); else if (device.getApnId() != null) sendApnMessage(account, device, message); - else if (device.getFetchesMessages()) storeFetchedMessage(device, message); + else if (device.getFetchesMessages()) storeFetchedMessage(account, device, message); else throw new NotPushRegisteredException("No delivery possible!"); } @@ -97,11 +97,11 @@ public class PushSender { } } - private void storeFetchedMessage(Device device, EncryptedOutgoingMessage outgoingMessage) + private void storeFetchedMessage(Account account, Device device, EncryptedOutgoingMessage outgoingMessage) throws NotPushRegisteredException { try { - storedMessageManager.storeMessage(device, outgoingMessage); + storedMessageManager.storeMessage(account, device, outgoingMessage); } catch (CryptoEncodingException e) { throw new NotPushRegisteredException(e); } diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java b/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java index d61832ae4..bd8738239 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java @@ -29,6 +29,9 @@ public class Account implements Serializable { public static final int MEMCACHE_VERION = 2; + @JsonIgnore + private long id; + @JsonProperty private String number; @@ -48,6 +51,14 @@ public class Account implements Serializable { this.supportsSms = supportsSms; } + public long getId() { + return id; + } + + public void setId(long id) { + this.id = id; + } + public Optional getAuthenticatedDevice() { return authenticatedDevice; } diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java b/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java index 5d74452ab..62c18b69c 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java @@ -95,7 +95,10 @@ public abstract class Accounts { throws SQLException { try { - return mapper.readValue(resultSet.getString(DATA), Account.class); + Account account = mapper.readValue(resultSet.getString(DATA), Account.class); + account.setId(resultSet.getLong(ID)); + + return account; } catch (IOException e) { throw new SQLException(e); } diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubListener.java b/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubListener.java new file mode 100644 index 000000000..d9a24d595 --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubListener.java @@ -0,0 +1,7 @@ +package org.whispersystems.textsecuregcm.storage; + +public interface PubSubListener { + + public void onPubSubMessage(PubSubMessage outgoingMessage); + +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubManager.java b/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubManager.java new file mode 100644 index 000000000..b9e10b7e8 --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubManager.java @@ -0,0 +1,157 @@ +package org.whispersystems.textsecuregcm.storage; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.websocket.InvalidWebsocketAddressException; +import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import redis.clients.jedis.Jedis; +import redis.clients.jedis.JedisPool; +import redis.clients.jedis.JedisPubSub; + +public class PubSubManager { + + private final Logger logger = LoggerFactory.getLogger(PubSubManager.class); + private final ObjectMapper mapper = new ObjectMapper(); + private final SubscriptionListener baseListener = new SubscriptionListener(); + private final Map listeners = new HashMap<>(); + + private final JedisPool jedisPool; + private boolean subscribed = false; + + public PubSubManager(final JedisPool jedisPool) { + this.jedisPool = jedisPool; + initializePubSubWorker(); + waitForSubscription(); + } + + public synchronized void subscribe(WebsocketAddress address, PubSubListener listener) { + listeners.put(address, listener); + baseListener.subscribe(address.toString()); + } + + public synchronized void unsubscribe(WebsocketAddress address, PubSubListener listener) { + if (listeners.get(address) == listener) { + listeners.remove(address); + baseListener.unsubscribe(address.toString()); + } + } + + public synchronized boolean publish(WebsocketAddress address, PubSubMessage message) { + try { + String serialized = mapper.writeValueAsString(message); + Jedis jedis = null; + + try { + jedis = jedisPool.getResource(); + return jedis.publish(address.toString(), serialized) != 0; + } finally { + if (jedis != null) + jedisPool.returnResource(jedis); + } + } catch (JsonProcessingException e) { + throw new AssertionError(e); + } + } + + private synchronized void waitForSubscription() { + try { + while (!subscribed) { + wait(); + } + } catch (InterruptedException e) { + throw new AssertionError(e); + } + } + + private void initializePubSubWorker() { + new Thread("PubSubListener") { + @Override + public void run() { + for (;;) { + Jedis jedis = null; + try { + jedis = jedisPool.getResource(); + jedis.subscribe(baseListener, new WebsocketAddress(0, 0).toString()); + logger.warn("**** Unsubscribed from holding channel!!! ******"); + } finally { + if (jedis != null) + jedisPool.returnResource(jedis); + } + } + } + }.start(); + + new Thread("PubSubKeepAlive") { + @Override + public void run() { + for (;;) { + try { + Thread.sleep(20000); + publish(new WebsocketAddress(0, 0), new PubSubMessage(0, "foo")); + } catch (InterruptedException e) { + throw new AssertionError(e); + } + } + } + }.start(); + } + + private class SubscriptionListener extends JedisPubSub { + + @Override + public void onMessage(String channel, String message) { + try { + WebsocketAddress address = new WebsocketAddress(channel); + PubSubListener listener; + + synchronized (PubSubManager.this) { + listener = listeners.get(address); + } + + if (listener != null) { + listener.onPubSubMessage(mapper.readValue(message, PubSubMessage.class)); + } + } catch (InvalidWebsocketAddressException e) { + logger.warn("Address", e); + } catch (IOException e) { + logger.warn("IOE", e); + } + } + + @Override + public void onPMessage(String s, String s2, String s3) { + logger.warn("Received PMessage!"); + } + + @Override + public void onSubscribe(String channel, int count) { + try { + WebsocketAddress address = new WebsocketAddress(channel); + if (address.getAccountId() == 0 && address.getDeviceId() == 0) { + synchronized (PubSubManager.this) { + subscribed = true; + PubSubManager.this.notifyAll(); + } + } + } catch (InvalidWebsocketAddressException e) { + logger.warn("Weird address", e); + } + } + + @Override + public void onUnsubscribe(String s, int i) {} + + @Override + public void onPUnsubscribe(String s, int i) {} + + @Override + public void onPSubscribe(String s, int i) {} + } +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubMessage.java b/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubMessage.java new file mode 100644 index 000000000..fa24dbd48 --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubMessage.java @@ -0,0 +1,32 @@ +package org.whispersystems.textsecuregcm.storage; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; + +@JsonIgnoreProperties(ignoreUnknown = true) +public class PubSubMessage { + + public static final int TYPE_QUERY_DB = 1; + public static final int TYPE_DELIVER = 2; + + @JsonProperty + private int type; + + @JsonProperty + private String contents; + + public PubSubMessage() {} + + public PubSubMessage(int type, String contents) { + this.type = type; + this.contents = contents; + } + + public int getType() { + return type; + } + + public String getContents() { + return contents; + } +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessageManager.java b/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessageManager.java index 7782e8969..b0de973af 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessageManager.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessageManager.java @@ -18,23 +18,49 @@ package org.whispersystems.textsecuregcm.storage; import org.whispersystems.textsecuregcm.entities.CryptoEncodingException; import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage; +import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; -import java.io.IOException; import java.util.List; public class StoredMessageManager { - StoredMessages storedMessages; - public StoredMessageManager(StoredMessages storedMessages) { + + private final StoredMessages storedMessages; + private final PubSubManager pubSubManager; + + public StoredMessageManager(StoredMessages storedMessages, PubSubManager pubSubManager) { this.storedMessages = storedMessages; + this.pubSubManager = pubSubManager; } - public void storeMessage(Device device, EncryptedOutgoingMessage outgoingMessage) + public void storeMessage(Account account, Device device, EncryptedOutgoingMessage outgoingMessage) throws CryptoEncodingException { - storedMessages.insert(device.getId(), outgoingMessage.serialize()); + storeMessage(account, device, outgoingMessage.serialize()); } - public List getStoredMessage(Device device) { - return storedMessages.getMessagesForAccountId(device.getId()); + public void storeMessages(Account account, Device device, List serializedMessages) { + for (String serializedMessage : serializedMessages) { + storeMessage(account, device, serializedMessage); + } + } + + private void storeMessage(Account account, Device device, String serializedMessage) { + if (device.getFetchesMessages()) { + WebsocketAddress address = new WebsocketAddress(account.getId(), device.getId()); + PubSubMessage pubSubMessage = new PubSubMessage(PubSubMessage.TYPE_DELIVER, serializedMessage); + + if (!pubSubManager.publish(address, pubSubMessage)) { + storedMessages.insert(account.getId(), device.getId(), serializedMessage); + pubSubManager.publish(address, new PubSubMessage(PubSubMessage.TYPE_QUERY_DB, null)); + } + + return; + } + + storedMessages.insert(account.getId(), device.getId(), serializedMessage); + } + + public List getOutgoingMessages(Account account, Device device) { + return storedMessages.getMessagesForDevice(account.getId(), device.getId()); } } diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessages.java b/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessages.java index 9bd1f4f9c..fdc18b8f1 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessages.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/StoredMessages.java @@ -17,6 +17,7 @@ package org.whispersystems.textsecuregcm.storage; import org.skife.jdbi.v2.sqlobject.Bind; +import org.skife.jdbi.v2.sqlobject.SqlBatch; import org.skife.jdbi.v2.sqlobject.SqlQuery; import org.skife.jdbi.v2.sqlobject.SqlUpdate; @@ -24,9 +25,12 @@ import java.util.List; public interface StoredMessages { - @SqlUpdate("INSERT INTO stored_messages (destination_id, encrypted_message) VALUES (:destination_id, :encrypted_message)") - void insert(@Bind("destination_id") long destinationAccountId, @Bind("encrypted_message") String encryptedOutgoingMessage); + @SqlUpdate("INSERT INTO messages (account_id, device_id, encrypted_message) VALUES (:account_id, :device_id, :encrypted_message)") + void insert(@Bind("account_id") long accountId, @Bind("device_id") long deviceId, @Bind("encrypted_message") String encryptedOutgoingMessage); - @SqlQuery("SELECT encrypted_message FROM stored_messages WHERE destination_id = :account_id") - List getMessagesForAccountId(@Bind("account_id") long accountId); + @SqlBatch("INSERT INTO messages (account_id, device_id, encrypted_message) VALUES (:account_id, :device_id, :encrypted_message)") + void insert(@Bind("account_id") long accountId, @Bind("device_id") long deviceId, @Bind("encrypted_message") List encryptedOutgoingMessages); + + @SqlQuery("DELETE FROM messages WHERE account_id = :account_id AND device_id = :device_id RETURNING encrypted_message") + List getMessagesForDevice(@Bind("account_id") long accountId, @Bind("device_id") long deviceId); } diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/InvalidWebsocketAddressException.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/InvalidWebsocketAddressException.java new file mode 100644 index 000000000..788283a57 --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/InvalidWebsocketAddressException.java @@ -0,0 +1,11 @@ +package org.whispersystems.textsecuregcm.websocket; + +public class InvalidWebsocketAddressException extends Exception { + public InvalidWebsocketAddressException(String serialized) { + super(serialized); + } + + public InvalidWebsocketAddressException(Exception e) { + super(e); + } +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketAddress.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketAddress.java new file mode 100644 index 000000000..72f72778c --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketAddress.java @@ -0,0 +1,57 @@ +package org.whispersystems.textsecuregcm.websocket; + +public class WebsocketAddress { + + private final long accountId; + private final long deviceId; + + public WebsocketAddress(String serialized) throws InvalidWebsocketAddressException { + try { + String[] parts = serialized.split(":"); + + if (parts == null || parts.length != 2) { + throw new InvalidWebsocketAddressException(serialized); + } + + this.accountId = Long.parseLong(parts[0]); + this.deviceId = Long.parseLong(parts[1]); + } catch (NumberFormatException e) { + throw new InvalidWebsocketAddressException(e); + } + } + + public WebsocketAddress(long accountId, long deviceId) { + this.accountId = accountId; + this.deviceId = deviceId; + } + + public long getAccountId() { + return accountId; + } + + public long getDeviceId() { + return deviceId; + } + + public String toString() { + return accountId + ":" + deviceId; + } + + @Override + public boolean equals(Object other) { + if (other == null) return false; + if (!(other instanceof WebsocketAddress)) return false; + + WebsocketAddress that = (WebsocketAddress)other; + + return + this.accountId == that.accountId && + this.deviceId == that.deviceId; + } + + @Override + public int hashCode() { + return (int)accountId ^ (int)deviceId; + } + +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketMessage.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketMessage.java new file mode 100644 index 000000000..04e5587b2 --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketMessage.java @@ -0,0 +1,18 @@ +package org.whispersystems.textsecuregcm.websocket; + +import com.fasterxml.jackson.annotation.JsonProperty; + +public class WebsocketMessage { + + @JsonProperty + private long id; + + @JsonProperty + private String message; + + public WebsocketMessage(long id, String message) { + this.id = id; + this.message = message; + } + +} diff --git a/src/main/resources/migrations.xml b/src/main/resources/migrations.xml index 38284dcac..9a6e4857b 100644 --- a/src/main/resources/migrations.xml +++ b/src/main/resources/migrations.xml @@ -104,27 +104,37 @@ - + - + - + - - + + + + + + + + + + + + diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java index 852fdb393..88f1333c7 100644 --- a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java @@ -62,6 +62,8 @@ public class DeviceControllerTest extends ResourceTest { when(rateLimiters.getSmsDestinationLimiter()).thenReturn(rateLimiter); when(rateLimiters.getVoiceDestinationLimiter()).thenReturn(rateLimiter); when(rateLimiters.getVerifyLimiter()).thenReturn(rateLimiter); + when(rateLimiters.getAllocateDeviceLimiter()).thenReturn(rateLimiter); + when(rateLimiters.getVerifyDeviceLimiter()).thenReturn(rateLimiter); when(account.getNextDeviceId()).thenReturn(42L);