From f766c57743d02e880ad1edfcf1e2ab923838ad16 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Wed, 9 Sep 2020 14:26:13 -0400 Subject: [PATCH] Query for more stored messages if an update happens while we're already processing a batch. --- .../websocket/WebSocketConnection.java | 11 +++- .../websocket/WebSocketConnectionTest.java | 51 +++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index d81b9f753..b67a24f31 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -65,6 +65,7 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability private final WebSocketClient client; private final String connectionId; + private int storedMessageState = 0; private boolean processingStoredMessages = false; private final AtomicBoolean sentInitialQueueEmptyMessage = new AtomicBoolean(false); @@ -93,6 +94,11 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability switch (pubSubMessage.getType().getNumber()) { case PubSubMessage.Type.QUERY_DB_VALUE: pubSubPersistedMeter.mark(); + + synchronized (this) { + storedMessageState++; + } + processStoredMessages(); break; case PubSubMessage.Type.DELIVER_VALUE: @@ -184,12 +190,15 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability @VisibleForTesting void processStoredMessages() { + final int processedState; + synchronized (this) { if (processingStoredMessages) { return; } processingStoredMessages = true; + processedState = storedMessageState; } OutgoingMessageEntityList messages = messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent()); @@ -227,7 +236,7 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability processingStoredMessages = false; } - if (messages.hasMore()) { + if (messages.hasMore() || storedMessageState > processedState) { processStoredMessages(); } else { final boolean shouldSendEmptyQueueMessage; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index edc91d681..f06f3dd62 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -522,6 +522,57 @@ public class WebSocketConnectionTest { verify(client, times(1)).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); } + @Test + public void testRequeryAfterOnStateMismatch() throws InterruptedException { + final MessagesManager messagesManager = mock(MessagesManager.class); + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency"); + + when(account.getNumber()).thenReturn("+18005551234"); + when(account.getUuid()).thenReturn(UUID.randomUUID()); + when(device.getId()).thenReturn(1L); + when(client.getUserAgent()).thenReturn("Test-UA"); + + final List firstPageMessages = + List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first"), + createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second")); + + final List secondPageMessages = + List.of(createMessage(3L, false, "sender1", UUID.randomUUID(), 3333, false, "third")); + + final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(firstPageMessages, false); + final OutgoingMessageEntityList secondPage = new OutgoingMessageEntityList(secondPageMessages, false); + + when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent())) + .thenReturn(firstPage) + .thenReturn(secondPage) + .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); + + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + when(successResponse.getStatus()).thenReturn(200); + + final byte[] queryDbMessageBytes = PubSubProtos.PubSubMessage.newBuilder() + .setType(PubSubProtos.PubSubMessage.Type.QUERY_DB) + .build() + .toByteArray(); + + final CountDownLatch sendLatch = new CountDownLatch(firstPageMessages.size() + secondPageMessages.size()); + + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))).thenAnswer((Answer>)invocation -> { + connection.onDispatchMessage("channel", queryDbMessageBytes); + sendLatch.countDown(); + + return CompletableFuture.completedFuture(successResponse); + }); + + connection.processStoredMessages(); + + sendLatch.await(); + + verify(client, times(firstPageMessages.size() + secondPageMessages.size())).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class)); + verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); + } + private OutgoingMessageEntity createMessage(long id, boolean cached, String sender, UUID senderUuid, long timestamp, boolean receipt, String content) { return new OutgoingMessageEntity(id, cached, UUID.randomUUID(), receipt ? Envelope.Type.RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE, null, timestamp, sender, senderUuid, 1, content.getBytes(), null, 0);