diff --git a/pom.xml b/pom.xml
index 8adf7b631..5f47387b9 100644
--- a/pom.xml
+++ b/pom.xml
@@ -96,7 +96,7 @@
org.whispersystems
websocket-resources
- 0.5.2
+ 0.5.3
org.whispersystems
diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java
index 64149932e..ceca0bdb3 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java
@@ -3,6 +3,7 @@ package org.whispersystems.textsecuregcm.websocket;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
+import com.google.protobuf.ByteString;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.push.PushSender;
@@ -18,6 +19,8 @@ import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import org.whispersystems.websocket.setup.WebSocketConnectListener;
+import java.security.SecureRandom;
+
import static com.codahale.metrics.MetricRegistry.name;
public class AuthenticatedConnectListener implements WebSocketConnectListener {
@@ -45,16 +48,21 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
@Override
public void onWebSocketConnect(WebSocketSessionContext context) {
- final Account account = context.getAuthenticated(Account.class);
- final Device device = account.getAuthenticatedDevice().get();
- final Timer.Context timer = durationTimer.time();
- final WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId());
- final WebSocketConnectionInfo info = new WebSocketConnectionInfo(address);
- final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender,
- messagesManager, account, device,
- context.getClient());
+ final Account account = context.getAuthenticated(Account.class);
+ final Device device = account.getAuthenticatedDevice().get();
+ final String connectionId = String.valueOf(new SecureRandom().nextLong());
+ final Timer.Context timer = durationTimer.time();
+ final WebsocketAddress address = new WebsocketAddress(account.getNumber(), device.getId());
+ final WebSocketConnectionInfo info = new WebSocketConnectionInfo(address);
+ final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender,
+ messagesManager, account, device,
+ context.getClient(), connectionId);
+ final PubSubMessage connectMessage = PubSubMessage.newBuilder().setType(PubSubMessage.Type.CONNECTED)
+ .setContent(ByteString.copyFrom(connectionId.getBytes()))
+ .build();
- pubSubManager.publish(info, PubSubMessage.newBuilder().setType(PubSubMessage.Type.CONNECTED).build());
+ pubSubManager.publish(info, connectMessage);
+ pubSubManager.publish(address, connectMessage);
pubSubManager.subscribe(address, connection);
context.addListener(new WebSocketSessionContext.WebSocketEventListener() {
diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java
index 6ad73ba98..7fb703c2b 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java
@@ -53,13 +53,15 @@ public class WebSocketConnection implements DispatchChannel {
private final Account account;
private final Device device;
private final WebSocketClient client;
+ private final String connectionId;
public WebSocketConnection(PushSender pushSender,
ReceiptSender receiptSender,
MessagesManager messagesManager,
Account account,
Device device,
- WebSocketClient client)
+ WebSocketClient client,
+ String connectionId)
{
this.pushSender = pushSender;
this.receiptSender = receiptSender;
@@ -67,6 +69,7 @@ public class WebSocketConnection implements DispatchChannel {
this.account = account;
this.device = device;
this.client = client;
+ this.connectionId = connectionId;
}
@Override
@@ -81,6 +84,11 @@ public class WebSocketConnection implements DispatchChannel {
case PubSubMessage.Type.DELIVER_VALUE:
sendMessage(Envelope.parseFrom(pubSubMessage.getContent()), Optional.absent(), false);
break;
+ case PubSubMessage.Type.CONNECTED_VALUE:
+ if (pubSubMessage.hasContent() && !new String(pubSubMessage.getContent().toByteArray()).equals(connectionId)) {
+ client.hardDisconnectQuietly();
+ }
+ break;
default:
logger.warn("Unknown pubsub message: " + pubSubMessage.getType().getNumber());
}
diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java
index 6785d7c39..949e5b1e5 100644
--- a/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java
+++ b/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java
@@ -143,7 +143,7 @@ public class WebSocketConnectionTest {
WebsocketAddress websocketAddress = new WebsocketAddress(account.getNumber(), device.getId());
WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, storedMessages,
- account, device, client);
+ account, device, client, "someid");
connection.onDispatchSubscribed(websocketAddress.serialize());
verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.>any());
@@ -227,7 +227,7 @@ public class WebSocketConnectionTest {
WebsocketAddress websocketAddress = new WebsocketAddress(account.getNumber(), device.getId());
WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, storedMessages,
- account, device, client);
+ account, device, client, "anotherid");
connection.onDispatchSubscribed(websocketAddress.serialize());
connection.onDispatchMessage(websocketAddress.serialize(), PubSubProtos.PubSubMessage.newBuilder()
@@ -333,7 +333,7 @@ public class WebSocketConnectionTest {
WebsocketAddress websocketAddress = new WebsocketAddress(account.getNumber(), device.getId());
WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, storedMessages,
- account, device, client);
+ account, device, client, "onemoreid");
connection.onDispatchSubscribed(websocketAddress.serialize());