From 62c31eb202461b7684ba8c82d4258ccf9c1e471d Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Mon, 14 Sep 2020 15:57:06 -0400 Subject: [PATCH] Revert "Revert keyspace delivery for all messages" This reverts commit 4dc49604b6c8be45916fd4d2a542390e53047277. --- .../textsecuregcm/WhisperServerService.java | 2 +- .../controllers/MessageController.java | 3 +- .../textsecuregcm/push/WebsocketSender.java | 28 +- .../storage/MessagesManager.java | 5 +- .../websocket/WebSocketConnection.java | 95 +++-- .../controllers/MessageControllerTest.java | 4 +- .../websocket/WebSocketConnectionTest.java | 337 +++++++++++++++++- 7 files changed, 427 insertions(+), 47 deletions(-) rename service/src/test/java/org/whispersystems/textsecuregcm/{tests => }/websocket/WebSocketConnectionTest.java (54%) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 995b43ef8..8753e1fbc 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -300,7 +300,7 @@ public class WhisperServerService extends Application pushLatencyManager.recordQueueRead(destinationUuid, destinationDevice, userAgent)); - List messages = this.messages.load(destination, destinationDevice); + List messages = cachedMessagesOnly ? new ArrayList<>() : this.messages.load(destination, destinationDevice); if (messages.size() <= Messages.RESULT_SET_CHUNK_SIZE) { messages.addAll(messagesCache.get(destinationUuid, destinationDevice, Messages.RESULT_SET_CHUNK_SIZE - messages.size())); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index 571b6ab7e..d23cc8a88 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -4,6 +4,7 @@ import com.codahale.metrics.Histogram; import com.codahale.metrics.Meter; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.SharedMetricRegistries; +import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; import org.slf4j.Logger; @@ -31,9 +32,12 @@ import org.whispersystems.websocket.messages.WebSocketResponseMessage; import javax.ws.rs.WebApplicationException; import java.util.Collections; -import java.util.Iterator; import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import static com.codahale.metrics.MetricRegistry.name; import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; @@ -63,6 +67,16 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability private final WebSocketClient client; private final String connectionId; + private final Semaphore processStoredMessagesSemaphore = new Semaphore(1); + private final AtomicReference storedMessageState = new AtomicReference<>(StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE); + private final AtomicBoolean sentInitialQueueEmptyMessage = new AtomicBoolean(false); + + private enum StoredMessageState { + EMPTY, + CACHED_NEW_MESSAGES_AVAILABLE, + PERSISTED_NEW_MESSAGES_AVAILABLE + } + public WebSocketConnection(PushSender pushSender, ReceiptSender receiptSender, MessagesManager messagesManager, @@ -88,11 +102,12 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability switch (pubSubMessage.getType().getNumber()) { case PubSubMessage.Type.QUERY_DB_VALUE: pubSubPersistedMeter.mark(); + storedMessageState.set(StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE); processStoredMessages(); break; case PubSubMessage.Type.DELIVER_VALUE: pubSubNewMessageMeter.mark(); - sendMessage(Envelope.parseFrom(pubSubMessage.getContent()), Optional.empty(), false); + sendMessage(Envelope.parseFrom(pubSubMessage.getContent()), Optional.empty()); break; default: logger.warn("Unknown pubsub message: " + pubSubMessage.getType().getNumber()); @@ -111,10 +126,7 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability processStoredMessages(); } - private void sendMessage(final Envelope message, - final Optional storedMessageInfo, - final boolean requery) - { + private CompletableFuture sendMessage(final Envelope message, final Optional storedMessageInfo) { try { String header; Optional body; @@ -129,7 +141,7 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability sendMessageMeter.mark(); - client.sendRequest("PUT", "/api/v1/message", List.of(header, TimestampHeaderUtil.getTimestampHeader()), body) + return client.sendRequest("PUT", "/api/v1/message", List.of(header, TimestampHeaderUtil.getTimestampHeader()), body) .thenAccept(response -> { boolean isReceipt = message.getType() == Envelope.Type.RECEIPT; @@ -140,7 +152,6 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability if (isSuccessResponse(response)) { if (storedMessageInfo.isPresent()) messagesManager.delete(account.getNumber(), account.getUuid(), device.getId(), storedMessageInfo.get().id, storedMessageInfo.get().cached); if (!isReceipt) sendDeliveryReceiptFor(message); - if (requery) processStoredMessages(); } else if (!isSuccessResponse(response) && !storedMessageInfo.isPresent()) { requeueMessage(message); } @@ -151,6 +162,7 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability }); } catch (CryptoEncodingException e) { logger.warn("Bad signaling key", e); + return CompletableFuture.failedFuture(e); } } @@ -180,20 +192,42 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability return response != null && response.getStatus() >= 200 && response.getStatus() < 300; } - private void processStoredMessages() { - OutgoingMessageEntityList messages = messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent()); - Iterator iterator = messages.getMessages().iterator(); + @VisibleForTesting + void processStoredMessages() { + if (processStoredMessagesSemaphore.tryAcquire()) { + final StoredMessageState state = storedMessageState.getAndSet(StoredMessageState.EMPTY); + final CompletableFuture queueClearedFuture = new CompletableFuture<>(); - while (iterator.hasNext()) { - OutgoingMessageEntity message = iterator.next(); - Envelope.Builder builder = Envelope.newBuilder() - .setType(Envelope.Type.valueOf(message.getType())) - .setTimestamp(message.getTimestamp()) - .setServerTimestamp(message.getServerTimestamp()); + sendNextMessagePage(state != StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE, queueClearedFuture); + + queueClearedFuture.whenComplete((v, cause) -> { + if (sentInitialQueueEmptyMessage.compareAndSet(false, true)) { + client.sendRequest("PUT", "/api/v1/queue/empty", Collections.singletonList(TimestampHeaderUtil.getTimestampHeader()), Optional.empty()); + } + + processStoredMessagesSemaphore.release(); + + if (storedMessageState.get() != StoredMessageState.EMPTY) { + processStoredMessages(); + } + }); + } + } + + private void sendNextMessagePage(final boolean cachedMessagesOnly, final CompletableFuture queueClearedFuture) { + final OutgoingMessageEntityList messages = messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), cachedMessagesOnly); + final CompletableFuture[] sendFutures = new CompletableFuture[messages.getMessages().size()]; + + for (int i = 0; i < messages.getMessages().size(); i++) { + final OutgoingMessageEntity message = messages.getMessages().get(i); + final Envelope.Builder builder = Envelope.newBuilder() + .setType(Envelope.Type.valueOf(message.getType())) + .setTimestamp(message.getTimestamp()) + .setServerTimestamp(message.getServerTimestamp()); if (!Util.isEmpty(message.getSource())) { builder.setSource(message.getSource()) - .setSourceDevice(message.getSourceDevice()); + .setSourceDevice(message.getSourceDevice()); } if (message.getMessage() != null) { @@ -208,33 +242,40 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability builder.setRelay(message.getRelay()); } - sendMessage(builder.build(), Optional.of(new StoredMessageInfo(message.getId(), message.isCached())), !iterator.hasNext() && messages.hasMore()); + sendFutures[i] = sendMessage(builder.build(), Optional.of(new StoredMessageInfo(message.getId(), message.isCached()))); } - if (!messages.hasMore()) { - client.sendRequest("PUT", "/api/v1/queue/empty", Collections.singletonList(TimestampHeaderUtil.getTimestampHeader()), Optional.empty()); - } + CompletableFuture.allOf(sendFutures).whenComplete((v, cause) -> { + if (messages.hasMore()) { + sendNextMessagePage(cachedMessagesOnly, queueClearedFuture); + } else { + queueClearedFuture.complete(null); + } + }); } @Override public void handleNewMessagesAvailable() { messageAvailableMeter.mark(); + + storedMessageState.compareAndSet(StoredMessageState.EMPTY, StoredMessageState.CACHED_NEW_MESSAGES_AVAILABLE); + processStoredMessages(); } @Override public void handleNewEphemeralMessageAvailable() { ephemeralMessageAvailableMeter.mark(); - final Optional maybeMessage = messagesManager.takeEphemeralMessage(account.getUuid(), device.getId()); - - if (maybeMessage.isPresent()) { - sendMessage(maybeMessage.get(), Optional.empty(), false); - } + messagesManager.takeEphemeralMessage(account.getUuid(), device.getId()) + .ifPresent(message -> sendMessage(message, Optional.empty())); } @Override public void handleMessagesPersisted() { messagesPersistedMeter.mark(); + + storedMessageState.set(StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE); + processStoredMessages(); } @Override diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java index f6a808c95..df58dad3e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java @@ -257,7 +257,7 @@ public class MessageControllerTest { OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false); - when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(AuthHelper.VALID_UUID), eq(1L), anyString())).thenReturn(messagesList); + when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean())).thenReturn(messagesList); OutgoingMessageEntityList response = resources.getJerseyTest().target("/v1/messages/") @@ -294,7 +294,7 @@ public class MessageControllerTest { OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false); - when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(AuthHelper.VALID_UUID), eq(1L), anyString())).thenReturn(messagesList); + when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean())).thenReturn(messagesList); Response response = resources.getJerseyTest().target("/v1/messages/") diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java similarity index 54% rename from service/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java rename to service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index 4e115b7a7..15dcddaaf 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -1,6 +1,7 @@ -package org.whispersystems.textsecuregcm.tests.websocket; +package org.whispersystems.textsecuregcm.websocket; import com.google.protobuf.ByteString; +import io.dropwizard.auth.basic.BasicCredentials; import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.junit.Test; import org.mockito.ArgumentMatchers; @@ -21,16 +22,13 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.PubSubManager; import org.whispersystems.textsecuregcm.storage.PubSubProtos; import org.whispersystems.textsecuregcm.util.Base64; -import org.whispersystems.textsecuregcm.websocket.AuthenticatedConnectListener; -import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator; -import org.whispersystems.textsecuregcm.websocket.WebSocketConnection; -import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; import org.whispersystems.websocket.WebSocketClient; import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult; import org.whispersystems.websocket.messages.WebSocketResponseMessage; import org.whispersystems.websocket.session.WebSocketSessionContext; import java.io.IOException; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.LinkedList; @@ -39,11 +37,25 @@ import java.util.Optional; import java.util.Set; import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; -import io.dropwizard.auth.basic.BasicCredentials; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyInt; +import static org.mockito.Mockito.anyLong; +import static org.mockito.Mockito.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; public class WebSocketConnectionTest { @@ -138,7 +150,7 @@ public class WebSocketConnectionTest { String userAgent = "user-agent"; - when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent)) + when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent, false)) .thenReturn(outgoingMessagesList); final List> futures = new LinkedList<>(); @@ -225,7 +237,7 @@ public class WebSocketConnectionTest { String userAgent = "user-agent"; - when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent)) + when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent, false)) .thenReturn(pendingMessagesList); final List> futures = new LinkedList<>(); @@ -274,6 +286,64 @@ public class WebSocketConnectionTest { verify(client).close(anyInt(), anyString()); } + @Test(timeout = 5_000L) + public void testOnlineSendViaKeyspaceNotification() throws Exception { + final MessagesManager messagesManager = mock(MessagesManager.class); + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency"); + + final UUID accountUuid = UUID.randomUUID(); + + when(account.getNumber()).thenReturn("+18005551234"); + when(account.getUuid()).thenReturn(accountUuid); + when(device.getId()).thenReturn(1L); + when(client.getUserAgent()).thenReturn("Test-UA"); + + when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) + .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)) + .thenReturn(new OutgoingMessageEntityList(List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first")), false)) + .thenReturn(new OutgoingMessageEntityList(List.of(createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second")), false)); + + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + when(successResponse.getStatus()).thenReturn(200); + + final AtomicInteger sendCounter = new AtomicInteger(0); + + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))).thenAnswer((Answer>)invocation -> { + synchronized (sendCounter) { + sendCounter.incrementAndGet(); + sendCounter.notifyAll(); + } + + return CompletableFuture.completedFuture(successResponse); + }); + + // This is a little hacky and non-obvious, but because the first call to getMessagesForDevice returns empty list of + // messages, the call to CompletableFuture.allOf(...) in processStoredMessages will produce an instantly-succeeded + // future, and the whenComplete method will get called immediately on THIS thread, so we don't need to synchronize + // or wait for anything. + connection.onDispatchSubscribed("channel"); + + connection.handleNewMessagesAvailable(); + + synchronized (sendCounter) { + while (sendCounter.get() < 1) { + sendCounter.wait(); + } + } + + connection.handleNewMessagesAvailable(); + + synchronized (sendCounter) { + while (sendCounter.get() < 2) { + sendCounter.wait(); + } + } + + verify(client, times(1)).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); + verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class)); + } + @Test public void testPendingSend() throws Exception { MessagesManager storedMessages = mock(MessagesManager.class); @@ -336,7 +406,7 @@ public class WebSocketConnectionTest { String userAgent = "user-agent"; - when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent)) + when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent, false)) .thenReturn(pendingMessagesList); final List> futures = new LinkedList<>(); @@ -376,6 +446,251 @@ public class WebSocketConnectionTest { verify(client).close(anyInt(), anyString()); } + @Test(timeout = 5000L) + public void testProcessStoredMessageConcurrency() throws InterruptedException { + final MessagesManager messagesManager = mock(MessagesManager.class); + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency"); + + when(account.getNumber()).thenReturn("+18005551234"); + when(account.getUuid()).thenReturn(UUID.randomUUID()); + when(device.getId()).thenReturn(1L); + when(client.getUserAgent()).thenReturn("Test-UA"); + + final AtomicBoolean threadWaiting = new AtomicBoolean(false); + final AtomicBoolean returnMessageList = new AtomicBoolean(false); + + when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent(), false)).thenAnswer((Answer)invocation -> { + synchronized (threadWaiting) { + threadWaiting.set(true); + threadWaiting.notifyAll(); + } + + synchronized (returnMessageList) { + while (!returnMessageList.get()) { + returnMessageList.wait(); + } + } + + return new OutgoingMessageEntityList(Collections.emptyList(), false); + }); + + final Thread[] threads = new Thread[10]; + final CountDownLatch unblockedThreadsLatch = new CountDownLatch(threads.length - 1); + + for (int i = 0; i < threads.length; i++) { + threads[i] = new Thread(() -> { + connection.processStoredMessages(); + unblockedThreadsLatch.countDown(); + }); + + threads[i].start(); + } + + unblockedThreadsLatch.await(); + + synchronized (threadWaiting) { + while (!threadWaiting.get()) { + threadWaiting.wait(); + } + } + + synchronized (returnMessageList) { + returnMessageList.set(true); + returnMessageList.notifyAll(); + } + + for (final Thread thread : threads) { + thread.join(); + } + + verify(messagesManager).getMessagesForDevice(anyString(), any(UUID.class), anyLong(), anyString(), eq(false)); + } + + @Test(timeout = 5000L) + public void testProcessStoredMessagesMultiplePages() throws InterruptedException { + final MessagesManager messagesManager = mock(MessagesManager.class); + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency"); + + when(account.getNumber()).thenReturn("+18005551234"); + when(account.getUuid()).thenReturn(UUID.randomUUID()); + when(device.getId()).thenReturn(1L); + when(client.getUserAgent()).thenReturn("Test-UA"); + + final List firstPageMessages = + List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first"), + createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second")); + + final List secondPageMessages = + List.of(createMessage(3L, false, "sender1", UUID.randomUUID(), 3333, false, "third")); + + final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(firstPageMessages, true); + final OutgoingMessageEntityList secondPage = new OutgoingMessageEntityList(secondPageMessages, false); + + when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent(), false)) + .thenReturn(firstPage) + .thenReturn(secondPage); + + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + when(successResponse.getStatus()).thenReturn(200); + + final CountDownLatch sendLatch = new CountDownLatch(firstPageMessages.size() + secondPageMessages.size()); + + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))).thenAnswer((Answer>)invocation -> { + sendLatch.countDown(); + return CompletableFuture.completedFuture(successResponse); + }); + + connection.processStoredMessages(); + + sendLatch.await(); + + verify(client, times(firstPageMessages.size() + secondPageMessages.size())).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class)); + verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); + } + + @Test + public void testProcessStoredMessagesSingleEmptyCall() { + final MessagesManager messagesManager = mock(MessagesManager.class); + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency"); + + final UUID accountUuid = UUID.randomUUID(); + + when(account.getNumber()).thenReturn("+18005551234"); + when(account.getUuid()).thenReturn(accountUuid); + when(device.getId()).thenReturn(1L); + when(client.getUserAgent()).thenReturn("Test-UA"); + + when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) + .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); + + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + when(successResponse.getStatus()).thenReturn(200); + + // This is a little hacky and non-obvious, but because we're always returning an empty list of messages, the call to + // CompletableFuture.allOf(...) in processStoredMessages will produce an instantly-succeeded future, and the + // whenComplete method will get called immediately on THIS thread, so we don't need to synchronize or wait for + // anything. + connection.processStoredMessages(); + connection.processStoredMessages(); + + verify(client, times(1)).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); + } + + @Test(timeout = 5000L) + public void testRequeryOnStateMismatch() throws InterruptedException { + final MessagesManager messagesManager = mock(MessagesManager.class); + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency"); + final UUID accountUuid = UUID.randomUUID(); + + when(account.getNumber()).thenReturn("+18005551234"); + when(account.getUuid()).thenReturn(accountUuid); + when(device.getId()).thenReturn(1L); + when(client.getUserAgent()).thenReturn("Test-UA"); + + final List firstPageMessages = + List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first"), + createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second")); + + final List secondPageMessages = + List.of(createMessage(3L, false, "sender1", UUID.randomUUID(), 3333, false, "third")); + + final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(firstPageMessages, false); + final OutgoingMessageEntityList secondPage = new OutgoingMessageEntityList(secondPageMessages, false); + + when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) + .thenReturn(firstPage) + .thenReturn(secondPage) + .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); + + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + when(successResponse.getStatus()).thenReturn(200); + + final byte[] queryDbMessageBytes = PubSubProtos.PubSubMessage.newBuilder() + .setType(PubSubProtos.PubSubMessage.Type.QUERY_DB) + .build() + .toByteArray(); + + final CountDownLatch sendLatch = new CountDownLatch(firstPageMessages.size() + secondPageMessages.size()); + + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))).thenAnswer((Answer>)invocation -> { + connection.onDispatchMessage("channel", queryDbMessageBytes); + sendLatch.countDown(); + + return CompletableFuture.completedFuture(successResponse); + }); + + connection.processStoredMessages(); + + sendLatch.await(); + + verify(client, times(firstPageMessages.size() + secondPageMessages.size())).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class)); + verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); + } + + @Test + public void testProcessCachedMessagesOnly() { + final MessagesManager messagesManager = mock(MessagesManager.class); + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency"); + + final UUID accountUuid = UUID.randomUUID(); + + when(account.getNumber()).thenReturn("+18005551234"); + when(account.getUuid()).thenReturn(accountUuid); + when(device.getId()).thenReturn(1L); + when(client.getUserAgent()).thenReturn("Test-UA"); + + when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) + .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); + + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + when(successResponse.getStatus()).thenReturn(200); + + // This is a little hacky and non-obvious, but because we're always returning an empty list of messages, the call to + // CompletableFuture.allOf(...) in processStoredMessages will produce an instantly-succeeded future, and the + // whenComplete method will get called immediately on THIS thread, so we don't need to synchronize or wait for + // anything. + connection.processStoredMessages(); + + verify(messagesManager).getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), false); + + connection.handleNewMessagesAvailable(); + + verify(messagesManager).getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), true); + } + + @Test + public void testProcessDatabaseMessagesAfterPersist() { + final MessagesManager messagesManager = mock(MessagesManager.class); + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency"); + + final UUID accountUuid = UUID.randomUUID(); + + when(account.getNumber()).thenReturn("+18005551234"); + when(account.getUuid()).thenReturn(accountUuid); + when(device.getId()).thenReturn(1L); + when(client.getUserAgent()).thenReturn("Test-UA"); + + when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) + .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); + + final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); + when(successResponse.getStatus()).thenReturn(200); + + // This is a little hacky and non-obvious, but because we're always returning an empty list of messages, the call to + // CompletableFuture.allOf(...) in processStoredMessages will produce an instantly-succeeded future, and the + // whenComplete method will get called immediately on THIS thread, so we don't need to synchronize or wait for + // anything. + connection.processStoredMessages(); + connection.handleMessagesPersisted(); + + verify(messagesManager, times(2)).getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), false); + } private OutgoingMessageEntity createMessage(long id, boolean cached, String sender, UUID senderUuid, long timestamp, boolean receipt, String content) { return new OutgoingMessageEntity(id, cached, UUID.randomUUID(), receipt ? Envelope.Type.RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE,