diff --git a/protobuf/PubSubMessage.proto b/protobuf/PubSubMessage.proto
index 7b257a368..980ab4840 100644
--- a/protobuf/PubSubMessage.proto
+++ b/protobuf/PubSubMessage.proto
@@ -25,6 +25,7 @@ message PubSubMessage {
QUERY_DB = 1;
DELIVER = 2;
KEEPALIVE = 3;
+ CLOSE = 4;
}
optional Type type = 1;
diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubProtos.java b/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubProtos.java
index 8785e372d..7ca6225e3 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubProtos.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubProtos.java
@@ -158,6 +158,10 @@ public final class PubSubProtos {
* KEEPALIVE = 3;
*/
KEEPALIVE(3, 3),
+ /**
+ * CLOSE = 4;
+ */
+ CLOSE(4, 4),
;
/**
@@ -176,6 +180,10 @@ public final class PubSubProtos {
* KEEPALIVE = 3;
*/
public static final int KEEPALIVE_VALUE = 3;
+ /**
+ * CLOSE = 4;
+ */
+ public static final int CLOSE_VALUE = 4;
public final int getNumber() { return value; }
@@ -186,6 +194,7 @@ public final class PubSubProtos {
case 1: return QUERY_DB;
case 2: return DELIVER;
case 3: return KEEPALIVE;
+ case 4: return CLOSE;
default: return null;
}
}
@@ -611,12 +620,13 @@ public final class PubSubProtos {
descriptor;
static {
java.lang.String[] descriptorData = {
- "\n\023PubSubMessage.proto\022\ntextsecure\"\215\001\n\rPu" +
+ "\n\023PubSubMessage.proto\022\ntextsecure\"\230\001\n\rPu" +
"bSubMessage\022,\n\004type\030\001 \001(\0162\036.textsecure.P" +
- "ubSubMessage.Type\022\017\n\007content\030\002 \001(\014\"=\n\004Ty" +
+ "ubSubMessage.Type\022\017\n\007content\030\002 \001(\014\"H\n\004Ty" +
"pe\022\013\n\007UNKNOWN\020\000\022\014\n\010QUERY_DB\020\001\022\013\n\007DELIVER" +
- "\020\002\022\r\n\tKEEPALIVE\020\003B8\n(org.whispersystems." +
- "textsecuregcm.storageB\014PubSubProtos"
+ "\020\002\022\r\n\tKEEPALIVE\020\003\022\t\n\005CLOSE\020\004B8\n(org.whis" +
+ "persystems.textsecuregcm.storageB\014PubSub" +
+ "Protos"
};
com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner =
new com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner() {
diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java
index 461e4c0d1..df11cf9be 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java
@@ -1,6 +1,5 @@
package org.whispersystems.textsecuregcm.websocket;
-import com.google.common.base.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.push.PushSender;
@@ -9,6 +8,7 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
+import org.whispersystems.textsecuregcm.storage.PubSubProtos;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import org.whispersystems.websocket.setup.WebSocketConnectListener;
@@ -36,10 +36,8 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
Account account = context.getAuthenticated(Account.class).get();
Device device = account.getAuthenticatedDevice().get();
- if (device.getLastSeen() != Util.todayInMillis()) {
- device.setLastSeen(Util.todayInMillis());
- accountsManager.update(account);
- }
+ updateLastSeen(account, device);
+ closeExistingDeviceConnection(account, device);
final WebSocketConnection connection = new WebSocketConnection(accountsManager, pushSender,
messagesManager, pubSubManager,
@@ -55,4 +53,19 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
}
});
}
+
+ private void updateLastSeen(Account account, Device device) {
+ if (device.getLastSeen() != Util.todayInMillis()) {
+ device.setLastSeen(Util.todayInMillis());
+ accountsManager.update(account);
+ }
+ }
+
+ private void closeExistingDeviceConnection(Account account, Device device) {
+ pubSubManager.publish(new WebsocketAddress(account.getNumber(), device.getId()),
+ PubSubProtos.PubSubMessage.newBuilder()
+ .setType(PubSubProtos.PubSubMessage.Type.CLOSE)
+ .build());
+ }
}
+
diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java
index 18ffd25b2..6bc9d9130 100644
--- a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java
+++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java
@@ -80,6 +80,10 @@ public class WebSocketConnection implements PubSubListener {
case PubSubMessage.Type.DELIVER_VALUE:
sendMessage(OutgoingMessageSignal.parseFrom(pubSubMessage.getContent()), Optional.absent());
break;
+ case PubSubMessage.Type.CLOSE_VALUE:
+ client.close(1000, "OK");
+ pubSubManager.unsubscribe(address, this);
+ break;
default:
logger.warn("Unknown pubsub message: " + pubSubMessage.getType().getNumber());
}
diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java
index 8b0bf8d2a..32833cd2f 100644
--- a/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java
+++ b/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java
@@ -5,6 +5,7 @@ import com.google.common.util.concurrent.SettableFuture;
import com.google.protobuf.ByteString;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.junit.Test;
+import org.mockito.ArgumentCaptor;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
@@ -14,6 +15,7 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
+import org.whispersystems.textsecuregcm.storage.PubSubProtos;
import org.whispersystems.textsecuregcm.util.Base64;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.websocket.AuthenticatedConnectListener;
@@ -23,6 +25,7 @@ import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.session.WebSocketSessionContext;
+import org.whispersystems.websocket.setup.WebSocketConnectListener;
import java.io.IOException;
import java.util.HashMap;
@@ -32,6 +35,7 @@ import java.util.List;
import java.util.Set;
import io.dropwizard.auth.basic.BasicCredentials;
+import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.eq;
@@ -58,6 +62,27 @@ public class WebSocketConnectionTest {
// private static final Session session = mock(Session.class );
private static final PushSender pushSender = mock(PushSender.class);
+ @Test
+ public void testCloseExisting() throws Exception {
+ MessagesManager storedMessages = mock(MessagesManager.class );
+ WebSocketConnectListener connectListener = new AuthenticatedConnectListener(accountsManager, pushSender, storedMessages, pubSubManager);
+ WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
+ Account account = mock(Account.class );
+ Device device = mock(Device.class );
+
+ when(sessionContext.getAuthenticated(Account.class)).thenReturn(Optional.of(account));
+ when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
+ when(account.getNumber()).thenReturn("+14157777777");
+ when(device.getId()).thenReturn(1L);
+
+ connectListener.onWebSocketConnect(sessionContext);
+
+ ArgumentCaptor message = ArgumentCaptor.forClass(PubSubProtos.PubSubMessage.class);
+
+ verify(pubSubManager).publish(eq(new WebsocketAddress("+14157777777", 1L)), message.capture());
+ assertEquals(message.getValue().getType().getNumber(), PubSubProtos.PubSubMessage.Type.CLOSE_VALUE);
+ }
+
@Test
public void testCredentials() throws Exception {
MessagesManager storedMessages = mock(MessagesManager.class);