From 55df5935613f694f2c6bb21b50a635f5a2645ebd Mon Sep 17 00:00:00 2001 From: Chris Eager Date: Mon, 1 Aug 2022 14:37:05 -0500 Subject: [PATCH] Clean up `MessageAvailabilityListener` if the websocket client is closed --- .../storage/MessageAvailabilityListener.java | 12 ++- .../textsecuregcm/storage/MessagesCache.java | 12 ++- .../websocket/WebSocketConnection.java | 62 +++++++++----- .../MessagePersisterIntegrationTest.java | 6 +- .../storage/MessagesCacheTest.java | 84 ++++++++++++++++++- .../websocket/WebSocketConnectionTest.java | 7 ++ 6 files changed, 154 insertions(+), 29 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessageAvailabilityListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessageAvailabilityListener.java index 12872c372..e7fed470a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessageAvailabilityListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessageAvailabilityListener.java @@ -12,7 +12,15 @@ package org.whispersystems.textsecuregcm.storage; */ public interface MessageAvailabilityListener { - void handleNewMessagesAvailable(); + /** + * @return whether the listener is still active. {@code false} indicates the listener can no longer handle messages + * and may be discarded + */ + boolean handleNewMessagesAvailable(); - void handleMessagesPersisted(); + /** + * @return whether the listener is still active. {@code false} indicates the listener can no longer handle messages + * and may be discarded + */ + boolean handleMessagesPersisted(); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 8ce563544..195813e22 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -351,7 +351,11 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp newMessageNotificationCounter.increment(); notificationExecutorService.execute(() -> { try { - findListener(channel).ifPresent(MessageAvailabilityListener::handleNewMessagesAvailable); + findListener(channel).ifPresent(listener -> { + if (!listener.handleNewMessagesAvailable()) { + removeMessageAvailabilityListener(listener); + } + }); } catch (final Exception e) { logger.warn("Unexpected error handling new message", e); } @@ -360,7 +364,11 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp queuePersistedNotificationCounter.increment(); notificationExecutorService.execute(() -> { try { - findListener(channel).ifPresent(MessageAvailabilityListener::handleMessagesPersisted); + findListener(channel).ifPresent(listener -> { + if (!listener.handleMessagesPersisted()) { + removeMessageAvailabilityListener(listener); + } + }); } catch (final Exception e) { logger.warn("Unexpected error handling messages persisted", e); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index 89466c63e..431494ef3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -57,24 +57,33 @@ import org.whispersystems.websocket.messages.WebSocketResponseMessage; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") public class WebSocketConnection implements MessageAvailabilityListener, DisplacedPresenceListener { - private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); - private static final Histogram messageTime = metricRegistry.histogram(name(MessageController.class, "message_delivery_duration")); - private static final Histogram primaryDeviceMessageTime = metricRegistry.histogram(name(MessageController.class, "primary_device_message_delivery_duration")); - private static final Meter sendMessageMeter = metricRegistry.meter(name(WebSocketConnection.class, "send_message")); - private static final Meter messageAvailableMeter = metricRegistry.meter(name(WebSocketConnection.class, "messagesAvailable")); - private static final Meter messagesPersistedMeter = metricRegistry.meter(name(WebSocketConnection.class, "messagesPersisted")); - private static final Meter bytesSentMeter = metricRegistry.meter(name(WebSocketConnection.class, "bytes_sent")); - private static final Meter sendFailuresMeter = metricRegistry.meter(name(WebSocketConnection.class, "send_failures")); - private static final Meter discardedMessagesMeter = metricRegistry.meter(name(WebSocketConnection.class, "discardedMessages")); + private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); + private static final Histogram messageTime = metricRegistry.histogram( + name(MessageController.class, "message_delivery_duration")); + private static final Histogram primaryDeviceMessageTime = metricRegistry.histogram( + name(MessageController.class, "primary_device_message_delivery_duration")); + private static final Meter sendMessageMeter = metricRegistry.meter(name(WebSocketConnection.class, "send_message")); + private static final Meter messageAvailableMeter = metricRegistry.meter( + name(WebSocketConnection.class, "messagesAvailable")); + private static final Meter messagesPersistedMeter = metricRegistry.meter( + name(WebSocketConnection.class, "messagesPersisted")); + private static final Meter bytesSentMeter = metricRegistry.meter(name(WebSocketConnection.class, "bytes_sent")); + private static final Meter sendFailuresMeter = metricRegistry.meter(name(WebSocketConnection.class, "send_failures")); + private static final Meter discardedMessagesMeter = metricRegistry.meter( + name(WebSocketConnection.class, "discardedMessages")); - private static final String INITIAL_QUEUE_LENGTH_DISTRIBUTION_NAME = name(WebSocketConnection.class, "initialQueueLength"); - private static final String INITIAL_QUEUE_DRAIN_TIMER_NAME = name(WebSocketConnection.class, "drainInitialQueue"); - private static final String SLOW_QUEUE_DRAIN_COUNTER_NAME = name(WebSocketConnection.class, "slowQueueDrain"); - private static final String QUEUE_DRAIN_RETRY_COUNTER_NAME = name(WebSocketConnection.class, "queueDrainRetry"); - private static final String DISPLACEMENT_COUNTER_NAME = name(WebSocketConnection.class, "displacement"); - private static final String NON_SUCCESS_RESPONSE_COUNTER_NAME = name(WebSocketConnection.class, "clientNonSuccessResponse"); - private static final String STATUS_CODE_TAG = "status"; - private static final String STATUS_MESSAGE_TAG = "message"; + private static final String INITIAL_QUEUE_LENGTH_DISTRIBUTION_NAME = name(WebSocketConnection.class, + "initialQueueLength"); + private static final String INITIAL_QUEUE_DRAIN_TIMER_NAME = name(WebSocketConnection.class, "drainInitialQueue"); + private static final String SLOW_QUEUE_DRAIN_COUNTER_NAME = name(WebSocketConnection.class, "slowQueueDrain"); + private static final String QUEUE_DRAIN_RETRY_COUNTER_NAME = name(WebSocketConnection.class, "queueDrainRetry"); + private static final String DISPLACEMENT_COUNTER_NAME = name(WebSocketConnection.class, "displacement"); + private static final String NON_SUCCESS_RESPONSE_COUNTER_NAME = name(WebSocketConnection.class, + "clientNonSuccessResponse"); + private static final String CLIENT_CLOSED_MESSAGE_AVAILABLE_COUNTER_NAME = name(WebSocketConnection.class, + "messageAvailableAfterClientClosed"); + private static final String STATUS_CODE_TAG = "status"; + private static final String STATUS_MESSAGE_TAG = "message"; private static final long SLOW_DRAIN_THRESHOLD = 10_000; @@ -350,19 +359,34 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac } @Override - public void handleNewMessagesAvailable() { + public boolean handleNewMessagesAvailable() { + if (!client.isOpen()) { + // The client may become closed without successful removal of references to the `MessageAvailabilityListener` + Metrics.counter(CLIENT_CLOSED_MESSAGE_AVAILABLE_COUNTER_NAME).increment(); + return false; + } + messageAvailableMeter.mark(); storedMessageState.compareAndSet(StoredMessageState.EMPTY, StoredMessageState.CACHED_NEW_MESSAGES_AVAILABLE); processStoredMessages(); + + return true; } @Override - public void handleMessagesPersisted() { + public boolean handleMessagesPersisted() { + if (!client.isOpen()) { + // The client may become without successful removal of references to the `MessageAvailabilityListener` + Metrics.counter(CLIENT_CLOSED_MESSAGE_AVAILABLE_COUNTER_NAME).increment(); + return false; + } messagesPersistedMeter.mark(); storedMessageState.set(StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE); processStoredMessages(); + + return true; } @Override diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java index d9a324c9c..bfbe204cb 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterIntegrationTest.java @@ -124,14 +124,16 @@ class MessagePersisterIntegrationTest { messagesManager.addMessageAvailabilityListener(account.getUuid(), 1, new MessageAvailabilityListener() { @Override - public void handleNewMessagesAvailable() { + public boolean handleNewMessagesAvailable() { + return true; } @Override - public void handleMessagesPersisted() { + public boolean handleMessagesPersisted() { synchronized (messagesPersisted) { messagesPersisted.set(true); messagesPersisted.notifyAll(); + return true; } } }); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java index 9e78c77ce..aea7a0a40 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java @@ -21,10 +21,12 @@ import java.util.List; import java.util.Optional; import java.util.Random; import java.util.UUID; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; import java.util.stream.Collectors; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterEach; @@ -287,15 +289,18 @@ class MessagesCacheTest { final MessageAvailabilityListener listener = new MessageAvailabilityListener() { @Override - public void handleNewMessagesAvailable() { + public boolean handleNewMessagesAvailable() { synchronized (notified) { notified.set(true); notified.notifyAll(); + + return true; } } @Override - public void handleMessagesPersisted() { + public boolean handleMessagesPersisted() { + return true; } }; @@ -320,14 +325,17 @@ class MessagesCacheTest { final MessageAvailabilityListener listener = new MessageAvailabilityListener() { @Override - public void handleNewMessagesAvailable() { + public boolean handleNewMessagesAvailable() { + return true; } @Override - public void handleMessagesPersisted() { + public boolean handleMessagesPersisted() { synchronized (notified) { notified.set(true); notified.notifyAll(); + + return true; } } }; @@ -348,4 +356,72 @@ class MessagesCacheTest { }); } + + /** + * Helper class that implements {@link MessageAvailabilityListener#handleNewMessagesAvailable()} by always returning + * {@code false}. Its {@code counter} field tracks how many times {@code handleNewMessagesAvailable} has been called. + *

+ * It uses a parameterized {@code AtomicBoolean} for asynchronous observation. It must be reset to + * {@code false} between observations. + */ + private static class NewMessagesAvailabilityClosedListener implements MessageAvailabilityListener { + + private int counter; + + private final Consumer messageHandledCallback; + private final CompletableFuture firstMessageHandled = new CompletableFuture<>(); + + private NewMessagesAvailabilityClosedListener(final Consumer messageHandledCallback) { + this.messageHandledCallback = messageHandledCallback; + } + + @Override + public boolean handleNewMessagesAvailable() { + counter++; + messageHandledCallback.accept(counter); + firstMessageHandled.complete(null); + + return false; + + } + + @Override + public boolean handleMessagesPersisted() { + return true; + } + } + + @Test + void testAvailabilityListenerResponses() { + final NewMessagesAvailabilityClosedListener listener1 = new NewMessagesAvailabilityClosedListener( + count -> assertEquals(1, count)); + final NewMessagesAvailabilityClosedListener listener2 = new NewMessagesAvailabilityClosedListener( + count -> assertEquals(1, count)); + + assertTimeoutPreemptively(Duration.ofSeconds(30), () -> { + messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener1); + final UUID messageGuid1 = UUID.randomUUID(); + messagesCache.insert(messageGuid1, DESTINATION_UUID, DESTINATION_DEVICE_ID, + generateRandomMessage(messageGuid1, true)); + + listener1.firstMessageHandled.get(); + + // Avoid a race condition by blocking on the message handled future *and* the current notification executor task— + // the notification executor task includes unsubscribing `listener1`, and, if we don’t wait, sometimes + // `listener2` will get subscribed before `listener1` is cleaned up + notificationExecutorService.submit(() -> listener1.firstMessageHandled.get()).get(); + + final UUID messageGuid2 = UUID.randomUUID(); + messagesCache.insert(messageGuid2, DESTINATION_UUID, DESTINATION_DEVICE_ID, + generateRandomMessage(messageGuid2, true)); + + messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener2); + + final UUID messageGuid3 = UUID.randomUUID(); + messagesCache.insert(messageGuid3, DESTINATION_UUID, DESTINATION_DEVICE_ID, + generateRandomMessage(messageGuid3, true)); + + listener2.firstMessageHandled.get(); + }); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index 8f0d98221..8df10d2b3 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -220,6 +220,7 @@ class WebSocketConnectionTest { when(account.getNumber()).thenReturn("+18005551234"); when(account.getUuid()).thenReturn(accountUuid); when(device.getId()).thenReturn(1L); + when(client.isOpen()).thenReturn(true); when(client.getUserAgent()).thenReturn("Test-UA"); when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) @@ -360,6 +361,7 @@ class WebSocketConnectionTest { when(account.getNumber()).thenReturn("+18005551234"); when(account.getUuid()).thenReturn(UUID.randomUUID()); when(device.getId()).thenReturn(1L); + when(client.isOpen()).thenReturn(true); when(client.getUserAgent()).thenReturn("Test-UA"); final AtomicBoolean threadWaiting = new AtomicBoolean(false); @@ -426,6 +428,7 @@ class WebSocketConnectionTest { when(account.getNumber()).thenReturn("+18005551234"); when(account.getUuid()).thenReturn(UUID.randomUUID()); when(device.getId()).thenReturn(1L); + when(client.isOpen()).thenReturn(true); when(client.getUserAgent()).thenReturn("Test-UA"); final List firstPageMessages = @@ -469,6 +472,7 @@ class WebSocketConnectionTest { when(account.getNumber()).thenReturn("+18005551234"); when(account.getUuid()).thenReturn(UUID.randomUUID()); when(device.getId()).thenReturn(1L); + when(client.isOpen()).thenReturn(true); when(client.getUserAgent()).thenReturn("Test-UA"); final UUID senderUuid = UUID.randomUUID(); @@ -525,6 +529,7 @@ class WebSocketConnectionTest { when(account.getNumber()).thenReturn("+18005551234"); when(account.getUuid()).thenReturn(accountUuid); when(device.getId()).thenReturn(1L); + when(client.isOpen()).thenReturn(true); when(client.getUserAgent()).thenReturn("Test-UA"); when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) @@ -554,6 +559,7 @@ class WebSocketConnectionTest { when(account.getNumber()).thenReturn("+18005551234"); when(account.getUuid()).thenReturn(accountUuid); when(device.getId()).thenReturn(1L); + when(client.isOpen()).thenReturn(true); when(client.getUserAgent()).thenReturn("Test-UA"); final List firstPageMessages = @@ -602,6 +608,7 @@ class WebSocketConnectionTest { when(account.getNumber()).thenReturn("+18005551234"); when(account.getUuid()).thenReturn(accountUuid); when(device.getId()).thenReturn(1L); + when(client.isOpen()).thenReturn(true); when(client.getUserAgent()).thenReturn("Test-UA"); when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))