From 089b6b1644d59eff741b1616a9a8196082405f23 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Thu, 18 Mar 2021 19:06:07 -0400 Subject: [PATCH] Retry attempts to get messages after a delay; close connections after a finite number of retries. --- .../textsecuregcm/WhisperServerService.java | 3 +- .../AuthenticatedConnectListener.java | 16 +++-- .../websocket/WebSocketConnection.java | 68 +++++++++++++++---- .../WebSocketConnectionIntegrationTest.java | 15 ++-- .../websocket/WebSocketConnectionTest.java | 63 +++++++++++++---- 5 files changed, 126 insertions(+), 39 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 56d859fcd..06565623f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -315,6 +315,7 @@ public class WhisperServerService extends Application webSocketEnvironment = new WebSocketEnvironment<>(environment, config.getWebSocketConfiguration(), 90000); webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator)); - webSocketEnvironment.setConnectListener(new AuthenticatedConnectListener(receiptSender, messagesManager, messageSender, apnFallbackManager, clientPresenceManager)); + webSocketEnvironment.setConnectListener(new AuthenticatedConnectListener(receiptSender, messagesManager, messageSender, apnFallbackManager, clientPresenceManager, retrySchedulingExecutor)); webSocketEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET)); webSocketEnvironment.jersey().register(new KeepAliveController(clientPresenceManager)); webSocketEnvironment.jersey().register(messageController); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java index 26cb56838..c735724d9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java @@ -23,6 +23,8 @@ import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.websocket.session.WebSocketSessionContext; import org.whispersystems.websocket.setup.WebSocketConnectListener; +import java.util.concurrent.ScheduledExecutorService; + import static com.codahale.metrics.MetricRegistry.name; public class AuthenticatedConnectListener implements WebSocketConnectListener { @@ -39,17 +41,20 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { private final MessageSender messageSender; private final ApnFallbackManager apnFallbackManager; private final ClientPresenceManager clientPresenceManager; + private final ScheduledExecutorService retrySchedulingExecutor; public AuthenticatedConnectListener(ReceiptSender receiptSender, - MessagesManager messagesManager, - final MessageSender messageSender, ApnFallbackManager apnFallbackManager, - ClientPresenceManager clientPresenceManager) + MessagesManager messagesManager, + final MessageSender messageSender, ApnFallbackManager apnFallbackManager, + ClientPresenceManager clientPresenceManager, + ScheduledExecutorService retrySchedulingExecutor) { this.receiptSender = receiptSender; this.messagesManager = messagesManager; this.messageSender = messageSender; this.apnFallbackManager = apnFallbackManager; this.clientPresenceManager = clientPresenceManager; + this.retrySchedulingExecutor = retrySchedulingExecutor; } @Override @@ -60,7 +65,8 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { final Timer.Context timer = durationTimer.time(); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, - context.getClient()); + context.getClient(), + retrySchedulingExecutor); openWebsocketCounter.inc(); RedisOperation.unchecked(() -> apnFallbackManager.cancel(account, device)); @@ -71,6 +77,8 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { openWebsocketCounter.dec(); timer.stop(); + connection.stop(); + RedisOperation.unchecked(() -> clientPresenceManager.clearPresence(account.getUuid(), device.getId())); RedisOperation.unchecked(() -> { messagesManager.removeMessageAvailabilityListener(connection); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index bb180a4ca..b836b90c6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -20,11 +20,15 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.Random; import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.LongAdder; @@ -79,6 +83,11 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac @VisibleForTesting static final int MAX_DESKTOP_MESSAGE_SIZE = 1024 * 1024; + @VisibleForTesting + static final int MAX_CONSECUTIVE_RETRIES = 5; + private static final long RETRY_DELAY_MILLIS = 1_000; + private static final int RETRY_DELAY_JITTER_MILLIS = 500; + private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class); private final ReceiptSender receiptSender; @@ -87,6 +96,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac private final Account account; private final Device device; private final WebSocketClient client; + private final ScheduledExecutorService retrySchedulingExecutor; private final boolean isDesktopClient; @@ -95,6 +105,10 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac private final AtomicBoolean sentInitialQueueEmptyMessage = new AtomicBoolean(false); private final LongAdder sentMessageCounter = new LongAdder(); private final AtomicLong queueDrainStartTime = new AtomicLong(); + private final AtomicInteger consecutiveRetries = new AtomicInteger(); + private final AtomicReference> retryFuture = new AtomicReference<>(); + + private final Random random = new Random(); private enum StoredMessageState { EMPTY, @@ -103,16 +117,18 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac } public WebSocketConnection(ReceiptSender receiptSender, - MessagesManager messagesManager, - Account account, - Device device, - WebSocketClient client) + MessagesManager messagesManager, + Account account, + Device device, + WebSocketClient client, + ScheduledExecutorService retrySchedulingExecutor) { this.receiptSender = receiptSender; this.messagesManager = messagesManager; this.account = account; this.device = device; this.client = client; + this.retrySchedulingExecutor = retrySchedulingExecutor; Optional maybePlatform; @@ -131,6 +147,12 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac } public void stop() { + final ScheduledFuture future = retryFuture.get(); + + if (future != null) { + future.cancel(false); + } + client.close(1000, "OK"); } @@ -203,24 +225,40 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac sendNextMessagePage(state != StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE, queueClearedFuture); queueClearedFuture.whenComplete((v, cause) -> { - if (cause == null && sentInitialQueueEmptyMessage.compareAndSet(false, true)) { - final List tags = List.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent())); - final long drainDuration = System.currentTimeMillis() - queueDrainStartTime.get(); + if (cause == null) { + consecutiveRetries.set(0); - Metrics.summary(INITIAL_QUEUE_LENGTH_DISTRIBUTION_NAME, tags).record(sentMessageCounter.sum()); - Metrics.timer(INITIAL_QUEUE_DRAIN_TIMER_NAME, tags).record(drainDuration, TimeUnit.MILLISECONDS); + if (sentInitialQueueEmptyMessage.compareAndSet(false, true)) { + final List tags = List.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent())); + final long drainDuration = System.currentTimeMillis() - queueDrainStartTime.get(); - if (drainDuration > SLOW_DRAIN_THRESHOLD) { - Metrics.counter(SLOW_QUEUE_DRAIN_COUNTER_NAME, tags).increment(); + Metrics.summary(INITIAL_QUEUE_LENGTH_DISTRIBUTION_NAME, tags).record(sentMessageCounter.sum()); + Metrics.timer(INITIAL_QUEUE_DRAIN_TIMER_NAME, tags).record(drainDuration, TimeUnit.MILLISECONDS); + + if (drainDuration > SLOW_DRAIN_THRESHOLD) { + Metrics.counter(SLOW_QUEUE_DRAIN_COUNTER_NAME, tags).increment(); + } + + client.sendRequest("PUT", "/api/v1/queue/empty", + Collections.singletonList(TimestampHeaderUtil.getTimestampHeader()), Optional.empty()); } - - client.sendRequest("PUT", "/api/v1/queue/empty", Collections.singletonList(TimestampHeaderUtil.getTimestampHeader()), Optional.empty()); + } else { + storedMessageState.compareAndSet(StoredMessageState.EMPTY, state); } processStoredMessagesSemaphore.release(); - if (cause == null && storedMessageState.get() != StoredMessageState.EMPTY) { - processStoredMessages(); + if (cause == null) { + if (storedMessageState.get() != StoredMessageState.EMPTY) { + processStoredMessages(); + } + } else { + if (consecutiveRetries.incrementAndGet() > MAX_CONSECUTIVE_RETRIES) { + client.close(1011, "Failed to retrieve messages"); + } else { + final long delay = RETRY_DELAY_MILLIS + random.nextInt(RETRY_DELAY_JITTER_MILLIS); + retryFuture.set(retrySchedulingExecutor.schedule(this::processStoredMessages, delay, TimeUnit.MILLISECONDS)); + } } }); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java index a17192fe0..310613f35 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java @@ -28,6 +28,7 @@ import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import org.apache.commons.lang3.RandomStringUtils; @@ -62,12 +63,9 @@ public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest private Device device; private WebSocketClient webSocketClient; private WebSocketConnection webSocketConnection; + private ScheduledExecutorService retrySchedulingExecutor; - private long serialTimestamp = System.currentTimeMillis(); - - @Before - public void setupAccountsDao() { - } + private long serialTimestamp = System.currentTimeMillis(); @Before @Override @@ -80,6 +78,7 @@ public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest account = mock(Account.class); device = mock(Device.class); webSocketClient = mock(WebSocketClient.class); + retrySchedulingExecutor = Executors.newSingleThreadScheduledExecutor(); when(account.getNumber()).thenReturn("+18005551234"); when(account.getUuid()).thenReturn(UUID.randomUUID()); @@ -90,7 +89,8 @@ public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest new MessagesManager(messagesDynamoDb, messagesCache, mock(PushLatencyManager.class)), account, device, - webSocketClient); + webSocketClient, + retrySchedulingExecutor); } @After @@ -99,6 +99,9 @@ public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest executorService.shutdown(); executorService.awaitTermination(2, TimeUnit.SECONDS); + retrySchedulingExecutor.shutdown(); + retrySchedulingExecutor.awaitTermination(2, TimeUnit.SECONDS); + super.tearDown(); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index 77d3e8462..52a18f718 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -35,8 +35,11 @@ import java.util.Set; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import io.lettuce.core.RedisException; import org.apache.commons.lang3.RandomStringUtils; import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.junit.Before; @@ -75,6 +78,7 @@ public class WebSocketConnectionTest { private UpgradeRequest upgradeRequest; private ReceiptSender receiptSender; private ApnFallbackManager apnFallbackManager; + private ScheduledExecutorService retrySchedulingExecutor; @Before public void setup() { @@ -85,13 +89,15 @@ public class WebSocketConnectionTest { upgradeRequest = mock(UpgradeRequest.class); receiptSender = mock(ReceiptSender.class); apnFallbackManager = mock(ApnFallbackManager.class); + retrySchedulingExecutor = mock(ScheduledExecutorService.class); } @Test public void testCredentials() throws Exception { MessagesManager storedMessages = mock(MessagesManager.class); WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator); - AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, storedMessages, mock(MessageSender.class), apnFallbackManager, mock(ClientPresenceManager.class)); + AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, storedMessages, mock(MessageSender.class), apnFallbackManager, mock(ClientPresenceManager.class), + retrySchedulingExecutor); WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD)))) @@ -178,7 +184,7 @@ public class WebSocketConnectionTest { }); WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, - account, device, client); + account, device, client, retrySchedulingExecutor); connection.start(); verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.>any()); @@ -203,7 +209,7 @@ public class WebSocketConnectionTest { public void testOnlineSend() throws Exception { final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client); + final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor); final UUID accountUuid = UUID.randomUUID(); @@ -330,7 +336,7 @@ public class WebSocketConnectionTest { }); WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, - account, device, client); + account, device, client, retrySchedulingExecutor); connection.start(); @@ -353,7 +359,7 @@ public class WebSocketConnectionTest { public void testProcessStoredMessageConcurrency() throws InterruptedException { final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client); + final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor); when(account.getNumber()).thenReturn("+18005551234"); when(account.getUuid()).thenReturn(UUID.randomUUID()); @@ -414,7 +420,7 @@ public class WebSocketConnectionTest { public void testProcessStoredMessagesMultiplePages() throws InterruptedException { final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client); + final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor); when(account.getNumber()).thenReturn("+18005551234"); when(account.getUuid()).thenReturn(UUID.randomUUID()); @@ -457,7 +463,7 @@ public class WebSocketConnectionTest { public void testProcessStoredMessagesContainsSenderUuid() throws InterruptedException { final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client); + final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor); when(account.getNumber()).thenReturn("+18005551234"); when(account.getUuid()).thenReturn(UUID.randomUUID()); @@ -507,7 +513,7 @@ public class WebSocketConnectionTest { public void testProcessStoredMessagesSingleEmptyCall() { final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client); + final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor); final UUID accountUuid = UUID.randomUUID(); @@ -536,7 +542,7 @@ public class WebSocketConnectionTest { public void testRequeryOnStateMismatch() throws InterruptedException { final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client); + final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor); final UUID accountUuid = UUID.randomUUID(); when(account.getNumber()).thenReturn("+18005551234"); @@ -583,7 +589,7 @@ public class WebSocketConnectionTest { public void testProcessCachedMessagesOnly() { final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client); + final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor); final UUID accountUuid = UUID.randomUUID(); @@ -615,7 +621,7 @@ public class WebSocketConnectionTest { public void testProcessDatabaseMessagesAfterPersist() { final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client); + final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor); final UUID accountUuid = UUID.randomUUID(); @@ -693,7 +699,7 @@ public class WebSocketConnectionTest { } }); - WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, account, device, client); + WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, account, device, client, retrySchedulingExecutor); connection.start(); verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.>any()); @@ -766,7 +772,7 @@ public class WebSocketConnectionTest { } }); - WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, account, device, client); + WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, account, device, client, retrySchedulingExecutor); connection.start(); verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.>any()); @@ -785,6 +791,37 @@ public class WebSocketConnectionTest { verify(client).close(anyInt(), anyString()); } + @Test + public void testRetrieveMessageException() { + MessagesManager storedMessages = mock(MessagesManager.class); + + UUID accountUuid = UUID.randomUUID(); + + when(device.getId()).thenReturn(2L); + + when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device)); + when(account.getNumber()).thenReturn("+14152222222"); + when(account.getUuid()).thenReturn(accountUuid); + + String userAgent = "Signal-Android/4.68.3"; + + when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false)) + .thenThrow(new RedisException("OH NO")); + + when(retrySchedulingExecutor.schedule(any(Runnable.class), anyLong(), any())).thenAnswer((Answer>) invocation -> { + invocation.getArgument(0, Runnable.class).run(); + return mock(ScheduledFuture.class); + }); + + final WebSocketClient client = mock(WebSocketClient.class); + + WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, account, device, client, retrySchedulingExecutor); + connection.start(); + + verify(retrySchedulingExecutor, times(WebSocketConnection.MAX_CONSECUTIVE_RETRIES)).schedule(any(Runnable.class), anyLong(), any()); + verify(client).close(eq(1011), anyString()); + } + private OutgoingMessageEntity createMessage(long id, boolean cached, String sender, UUID senderUuid, long timestamp, boolean receipt, String content) { return new OutgoingMessageEntity(id, cached, UUID.randomUUID(), receipt ? Envelope.Type.RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE, null, timestamp, sender, senderUuid, 1, content.getBytes(), null, 0);