Clean up `MessageAvailabilityListener` if the websocket client is closed
This commit is contained in:
parent
a06a663b94
commit
55df593561
|
@ -12,7 +12,15 @@ package org.whispersystems.textsecuregcm.storage;
|
|||
*/
|
||||
public interface MessageAvailabilityListener {
|
||||
|
||||
void handleNewMessagesAvailable();
|
||||
/**
|
||||
* @return whether the listener is still active. {@code false} indicates the listener can no longer handle messages
|
||||
* and may be discarded
|
||||
*/
|
||||
boolean handleNewMessagesAvailable();
|
||||
|
||||
void handleMessagesPersisted();
|
||||
/**
|
||||
* @return whether the listener is still active. {@code false} indicates the listener can no longer handle messages
|
||||
* and may be discarded
|
||||
*/
|
||||
boolean handleMessagesPersisted();
|
||||
}
|
||||
|
|
|
@ -351,7 +351,11 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
|
|||
newMessageNotificationCounter.increment();
|
||||
notificationExecutorService.execute(() -> {
|
||||
try {
|
||||
findListener(channel).ifPresent(MessageAvailabilityListener::handleNewMessagesAvailable);
|
||||
findListener(channel).ifPresent(listener -> {
|
||||
if (!listener.handleNewMessagesAvailable()) {
|
||||
removeMessageAvailabilityListener(listener);
|
||||
}
|
||||
});
|
||||
} catch (final Exception e) {
|
||||
logger.warn("Unexpected error handling new message", e);
|
||||
}
|
||||
|
@ -360,7 +364,11 @@ public class MessagesCache extends RedisClusterPubSubAdapter<String, String> imp
|
|||
queuePersistedNotificationCounter.increment();
|
||||
notificationExecutorService.execute(() -> {
|
||||
try {
|
||||
findListener(channel).ifPresent(MessageAvailabilityListener::handleMessagesPersisted);
|
||||
findListener(channel).ifPresent(listener -> {
|
||||
if (!listener.handleMessagesPersisted()) {
|
||||
removeMessageAvailabilityListener(listener);
|
||||
}
|
||||
});
|
||||
} catch (final Exception e) {
|
||||
logger.warn("Unexpected error handling messages persisted", e);
|
||||
}
|
||||
|
|
|
@ -57,24 +57,33 @@ import org.whispersystems.websocket.messages.WebSocketResponseMessage;
|
|||
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
|
||||
public class WebSocketConnection implements MessageAvailabilityListener, DisplacedPresenceListener {
|
||||
|
||||
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
|
||||
private static final Histogram messageTime = metricRegistry.histogram(name(MessageController.class, "message_delivery_duration"));
|
||||
private static final Histogram primaryDeviceMessageTime = metricRegistry.histogram(name(MessageController.class, "primary_device_message_delivery_duration"));
|
||||
private static final Meter sendMessageMeter = metricRegistry.meter(name(WebSocketConnection.class, "send_message"));
|
||||
private static final Meter messageAvailableMeter = metricRegistry.meter(name(WebSocketConnection.class, "messagesAvailable"));
|
||||
private static final Meter messagesPersistedMeter = metricRegistry.meter(name(WebSocketConnection.class, "messagesPersisted"));
|
||||
private static final Meter bytesSentMeter = metricRegistry.meter(name(WebSocketConnection.class, "bytes_sent"));
|
||||
private static final Meter sendFailuresMeter = metricRegistry.meter(name(WebSocketConnection.class, "send_failures"));
|
||||
private static final Meter discardedMessagesMeter = metricRegistry.meter(name(WebSocketConnection.class, "discardedMessages"));
|
||||
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
|
||||
private static final Histogram messageTime = metricRegistry.histogram(
|
||||
name(MessageController.class, "message_delivery_duration"));
|
||||
private static final Histogram primaryDeviceMessageTime = metricRegistry.histogram(
|
||||
name(MessageController.class, "primary_device_message_delivery_duration"));
|
||||
private static final Meter sendMessageMeter = metricRegistry.meter(name(WebSocketConnection.class, "send_message"));
|
||||
private static final Meter messageAvailableMeter = metricRegistry.meter(
|
||||
name(WebSocketConnection.class, "messagesAvailable"));
|
||||
private static final Meter messagesPersistedMeter = metricRegistry.meter(
|
||||
name(WebSocketConnection.class, "messagesPersisted"));
|
||||
private static final Meter bytesSentMeter = metricRegistry.meter(name(WebSocketConnection.class, "bytes_sent"));
|
||||
private static final Meter sendFailuresMeter = metricRegistry.meter(name(WebSocketConnection.class, "send_failures"));
|
||||
private static final Meter discardedMessagesMeter = metricRegistry.meter(
|
||||
name(WebSocketConnection.class, "discardedMessages"));
|
||||
|
||||
private static final String INITIAL_QUEUE_LENGTH_DISTRIBUTION_NAME = name(WebSocketConnection.class, "initialQueueLength");
|
||||
private static final String INITIAL_QUEUE_DRAIN_TIMER_NAME = name(WebSocketConnection.class, "drainInitialQueue");
|
||||
private static final String SLOW_QUEUE_DRAIN_COUNTER_NAME = name(WebSocketConnection.class, "slowQueueDrain");
|
||||
private static final String QUEUE_DRAIN_RETRY_COUNTER_NAME = name(WebSocketConnection.class, "queueDrainRetry");
|
||||
private static final String DISPLACEMENT_COUNTER_NAME = name(WebSocketConnection.class, "displacement");
|
||||
private static final String NON_SUCCESS_RESPONSE_COUNTER_NAME = name(WebSocketConnection.class, "clientNonSuccessResponse");
|
||||
private static final String STATUS_CODE_TAG = "status";
|
||||
private static final String STATUS_MESSAGE_TAG = "message";
|
||||
private static final String INITIAL_QUEUE_LENGTH_DISTRIBUTION_NAME = name(WebSocketConnection.class,
|
||||
"initialQueueLength");
|
||||
private static final String INITIAL_QUEUE_DRAIN_TIMER_NAME = name(WebSocketConnection.class, "drainInitialQueue");
|
||||
private static final String SLOW_QUEUE_DRAIN_COUNTER_NAME = name(WebSocketConnection.class, "slowQueueDrain");
|
||||
private static final String QUEUE_DRAIN_RETRY_COUNTER_NAME = name(WebSocketConnection.class, "queueDrainRetry");
|
||||
private static final String DISPLACEMENT_COUNTER_NAME = name(WebSocketConnection.class, "displacement");
|
||||
private static final String NON_SUCCESS_RESPONSE_COUNTER_NAME = name(WebSocketConnection.class,
|
||||
"clientNonSuccessResponse");
|
||||
private static final String CLIENT_CLOSED_MESSAGE_AVAILABLE_COUNTER_NAME = name(WebSocketConnection.class,
|
||||
"messageAvailableAfterClientClosed");
|
||||
private static final String STATUS_CODE_TAG = "status";
|
||||
private static final String STATUS_MESSAGE_TAG = "message";
|
||||
|
||||
private static final long SLOW_DRAIN_THRESHOLD = 10_000;
|
||||
|
||||
|
@ -350,19 +359,34 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
|
|||
}
|
||||
|
||||
@Override
|
||||
public void handleNewMessagesAvailable() {
|
||||
public boolean handleNewMessagesAvailable() {
|
||||
if (!client.isOpen()) {
|
||||
// The client may become closed without successful removal of references to the `MessageAvailabilityListener`
|
||||
Metrics.counter(CLIENT_CLOSED_MESSAGE_AVAILABLE_COUNTER_NAME).increment();
|
||||
return false;
|
||||
}
|
||||
|
||||
messageAvailableMeter.mark();
|
||||
|
||||
storedMessageState.compareAndSet(StoredMessageState.EMPTY, StoredMessageState.CACHED_NEW_MESSAGES_AVAILABLE);
|
||||
processStoredMessages();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleMessagesPersisted() {
|
||||
public boolean handleMessagesPersisted() {
|
||||
if (!client.isOpen()) {
|
||||
// The client may become without successful removal of references to the `MessageAvailabilityListener`
|
||||
Metrics.counter(CLIENT_CLOSED_MESSAGE_AVAILABLE_COUNTER_NAME).increment();
|
||||
return false;
|
||||
}
|
||||
messagesPersistedMeter.mark();
|
||||
|
||||
storedMessageState.set(StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE);
|
||||
processStoredMessages();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -124,14 +124,16 @@ class MessagePersisterIntegrationTest {
|
|||
|
||||
messagesManager.addMessageAvailabilityListener(account.getUuid(), 1, new MessageAvailabilityListener() {
|
||||
@Override
|
||||
public void handleNewMessagesAvailable() {
|
||||
public boolean handleNewMessagesAvailable() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleMessagesPersisted() {
|
||||
public boolean handleMessagesPersisted() {
|
||||
synchronized (messagesPersisted) {
|
||||
messagesPersisted.set(true);
|
||||
messagesPersisted.notifyAll();
|
||||
return true;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
|
|
@ -21,10 +21,12 @@ import java.util.List;
|
|||
import java.util.Optional;
|
||||
import java.util.Random;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.commons.lang3.RandomStringUtils;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
|
@ -287,15 +289,18 @@ class MessagesCacheTest {
|
|||
|
||||
final MessageAvailabilityListener listener = new MessageAvailabilityListener() {
|
||||
@Override
|
||||
public void handleNewMessagesAvailable() {
|
||||
public boolean handleNewMessagesAvailable() {
|
||||
synchronized (notified) {
|
||||
notified.set(true);
|
||||
notified.notifyAll();
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleMessagesPersisted() {
|
||||
public boolean handleMessagesPersisted() {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -320,14 +325,17 @@ class MessagesCacheTest {
|
|||
|
||||
final MessageAvailabilityListener listener = new MessageAvailabilityListener() {
|
||||
@Override
|
||||
public void handleNewMessagesAvailable() {
|
||||
public boolean handleNewMessagesAvailable() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleMessagesPersisted() {
|
||||
public boolean handleMessagesPersisted() {
|
||||
synchronized (notified) {
|
||||
notified.set(true);
|
||||
notified.notifyAll();
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -348,4 +356,72 @@ class MessagesCacheTest {
|
|||
});
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Helper class that implements {@link MessageAvailabilityListener#handleNewMessagesAvailable()} by always returning
|
||||
* {@code false}. Its {@code counter} field tracks how many times {@code handleNewMessagesAvailable} has been called.
|
||||
* <p>
|
||||
* It uses a parameterized {@code AtomicBoolean} for asynchronous observation. It <em>must</em> be reset to
|
||||
* {@code false} between observations.
|
||||
*/
|
||||
private static class NewMessagesAvailabilityClosedListener implements MessageAvailabilityListener {
|
||||
|
||||
private int counter;
|
||||
|
||||
private final Consumer<Integer> messageHandledCallback;
|
||||
private final CompletableFuture<Void> firstMessageHandled = new CompletableFuture<>();
|
||||
|
||||
private NewMessagesAvailabilityClosedListener(final Consumer<Integer> messageHandledCallback) {
|
||||
this.messageHandledCallback = messageHandledCallback;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean handleNewMessagesAvailable() {
|
||||
counter++;
|
||||
messageHandledCallback.accept(counter);
|
||||
firstMessageHandled.complete(null);
|
||||
|
||||
return false;
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean handleMessagesPersisted() {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testAvailabilityListenerResponses() {
|
||||
final NewMessagesAvailabilityClosedListener listener1 = new NewMessagesAvailabilityClosedListener(
|
||||
count -> assertEquals(1, count));
|
||||
final NewMessagesAvailabilityClosedListener listener2 = new NewMessagesAvailabilityClosedListener(
|
||||
count -> assertEquals(1, count));
|
||||
|
||||
assertTimeoutPreemptively(Duration.ofSeconds(30), () -> {
|
||||
messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener1);
|
||||
final UUID messageGuid1 = UUID.randomUUID();
|
||||
messagesCache.insert(messageGuid1, DESTINATION_UUID, DESTINATION_DEVICE_ID,
|
||||
generateRandomMessage(messageGuid1, true));
|
||||
|
||||
listener1.firstMessageHandled.get();
|
||||
|
||||
// Avoid a race condition by blocking on the message handled future *and* the current notification executor task—
|
||||
// the notification executor task includes unsubscribing `listener1`, and, if we don’t wait, sometimes
|
||||
// `listener2` will get subscribed before `listener1` is cleaned up
|
||||
notificationExecutorService.submit(() -> listener1.firstMessageHandled.get()).get();
|
||||
|
||||
final UUID messageGuid2 = UUID.randomUUID();
|
||||
messagesCache.insert(messageGuid2, DESTINATION_UUID, DESTINATION_DEVICE_ID,
|
||||
generateRandomMessage(messageGuid2, true));
|
||||
|
||||
messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener2);
|
||||
|
||||
final UUID messageGuid3 = UUID.randomUUID();
|
||||
messagesCache.insert(messageGuid3, DESTINATION_UUID, DESTINATION_DEVICE_ID,
|
||||
generateRandomMessage(messageGuid3, true));
|
||||
|
||||
listener2.firstMessageHandled.get();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -220,6 +220,7 @@ class WebSocketConnectionTest {
|
|||
when(account.getNumber()).thenReturn("+18005551234");
|
||||
when(account.getUuid()).thenReturn(accountUuid);
|
||||
when(device.getId()).thenReturn(1L);
|
||||
when(client.isOpen()).thenReturn(true);
|
||||
when(client.getUserAgent()).thenReturn("Test-UA");
|
||||
|
||||
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
|
||||
|
@ -360,6 +361,7 @@ class WebSocketConnectionTest {
|
|||
when(account.getNumber()).thenReturn("+18005551234");
|
||||
when(account.getUuid()).thenReturn(UUID.randomUUID());
|
||||
when(device.getId()).thenReturn(1L);
|
||||
when(client.isOpen()).thenReturn(true);
|
||||
when(client.getUserAgent()).thenReturn("Test-UA");
|
||||
|
||||
final AtomicBoolean threadWaiting = new AtomicBoolean(false);
|
||||
|
@ -426,6 +428,7 @@ class WebSocketConnectionTest {
|
|||
when(account.getNumber()).thenReturn("+18005551234");
|
||||
when(account.getUuid()).thenReturn(UUID.randomUUID());
|
||||
when(device.getId()).thenReturn(1L);
|
||||
when(client.isOpen()).thenReturn(true);
|
||||
when(client.getUserAgent()).thenReturn("Test-UA");
|
||||
|
||||
final List<Envelope> firstPageMessages =
|
||||
|
@ -469,6 +472,7 @@ class WebSocketConnectionTest {
|
|||
when(account.getNumber()).thenReturn("+18005551234");
|
||||
when(account.getUuid()).thenReturn(UUID.randomUUID());
|
||||
when(device.getId()).thenReturn(1L);
|
||||
when(client.isOpen()).thenReturn(true);
|
||||
when(client.getUserAgent()).thenReturn("Test-UA");
|
||||
|
||||
final UUID senderUuid = UUID.randomUUID();
|
||||
|
@ -525,6 +529,7 @@ class WebSocketConnectionTest {
|
|||
when(account.getNumber()).thenReturn("+18005551234");
|
||||
when(account.getUuid()).thenReturn(accountUuid);
|
||||
when(device.getId()).thenReturn(1L);
|
||||
when(client.isOpen()).thenReturn(true);
|
||||
when(client.getUserAgent()).thenReturn("Test-UA");
|
||||
|
||||
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
|
||||
|
@ -554,6 +559,7 @@ class WebSocketConnectionTest {
|
|||
when(account.getNumber()).thenReturn("+18005551234");
|
||||
when(account.getUuid()).thenReturn(accountUuid);
|
||||
when(device.getId()).thenReturn(1L);
|
||||
when(client.isOpen()).thenReturn(true);
|
||||
when(client.getUserAgent()).thenReturn("Test-UA");
|
||||
|
||||
final List<Envelope> firstPageMessages =
|
||||
|
@ -602,6 +608,7 @@ class WebSocketConnectionTest {
|
|||
when(account.getNumber()).thenReturn("+18005551234");
|
||||
when(account.getUuid()).thenReturn(accountUuid);
|
||||
when(device.getId()).thenReturn(1L);
|
||||
when(client.isOpen()).thenReturn(true);
|
||||
when(client.getUserAgent()).thenReturn("Test-UA");
|
||||
|
||||
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
|
||||
|
|
Loading…
Reference in New Issue