From 789f11a5c4d1241d2aa5c300f1e2d9c7ab273507 Mon Sep 17 00:00:00 2001 From: Moxie Marlinspike Date: Fri, 18 Aug 2017 15:58:12 -0700 Subject: [PATCH] Disconnect sockets on other servers when new websocket comes in // FREEBIE --- pom.xml | 2 +- .../AuthenticatedConnectListener.java | 26 ++++++++++++------- .../websocket/WebSocketConnection.java | 10 ++++++- .../websocket/WebSocketConnectionTest.java | 6 ++--- 4 files changed, 30 insertions(+), 14 deletions(-) diff --git a/pom.xml b/pom.xml index 8adf7b631..5f47387b9 100644 --- a/pom.xml +++ b/pom.xml @@ -96,7 +96,7 @@ org.whispersystems websocket-resources - 0.5.2 + 0.5.3 org.whispersystems diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java index 64149932e..ceca0bdb3 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java +++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java @@ -3,6 +3,7 @@ package org.whispersystems.textsecuregcm.websocket; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.Timer; +import com.google.protobuf.ByteString; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.push.PushSender; @@ -18,6 +19,8 @@ import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.websocket.session.WebSocketSessionContext; import org.whispersystems.websocket.setup.WebSocketConnectListener; +import java.security.SecureRandom; + import static com.codahale.metrics.MetricRegistry.name; public class AuthenticatedConnectListener implements WebSocketConnectListener { @@ -45,16 +48,21 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { @Override public void onWebSocketConnect(WebSocketSessionContext context) { - final Account account = context.getAuthenticated(Account.class); - final Device device = account.getAuthenticatedDevice().get(); - final Timer.Context timer = durationTimer.time(); - final WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId()); - final WebSocketConnectionInfo info = new WebSocketConnectionInfo(address); - final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, - messagesManager, account, device, - context.getClient()); + final Account account = context.getAuthenticated(Account.class); + final Device device = account.getAuthenticatedDevice().get(); + final String connectionId = String.valueOf(new SecureRandom().nextLong()); + final Timer.Context timer = durationTimer.time(); + final WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId()); + final WebSocketConnectionInfo info = new WebSocketConnectionInfo(address); + final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, + messagesManager, account, device, + context.getClient(), connectionId); + final PubSubMessage connectMessage = PubSubMessage.newBuilder().setType(PubSubMessage.Type.CONNECTED) + .setContent(ByteString.copyFrom(connectionId.getBytes())) + .build(); - pubSubManager.publish(info, PubSubMessage.newBuilder().setType(PubSubMessage.Type.CONNECTED).build()); + pubSubManager.publish(info, connectMessage); + pubSubManager.publish(address, connectMessage); pubSubManager.subscribe(address, connection); context.addListener(new WebSocketSessionContext.WebSocketEventListener() { diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index 6ad73ba98..7fb703c2b 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -53,13 +53,15 @@ public class WebSocketConnection implements DispatchChannel { private final Account account; private final Device device; private final WebSocketClient client; + private final String connectionId; public WebSocketConnection(PushSender pushSender, ReceiptSender receiptSender, MessagesManager messagesManager, Account account, Device device, - WebSocketClient client) + WebSocketClient client, + String connectionId) { this.pushSender = pushSender; this.receiptSender = receiptSender; @@ -67,6 +69,7 @@ public class WebSocketConnection implements DispatchChannel { this.account = account; this.device = device; this.client = client; + this.connectionId = connectionId; } @Override @@ -81,6 +84,11 @@ public class WebSocketConnection implements DispatchChannel { case PubSubMessage.Type.DELIVER_VALUE: sendMessage(Envelope.parseFrom(pubSubMessage.getContent()), Optional.absent(), false); break; + case PubSubMessage.Type.CONNECTED_VALUE: + if (pubSubMessage.hasContent() && !new String(pubSubMessage.getContent().toByteArray()).equals(connectionId)) { + client.hardDisconnectQuietly(); + } + break; default: logger.warn("Unknown pubsub message: " + pubSubMessage.getType().getNumber()); } diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java index 6785d7c39..949e5b1e5 100644 --- a/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java @@ -143,7 +143,7 @@ public class WebSocketConnectionTest { WebsocketAddress websocketAddress = new WebsocketAddress(account.getNumber(), device.getId()); WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, storedMessages, - account, device, client); + account, device, client, "someid"); connection.onDispatchSubscribed(websocketAddress.serialize()); verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.>any()); @@ -227,7 +227,7 @@ public class WebSocketConnectionTest { WebsocketAddress websocketAddress = new WebsocketAddress(account.getNumber(), device.getId()); WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, storedMessages, - account, device, client); + account, device, client, "anotherid"); connection.onDispatchSubscribed(websocketAddress.serialize()); connection.onDispatchMessage(websocketAddress.serialize(), PubSubProtos.PubSubMessage.newBuilder() @@ -333,7 +333,7 @@ public class WebSocketConnectionTest { WebsocketAddress websocketAddress = new WebsocketAddress(account.getNumber(), device.getId()); WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, storedMessages, - account, device, client); + account, device, client, "onemoreid"); connection.onDispatchSubscribed(websocketAddress.serialize());