From fdb35d4f778a066661495dd0edbd4b4422ed6a80 Mon Sep 17 00:00:00 2001 From: Moxie Marlinspike Date: Fri, 14 Nov 2014 17:59:50 -0800 Subject: [PATCH] Switch to WebSocket-Resources // FREEBIE --- pom.xml | 5 + .../textsecuregcm/WhisperServerService.java | 15 +- .../controllers/WebsocketController.java | 263 ------------------ .../textsecuregcm/push/WebsocketSender.java | 3 +- .../websocket/ConnectListener.java | 65 +++++ .../WebSocketAccountAuthenticator.java | 43 +++ .../websocket/WebSocketConnection.java | 160 +++++++++++ .../websocket/WebsocketControllerFactory.java | 50 ---- .../websocket/WebsocketMessage.java | 18 -- .../WebSocketConnectionTest.java} | 114 +++++--- 10 files changed, 359 insertions(+), 377 deletions(-) delete mode 100644 src/main/java/org/whispersystems/textsecuregcm/controllers/WebsocketController.java create mode 100644 src/main/java/org/whispersystems/textsecuregcm/websocket/ConnectListener.java create mode 100644 src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java create mode 100644 src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java delete mode 100644 src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketControllerFactory.java delete mode 100644 src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketMessage.java rename src/test/java/org/whispersystems/textsecuregcm/tests/{controllers/WebsocketControllerTest.java => websocket/WebSocketConnectionTest.java} (50%) diff --git a/pom.xml b/pom.xml index 695f042f9..05ed967dc 100644 --- a/pom.xml +++ b/pom.xml @@ -130,6 +130,11 @@ smack-tcp 4.0.0 + + org.whispersystems.websocket + websocket-resources + 0.1-SNAPSHOT + diff --git a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index d15d0f313..13965d81c 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -73,9 +73,12 @@ import org.whispersystems.textsecuregcm.storage.PubSubManager; import org.whispersystems.textsecuregcm.storage.StoredMessages; import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.UrlSigner; -import org.whispersystems.textsecuregcm.websocket.WebsocketControllerFactory; +import org.whispersystems.textsecuregcm.websocket.ConnectListener; +import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator; import org.whispersystems.textsecuregcm.workers.DirectoryCommand; import org.whispersystems.textsecuregcm.workers.VacuumCommand; +import org.whispersystems.websocket.WebSocketResourceProviderFactory; +import org.whispersystems.websocket.setup.WebSocketEnvironment; import javax.servlet.DispatcherType; import javax.servlet.FilterRegistration; @@ -188,11 +191,11 @@ public class WhisperServerService extends Application pendingMessages = new HashMap<>(); - - private final AccountAuthenticator accountAuthenticator; - private final AccountsManager accountsManager; - private final PubSubManager pubSubManager; - private final StoredMessages storedMessages; - private final PushSender pushSender; - - private WebsocketAddress address; - private Account account; - private Device device; - private Session session; - - private long pendingMessageSequence; - - public WebsocketController(AccountAuthenticator accountAuthenticator, - AccountsManager accountsManager, - PushSender pushSender, - PubSubManager pubSubManager, - StoredMessages storedMessages) - { - this.accountAuthenticator = accountAuthenticator; - this.accountsManager = accountsManager; - this.pushSender = pushSender; - this.pubSubManager = pubSubManager; - this.storedMessages = storedMessages; - } - - @Override - public void onWebSocketConnect(Session session) { - try { - UpgradeRequest request = session.getUpgradeRequest(); - Map parameters = request.getParameterMap(); - String[] usernames = parameters.get("login" ); - String[] passwords = parameters.get("password"); - - if (usernames == null || usernames.length == 0 || - passwords == null || passwords.length == 0) - { - session.close(new CloseStatus(4001, "Unauthorized")); - return; - } - - BasicCredentials credentials = new BasicCredentials(usernames[0], passwords[0]); - Optional account = accountAuthenticator.authenticate(credentials); - - if (!account.isPresent()) { - session.close(new CloseStatus(4001, "Unauthorized")); - return; - } - - this.account = account.get(); - this.device = account.get().getAuthenticatedDevice().get(); - this.address = new WebsocketAddress(this.account.getNumber(), this.device.getId()); - this.session = session; - - this.session.setIdleTimeout(10 * 60 * 1000); - this.pubSubManager.subscribe(this.address, this); - - handleQueryDatabase(); - } catch (AuthenticationException e) { - try { session.close(1011, "Server Error");} catch (IOException e1) {} - } catch (IOException ioe) { - logger.info("Abrupt session close."); - } - } - - @Override - public void onWebSocketText(String body) { - try { - IncomingWebsocketMessage incomingMessage = mapper.readValue(body, IncomingWebsocketMessage.class); - - switch (incomingMessage.getType()) { - case IncomingWebsocketMessage.TYPE_ACKNOWLEDGE_MESSAGE: - handleMessageAck(body); - break; - default: - close(new CloseStatus(1008, "Unknown Type")); - } - } catch (IOException e) { - logger.debug("Parse", e); - close(new CloseStatus(1008, "Badly Formatted")); - } - } - - @Override - public void onWebSocketClose(int i, String s) { - pubSubManager.unsubscribe(this.address, this); - - List remainingMessages = new LinkedList<>(); - - synchronized (pendingMessages) { - Long[] pendingKeys = pendingMessages.keySet().toArray(new Long[0]); - Arrays.sort(pendingKeys); - - for (long pendingKey : pendingKeys) { - remainingMessages.add(pendingMessages.get(pendingKey)); - } - - pendingMessages.clear(); - } - - for (PendingMessage remainingMessage : remainingMessages) { - try { - pushSender.sendMessage(account, device, remainingMessage); - } catch (NotPushRegisteredException | TransientPushFailureException e) { - logger.warn("onWebSocketClose", e); - storedMessages.insert(address, remainingMessage); - } - } - } - - @Override - public void onPubSubMessage(PubSubMessage outgoingMessage) { - switch (outgoingMessage.getType()) { - case PubSubMessage.TYPE_DELIVER: - try { - PendingMessage pendingMessage = mapper.readValue(outgoingMessage.getContents(), PendingMessage.class); - handleDeliverOutgoingMessage(pendingMessage); - } catch (IOException e) { - logger.warn("WebsocketController", "Error deserializing PendingMessage", e); - } - break; - case PubSubMessage.TYPE_QUERY_DB: - handleQueryDatabase(); - break; - default: - logger.warn("Unknown pubsub message: " + outgoingMessage.getType()); - } - } - - private void handleDeliverOutgoingMessage(PendingMessage message) { - try { - long messageSequence; - - synchronized (pendingMessages) { - messageSequence = pendingMessageSequence++; - pendingMessages.put(messageSequence, message); - } - - WebsocketMessage websocketMessage = new WebsocketMessage(messageSequence, message.getEncryptedOutgoingMessage()); - session.getRemote().sendStringByFuture(mapper.writeValueAsString(websocketMessage)); - } catch (IOException e) { - logger.debug("Response failed", e); - close(null); - } - } - - private void handleMessageAck(String message) { - try { - AcknowledgeWebsocketMessage ack = mapper.readValue(message, AcknowledgeWebsocketMessage.class); - PendingMessage acknowledgedMessage; - - synchronized (pendingMessages) { - acknowledgedMessage = pendingMessages.remove(ack.getId()); - } - - if (acknowledgedMessage != null && !acknowledgedMessage.isReceipt()) { - sendDeliveryReceipt(acknowledgedMessage); - } - - } catch (IOException e) { - logger.warn("Mapping", e); - } - } - - private void handleQueryDatabase() { - List messages = storedMessages.getMessagesForDevice(address); - - for (PendingMessage message : messages) { - handleDeliverOutgoingMessage(message); - } - } - - private void sendDeliveryReceipt(PendingMessage acknowledgedMessage) { - try { - Optional source = accountsManager.get(acknowledgedMessage.getSender()); - - if (!source.isPresent()) { - logger.warn("Source account disappeared? (%s)", acknowledgedMessage.getSender()); - return; - } - - OutgoingMessageSignal.Builder receipt = - OutgoingMessageSignal.newBuilder() - .setSource(account.getNumber()) - .setSourceDevice((int) device.getId()) - .setTimestamp(acknowledgedMessage.getMessageId()) - .setType(OutgoingMessageSignal.Type.RECEIPT_VALUE); - - for (Device device : source.get().getDevices()) { - pushSender.sendMessage(source.get(), device, receipt.build()); - } - } catch (NotPushRegisteredException | TransientPushFailureException e) { - logger.warn("Websocket", "Delivery receipet", e); - } - } - - @Override - public void onWebSocketBinary(byte[] bytes, int i, int i2) { - logger.info("Received binary message!"); - } - - @Override - public void onWebSocketError(Throwable throwable) { - logger.info("onWebSocketError", throwable); - } - - - private void close(CloseStatus closeStatus) { - try { - if (this.session != null) { - if (closeStatus != null) this.session.close(closeStatus); - else this.session.close(); - } - } catch (IOException e) { - logger.info("close()", e); - } - } -} diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java b/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java index fe3cd00e4..f33944faf 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java +++ b/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java @@ -23,7 +23,6 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.controllers.WebsocketController; import org.whispersystems.textsecuregcm.entities.PendingMessage; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; @@ -38,7 +37,7 @@ import static com.codahale.metrics.MetricRegistry.name; public class WebsocketSender { - private static final Logger logger = LoggerFactory.getLogger(WebsocketController.class); + private static final Logger logger = LoggerFactory.getLogger(WebsocketSender.class); private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); private final Meter onlineMeter = metricRegistry.meter(name(getClass(), "online")); diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/ConnectListener.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/ConnectListener.java new file mode 100644 index 000000000..9e4c2c87d --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/ConnectListener.java @@ -0,0 +1,65 @@ +package org.whispersystems.textsecuregcm.websocket; + +import com.google.common.base.Optional; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.push.PushSender; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.PubSubManager; +import org.whispersystems.textsecuregcm.storage.StoredMessages; +import org.whispersystems.websocket.session.WebSocketSessionContext; +import org.whispersystems.websocket.setup.WebSocketConnectListener; + +public class ConnectListener implements WebSocketConnectListener { + + private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class); + + private final AccountsManager accountsManager; + private final PushSender pushSender; + private final StoredMessages storedMessages; + private final PubSubManager pubSubManager; + + public ConnectListener(AccountsManager accountsManager, PushSender pushSender, + StoredMessages storedMessages, PubSubManager pubSubManager) + { + this.accountsManager = accountsManager; + this.pushSender = pushSender; + this.storedMessages = storedMessages; + this.pubSubManager = pubSubManager; + } + + @Override + public void onWebSocketConnect(WebSocketSessionContext context) { + Optional account = Optional.fromNullable((Account) context.getAuthenticated()); + + if (!account.isPresent()) { + logger.debug("WS Connection with no authentication..."); + context.getClient().close(4001, "Authentication failed"); + return; + } + + Optional device = account.get().getAuthenticatedDevice(); + + if (!device.isPresent()) { + logger.debug("WS Connection with no authenticated device..."); + context.getClient().close(4001, "Device authentication failed"); + return; + } + + final WebSocketConnection connection = new WebSocketConnection(accountsManager, pushSender, + storedMessages, pubSubManager, + account.get(), device.get(), + context.getClient()); + + connection.onConnected(); + + context.addListener(new WebSocketSessionContext.WebSocketEventListener() { + @Override + public void onWebSocketClose(WebSocketSessionContext context, int statusCode, String reason) { + connection.onConnectionLost(); + } + }); + } +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java new file mode 100644 index 000000000..d756876ac --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java @@ -0,0 +1,43 @@ +package org.whispersystems.textsecuregcm.websocket; + +import com.google.common.base.Optional; +import org.eclipse.jetty.websocket.api.UpgradeRequest; +import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.websocket.auth.AuthenticationException; +import org.whispersystems.websocket.auth.WebSocketAuthenticator; + +import java.util.Map; + +import io.dropwizard.auth.basic.BasicCredentials; + + +public class WebSocketAccountAuthenticator implements WebSocketAuthenticator { + + private final AccountAuthenticator accountAuthenticator; + + public WebSocketAccountAuthenticator(AccountAuthenticator accountAuthenticator) { + this.accountAuthenticator = accountAuthenticator; + } + + @Override + public Optional authenticate(UpgradeRequest request) throws AuthenticationException { + try { + Map parameters = request.getParameterMap(); + String[] usernames = parameters.get("login"); + String[] passwords = parameters.get("password"); + + if (usernames == null || usernames.length == 0 || + passwords == null || passwords.length == 0) + { + return Optional.absent(); + } + + BasicCredentials credentials = new BasicCredentials(usernames[0], passwords[0]); + return accountAuthenticator.authenticate(credentials); + } catch (io.dropwizard.auth.AuthenticationException e) { + throw new AuthenticationException(e); + } + } + +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java new file mode 100644 index 000000000..5c1902d37 --- /dev/null +++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -0,0 +1,160 @@ +package org.whispersystems.textsecuregcm.websocket; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Optional; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.entities.PendingMessage; +import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; +import org.whispersystems.textsecuregcm.push.PushSender; +import org.whispersystems.textsecuregcm.push.TransientPushFailureException; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.PubSubListener; +import org.whispersystems.textsecuregcm.storage.PubSubManager; +import org.whispersystems.textsecuregcm.storage.PubSubMessage; +import org.whispersystems.textsecuregcm.storage.StoredMessages; +import org.whispersystems.textsecuregcm.util.SystemMapper; +import org.whispersystems.websocket.WebSocketClient; +import org.whispersystems.websocket.messages.WebSocketResponseMessage; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.io.IOException; +import java.util.List; + +import static org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal; + +public class WebSocketConnection implements PubSubListener { + + private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class); + private static final ObjectMapper objectMapper = SystemMapper.getMapper(); + + private final AccountsManager accountsManager; + private final PushSender pushSender; + private final StoredMessages storedMessages; + private final PubSubManager pubSubManager; + + private final Account account; + private final Device device; + private final WebsocketAddress address; + private final WebSocketClient client; + + public WebSocketConnection(AccountsManager accountsManager, + PushSender pushSender, + StoredMessages storedMessages, + PubSubManager pubSubManager, + Account account, + Device device, + WebSocketClient client) + { + this.accountsManager = accountsManager; + this.pushSender = pushSender; + this.storedMessages = storedMessages; + this.pubSubManager = pubSubManager; + this.account = account; + this.device = device; + this.client = client; + this.address = new WebsocketAddress(account.getNumber(), device.getId()); + } + + public void onConnected() { + pubSubManager.subscribe(address, this); + processStoredMessages(); + } + + public void onConnectionLost() { + pubSubManager.unsubscribe(address, this); + } + + @Override + public void onPubSubMessage(PubSubMessage message) { + try { + switch (message.getType()) { + case PubSubMessage.TYPE_QUERY_DB: + processStoredMessages(); + break; + case PubSubMessage.TYPE_DELIVER: + PendingMessage pendingMessage = objectMapper.readValue(message.getContents(), + PendingMessage.class); + sendMessage(pendingMessage); + break; + default: + logger.warn("Unknown pubsub message: " + message.getType()); + } + } catch (IOException e) { + logger.warn("Error deserializing PendingMessage", e); + } + } + + private void sendMessage(final PendingMessage message) { + String content = message.getEncryptedOutgoingMessage(); + Optional body = Optional.fromNullable(content.getBytes()); + ListenableFuture response = client.sendRequest("PUT", "/api/v1/message", body); + + Futures.addCallback(response, new FutureCallback() { + @Override + public void onSuccess(@Nullable WebSocketResponseMessage response) { + if (isSuccessResponse(response) && !message.isReceipt()) { + sendDeliveryReceiptFor(message); + } else if (!isSuccessResponse(response)) { + requeueMessage(message); + } + } + + @Override + public void onFailure(@Nonnull Throwable throwable) { + requeueMessage(message); + } + + private boolean isSuccessResponse(WebSocketResponseMessage response) { + return response != null && response.getStatus() >= 200 && response.getStatus() < 300; + } + }); + } + + private void requeueMessage(PendingMessage message) { + try { + pushSender.sendMessage(account, device, message); + } catch (NotPushRegisteredException | TransientPushFailureException e) { + logger.warn("requeueMessage", e); + storedMessages.insert(address, message); + } + } + + private void sendDeliveryReceiptFor(PendingMessage message) { + try { + Optional source = accountsManager.get(message.getSender()); + + if (!source.isPresent()) { + logger.warn("Source account disappeared? (%s)", message.getSender()); + return; + } + + OutgoingMessageSignal.Builder receipt = + OutgoingMessageSignal.newBuilder() + .setSource(account.getNumber()) + .setSourceDevice((int) device.getId()) + .setTimestamp(message.getMessageId()) + .setType(OutgoingMessageSignal.Type.RECEIPT_VALUE); + + for (Device device : source.get().getDevices()) { + pushSender.sendMessage(source.get(), device, receipt.build()); + } + } catch (NotPushRegisteredException | TransientPushFailureException e) { + logger.warn("sendDeliveryReceiptFor", "Delivery receipet", e); + } + } + + private void processStoredMessages() { + List messages = storedMessages.getMessagesForDevice(address); + + for (PendingMessage message : messages) { + sendMessage(message); + } + } +} diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketControllerFactory.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketControllerFactory.java deleted file mode 100644 index 364d4788f..000000000 --- a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketControllerFactory.java +++ /dev/null @@ -1,50 +0,0 @@ -package org.whispersystems.textsecuregcm.websocket; - -import org.eclipse.jetty.websocket.api.UpgradeRequest; -import org.eclipse.jetty.websocket.api.UpgradeResponse; -import org.eclipse.jetty.websocket.servlet.WebSocketCreator; -import org.eclipse.jetty.websocket.servlet.WebSocketServlet; -import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; -import org.whispersystems.textsecuregcm.controllers.WebsocketController; -import org.whispersystems.textsecuregcm.push.PushSender; -import org.whispersystems.textsecuregcm.storage.AccountsManager; -import org.whispersystems.textsecuregcm.storage.PubSubManager; -import org.whispersystems.textsecuregcm.storage.StoredMessages; - - -public class WebsocketControllerFactory extends WebSocketServlet implements WebSocketCreator { - - private final Logger logger = LoggerFactory.getLogger(WebsocketControllerFactory.class); - - private final PushSender pushSender; - private final StoredMessages storedMessages; - private final PubSubManager pubSubManager; - private final AccountAuthenticator accountAuthenticator; - private final AccountsManager accounts; - - public WebsocketControllerFactory(AccountAuthenticator accountAuthenticator, - AccountsManager accounts, - PushSender pushSender, - StoredMessages storedMessages, - PubSubManager pubSubManager) - { - this.accountAuthenticator = accountAuthenticator; - this.accounts = accounts; - this.pushSender = pushSender; - this.storedMessages = storedMessages; - this.pubSubManager = pubSubManager; - } - - @Override - public void configure(WebSocketServletFactory factory) { - factory.setCreator(this); - } - - @Override - public Object createWebSocket(UpgradeRequest upgradeRequest, UpgradeResponse upgradeResponse) { - return new WebsocketController(accountAuthenticator, accounts, pushSender, pubSubManager, storedMessages); - } -} diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketMessage.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketMessage.java deleted file mode 100644 index 04e5587b2..000000000 --- a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketMessage.java +++ /dev/null @@ -1,18 +0,0 @@ -package org.whispersystems.textsecuregcm.websocket; - -import com.fasterxml.jackson.annotation.JsonProperty; - -public class WebsocketMessage { - - @JsonProperty - private long id; - - @JsonProperty - private String message; - - public WebsocketMessage(long id, String message) { - this.id = id; - this.message = message; - } - -} diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/WebsocketControllerTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java similarity index 50% rename from src/test/java/org/whispersystems/textsecuregcm/tests/controllers/WebsocketControllerTest.java rename to src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java index 16424c17f..7b2ba931c 100644 --- a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/WebsocketControllerTest.java +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java @@ -1,16 +1,13 @@ -package org.whispersystems.textsecuregcm.tests.controllers; +package org.whispersystems.textsecuregcm.tests.websocket; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.base.Optional; -import org.eclipse.jetty.websocket.api.CloseStatus; -import org.eclipse.jetty.websocket.api.RemoteEndpoint; -import org.eclipse.jetty.websocket.api.Session; +import com.google.common.util.concurrent.SettableFuture; import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; -import org.whispersystems.textsecuregcm.controllers.WebsocketController; -import org.whispersystems.textsecuregcm.entities.AcknowledgeWebsocketMessage; -import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.PendingMessage; import org.whispersystems.textsecuregcm.push.PushSender; @@ -19,20 +16,27 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.PubSubManager; import org.whispersystems.textsecuregcm.storage.StoredMessages; +import org.whispersystems.textsecuregcm.websocket.ConnectListener; +import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator; +import org.whispersystems.textsecuregcm.websocket.WebSocketConnection; import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; -import org.whispersystems.textsecuregcm.websocket.WebsocketControllerFactory; +import org.whispersystems.websocket.WebSocketClient; +import org.whispersystems.websocket.messages.WebSocketResponseMessage; +import org.whispersystems.websocket.session.WebSocketSessionContext; +import java.io.IOException; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import io.dropwizard.auth.basic.BasicCredentials; +import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.*; -public class WebsocketControllerTest { +public class WebSocketConnectionTest { - private static final ObjectMapper mapper = new ObjectMapper(); +// private static final ObjectMapper mapper = new ObjectMapper(); private static final String VALID_USER = "+14152222222"; private static final String INVALID_USER = "+14151111111"; @@ -40,52 +44,74 @@ public class WebsocketControllerTest { private static final String VALID_PASSWORD = "secure"; private static final String INVALID_PASSWORD = "insecure"; - private static final StoredMessages storedMessages = mock(StoredMessages.class); +// private static final StoredMessages storedMessages = mock(StoredMessages.class); private static final AccountAuthenticator accountAuthenticator = mock(AccountAuthenticator.class); private static final AccountsManager accountsManager = mock(AccountsManager.class); private static final PubSubManager pubSubManager = mock(PubSubManager.class ); private static final Account account = mock(Account.class ); private static final Device device = mock(Device.class ); private static final UpgradeRequest upgradeRequest = mock(UpgradeRequest.class ); - private static final Session session = mock(Session.class ); +// private static final Session session = mock(Session.class ); private static final PushSender pushSender = mock(PushSender.class); @Test public void testCredentials() throws Exception { + StoredMessages storedMessages = mock(StoredMessages.class); + WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator); + ConnectListener connectListener = new ConnectListener(accountsManager, pushSender, storedMessages, pubSubManager); + WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); + when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD)))) .thenReturn(Optional.of(account)); when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD)))) .thenReturn(Optional.absent()); - when(session.getUpgradeRequest()).thenReturn(upgradeRequest); + when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device)); - WebsocketController controller = new WebsocketController(accountAuthenticator, accountsManager, pushSender, pubSubManager, storedMessages); +// when(session.getUpgradeRequest()).thenReturn(upgradeRequest); +// +// WebsocketController controller = new WebsocketController(accountAuthenticator, accountsManager, pushSender, pubSubManager, storedMessages); when(upgradeRequest.getParameterMap()).thenReturn(new HashMap() {{ put("login", new String[] {VALID_USER}); put("password", new String[] {VALID_PASSWORD}); }}); - controller.onWebSocketConnect(session); + Optional account = webSocketAuthenticator.authenticate(upgradeRequest); + when(sessionContext.getAuthenticated()).thenReturn(account.orNull()); - verify(session, never()).close(); - verify(session, never()).close(any(CloseStatus.class)); - verify(session, never()).close(anyInt(), anyString()); + connectListener.onWebSocketConnect(sessionContext); + + verify(sessionContext).addListener(any(WebSocketSessionContext.WebSocketEventListener.class)); + +// +// controller.onWebSocketConnect(session); + +// verify(session, never()).close(); +// verify(session, never()).close(any(CloseStatus.class)); +// verify(session, never()).close(anyInt(), anyString()); when(upgradeRequest.getParameterMap()).thenReturn(new HashMap() {{ put("login", new String[] {INVALID_USER}); put("password", new String[] {INVALID_PASSWORD}); }}); - controller.onWebSocketConnect(session); + account = webSocketAuthenticator.authenticate(upgradeRequest); + when(sessionContext.getAuthenticated()).thenReturn(account.orNull()); - verify(session).close(any(CloseStatus.class)); + WebSocketClient client = mock(WebSocketClient.class); + when(sessionContext.getClient()).thenReturn(client); + + connectListener.onWebSocketConnect(sessionContext); + + verify(sessionContext, times(1)).addListener(any(WebSocketSessionContext.WebSocketEventListener.class)); + verify(client).close(eq(4001), anyString()); } @Test public void testOpen() throws Exception { - RemoteEndpoint remote = mock(RemoteEndpoint.class); + StoredMessages storedMessages = mock(StoredMessages.class); List outgoingMessages = new LinkedList() {{ add(new PendingMessage("sender1", 1111, false, "first")); @@ -96,8 +122,6 @@ public class WebsocketControllerTest { when(device.getId()).thenReturn(2L); when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device)); when(account.getNumber()).thenReturn("+14152222222"); - when(session.getRemote()).thenReturn(remote); - when(session.getUpgradeRequest()).thenReturn(upgradeRequest); final Device sender1device = mock(Device.class); @@ -111,27 +135,38 @@ public class WebsocketControllerTest { when(accountsManager.get("sender1")).thenReturn(Optional.of(sender1)); when(accountsManager.get("sender2")).thenReturn(Optional.absent()); - when(upgradeRequest.getParameterMap()).thenReturn(new HashMap() {{ - put("login", new String[] {VALID_USER}); - put("password", new String[] {VALID_PASSWORD}); - }}); - - when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD)))) - .thenReturn(Optional.of(account)); - when(storedMessages.getMessagesForDevice(new WebsocketAddress(account.getNumber(), device.getId()))) .thenReturn(outgoingMessages); - WebsocketControllerFactory factory = new WebsocketControllerFactory(accountAuthenticator, accountsManager, pushSender, storedMessages, pubSubManager); - WebsocketController controller = (WebsocketController) factory.createWebSocket(null, null); + final List> futures = new LinkedList<>(); + final WebSocketClient client = mock(WebSocketClient.class); - controller.onWebSocketConnect(session); + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(Optional.class))) + .thenAnswer(new Answer>() { + @Override + public SettableFuture answer(InvocationOnMock invocationOnMock) throws Throwable { + SettableFuture future = SettableFuture.create(); + futures.add(future); + return future; + } + }); - verify(pubSubManager).subscribe(eq(new WebsocketAddress("+14152222222", 2L)), eq((controller))); - verify(remote, times(3)).sendStringByFuture(anyString()); + WebSocketConnection connection = new WebSocketConnection(accountsManager, pushSender, storedMessages, + pubSubManager, account, device, client); - controller.onWebSocketText(mapper.writeValueAsString(new AcknowledgeWebsocketMessage(1))); - controller.onWebSocketClose(1000, "Closed"); + connection.onConnected(); + + verify(pubSubManager).subscribe(eq(new WebsocketAddress("+14152222222", 2L)), eq((connection))); + verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(Optional.class)); + + assertTrue(futures.size() == 3); + + WebSocketResponseMessage response = mock(WebSocketResponseMessage.class); + when(response.getStatus()).thenReturn(200); + futures.get(1).set(response); + + futures.get(0).setException(new IOException()); + futures.get(2).setException(new IOException()); List pending = new LinkedList() {{ add(new PendingMessage("sender1", 1111, false, "first")); @@ -140,6 +175,9 @@ public class WebsocketControllerTest { verify(pushSender, times(2)).sendMessage(eq(account), eq(device), any(PendingMessage.class)); verify(pushSender, times(1)).sendMessage(eq(sender1), eq(sender1device), any(MessageProtos.OutgoingMessageSignal.class)); + + connection.onConnectionLost(); + verify(pubSubManager).unsubscribe(eq(new WebsocketAddress("+14152222222", 2L)), eq(connection)); } }