From 278b4e810d292c724d3dc7be165d1106196c9a6f Mon Sep 17 00:00:00 2001 From: Chris Eager Date: Thu, 16 Dec 2021 12:03:35 -0800 Subject: [PATCH] Add (failing) test for send message timeouts --- .../WebSocketConnectionIntegrationTest.java | 88 ++++++++++++++++++- 1 file changed, 85 insertions(+), 3 deletions(-) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java index 6071d24cf..a505c5fbe 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java @@ -33,7 +33,6 @@ import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.stream.Collectors; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -243,8 +242,91 @@ class WebSocketConnectionIntegrationTest { } catch (InvalidProtocolBufferException e) { throw new RuntimeException(e); } - }) - .collect(Collectors.toList()); + }).toList(); + + assertTrue(expectedMessages.containsAll(sentMessages)); + }); + } + + @Test + void testProcessStoredMessagesSendFutureTimeout() { + final int persistedMessageCount = 207; + final int cachedMessageCount = 173; + + final List expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount); + + assertTimeoutPreemptively(Duration.ofSeconds(15), () -> { + + { + final List persistedMessages = new ArrayList<>(persistedMessageCount); + + for (int i = 0; i < persistedMessageCount; i++) { + final MessageProtos.Envelope envelope = generateRandomMessage(UUID.randomUUID()); + persistedMessages.add(envelope); + expectedMessages.add(envelope); + } + + messagesDynamoDb.store(persistedMessages, account.getUuid(), device.getId()); + } + + for (int i = 0; i < cachedMessageCount; i++) { + final UUID messageGuid = UUID.randomUUID(); + final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid); + messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope); + + expectedMessages.add(envelope); + } + + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + when(successResponse.getStatus()).thenReturn(200); + + final CompletableFuture neverCompleting = new CompletableFuture<>(); + + // for the first message, return a future that never completes + when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())) + .thenReturn(neverCompleting) + .thenReturn(CompletableFuture.completedFuture(successResponse)); + + when(webSocketClient.isOpen()).thenReturn(true); + + final AtomicBoolean queueCleared = new AtomicBoolean(false); + + when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), any())).thenAnswer( + (Answer>) invocation -> { + synchronized (queueCleared) { + queueCleared.set(true); + queueCleared.notifyAll(); + } + + return CompletableFuture.completedFuture(successResponse); + }); + + webSocketConnection.processStoredMessages(); + + synchronized (queueCleared) { + while (!queueCleared.get()) { + queueCleared.wait(); + } + } + + //noinspection unchecked + ArgumentCaptor> messageBodyCaptor = ArgumentCaptor.forClass(Optional.class); + + // We expect all of the messages from both pools to be sent, plus one for the future that times out + verify(webSocketClient, atMost(persistedMessageCount + cachedMessageCount + 1)).sendRequest(eq("PUT"), + eq("/api/v1/message"), anyList(), messageBodyCaptor.capture()); + + verify(webSocketClient).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), eq(Optional.empty())); + + final List sentMessages = messageBodyCaptor.getAllValues().stream() + .map(Optional::get) + .map(messageBytes -> { + try { + return Envelope.parseFrom(messageBytes); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + }).toList(); assertTrue(expectedMessages.containsAll(sentMessages)); });