Clean up `MessageAvailabilityListener` if the websocket client is closed

This commit is contained in:
Chris Eager 2022-08-01 14:37:05 -05:00 committed by Chris Eager
parent a06a663b94
commit 55df593561
6 changed files with 154 additions and 29 deletions

View File

@ -12,7 +12,15 @@ package org.whispersystems.textsecuregcm.storage;
*/
public interface MessageAvailabilityListener {
void handleNewMessagesAvailable();
/**
* @return whether the listener is still active. {@code false} indicates the listener can no longer handle messages
* and may be discarded
*/
boolean handleNewMessagesAvailable();
void handleMessagesPersisted();
/**
* @return whether the listener is still active. {@code false} indicates the listener can no longer handle messages
* and may be discarded
*/
boolean handleMessagesPersisted();
}

View File

@ -351,7 +351,11 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
newMessageNotificationCounter.increment();
notificationExecutorService.execute(() -> {
try {
findListener(channel).ifPresent(MessageAvailabilityListener::handleNewMessagesAvailable);
findListener(channel).ifPresent(listener -> {
if (!listener.handleNewMessagesAvailable()) {
removeMessageAvailabilityListener(listener);
}
});
} catch (final Exception e) {
logger.warn("Unexpected error handling new message", e);
}
@ -360,7 +364,11 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
queuePersistedNotificationCounter.increment();
notificationExecutorService.execute(() -> {
try {
findListener(channel).ifPresent(MessageAvailabilityListener::handleMessagesPersisted);
findListener(channel).ifPresent(listener -> {
if (!listener.handleMessagesPersisted()) {
removeMessageAvailabilityListener(listener);
}
});
} catch (final Exception e) {
logger.warn("Unexpected error handling messages persisted", e);
}

View File

@ -57,24 +57,33 @@ import org.whispersystems.websocket.messages.WebSocketResponseMessage;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class WebSocketConnection implements MessageAvailabilityListener, DisplacedPresenceListener {
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private static final Histogram messageTime = metricRegistry.histogram(name(MessageController.class, "message_delivery_duration"));
private static final Histogram primaryDeviceMessageTime = metricRegistry.histogram(name(MessageController.class, "primary_device_message_delivery_duration"));
private static final Meter sendMessageMeter = metricRegistry.meter(name(WebSocketConnection.class, "send_message"));
private static final Meter messageAvailableMeter = metricRegistry.meter(name(WebSocketConnection.class, "messagesAvailable"));
private static final Meter messagesPersistedMeter = metricRegistry.meter(name(WebSocketConnection.class, "messagesPersisted"));
private static final Meter bytesSentMeter = metricRegistry.meter(name(WebSocketConnection.class, "bytes_sent"));
private static final Meter sendFailuresMeter = metricRegistry.meter(name(WebSocketConnection.class, "send_failures"));
private static final Meter discardedMessagesMeter = metricRegistry.meter(name(WebSocketConnection.class, "discardedMessages"));
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private static final Histogram messageTime = metricRegistry.histogram(
name(MessageController.class, "message_delivery_duration"));
private static final Histogram primaryDeviceMessageTime = metricRegistry.histogram(
name(MessageController.class, "primary_device_message_delivery_duration"));
private static final Meter sendMessageMeter = metricRegistry.meter(name(WebSocketConnection.class, "send_message"));
private static final Meter messageAvailableMeter = metricRegistry.meter(
name(WebSocketConnection.class, "messagesAvailable"));
private static final Meter messagesPersistedMeter = metricRegistry.meter(
name(WebSocketConnection.class, "messagesPersisted"));
private static final Meter bytesSentMeter = metricRegistry.meter(name(WebSocketConnection.class, "bytes_sent"));
private static final Meter sendFailuresMeter = metricRegistry.meter(name(WebSocketConnection.class, "send_failures"));
private static final Meter discardedMessagesMeter = metricRegistry.meter(
name(WebSocketConnection.class, "discardedMessages"));
private static final String INITIAL_QUEUE_LENGTH_DISTRIBUTION_NAME = name(WebSocketConnection.class, "initialQueueLength");
private static final String INITIAL_QUEUE_DRAIN_TIMER_NAME = name(WebSocketConnection.class, "drainInitialQueue");
private static final String SLOW_QUEUE_DRAIN_COUNTER_NAME = name(WebSocketConnection.class, "slowQueueDrain");
private static final String QUEUE_DRAIN_RETRY_COUNTER_NAME = name(WebSocketConnection.class, "queueDrainRetry");
private static final String DISPLACEMENT_COUNTER_NAME = name(WebSocketConnection.class, "displacement");
private static final String NON_SUCCESS_RESPONSE_COUNTER_NAME = name(WebSocketConnection.class, "clientNonSuccessResponse");
private static final String STATUS_CODE_TAG = "status";
private static final String STATUS_MESSAGE_TAG = "message";
private static final String INITIAL_QUEUE_LENGTH_DISTRIBUTION_NAME = name(WebSocketConnection.class,
"initialQueueLength");
private static final String INITIAL_QUEUE_DRAIN_TIMER_NAME = name(WebSocketConnection.class, "drainInitialQueue");
private static final String SLOW_QUEUE_DRAIN_COUNTER_NAME = name(WebSocketConnection.class, "slowQueueDrain");
private static final String QUEUE_DRAIN_RETRY_COUNTER_NAME = name(WebSocketConnection.class, "queueDrainRetry");
private static final String DISPLACEMENT_COUNTER_NAME = name(WebSocketConnection.class, "displacement");
private static final String NON_SUCCESS_RESPONSE_COUNTER_NAME = name(WebSocketConnection.class,
"clientNonSuccessResponse");
private static final String CLIENT_CLOSED_MESSAGE_AVAILABLE_COUNTER_NAME = name(WebSocketConnection.class,
"messageAvailableAfterClientClosed");
private static final String STATUS_CODE_TAG = "status";
private static final String STATUS_MESSAGE_TAG = "message";
private static final long SLOW_DRAIN_THRESHOLD = 10_000;
@ -350,19 +359,34 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
}
@Override
public void handleNewMessagesAvailable() {
public boolean handleNewMessagesAvailable() {
if (!client.isOpen()) {
// The client may become closed without successful removal of references to the `MessageAvailabilityListener`
Metrics.counter(CLIENT_CLOSED_MESSAGE_AVAILABLE_COUNTER_NAME).increment();
return false;
}
messageAvailableMeter.mark();
storedMessageState.compareAndSet(StoredMessageState.EMPTY, StoredMessageState.CACHED_NEW_MESSAGES_AVAILABLE);
processStoredMessages();
return true;
}
@Override
public void handleMessagesPersisted() {
public boolean handleMessagesPersisted() {
if (!client.isOpen()) {
// The client may become without successful removal of references to the `MessageAvailabilityListener`
Metrics.counter(CLIENT_CLOSED_MESSAGE_AVAILABLE_COUNTER_NAME).increment();
return false;
}
messagesPersistedMeter.mark();
storedMessageState.set(StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE);
processStoredMessages();
return true;
}
@Override

View File

@ -124,14 +124,16 @@ class MessagePersisterIntegrationTest {
messagesManager.addMessageAvailabilityListener(account.getUuid(), 1, new MessageAvailabilityListener() {
@Override
public void handleNewMessagesAvailable() {
public boolean handleNewMessagesAvailable() {
return true;
}
@Override
public void handleMessagesPersisted() {
public boolean handleMessagesPersisted() {
synchronized (messagesPersisted) {
messagesPersisted.set(true);
messagesPersisted.notifyAll();
return true;
}
}
});

View File

@ -21,10 +21,12 @@ import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach;
@ -287,15 +289,18 @@ class MessagesCacheTest {
final MessageAvailabilityListener listener = new MessageAvailabilityListener() {
@Override
public void handleNewMessagesAvailable() {
public boolean handleNewMessagesAvailable() {
synchronized (notified) {
notified.set(true);
notified.notifyAll();
return true;
}
}
@Override
public void handleMessagesPersisted() {
public boolean handleMessagesPersisted() {
return true;
}
};
@ -320,14 +325,17 @@ class MessagesCacheTest {
final MessageAvailabilityListener listener = new MessageAvailabilityListener() {
@Override
public void handleNewMessagesAvailable() {
public boolean handleNewMessagesAvailable() {
return true;
}
@Override
public void handleMessagesPersisted() {
public boolean handleMessagesPersisted() {
synchronized (notified) {
notified.set(true);
notified.notifyAll();
return true;
}
}
};
@ -348,4 +356,72 @@ class MessagesCacheTest {
});
}
/**
* Helper class that implements {@link MessageAvailabilityListener#handleNewMessagesAvailable()} by always returning
* {@code false}. Its {@code counter} field tracks how many times {@code handleNewMessagesAvailable} has been called.
* <p>
* It uses a parameterized {@code AtomicBoolean} for asynchronous observation. It <em>must</em> be reset to
* {@code false} between observations.
*/
private static class NewMessagesAvailabilityClosedListener implements MessageAvailabilityListener {
private int counter;
private final Consumer<Integer> messageHandledCallback;
private final CompletableFuture<Void> firstMessageHandled = new CompletableFuture<>();
private NewMessagesAvailabilityClosedListener(final Consumer<Integer> messageHandledCallback) {
this.messageHandledCallback = messageHandledCallback;
}
@Override
public boolean handleNewMessagesAvailable() {
counter++;
messageHandledCallback.accept(counter);
firstMessageHandled.complete(null);
return false;
}
@Override
public boolean handleMessagesPersisted() {
return true;
}
}
@Test
void testAvailabilityListenerResponses() {
final NewMessagesAvailabilityClosedListener listener1 = new NewMessagesAvailabilityClosedListener(
count -> assertEquals(1, count));
final NewMessagesAvailabilityClosedListener listener2 = new NewMessagesAvailabilityClosedListener(
count -> assertEquals(1, count));
assertTimeoutPreemptively(Duration.ofSeconds(30), () -> {
messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener1);
final UUID messageGuid1 = UUID.randomUUID();
messagesCache.insert(messageGuid1, DESTINATION_UUID, DESTINATION_DEVICE_ID,
generateRandomMessage(messageGuid1, true));
listener1.firstMessageHandled.get();
// Avoid a race condition by blocking on the message handled future *and* the current notification executor task
// the notification executor task includes unsubscribing `listener1`, and, if we dont wait, sometimes
// `listener2` will get subscribed before `listener1` is cleaned up
notificationExecutorService.submit(() -> listener1.firstMessageHandled.get()).get();
final UUID messageGuid2 = UUID.randomUUID();
messagesCache.insert(messageGuid2, DESTINATION_UUID, DESTINATION_DEVICE_ID,
generateRandomMessage(messageGuid2, true));
messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener2);
final UUID messageGuid3 = UUID.randomUUID();
messagesCache.insert(messageGuid3, DESTINATION_UUID, DESTINATION_DEVICE_ID,
generateRandomMessage(messageGuid3, true));
listener2.firstMessageHandled.get();
});
}
}

View File

@ -220,6 +220,7 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(device.getId()).thenReturn(1L);
when(client.isOpen()).thenReturn(true);
when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
@ -360,6 +361,7 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(1L);
when(client.isOpen()).thenReturn(true);
when(client.getUserAgent()).thenReturn("Test-UA");
final AtomicBoolean threadWaiting = new AtomicBoolean(false);
@ -426,6 +428,7 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(1L);
when(client.isOpen()).thenReturn(true);
when(client.getUserAgent()).thenReturn("Test-UA");
final List<Envelope> firstPageMessages =
@ -469,6 +472,7 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(1L);
when(client.isOpen()).thenReturn(true);
when(client.getUserAgent()).thenReturn("Test-UA");
final UUID senderUuid = UUID.randomUUID();
@ -525,6 +529,7 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(device.getId()).thenReturn(1L);
when(client.isOpen()).thenReturn(true);
when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
@ -554,6 +559,7 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(device.getId()).thenReturn(1L);
when(client.isOpen()).thenReturn(true);
when(client.getUserAgent()).thenReturn("Test-UA");
final List<Envelope> firstPageMessages =
@ -602,6 +608,7 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(device.getId()).thenReturn(1L);
when(client.isOpen()).thenReturn(true);
when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))