diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index a32fdfa25..45bb6ef4f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -601,7 +601,6 @@ public class WhisperServerService extends Application implements Managed, DisconnectionRequestListener { + private final AccountsManager accountsManager; + private final PushNotificationManager pushNotificationManager; private final FaultTolerantRedisClusterClient clusterClient; private final Executor listenerEventExecutor; @@ -81,8 +84,11 @@ public class WebSocketConnectionEventManager extends RedisClusterPubSubAdapter logger.warn("Unexpected client event type: {}", clientEvent.getClass()); } } else { - MESSAGE_WITHOUT_LISTENER_COUNTER.increment(); + PUB_SUB_EVENT_WITHOUT_LISTENER_COUNTER.increment(); + listenerEventExecutor.execute(() -> unsubscribeIfMissingListener(accountAndDeviceIdentifier)); + + if (clientEvent.getEventCase() == ClientEvent.EventCase.NEW_MESSAGE_AVAILABLE) { + MESSAGE_AVAILABLE_WITHOUT_LISTENER_COUNTER.increment(); + + // If we have an active subscription but no registered listener, it's likely that the publisher of this event + // believes that the receiving client was present when it really wasn't. Send a push notification as a + // just-in-case measure. + accountsManager.getByAccountIdentifierAsync(accountAndDeviceIdentifier.accountIdentifier()) + .thenAccept(maybeAccount -> maybeAccount.ifPresent(account -> { + try { + pushNotificationManager.sendNewMessageNotification(account, accountAndDeviceIdentifier.deviceId(), true); + } catch (final NotPushRegisteredException ignored) { + } + })) + .whenComplete((ignored, throwable) -> { + if (throwable != null) { + logger.warn("Failed to send follow-up notification to {}:{}", accountAndDeviceIdentifier.accountIdentifier(), accountAndDeviceIdentifier.deviceId()); + } + }); + } } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java index 7a6e9c4ac..9d207dd67 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java @@ -211,7 +211,6 @@ record CommandDependencies( SecureStorageClient secureStorageClient = new SecureStorageClient(storageCredentialsGenerator, storageServiceExecutor, storageServiceRetryExecutor, configuration.getSecureStorageServiceConfiguration()); DisconnectionRequestManager disconnectionRequestManager = new DisconnectionRequestManager(pubsubClient, disconnectionRequestListenerExecutor); - WebSocketConnectionEventManager webSocketConnectionEventManager = new WebSocketConnectionEventManager(messagesCluster, clientEventExecutor); MessagesCache messagesCache = new MessagesCache(messagesCluster, messageDeliveryScheduler, messageDeletionExecutor, Clock.systemUTC(), dynamicConfigurationManager); ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster); @@ -264,6 +263,9 @@ record CommandDependencies( configuration.getDynamoDbTables().getPushNotificationExperimentSamples().getTableName(), Clock.systemUTC()); + WebSocketConnectionEventManager webSocketConnectionEventManager = + new WebSocketConnectionEventManager(accountsManager, pushNotificationManager, messagesCluster, clientEventExecutor); + environment.lifecycle().manage(apnSender); environment.lifecycle().manage(disconnectionRequestManager); environment.lifecycle().manage(webSocketConnectionEventManager); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/WebSocketConnectionEventManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/WebSocketConnectionEventManagerTest.java index 8819e1730..917cbb85a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/push/WebSocketConnectionEventManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/push/WebSocketConnectionEventManagerTest.java @@ -11,6 +11,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; import io.lettuce.core.cluster.SlotHash; @@ -19,7 +20,9 @@ import io.lettuce.core.cluster.models.partitions.RedisClusterNode; import io.lettuce.core.cluster.pubsub.api.async.RedisClusterPubSubAsyncCommands; import io.lettuce.core.cluster.pubsub.api.sync.RedisClusterPubSubCommands; import java.util.List; +import java.util.Optional; import java.util.UUID; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; @@ -35,6 +38,8 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture; import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; @@ -45,7 +50,7 @@ class WebSocketConnectionEventManagerTest { private WebSocketConnectionEventManager localEventManager; private WebSocketConnectionEventManager remoteEventManager; - private static ExecutorService clientEventExecutor; + private static ExecutorService webSocketConnectionEventExecutor; @RegisterExtension static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); @@ -67,13 +72,20 @@ class WebSocketConnectionEventManagerTest { @BeforeAll static void setUpBeforeAll() { - clientEventExecutor = Executors.newVirtualThreadPerTaskExecutor(); + webSocketConnectionEventExecutor = Executors.newVirtualThreadPerTaskExecutor(); } @BeforeEach void setUp() { - localEventManager = new WebSocketConnectionEventManager(REDIS_CLUSTER_EXTENSION.getRedisCluster(), clientEventExecutor); - remoteEventManager = new WebSocketConnectionEventManager(REDIS_CLUSTER_EXTENSION.getRedisCluster(), clientEventExecutor); + localEventManager = new WebSocketConnectionEventManager(mock(AccountsManager.class), + mock(PushNotificationManager.class), + REDIS_CLUSTER_EXTENSION.getRedisCluster(), + webSocketConnectionEventExecutor); + + remoteEventManager = new WebSocketConnectionEventManager(mock(AccountsManager.class), + mock(PushNotificationManager.class), + REDIS_CLUSTER_EXTENSION.getRedisCluster(), + webSocketConnectionEventExecutor); localEventManager.start(); remoteEventManager.start(); @@ -87,7 +99,7 @@ class WebSocketConnectionEventManagerTest { @AfterAll static void tearDownAfterAll() { - clientEventExecutor.shutdown(); + webSocketConnectionEventExecutor.shutdown(); } @ParameterizedTest @@ -226,7 +238,11 @@ class WebSocketConnectionEventManagerTest { .binaryPubSubAsyncCommands(pubSubAsyncCommands) .build(); - final WebSocketConnectionEventManager eventManager = new WebSocketConnectionEventManager(clusterClient, Runnable::run); + final WebSocketConnectionEventManager eventManager = new WebSocketConnectionEventManager( + mock(AccountsManager.class), + mock(PushNotificationManager.class), + clusterClient, + Runnable::run); eventManager.start(); @@ -279,7 +295,7 @@ class WebSocketConnectionEventManagerTest { } @Test - void smessageWithoutListener() { + void unsubscribeIfMissingListener() { @SuppressWarnings("unchecked") final RedisClusterPubSubAsyncCommands pubSubAsyncCommands = mock(RedisClusterPubSubAsyncCommands.class); @@ -289,7 +305,11 @@ class WebSocketConnectionEventManagerTest { .binaryPubSubAsyncCommands(pubSubAsyncCommands) .build(); - final WebSocketConnectionEventManager eventManager = new WebSocketConnectionEventManager(clusterClient, Runnable::run); + final WebSocketConnectionEventManager eventManager = new WebSocketConnectionEventManager( + mock(AccountsManager.class), + mock(PushNotificationManager.class), + clusterClient, + Runnable::run); eventManager.start(); @@ -315,4 +335,59 @@ class WebSocketConnectionEventManagerTest { verify(pubSubAsyncCommands) .sunsubscribe(WebSocketConnectionEventManager.getClientEventChannel(noListenerAccountIdentifier, noListenerDeviceId)); } + + @Test + void newMessageNotificationWithoutListener() throws NotPushRegisteredException { + final UUID listenerAccountIdentifier = UUID.randomUUID(); + final byte listenerDeviceId = Device.PRIMARY_ID; + + final UUID noListenerAccountIdentifier = UUID.randomUUID(); + final byte noListenerDeviceId = listenerDeviceId + 1; + + final Account noListenerAccount = mock(Account.class); + + final AccountsManager accountsManager = mock(AccountsManager.class); + + when(accountsManager.getByAccountIdentifierAsync(noListenerAccountIdentifier)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(noListenerAccount))); + + final PushNotificationManager pushNotificationManager = mock(PushNotificationManager.class); + + @SuppressWarnings("unchecked") final RedisClusterPubSubAsyncCommands pubSubAsyncCommands = + mock(RedisClusterPubSubAsyncCommands.class); + + when(pubSubAsyncCommands.ssubscribe(any())).thenReturn(MockRedisFuture.completedFuture(null)); + + final FaultTolerantRedisClusterClient clusterClient = RedisClusterHelper.builder() + .binaryPubSubAsyncCommands(pubSubAsyncCommands) + .build(); + + final WebSocketConnectionEventManager eventManager = new WebSocketConnectionEventManager( + accountsManager, + pushNotificationManager, + clusterClient, + Runnable::run); + + eventManager.start(); + + eventManager.handleClientConnected(listenerAccountIdentifier, listenerDeviceId, new WebSocketConnectionEventAdapter()) + .toCompletableFuture() + .join(); + + final byte[] newMessagePayload = ClientEvent.newBuilder() + .setNewMessageAvailable(NewMessageAvailableEvent.getDefaultInstance()) + .build() + .toByteArray(); + + eventManager.smessage(mock(RedisClusterNode.class), + WebSocketConnectionEventManager.getClientEventChannel(listenerAccountIdentifier, listenerDeviceId), + newMessagePayload); + + eventManager.smessage(mock(RedisClusterNode.class), + WebSocketConnectionEventManager.getClientEventChannel(noListenerAccountIdentifier, noListenerDeviceId), + newMessagePayload); + + verify(pushNotificationManager).sendNewMessageNotification(noListenerAccount, noListenerDeviceId, true); + verifyNoMoreInteractions(pushNotificationManager); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java index 85ff58322..cfb873ae7 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java @@ -32,6 +32,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventListener; import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; @@ -52,7 +53,7 @@ class MessagePersisterIntegrationTest { private Scheduler messageDeliveryScheduler; private ExecutorService messageDeletionExecutorService; - private ExecutorService clientEventExecutorService; + private ExecutorService websocketConnectionEventExecutor; private MessagesCache messagesCache; private MessagesManager messagesManager; private WebSocketConnectionEventManager webSocketConnectionEventManager; @@ -84,8 +85,12 @@ class MessagePersisterIntegrationTest { messagesManager = new MessagesManager(messagesDynamoDb, messagesCache, mock(ReportMessageManager.class), messageDeletionExecutorService); - clientEventExecutorService = Executors.newVirtualThreadPerTaskExecutor(); - webSocketConnectionEventManager = new WebSocketConnectionEventManager(REDIS_CLUSTER_EXTENSION.getRedisCluster(), clientEventExecutorService); + websocketConnectionEventExecutor = Executors.newVirtualThreadPerTaskExecutor(); + webSocketConnectionEventManager = new WebSocketConnectionEventManager(mock(AccountsManager.class), + mock(PushNotificationManager.class), + REDIS_CLUSTER_EXTENSION.getRedisCluster(), + websocketConnectionEventExecutor); + webSocketConnectionEventManager.start(); messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, @@ -108,8 +113,8 @@ class MessagePersisterIntegrationTest { messageDeletionExecutorService.shutdown(); messageDeletionExecutorService.awaitTermination(15, TimeUnit.SECONDS); - clientEventExecutorService.shutdown(); - clientEventExecutorService.awaitTermination(15, TimeUnit.SECONDS); + websocketConnectionEventExecutor.shutdown(); + websocketConnectionEventExecutor.awaitTermination(15, TimeUnit.SECONDS); messageDeliveryScheduler.dispose();