diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index ecb831c28..c0be72667 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -49,6 +49,7 @@ import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.RedisClusterUtil; +import org.whispersystems.textsecuregcm.util.Util; import reactor.core.observability.micrometer.Micrometer; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -89,6 +90,8 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp name(MessagesCache.class, "staleEphemeralMessages")); private final Counter messageAvailabilityListenerRemovedAfterAddCounter = Metrics.counter( name(MessagesCache.class, "messageAvailabilityListenerRemovedAfterAdd")); + private final Counter prunedStaleSubscriptionCounter = Metrics.counter( + name(MessagesCache.class, "prunedStaleSubscription")); static final String NEXT_SLOT_TO_PERSIST_KEY = "user_queue_persist_slot"; private static final byte[] LOCK_VALUE = "1".getBytes(StandardCharsets.UTF_8); @@ -147,7 +150,8 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp } for (final String queueName : queueNames) { - subscribeForKeyspaceNotifications(queueName); + // avoid overwhelming a newly recovered node by processing synchronously, rather than using CompletableFuture.allOf() + subscribeForKeyspaceNotifications(queueName).join(); } } @@ -384,12 +388,15 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp final MessageAvailabilityListener listener) { final String queueName = getQueueName(destinationUuid, deviceId); + final CompletableFuture subscribeFuture; synchronized (messageListenersByQueueName) { messageListenersByQueueName.put(queueName, listener); queueNamesByMessageListener.put(listener, queueName); + // Submit to the Redis queue within the synchronized block, but don’t wait until exiting + subscribeFuture = subscribeForKeyspaceNotifications(queueName); } - subscribeForKeyspaceNotifications(queueName); + subscribeFuture.join(); } public void removeMessageAvailabilityListener(final MessageAvailabilityListener listener) { @@ -399,33 +406,49 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp } if (queueName != null) { - unsubscribeFromKeyspaceNotifications(queueName); + final CompletableFuture unsubscribeFuture; synchronized (messageListenersByQueueName) { queueNamesByMessageListener.remove(listener); - if (!messageListenersByQueueName.remove(queueName, listener)) { + if (messageListenersByQueueName.remove(queueName, listener)) { + // Submit to the Redis queue within the synchronized block, but don’t wait until exiting + unsubscribeFuture = unsubscribeFromKeyspaceNotifications(queueName); + } else { messageAvailabilityListenerRemovedAfterAddCounter.increment(); + unsubscribeFuture = CompletableFuture.completedFuture(null); } } + + unsubscribeFuture.join(); } } - private void subscribeForKeyspaceNotifications(final String queueName) { - final int slot = SlotHash.getSlot(queueName); - - pubSubConnection.usePubSubConnection( - connection -> connection.sync().nodes(node -> node.is(RedisClusterNode.NodeFlag.UPSTREAM) && node.hasSlot(slot)) - .commands() - .subscribe(getKeyspaceChannels(queueName))); + private void pruneStaleSubscription(final String channel) { + unsubscribeFromKeyspaceNotifications(getQueueNameFromKeyspaceChannel(channel)) + .thenRun(prunedStaleSubscriptionCounter::increment); } - private void unsubscribeFromKeyspaceNotifications(final String queueName) { + private CompletableFuture subscribeForKeyspaceNotifications(final String queueName) { final int slot = SlotHash.getSlot(queueName); - pubSubConnection.usePubSubConnection( - connection -> connection.sync().nodes(node -> node.is(RedisClusterNode.NodeFlag.UPSTREAM) && node.hasSlot(slot)) + return pubSubConnection.withPubSubConnection( + connection -> connection.async() + .nodes(node -> node.is(RedisClusterNode.NodeFlag.UPSTREAM) && node.hasSlot(slot)) .commands() - .unsubscribe(getKeyspaceChannels(queueName))); + .subscribe(getKeyspaceChannels(queueName))).toCompletableFuture() + .thenRun(Util.NOOP); + } + + private CompletableFuture unsubscribeFromKeyspaceNotifications(final String queueName) { + final int slot = SlotHash.getSlot(queueName); + + return pubSubConnection.withPubSubConnection( + connection -> connection.async() + .nodes(node -> node.is(RedisClusterNode.NodeFlag.UPSTREAM) && node.hasSlot(slot)) + .commands() + .unsubscribe(getKeyspaceChannels(queueName))) + .toCompletableFuture() + .thenRun(Util.NOOP); } private static String[] getKeyspaceChannels(final String queueName) { @@ -443,11 +466,11 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp newMessageNotificationCounter.increment(); notificationExecutorService.execute(() -> { try { - findListener(channel).ifPresent(listener -> { + findListener(channel).ifPresentOrElse(listener -> { if (!listener.handleNewMessagesAvailable()) { removeMessageAvailabilityListener(listener); } - }); + }, () -> pruneStaleSubscription(channel)); } catch (final Exception e) { logger.warn("Unexpected error handling new message", e); } @@ -456,11 +479,11 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp queuePersistedNotificationCounter.increment(); notificationExecutorService.execute(() -> { try { - findListener(channel).ifPresent(listener -> { + findListener(channel).ifPresentOrElse(listener -> { if (!listener.handleMessagesPersisted()) { removeMessageAvailabilityListener(listener); } - }); + }, () -> pruneStaleSubscription(channel)); } catch (final Exception e) { logger.warn("Unexpected error handling messages persisted", e); }