From a74438d1ee2213eeecb0bf547bbad5fc94a1a677 Mon Sep 17 00:00:00 2001 From: Chris Eager Date: Tue, 21 Nov 2023 17:34:55 -0600 Subject: [PATCH] Add test for concurrent in-flight outbound messages on WebSocket queue processing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This also elevates the implicit default concurrency (via reactor’s `Queues.SMALL_BUFFER_SIZE`) to be explicit. --- .../websocket/WebSocketConnection.java | 6 +- .../websocket/WebSocketConnectionTest.java | 91 ++++++++++++++++++- 2 files changed, 91 insertions(+), 6 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index 57da2415d..681fa2e23 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -99,6 +99,9 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac @VisibleForTesting static final int MESSAGE_PUBLISHER_LIMIT_RATE = 100; + @VisibleForTesting + static final int MESSAGE_SENDER_MAX_CONCURRENCY = 256; + @VisibleForTesting static final int MAX_CONSECUTIVE_RETRIES = 5; private static final long RETRY_DELAY_MILLIS = 1_000; @@ -372,8 +375,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac }, // otherwise just emit nothing e -> Mono.empty() - ) - ) + ), MESSAGE_SENDER_MAX_CONCURRENCY) .subscribeOn(messageDeliveryScheduler) .subscribe( // no additional consumer of values - it is Flux by now diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index 86ea865e6..fc03a4bbd 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -38,6 +38,7 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Queue; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; @@ -45,6 +46,7 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.IntStream; import java.util.stream.Stream; import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.junit.jupiter.api.AfterEach; @@ -463,9 +465,7 @@ class WebSocketConnectionTest { final CountDownLatch queueEmptyLatch = new CountDownLatch(1); when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))) - .thenAnswer(invocation -> { - return CompletableFuture.completedFuture(successResponse); - }); + .thenAnswer(invocation -> CompletableFuture.completedFuture(successResponse)); when(client.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()))) .thenAnswer(invocation -> { @@ -475,7 +475,6 @@ class WebSocketConnectionTest { assertTimeoutPreemptively(Duration.ofSeconds(5), () -> { connection.processStoredMessages(); - queueEmptyLatch.await(); }); @@ -484,6 +483,90 @@ class WebSocketConnectionTest { verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); } + @Test + void testProcessStoredMessagesMultiplePagesBackpressure() { + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, + retrySchedulingExecutor, Schedulers.immediate(), clientReleaseManager); + + when(account.getNumber()).thenReturn("+18005551234"); + final UUID accountUuid = UUID.randomUUID(); + when(account.getUuid()).thenReturn(accountUuid); + when(device.getId()).thenReturn(Device.PRIMARY_ID); + when(client.isOpen()).thenReturn(true); + + // Create two publishers, each with >2x WebSocketConnection.MESSAGE_SENDER_MAX_CONCURRENCY messages + final TestPublisher firstPublisher = TestPublisher.createCold(); + final List firstPublisherMessages = IntStream.range(1, + 2 * WebSocketConnection.MESSAGE_SENDER_MAX_CONCURRENCY + 23) + .mapToObj(i -> createMessage(UUID.randomUUID(), UUID.randomUUID(), i, "content " + i)) + .toList(); + + final TestPublisher secondPublisher = TestPublisher.createCold(); + final List secondPublisherMessages = IntStream.range(firstPublisherMessages.size(), + firstPublisherMessages.size() + 2 * WebSocketConnection.MESSAGE_SENDER_MAX_CONCURRENCY + 73) + .mapToObj(i -> createMessage(UUID.randomUUID(), UUID.randomUUID(), i, "content " + i)) + .toList(); + + final Flux allMessages = Flux.concat(firstPublisher, secondPublisher); + when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(Device.PRIMARY_ID), eq(false))) + .thenReturn(allMessages); + + when(messagesManager.delete(eq(accountUuid), eq(Device.PRIMARY_ID), any(), any())) + .thenReturn(CompletableFuture.completedFuture(null)); + + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + when(successResponse.getStatus()).thenReturn(200); + + final CountDownLatch queueEmptyLatch = new CountDownLatch(1); + + final Queue> pendingClientAcks = new LinkedList<>(); + + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))) + .thenAnswer(invocation -> { + final CompletableFuture pendingAck = new CompletableFuture<>(); + pendingClientAcks.add(pendingAck); + return pendingAck; + }); + + when(client.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()))) + .thenAnswer(invocation -> { + queueEmptyLatch.countDown(); + return CompletableFuture.completedFuture(successResponse); + }); + + assertTimeoutPreemptively(Duration.ofSeconds(5), () -> { + // start processing + connection.processStoredMessages(); + + firstPublisher.assertWasRequested(); + // emit all messages from the first publisher + firstPublisher.emit(firstPublisherMessages.toArray(new Envelope[]{})); + // nothing should be requested from the second publisher, because max concurrency is less than the number emitted, + // and none have completed + secondPublisher.assertWasNotRequested(); + // there should only be MESSAGE_SENDER_MAX_CONCURRENCY pending client acknowledgements + assertEquals(WebSocketConnection.MESSAGE_SENDER_MAX_CONCURRENCY, pendingClientAcks.size()); + + while (!pendingClientAcks.isEmpty()) { + pendingClientAcks.poll().complete(successResponse); + } + + secondPublisher.assertWasRequested(); + secondPublisher.emit(secondPublisherMessages.toArray(new Envelope[0])); + + while (!pendingClientAcks.isEmpty()) { + pendingClientAcks.poll().complete(successResponse); + } + + queueEmptyLatch.await(); + }); + + verify(client, times(firstPublisherMessages.size() + secondPublisherMessages.size())).sendRequest(eq("PUT"), + eq("/api/v1/message"), any(List.class), any(Optional.class)); + verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); + } + @Test void testProcessStoredMessagesContainsSenderUuid() { final WebSocketClient client = mock(WebSocketClient.class);