diff --git a/protobuf/PubSubMessage.proto b/protobuf/PubSubMessage.proto index 7b257a368..980ab4840 100644 --- a/protobuf/PubSubMessage.proto +++ b/protobuf/PubSubMessage.proto @@ -25,6 +25,7 @@ message PubSubMessage { QUERY_DB = 1; DELIVER = 2; KEEPALIVE = 3; + CLOSE = 4; } optional Type type = 1; diff --git a/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubProtos.java b/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubProtos.java index 8785e372d..7ca6225e3 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubProtos.java +++ b/src/main/java/org/whispersystems/textsecuregcm/storage/PubSubProtos.java @@ -158,6 +158,10 @@ public final class PubSubProtos { * KEEPALIVE = 3; */ KEEPALIVE(3, 3), + /** + * CLOSE = 4; + */ + CLOSE(4, 4), ; /** @@ -176,6 +180,10 @@ public final class PubSubProtos { * KEEPALIVE = 3; */ public static final int KEEPALIVE_VALUE = 3; + /** + * CLOSE = 4; + */ + public static final int CLOSE_VALUE = 4; public final int getNumber() { return value; } @@ -186,6 +194,7 @@ public final class PubSubProtos { case 1: return QUERY_DB; case 2: return DELIVER; case 3: return KEEPALIVE; + case 4: return CLOSE; default: return null; } } @@ -611,12 +620,13 @@ public final class PubSubProtos { descriptor; static { java.lang.String[] descriptorData = { - "\n\023PubSubMessage.proto\022\ntextsecure\"\215\001\n\rPu" + + "\n\023PubSubMessage.proto\022\ntextsecure\"\230\001\n\rPu" + "bSubMessage\022,\n\004type\030\001 \001(\0162\036.textsecure.P" + - "ubSubMessage.Type\022\017\n\007content\030\002 \001(\014\"=\n\004Ty" + + "ubSubMessage.Type\022\017\n\007content\030\002 \001(\014\"H\n\004Ty" + "pe\022\013\n\007UNKNOWN\020\000\022\014\n\010QUERY_DB\020\001\022\013\n\007DELIVER" + - "\020\002\022\r\n\tKEEPALIVE\020\003B8\n(org.whispersystems." + - "textsecuregcm.storageB\014PubSubProtos" + "\020\002\022\r\n\tKEEPALIVE\020\003\022\t\n\005CLOSE\020\004B8\n(org.whis" + + "persystems.textsecuregcm.storageB\014PubSub" + + "Protos" }; com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner = new com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner() { diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java index 461e4c0d1..df11cf9be 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java +++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java @@ -1,6 +1,5 @@ package org.whispersystems.textsecuregcm.websocket; -import com.google.common.base.Optional; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.push.PushSender; @@ -9,6 +8,7 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.PubSubManager; +import org.whispersystems.textsecuregcm.storage.PubSubProtos; import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.websocket.session.WebSocketSessionContext; import org.whispersystems.websocket.setup.WebSocketConnectListener; @@ -36,10 +36,8 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { Account account = context.getAuthenticated(Account.class).get(); Device device = account.getAuthenticatedDevice().get(); - if (device.getLastSeen() != Util.todayInMillis()) { - device.setLastSeen(Util.todayInMillis()); - accountsManager.update(account); - } + updateLastSeen(account, device); + closeExistingDeviceConnection(account, device); final WebSocketConnection connection = new WebSocketConnection(accountsManager, pushSender, messagesManager, pubSubManager, @@ -55,4 +53,19 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { } }); } + + private void updateLastSeen(Account account, Device device) { + if (device.getLastSeen() != Util.todayInMillis()) { + device.setLastSeen(Util.todayInMillis()); + accountsManager.update(account); + } + } + + private void closeExistingDeviceConnection(Account account, Device device) { + pubSubManager.publish(new WebsocketAddress(account.getNumber(), device.getId()), + PubSubProtos.PubSubMessage.newBuilder() + .setType(PubSubProtos.PubSubMessage.Type.CLOSE) + .build()); + } } + diff --git a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index 18ffd25b2..6bc9d9130 100644 --- a/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -80,6 +80,10 @@ public class WebSocketConnection implements PubSubListener { case PubSubMessage.Type.DELIVER_VALUE: sendMessage(OutgoingMessageSignal.parseFrom(pubSubMessage.getContent()), Optional.absent()); break; + case PubSubMessage.Type.CLOSE_VALUE: + client.close(1000, "OK"); + pubSubManager.unsubscribe(address, this); + break; default: logger.warn("Unknown pubsub message: " + pubSubMessage.getType().getNumber()); } diff --git a/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java b/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java index 8b0bf8d2a..32833cd2f 100644 --- a/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java +++ b/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java @@ -5,6 +5,7 @@ import com.google.common.util.concurrent.SettableFuture; import com.google.protobuf.ByteString; import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.junit.Test; +import org.mockito.ArgumentCaptor; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; @@ -14,6 +15,7 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.PubSubManager; +import org.whispersystems.textsecuregcm.storage.PubSubProtos; import org.whispersystems.textsecuregcm.util.Base64; import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.websocket.AuthenticatedConnectListener; @@ -23,6 +25,7 @@ import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; import org.whispersystems.websocket.WebSocketClient; import org.whispersystems.websocket.messages.WebSocketResponseMessage; import org.whispersystems.websocket.session.WebSocketSessionContext; +import org.whispersystems.websocket.setup.WebSocketConnectListener; import java.io.IOException; import java.util.HashMap; @@ -32,6 +35,7 @@ import java.util.List; import java.util.Set; import io.dropwizard.auth.basic.BasicCredentials; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.eq; @@ -58,6 +62,27 @@ public class WebSocketConnectionTest { // private static final Session session = mock(Session.class ); private static final PushSender pushSender = mock(PushSender.class); + @Test + public void testCloseExisting() throws Exception { + MessagesManager storedMessages = mock(MessagesManager.class ); + WebSocketConnectListener connectListener = new AuthenticatedConnectListener(accountsManager, pushSender, storedMessages, pubSubManager); + WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); + Account account = mock(Account.class ); + Device device = mock(Device.class ); + + when(sessionContext.getAuthenticated(Account.class)).thenReturn(Optional.of(account)); + when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device)); + when(account.getNumber()).thenReturn("+14157777777"); + when(device.getId()).thenReturn(1L); + + connectListener.onWebSocketConnect(sessionContext); + + ArgumentCaptor message = ArgumentCaptor.forClass(PubSubProtos.PubSubMessage.class); + + verify(pubSubManager).publish(eq(new WebsocketAddress("+14157777777", 1L)), message.capture()); + assertEquals(message.getValue().getType().getNumber(), PubSubProtos.PubSubMessage.Type.CLOSE_VALUE); + } + @Test public void testCredentials() throws Exception { MessagesManager storedMessages = mock(MessagesManager.class);