diff --git a/pom.xml b/pom.xml
index 695f042f9..05ed967dc 100644
--- a/pom.xml
+++ b/pom.xml
@@ -130,6 +130,11 @@
smack-tcp
4.0.0
+
+ org.whispersystems.websocket
+ websocket-resources
+ 0.1-SNAPSHOT
+
diff --git a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
index d15d0f313..13965d81c 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
@@ -73,9 +73,12 @@ import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.StoredMessages;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.UrlSigner;
-import org.whispersystems.textsecuregcm.websocket.WebsocketControllerFactory;
+import org.whispersystems.textsecuregcm.websocket.ConnectListener;
+import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator;
import org.whispersystems.textsecuregcm.workers.DirectoryCommand;
import org.whispersystems.textsecuregcm.workers.VacuumCommand;
+import org.whispersystems.websocket.WebSocketResourceProviderFactory;
+import org.whispersystems.websocket.setup.WebSocketEnvironment;
import javax.servlet.DispatcherType;
import javax.servlet.FilterRegistration;
@@ -188,11 +191,11 @@ public class WhisperServerService extends Application pendingMessages = new HashMap<>();
-
- private final AccountAuthenticator accountAuthenticator;
- private final AccountsManager accountsManager;
- private final PubSubManager pubSubManager;
- private final StoredMessages storedMessages;
- private final PushSender pushSender;
-
- private WebsocketAddress address;
- private Account account;
- private Device device;
- private Session session;
-
- private long pendingMessageSequence;
-
- public WebsocketController(AccountAuthenticator accountAuthenticator,
- AccountsManager accountsManager,
- PushSender pushSender,
- PubSubManager pubSubManager,
- StoredMessages storedMessages)
- {
- this.accountAuthenticator = accountAuthenticator;
- this.accountsManager = accountsManager;
- this.pushSender = pushSender;
- this.pubSubManager = pubSubManager;
- this.storedMessages = storedMessages;
- }
-
- @Override
- public void onWebSocketConnect(Session session) {
- try {
- UpgradeRequest request = session.getUpgradeRequest();
- Map parameters = request.getParameterMap();
- String[] usernames = parameters.get("login" );
- String[] passwords = parameters.get("password");
-
- if (usernames == null || usernames.length == 0 ||
- passwords == null || passwords.length == 0)
- {
- session.close(new CloseStatus(4001, "Unauthorized"));
- return;
- }
-
- BasicCredentials credentials = new BasicCredentials(usernames[0], passwords[0]);
- Optional account = accountAuthenticator.authenticate(credentials);
-
- if (!account.isPresent()) {
- session.close(new CloseStatus(4001, "Unauthorized"));
- return;
- }
-
- this.account = account.get();
- this.device = account.get().getAuthenticatedDevice().get();
- this.address = new WebsocketAddress(this.account.getNumber(), this.device.getId());
- this.session = session;
-
- this.session.setIdleTimeout(10 * 60 * 1000);
- this.pubSubManager.subscribe(this.address, this);
-
- handleQueryDatabase();
- } catch (AuthenticationException e) {
- try { session.close(1011, "Server Error");} catch (IOException e1) {}
- } catch (IOException ioe) {
- logger.info("Abrupt session close.");
- }
- }
-
- @Override
- public void onWebSocketText(String body) {
- try {
- IncomingWebsocketMessage incomingMessage = mapper.readValue(body, IncomingWebsocketMessage.class);
-
- switch (incomingMessage.getType()) {
- case IncomingWebsocketMessage.TYPE_ACKNOWLEDGE_MESSAGE:
- handleMessageAck(body);
- break;
- default:
- close(new CloseStatus(1008, "Unknown Type"));
- }
- } catch (IOException e) {
- logger.debug("Parse", e);
- close(new CloseStatus(1008, "Badly Formatted"));
- }
- }
-
- @Override
- public void onWebSocketClose(int i, String s) {
- pubSubManager.unsubscribe(this.address, this);
-
- List remainingMessages = new LinkedList<>();
-
- synchronized (pendingMessages) {
- Long[] pendingKeys = pendingMessages.keySet().toArray(new Long[0]);
- Arrays.sort(pendingKeys);
-
- for (long pendingKey : pendingKeys) {
- remainingMessages.add(pendingMessages.get(pendingKey));
- }
-
- pendingMessages.clear();
- }
-
- for (PendingMessage remainingMessage : remainingMessages) {
- try {
- pushSender.sendMessage(account, device, remainingMessage);
- } catch (NotPushRegisteredException | TransientPushFailureException e) {
- logger.warn("onWebSocketClose", e);
- storedMessages.insert(address, remainingMessage);
- }
- }
- }
-
- @Override
- public void onPubSubMessage(PubSubMessage outgoingMessage) {
- switch (outgoingMessage.getType()) {
- case PubSubMessage.TYPE_DELIVER:
- try {
- PendingMessage pendingMessage = mapper.readValue(outgoingMessage.getContents(), PendingMessage.class);
- handleDeliverOutgoingMessage(pendingMessage);
- } catch (IOException e) {
- logger.warn("WebsocketController", "Error deserializing PendingMessage", e);
- }
- break;
- case PubSubMessage.TYPE_QUERY_DB:
- handleQueryDatabase();
- break;
- default:
- logger.warn("Unknown pubsub message: " + outgoingMessage.getType());
- }
- }
-
- private void handleDeliverOutgoingMessage(PendingMessage message) {
- try {
- long messageSequence;
-
- synchronized (pendingMessages) {
- messageSequence = pendingMessageSequence++;
- pendingMessages.put(messageSequence, message);
- }
-
- WebsocketMessage websocketMessage = new WebsocketMessage(messageSequence, message.getEncryptedOutgoingMessage());
- session.getRemote().sendStringByFuture(mapper.writeValueAsString(websocketMessage));
- } catch (IOException e) {
- logger.debug("Response failed", e);
- close(null);
- }
- }
-
- private void handleMessageAck(String message) {
- try {
- AcknowledgeWebsocketMessage ack = mapper.readValue(message, AcknowledgeWebsocketMessage.class);
- PendingMessage acknowledgedMessage;
-
- synchronized (pendingMessages) {
- acknowledgedMessage = pendingMessages.remove(ack.getId());
- }
-
- if (acknowledgedMessage != null && !acknowledgedMessage.isReceipt()) {
- sendDeliveryReceipt(acknowledgedMessage);
- }
-
- } catch (IOException e) {
- logger.warn("Mapping", e);
- }
- }
-
- private void handleQueryDatabase() {
- List messages = storedMessages.getMessagesForDevice(address);
-
- for (PendingMessage message : messages) {
- handleDeliverOutgoingMessage(message);
- }
- }
-
- private void sendDeliveryReceipt(PendingMessage acknowledgedMessage) {
- try {
- Optional source = accountsManager.get(acknowledgedMessage.getSender());
-
- if (!source.isPresent()) {
- logger.warn("Source account disappeared? (%s)", acknowledgedMessage.getSender());
- return;
- }
-
- OutgoingMessageSignal.Builder receipt =
- OutgoingMessageSignal.newBuilder()
- .setSource(account.getNumber())
- .setSourceDevice((int) device.getId())
- .setTimestamp(acknowledgedMessage.getMessageId())
- .setType(OutgoingMessageSignal.Type.RECEIPT_VALUE);
-
- for (Device device : source.get().getDevices()) {
- pushSender.sendMessage(source.get(), device, receipt.build());
- }
- } catch (NotPushRegisteredException | TransientPushFailureException e) {
- logger.warn("Websocket", "Delivery receipet", e);
- }
- }
-
- @Override
- public void onWebSocketBinary(byte[] bytes, int i, int i2) {
- logger.info("Received binary message!");
- }
-
- @Override
- public void onWebSocketError(Throwable throwable) {
- logger.info("onWebSocketError", throwable);
- }
-
-
- private void close(CloseStatus closeStatus) {
- try {
- if (this.session != null) {
- if (closeStatus != null) this.session.close(closeStatus);
- else this.session.close();
- }
- } catch (IOException e) {
- logger.info("close()", e);
- }
- }
-}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java b/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java
index fe3cd00e4..f33944faf 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/push/WebsocketSender.java
@@ -23,7 +23,6 @@ import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import org.whispersystems.textsecuregcm.controllers.WebsocketController;
import org.whispersystems.textsecuregcm.entities.PendingMessage;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
@@ -38,7 +37,7 @@ import static com.codahale.metrics.MetricRegistry.name;
public class WebsocketSender {
- private static final Logger logger = LoggerFactory.getLogger(WebsocketController.class);
+ private static final Logger logger = LoggerFactory.getLogger(WebsocketSender.class);
private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private final Meter onlineMeter = metricRegistry.meter(name(getClass(), "online"));
diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/ConnectListener.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/ConnectListener.java
new file mode 100644
index 000000000..9e4c2c87d
--- /dev/null
+++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/ConnectListener.java
@@ -0,0 +1,65 @@
+package org.whispersystems.textsecuregcm.websocket;
+
+import com.google.common.base.Optional;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.whispersystems.textsecuregcm.push.PushSender;
+import org.whispersystems.textsecuregcm.storage.Account;
+import org.whispersystems.textsecuregcm.storage.AccountsManager;
+import org.whispersystems.textsecuregcm.storage.Device;
+import org.whispersystems.textsecuregcm.storage.PubSubManager;
+import org.whispersystems.textsecuregcm.storage.StoredMessages;
+import org.whispersystems.websocket.session.WebSocketSessionContext;
+import org.whispersystems.websocket.setup.WebSocketConnectListener;
+
+public class ConnectListener implements WebSocketConnectListener {
+
+ private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class);
+
+ private final AccountsManager accountsManager;
+ private final PushSender pushSender;
+ private final StoredMessages storedMessages;
+ private final PubSubManager pubSubManager;
+
+ public ConnectListener(AccountsManager accountsManager, PushSender pushSender,
+ StoredMessages storedMessages, PubSubManager pubSubManager)
+ {
+ this.accountsManager = accountsManager;
+ this.pushSender = pushSender;
+ this.storedMessages = storedMessages;
+ this.pubSubManager = pubSubManager;
+ }
+
+ @Override
+ public void onWebSocketConnect(WebSocketSessionContext context) {
+ Optional account = Optional.fromNullable((Account) context.getAuthenticated());
+
+ if (!account.isPresent()) {
+ logger.debug("WS Connection with no authentication...");
+ context.getClient().close(4001, "Authentication failed");
+ return;
+ }
+
+ Optional device = account.get().getAuthenticatedDevice();
+
+ if (!device.isPresent()) {
+ logger.debug("WS Connection with no authenticated device...");
+ context.getClient().close(4001, "Device authentication failed");
+ return;
+ }
+
+ final WebSocketConnection connection = new WebSocketConnection(accountsManager, pushSender,
+ storedMessages, pubSubManager,
+ account.get(), device.get(),
+ context.getClient());
+
+ connection.onConnected();
+
+ context.addListener(new WebSocketSessionContext.WebSocketEventListener() {
+ @Override
+ public void onWebSocketClose(WebSocketSessionContext context, int statusCode, String reason) {
+ connection.onConnectionLost();
+ }
+ });
+ }
+}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java
new file mode 100644
index 000000000..d756876ac
--- /dev/null
+++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java
@@ -0,0 +1,43 @@
+package org.whispersystems.textsecuregcm.websocket;
+
+import com.google.common.base.Optional;
+import org.eclipse.jetty.websocket.api.UpgradeRequest;
+import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
+import org.whispersystems.textsecuregcm.storage.Account;
+import org.whispersystems.websocket.auth.AuthenticationException;
+import org.whispersystems.websocket.auth.WebSocketAuthenticator;
+
+import java.util.Map;
+
+import io.dropwizard.auth.basic.BasicCredentials;
+
+
+public class WebSocketAccountAuthenticator implements WebSocketAuthenticator {
+
+ private final AccountAuthenticator accountAuthenticator;
+
+ public WebSocketAccountAuthenticator(AccountAuthenticator accountAuthenticator) {
+ this.accountAuthenticator = accountAuthenticator;
+ }
+
+ @Override
+ public Optional authenticate(UpgradeRequest request) throws AuthenticationException {
+ try {
+ Map parameters = request.getParameterMap();
+ String[] usernames = parameters.get("login");
+ String[] passwords = parameters.get("password");
+
+ if (usernames == null || usernames.length == 0 ||
+ passwords == null || passwords.length == 0)
+ {
+ return Optional.absent();
+ }
+
+ BasicCredentials credentials = new BasicCredentials(usernames[0], passwords[0]);
+ return accountAuthenticator.authenticate(credentials);
+ } catch (io.dropwizard.auth.AuthenticationException e) {
+ throw new AuthenticationException(e);
+ }
+ }
+
+}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java
new file mode 100644
index 000000000..5c1902d37
--- /dev/null
+++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java
@@ -0,0 +1,160 @@
+package org.whispersystems.textsecuregcm.websocket;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.base.Optional;
+import com.google.common.util.concurrent.FutureCallback;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.whispersystems.textsecuregcm.entities.PendingMessage;
+import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
+import org.whispersystems.textsecuregcm.push.PushSender;
+import org.whispersystems.textsecuregcm.push.TransientPushFailureException;
+import org.whispersystems.textsecuregcm.storage.Account;
+import org.whispersystems.textsecuregcm.storage.AccountsManager;
+import org.whispersystems.textsecuregcm.storage.Device;
+import org.whispersystems.textsecuregcm.storage.PubSubListener;
+import org.whispersystems.textsecuregcm.storage.PubSubManager;
+import org.whispersystems.textsecuregcm.storage.PubSubMessage;
+import org.whispersystems.textsecuregcm.storage.StoredMessages;
+import org.whispersystems.textsecuregcm.util.SystemMapper;
+import org.whispersystems.websocket.WebSocketClient;
+import org.whispersystems.websocket.messages.WebSocketResponseMessage;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import java.io.IOException;
+import java.util.List;
+
+import static org.whispersystems.textsecuregcm.entities.MessageProtos.OutgoingMessageSignal;
+
+public class WebSocketConnection implements PubSubListener {
+
+ private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class);
+ private static final ObjectMapper objectMapper = SystemMapper.getMapper();
+
+ private final AccountsManager accountsManager;
+ private final PushSender pushSender;
+ private final StoredMessages storedMessages;
+ private final PubSubManager pubSubManager;
+
+ private final Account account;
+ private final Device device;
+ private final WebsocketAddress address;
+ private final WebSocketClient client;
+
+ public WebSocketConnection(AccountsManager accountsManager,
+ PushSender pushSender,
+ StoredMessages storedMessages,
+ PubSubManager pubSubManager,
+ Account account,
+ Device device,
+ WebSocketClient client)
+ {
+ this.accountsManager = accountsManager;
+ this.pushSender = pushSender;
+ this.storedMessages = storedMessages;
+ this.pubSubManager = pubSubManager;
+ this.account = account;
+ this.device = device;
+ this.client = client;
+ this.address = new WebsocketAddress(account.getNumber(), device.getId());
+ }
+
+ public void onConnected() {
+ pubSubManager.subscribe(address, this);
+ processStoredMessages();
+ }
+
+ public void onConnectionLost() {
+ pubSubManager.unsubscribe(address, this);
+ }
+
+ @Override
+ public void onPubSubMessage(PubSubMessage message) {
+ try {
+ switch (message.getType()) {
+ case PubSubMessage.TYPE_QUERY_DB:
+ processStoredMessages();
+ break;
+ case PubSubMessage.TYPE_DELIVER:
+ PendingMessage pendingMessage = objectMapper.readValue(message.getContents(),
+ PendingMessage.class);
+ sendMessage(pendingMessage);
+ break;
+ default:
+ logger.warn("Unknown pubsub message: " + message.getType());
+ }
+ } catch (IOException e) {
+ logger.warn("Error deserializing PendingMessage", e);
+ }
+ }
+
+ private void sendMessage(final PendingMessage message) {
+ String content = message.getEncryptedOutgoingMessage();
+ Optional body = Optional.fromNullable(content.getBytes());
+ ListenableFuture response = client.sendRequest("PUT", "/api/v1/message", body);
+
+ Futures.addCallback(response, new FutureCallback() {
+ @Override
+ public void onSuccess(@Nullable WebSocketResponseMessage response) {
+ if (isSuccessResponse(response) && !message.isReceipt()) {
+ sendDeliveryReceiptFor(message);
+ } else if (!isSuccessResponse(response)) {
+ requeueMessage(message);
+ }
+ }
+
+ @Override
+ public void onFailure(@Nonnull Throwable throwable) {
+ requeueMessage(message);
+ }
+
+ private boolean isSuccessResponse(WebSocketResponseMessage response) {
+ return response != null && response.getStatus() >= 200 && response.getStatus() < 300;
+ }
+ });
+ }
+
+ private void requeueMessage(PendingMessage message) {
+ try {
+ pushSender.sendMessage(account, device, message);
+ } catch (NotPushRegisteredException | TransientPushFailureException e) {
+ logger.warn("requeueMessage", e);
+ storedMessages.insert(address, message);
+ }
+ }
+
+ private void sendDeliveryReceiptFor(PendingMessage message) {
+ try {
+ Optional source = accountsManager.get(message.getSender());
+
+ if (!source.isPresent()) {
+ logger.warn("Source account disappeared? (%s)", message.getSender());
+ return;
+ }
+
+ OutgoingMessageSignal.Builder receipt =
+ OutgoingMessageSignal.newBuilder()
+ .setSource(account.getNumber())
+ .setSourceDevice((int) device.getId())
+ .setTimestamp(message.getMessageId())
+ .setType(OutgoingMessageSignal.Type.RECEIPT_VALUE);
+
+ for (Device device : source.get().getDevices()) {
+ pushSender.sendMessage(source.get(), device, receipt.build());
+ }
+ } catch (NotPushRegisteredException | TransientPushFailureException e) {
+ logger.warn("sendDeliveryReceiptFor", "Delivery receipet", e);
+ }
+ }
+
+ private void processStoredMessages() {
+ List messages = storedMessages.getMessagesForDevice(address);
+
+ for (PendingMessage message : messages) {
+ sendMessage(message);
+ }
+ }
+}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketControllerFactory.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketControllerFactory.java
deleted file mode 100644
index 364d4788f..000000000
--- a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketControllerFactory.java
+++ /dev/null
@@ -1,50 +0,0 @@
-package org.whispersystems.textsecuregcm.websocket;
-
-import org.eclipse.jetty.websocket.api.UpgradeRequest;
-import org.eclipse.jetty.websocket.api.UpgradeResponse;
-import org.eclipse.jetty.websocket.servlet.WebSocketCreator;
-import org.eclipse.jetty.websocket.servlet.WebSocketServlet;
-import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
-import org.whispersystems.textsecuregcm.controllers.WebsocketController;
-import org.whispersystems.textsecuregcm.push.PushSender;
-import org.whispersystems.textsecuregcm.storage.AccountsManager;
-import org.whispersystems.textsecuregcm.storage.PubSubManager;
-import org.whispersystems.textsecuregcm.storage.StoredMessages;
-
-
-public class WebsocketControllerFactory extends WebSocketServlet implements WebSocketCreator {
-
- private final Logger logger = LoggerFactory.getLogger(WebsocketControllerFactory.class);
-
- private final PushSender pushSender;
- private final StoredMessages storedMessages;
- private final PubSubManager pubSubManager;
- private final AccountAuthenticator accountAuthenticator;
- private final AccountsManager accounts;
-
- public WebsocketControllerFactory(AccountAuthenticator accountAuthenticator,
- AccountsManager accounts,
- PushSender pushSender,
- StoredMessages storedMessages,
- PubSubManager pubSubManager)
- {
- this.accountAuthenticator = accountAuthenticator;
- this.accounts = accounts;
- this.pushSender = pushSender;
- this.storedMessages = storedMessages;
- this.pubSubManager = pubSubManager;
- }
-
- @Override
- public void configure(WebSocketServletFactory factory) {
- factory.setCreator(this);
- }
-
- @Override
- public Object createWebSocket(UpgradeRequest upgradeRequest, UpgradeResponse upgradeResponse) {
- return new WebsocketController(accountAuthenticator, accounts, pushSender, pubSubManager, storedMessages);
- }
-}
diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketMessage.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketMessage.java
deleted file mode 100644
index 04e5587b2..000000000
--- a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebsocketMessage.java
+++ /dev/null
@@ -1,18 +0,0 @@
-package org.whispersystems.textsecuregcm.websocket;
-
-import com.fasterxml.jackson.annotation.JsonProperty;
-
-public class WebsocketMessage {
-
- @JsonProperty
- private long id;
-
- @JsonProperty
- private String message;
-
- public WebsocketMessage(long id, String message) {
- this.id = id;
- this.message = message;
- }
-
-}
diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/WebsocketControllerTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java
similarity index 50%
rename from src/test/java/org/whispersystems/textsecuregcm/tests/controllers/WebsocketControllerTest.java
rename to src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java
index 16424c17f..7b2ba931c 100644
--- a/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/WebsocketControllerTest.java
+++ b/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java
@@ -1,16 +1,13 @@
-package org.whispersystems.textsecuregcm.tests.controllers;
+package org.whispersystems.textsecuregcm.tests.websocket;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Optional;
-import org.eclipse.jetty.websocket.api.CloseStatus;
-import org.eclipse.jetty.websocket.api.RemoteEndpoint;
-import org.eclipse.jetty.websocket.api.Session;
+import com.google.common.util.concurrent.SettableFuture;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
-import org.whispersystems.textsecuregcm.controllers.WebsocketController;
-import org.whispersystems.textsecuregcm.entities.AcknowledgeWebsocketMessage;
-import org.whispersystems.textsecuregcm.entities.EncryptedOutgoingMessage;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.PendingMessage;
import org.whispersystems.textsecuregcm.push.PushSender;
@@ -19,20 +16,27 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.StoredMessages;
+import org.whispersystems.textsecuregcm.websocket.ConnectListener;
+import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator;
+import org.whispersystems.textsecuregcm.websocket.WebSocketConnection;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
-import org.whispersystems.textsecuregcm.websocket.WebsocketControllerFactory;
+import org.whispersystems.websocket.WebSocketClient;
+import org.whispersystems.websocket.messages.WebSocketResponseMessage;
+import org.whispersystems.websocket.session.WebSocketSessionContext;
+import java.io.IOException;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import io.dropwizard.auth.basic.BasicCredentials;
+import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.*;
-public class WebsocketControllerTest {
+public class WebSocketConnectionTest {
- private static final ObjectMapper mapper = new ObjectMapper();
+// private static final ObjectMapper mapper = new ObjectMapper();
private static final String VALID_USER = "+14152222222";
private static final String INVALID_USER = "+14151111111";
@@ -40,52 +44,74 @@ public class WebsocketControllerTest {
private static final String VALID_PASSWORD = "secure";
private static final String INVALID_PASSWORD = "insecure";
- private static final StoredMessages storedMessages = mock(StoredMessages.class);
+// private static final StoredMessages storedMessages = mock(StoredMessages.class);
private static final AccountAuthenticator accountAuthenticator = mock(AccountAuthenticator.class);
private static final AccountsManager accountsManager = mock(AccountsManager.class);
private static final PubSubManager pubSubManager = mock(PubSubManager.class );
private static final Account account = mock(Account.class );
private static final Device device = mock(Device.class );
private static final UpgradeRequest upgradeRequest = mock(UpgradeRequest.class );
- private static final Session session = mock(Session.class );
+// private static final Session session = mock(Session.class );
private static final PushSender pushSender = mock(PushSender.class);
@Test
public void testCredentials() throws Exception {
+ StoredMessages storedMessages = mock(StoredMessages.class);
+ WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator);
+ ConnectListener connectListener = new ConnectListener(accountsManager, pushSender, storedMessages, pubSubManager);
+ WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
+
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
.thenReturn(Optional.of(account));
when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD))))
.thenReturn(Optional.absent());
- when(session.getUpgradeRequest()).thenReturn(upgradeRequest);
+ when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
- WebsocketController controller = new WebsocketController(accountAuthenticator, accountsManager, pushSender, pubSubManager, storedMessages);
+// when(session.getUpgradeRequest()).thenReturn(upgradeRequest);
+//
+// WebsocketController controller = new WebsocketController(accountAuthenticator, accountsManager, pushSender, pubSubManager, storedMessages);
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap() {{
put("login", new String[] {VALID_USER});
put("password", new String[] {VALID_PASSWORD});
}});
- controller.onWebSocketConnect(session);
+ Optional account = webSocketAuthenticator.authenticate(upgradeRequest);
+ when(sessionContext.getAuthenticated()).thenReturn(account.orNull());
- verify(session, never()).close();
- verify(session, never()).close(any(CloseStatus.class));
- verify(session, never()).close(anyInt(), anyString());
+ connectListener.onWebSocketConnect(sessionContext);
+
+ verify(sessionContext).addListener(any(WebSocketSessionContext.WebSocketEventListener.class));
+
+//
+// controller.onWebSocketConnect(session);
+
+// verify(session, never()).close();
+// verify(session, never()).close(any(CloseStatus.class));
+// verify(session, never()).close(anyInt(), anyString());
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap() {{
put("login", new String[] {INVALID_USER});
put("password", new String[] {INVALID_PASSWORD});
}});
- controller.onWebSocketConnect(session);
+ account = webSocketAuthenticator.authenticate(upgradeRequest);
+ when(sessionContext.getAuthenticated()).thenReturn(account.orNull());
- verify(session).close(any(CloseStatus.class));
+ WebSocketClient client = mock(WebSocketClient.class);
+ when(sessionContext.getClient()).thenReturn(client);
+
+ connectListener.onWebSocketConnect(sessionContext);
+
+ verify(sessionContext, times(1)).addListener(any(WebSocketSessionContext.WebSocketEventListener.class));
+ verify(client).close(eq(4001), anyString());
}
@Test
public void testOpen() throws Exception {
- RemoteEndpoint remote = mock(RemoteEndpoint.class);
+ StoredMessages storedMessages = mock(StoredMessages.class);
List outgoingMessages = new LinkedList() {{
add(new PendingMessage("sender1", 1111, false, "first"));
@@ -96,8 +122,6 @@ public class WebsocketControllerTest {
when(device.getId()).thenReturn(2L);
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn("+14152222222");
- when(session.getRemote()).thenReturn(remote);
- when(session.getUpgradeRequest()).thenReturn(upgradeRequest);
final Device sender1device = mock(Device.class);
@@ -111,27 +135,38 @@ public class WebsocketControllerTest {
when(accountsManager.get("sender1")).thenReturn(Optional.of(sender1));
when(accountsManager.get("sender2")).thenReturn(Optional.absent());
- when(upgradeRequest.getParameterMap()).thenReturn(new HashMap() {{
- put("login", new String[] {VALID_USER});
- put("password", new String[] {VALID_PASSWORD});
- }});
-
- when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
- .thenReturn(Optional.of(account));
-
when(storedMessages.getMessagesForDevice(new WebsocketAddress(account.getNumber(), device.getId())))
.thenReturn(outgoingMessages);
- WebsocketControllerFactory factory = new WebsocketControllerFactory(accountAuthenticator, accountsManager, pushSender, storedMessages, pubSubManager);
- WebsocketController controller = (WebsocketController) factory.createWebSocket(null, null);
+ final List> futures = new LinkedList<>();
+ final WebSocketClient client = mock(WebSocketClient.class);
- controller.onWebSocketConnect(session);
+ when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(Optional.class)))
+ .thenAnswer(new Answer>() {
+ @Override
+ public SettableFuture answer(InvocationOnMock invocationOnMock) throws Throwable {
+ SettableFuture future = SettableFuture.create();
+ futures.add(future);
+ return future;
+ }
+ });
- verify(pubSubManager).subscribe(eq(new WebsocketAddress("+14152222222", 2L)), eq((controller)));
- verify(remote, times(3)).sendStringByFuture(anyString());
+ WebSocketConnection connection = new WebSocketConnection(accountsManager, pushSender, storedMessages,
+ pubSubManager, account, device, client);
- controller.onWebSocketText(mapper.writeValueAsString(new AcknowledgeWebsocketMessage(1)));
- controller.onWebSocketClose(1000, "Closed");
+ connection.onConnected();
+
+ verify(pubSubManager).subscribe(eq(new WebsocketAddress("+14152222222", 2L)), eq((connection)));
+ verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(Optional.class));
+
+ assertTrue(futures.size() == 3);
+
+ WebSocketResponseMessage response = mock(WebSocketResponseMessage.class);
+ when(response.getStatus()).thenReturn(200);
+ futures.get(1).set(response);
+
+ futures.get(0).setException(new IOException());
+ futures.get(2).setException(new IOException());
List pending = new LinkedList() {{
add(new PendingMessage("sender1", 1111, false, "first"));
@@ -140,6 +175,9 @@ public class WebsocketControllerTest {
verify(pushSender, times(2)).sendMessage(eq(account), eq(device), any(PendingMessage.class));
verify(pushSender, times(1)).sendMessage(eq(sender1), eq(sender1device), any(MessageProtos.OutgoingMessageSignal.class));
+
+ connection.onConnectionLost();
+ verify(pubSubManager).unsubscribe(eq(new WebsocketAddress("+14152222222", 2L)), eq(connection));
}
}