From 05d9ec673e8c8d096e79b1f9a775eab8db5c6d4e Mon Sep 17 00:00:00 2001 From: Jon Chambers <63609320+jon-signal@users.noreply.github.com> Date: Tue, 27 Oct 2020 16:02:55 -0400 Subject: [PATCH] Send push notifications if websockets close before all messages are delivered --- .../textsecuregcm/WhisperServerService.java | 2 +- .../textsecuregcm/push/MessageSender.java | 23 ++++++++++++++----- .../textsecuregcm/storage/Messages.java | 1 + .../textsecuregcm/storage/MessagesCache.java | 4 ++++ .../storage/MessagesManager.java | 4 ++++ .../AuthenticatedConnectListener.java | 9 +++++++- .../storage/MessagesCacheTest.java | 11 +++++++++ .../websocket/WebSocketConnectionTest.java | 3 ++- 8 files changed, 48 insertions(+), 9 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 75c0bef15..d5f39b81a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -423,7 +423,7 @@ public class WhisperServerService extends Application webSocketEnvironment = new WebSocketEnvironment<>(environment, config.getWebSocketConfiguration(), 90000); webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator)); - webSocketEnvironment.setConnectListener(new AuthenticatedConnectListener(receiptSender, messagesManager, apnFallbackManager, clientPresenceManager)); + webSocketEnvironment.setConnectListener(new AuthenticatedConnectListener(receiptSender, messagesManager, messageSender, apnFallbackManager, clientPresenceManager)); webSocketEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET)); webSocketEnvironment.jersey().register(new KeepAliveController(clientPresenceManager)); webSocketEnvironment.jersey().register(messageController); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java index 72f892f0c..78bbbe1bd 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/MessageSender.java @@ -138,21 +138,24 @@ public class MessageSender implements Managed { throw new AssertionError(); } - final boolean clientPresent = clientPresenceManager.isPresent(account.getUuid(), device.getId()); + final boolean clientPresent; if (online) { + clientPresent = clientPresenceManager.isPresent(account.getUuid(), device.getId()); + if (clientPresent) { messagesManager.insertEphemeral(account.getUuid(), device.getId(), message); } } else { messagesManager.insert(account.getUuid(), device.getId(), message); + // We check for client presence after inserting the message to take a conservative view of notifications. If the + // client wasn't present at the time of insertion but is now, they'll retrieve the message. If they were present + // but disconnected before the message was delivered, we should send a notification. + clientPresent = clientPresenceManager.isPresent(account.getUuid(), device.getId()); + if (!clientPresent) { - if (!Util.isEmpty(device.getGcmId())) { - sendGcmNotification(account, device); - } else if (!Util.isEmpty(device.getApnId()) || !Util.isEmpty(device.getVoipApnId())) { - sendApnNotification(account, device); - } + sendNewMessageNotification(account, device); } } @@ -164,6 +167,14 @@ public class MessageSender implements Managed { Metrics.counter(SEND_COUNTER_NAME, tags).increment(); } + public void sendNewMessageNotification(final Account account, final Device device) { + if (!Util.isEmpty(device.getGcmId())) { + sendGcmNotification(account, device); + } else if (!Util.isEmpty(device.getApnId()) || !Util.isEmpty(device.getVoipApnId())) { + sendApnNotification(account, device); + } + } + private void sendGcmNotification(Account account, Device device) { GcmMessage gcmMessage = new GcmMessage(device.getGcmId(), account.getNumber(), (int)device.getId(), GcmMessage.Type.NOTIFICATION, Optional.empty()); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java index f3e6636e5..f624cab53 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Messages.java @@ -41,6 +41,7 @@ public class Messages { private final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); private final Timer storeTimer = metricRegistry.timer(name(Messages.class, "store" )); private final Timer loadTimer = metricRegistry.timer(name(Messages.class, "load" )); + private final Timer hasMessagesTimer = metricRegistry.timer(name(Messages.class, "hasMessages" )); private final Timer removeBySourceTimer = metricRegistry.timer(name(Messages.class, "removeBySource")); private final Timer removeByGuidTimer = metricRegistry.timer(name(Messages.class, "removeByGuid" )); private final Timer removeByIdTimer = metricRegistry.timer(name(Messages.class, "removeById" )); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 8100f2c0f..3ccfa3cce 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -220,6 +220,10 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp return removedMessages; } + public boolean hasMessages(final UUID destinationUuid, final long destinationDevice) { + return redisCluster.withBinaryCluster(connection -> connection.sync().zcard(getMessageQueueKey(destinationUuid, destinationDevice)) > 0); + } + @SuppressWarnings("unchecked") public List get(final UUID destinationUuid, final long destinationDevice, final int limit) { return getMessagesTimer.record(() -> { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java index 61c941e24..519b2aeff 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesManager.java @@ -51,6 +51,10 @@ public class MessagesManager { return messagesCache.takeEphemeralMessage(destinationUuid, destinationDevice); } + public boolean hasCachedMessages(final UUID destinationUuid, final long destinationDevice) { + return messagesCache.hasMessages(destinationUuid, destinationDevice); + } + public OutgoingMessageEntityList getMessagesForDevice(String destination, UUID destinationUuid, long destinationDevice, final String userAgent, final boolean cachedMessagesOnly) { RedisOperation.unchecked(() -> pushLatencyManager.recordQueueRead(destinationUuid, destinationDevice, userAgent)); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java index 25bce383b..e60dc49bc 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java @@ -6,6 +6,7 @@ import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.Timer; import org.whispersystems.textsecuregcm.push.ApnFallbackManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.redis.RedisOperation; import org.whispersystems.textsecuregcm.storage.Account; @@ -26,16 +27,18 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { private final ReceiptSender receiptSender; private final MessagesManager messagesManager; + private final MessageSender messageSender; private final ApnFallbackManager apnFallbackManager; private final ClientPresenceManager clientPresenceManager; public AuthenticatedConnectListener(ReceiptSender receiptSender, MessagesManager messagesManager, - ApnFallbackManager apnFallbackManager, + final MessageSender messageSender, ApnFallbackManager apnFallbackManager, ClientPresenceManager clientPresenceManager) { this.receiptSender = receiptSender; this.messagesManager = messagesManager; + this.messageSender = messageSender; this.apnFallbackManager = apnFallbackManager; this.clientPresenceManager = clientPresenceManager; } @@ -66,6 +69,10 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { openWebsocketCounter.dec(); timer.stop(); + + if (messagesManager.hasCachedMessages(account.getUuid(), device.getId())) { + messageSender.sendNewMessageNotification(account, device); + } } }); } else { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java index 7fc5996a8..17dc78aab 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java @@ -160,6 +160,17 @@ public class MessagesCacheTest extends AbstractRedisClusterTest { assertEquals(messagesToPreserve, messagesCache.getMessagesToPersist(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); } + @Test + public void testHasMessages() { + assertFalse(messagesCache.hasMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID)); + + final UUID messageGuid = UUID.randomUUID(); + final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true); + messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); + + assertTrue(messagesCache.hasMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID)); + } + @Test @Parameters({"true", "false"}) public void testGetMessages(final boolean sealedSender) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index e19c24e89..14c89ffa3 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -12,6 +12,7 @@ import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.push.ApnFallbackManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; @@ -72,7 +73,7 @@ public class WebSocketConnectionTest { public void testCredentials() throws Exception { MessagesManager storedMessages = mock(MessagesManager.class); WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator); - AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, storedMessages, apnFallbackManager, mock(ClientPresenceManager.class)); + AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, storedMessages, mock(MessageSender.class), apnFallbackManager, mock(ClientPresenceManager.class)); WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))