diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 8753e1fbc..995b43ef8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -300,7 +300,7 @@ public class WhisperServerService extends Application pushLatencyManager.recordQueueRead(destinationUuid, destinationDevice, userAgent)); - List messages = cachedMessagesOnly ? new ArrayList<>() : this.messages.load(destination, destinationDevice); + List messages = this.messages.load(destination, destinationDevice); if (messages.size() <= Messages.RESULT_SET_CHUNK_SIZE) { messages.addAll(messagesCache.get(destinationUuid, destinationDevice, Messages.RESULT_SET_CHUNK_SIZE - messages.size())); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index d23cc8a88..571b6ab7e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -4,7 +4,6 @@ import com.codahale.metrics.Histogram; import com.codahale.metrics.Meter; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.SharedMetricRegistries; -import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; import org.slf4j.Logger; @@ -32,12 +31,9 @@ import org.whispersystems.websocket.messages.WebSocketResponseMessage; import javax.ws.rs.WebApplicationException; import java.util.Collections; +import java.util.Iterator; import java.util.List; import java.util.Optional; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Semaphore; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; import static com.codahale.metrics.MetricRegistry.name; import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; @@ -67,16 +63,6 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability private final WebSocketClient client; private final String connectionId; - private final Semaphore processStoredMessagesSemaphore = new Semaphore(1); - private final AtomicReference storedMessageState = new AtomicReference<>(StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE); - private final AtomicBoolean sentInitialQueueEmptyMessage = new AtomicBoolean(false); - - private enum StoredMessageState { - EMPTY, - CACHED_NEW_MESSAGES_AVAILABLE, - PERSISTED_NEW_MESSAGES_AVAILABLE - } - public WebSocketConnection(PushSender pushSender, ReceiptSender receiptSender, MessagesManager messagesManager, @@ -102,12 +88,11 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability switch (pubSubMessage.getType().getNumber()) { case PubSubMessage.Type.QUERY_DB_VALUE: pubSubPersistedMeter.mark(); - storedMessageState.set(StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE); processStoredMessages(); break; case PubSubMessage.Type.DELIVER_VALUE: pubSubNewMessageMeter.mark(); - sendMessage(Envelope.parseFrom(pubSubMessage.getContent()), Optional.empty()); + sendMessage(Envelope.parseFrom(pubSubMessage.getContent()), Optional.empty(), false); break; default: logger.warn("Unknown pubsub message: " + pubSubMessage.getType().getNumber()); @@ -126,7 +111,10 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability processStoredMessages(); } - private CompletableFuture sendMessage(final Envelope message, final Optional storedMessageInfo) { + private void sendMessage(final Envelope message, + final Optional storedMessageInfo, + final boolean requery) + { try { String header; Optional body; @@ -141,7 +129,7 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability sendMessageMeter.mark(); - return client.sendRequest("PUT", "/api/v1/message", List.of(header, TimestampHeaderUtil.getTimestampHeader()), body) + client.sendRequest("PUT", "/api/v1/message", List.of(header, TimestampHeaderUtil.getTimestampHeader()), body) .thenAccept(response -> { boolean isReceipt = message.getType() == Envelope.Type.RECEIPT; @@ -152,6 +140,7 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability if (isSuccessResponse(response)) { if (storedMessageInfo.isPresent()) messagesManager.delete(account.getNumber(), account.getUuid(), device.getId(), storedMessageInfo.get().id, storedMessageInfo.get().cached); if (!isReceipt) sendDeliveryReceiptFor(message); + if (requery) processStoredMessages(); } else if (!isSuccessResponse(response) && !storedMessageInfo.isPresent()) { requeueMessage(message); } @@ -162,7 +151,6 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability }); } catch (CryptoEncodingException e) { logger.warn("Bad signaling key", e); - return CompletableFuture.failedFuture(e); } } @@ -192,42 +180,20 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability return response != null && response.getStatus() >= 200 && response.getStatus() < 300; } - @VisibleForTesting - void processStoredMessages() { - if (processStoredMessagesSemaphore.tryAcquire()) { - final StoredMessageState state = storedMessageState.getAndSet(StoredMessageState.EMPTY); - final CompletableFuture queueClearedFuture = new CompletableFuture<>(); + private void processStoredMessages() { + OutgoingMessageEntityList messages = messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent()); + Iterator iterator = messages.getMessages().iterator(); - sendNextMessagePage(state != StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE, queueClearedFuture); - - queueClearedFuture.whenComplete((v, cause) -> { - if (sentInitialQueueEmptyMessage.compareAndSet(false, true)) { - client.sendRequest("PUT", "/api/v1/queue/empty", Collections.singletonList(TimestampHeaderUtil.getTimestampHeader()), Optional.empty()); - } - - processStoredMessagesSemaphore.release(); - - if (storedMessageState.get() != StoredMessageState.EMPTY) { - processStoredMessages(); - } - }); - } - } - - private void sendNextMessagePage(final boolean cachedMessagesOnly, final CompletableFuture queueClearedFuture) { - final OutgoingMessageEntityList messages = messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), cachedMessagesOnly); - final CompletableFuture[] sendFutures = new CompletableFuture[messages.getMessages().size()]; - - for (int i = 0; i < messages.getMessages().size(); i++) { - final OutgoingMessageEntity message = messages.getMessages().get(i); - final Envelope.Builder builder = Envelope.newBuilder() - .setType(Envelope.Type.valueOf(message.getType())) - .setTimestamp(message.getTimestamp()) - .setServerTimestamp(message.getServerTimestamp()); + while (iterator.hasNext()) { + OutgoingMessageEntity message = iterator.next(); + Envelope.Builder builder = Envelope.newBuilder() + .setType(Envelope.Type.valueOf(message.getType())) + .setTimestamp(message.getTimestamp()) + .setServerTimestamp(message.getServerTimestamp()); if (!Util.isEmpty(message.getSource())) { builder.setSource(message.getSource()) - .setSourceDevice(message.getSourceDevice()); + .setSourceDevice(message.getSourceDevice()); } if (message.getMessage() != null) { @@ -242,40 +208,33 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability builder.setRelay(message.getRelay()); } - sendFutures[i] = sendMessage(builder.build(), Optional.of(new StoredMessageInfo(message.getId(), message.isCached()))); + sendMessage(builder.build(), Optional.of(new StoredMessageInfo(message.getId(), message.isCached())), !iterator.hasNext() && messages.hasMore()); } - CompletableFuture.allOf(sendFutures).whenComplete((v, cause) -> { - if (messages.hasMore()) { - sendNextMessagePage(cachedMessagesOnly, queueClearedFuture); - } else { - queueClearedFuture.complete(null); - } - }); + if (!messages.hasMore()) { + client.sendRequest("PUT", "/api/v1/queue/empty", Collections.singletonList(TimestampHeaderUtil.getTimestampHeader()), Optional.empty()); + } } @Override public void handleNewMessagesAvailable() { messageAvailableMeter.mark(); - - storedMessageState.compareAndSet(StoredMessageState.EMPTY, StoredMessageState.CACHED_NEW_MESSAGES_AVAILABLE); - processStoredMessages(); } @Override public void handleNewEphemeralMessageAvailable() { ephemeralMessageAvailableMeter.mark(); - messagesManager.takeEphemeralMessage(account.getUuid(), device.getId()) - .ifPresent(message -> sendMessage(message, Optional.empty())); + final Optional maybeMessage = messagesManager.takeEphemeralMessage(account.getUuid(), device.getId()); + + if (maybeMessage.isPresent()) { + sendMessage(maybeMessage.get(), Optional.empty(), false); + } } @Override public void handleMessagesPersisted() { messagesPersistedMeter.mark(); - - storedMessageState.set(StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE); - processStoredMessages(); } @Override diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java index dc8738068..757c6a104 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java @@ -257,7 +257,7 @@ public class MessageControllerTest { OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false); - when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean())).thenReturn(messagesList); + when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(AuthHelper.VALID_UUID), eq(1L), anyString())).thenReturn(messagesList); OutgoingMessageEntityList response = resources.getJerseyTest().target("/v1/messages/") @@ -294,7 +294,7 @@ public class MessageControllerTest { OutgoingMessageEntityList messagesList = new OutgoingMessageEntityList(messages, false); - when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(AuthHelper.VALID_UUID), eq(1L), anyString(), anyBoolean())).thenReturn(messagesList); + when(messagesManager.getMessagesForDevice(eq(AuthHelper.VALID_NUMBER), eq(AuthHelper.VALID_UUID), eq(1L), anyString())).thenReturn(messagesList); Response response = resources.getJerseyTest().target("/v1/messages/") diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java similarity index 54% rename from service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java rename to service/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java index 15dcddaaf..4e115b7a7 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/websocket/WebSocketConnectionTest.java @@ -1,7 +1,6 @@ -package org.whispersystems.textsecuregcm.websocket; +package org.whispersystems.textsecuregcm.tests.websocket; import com.google.protobuf.ByteString; -import io.dropwizard.auth.basic.BasicCredentials; import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.junit.Test; import org.mockito.ArgumentMatchers; @@ -22,13 +21,16 @@ import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.PubSubManager; import org.whispersystems.textsecuregcm.storage.PubSubProtos; import org.whispersystems.textsecuregcm.util.Base64; +import org.whispersystems.textsecuregcm.websocket.AuthenticatedConnectListener; +import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator; +import org.whispersystems.textsecuregcm.websocket.WebSocketConnection; +import org.whispersystems.textsecuregcm.websocket.WebsocketAddress; import org.whispersystems.websocket.WebSocketClient; import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult; import org.whispersystems.websocket.messages.WebSocketResponseMessage; import org.whispersystems.websocket.session.WebSocketSessionContext; import java.io.IOException; -import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.LinkedList; @@ -37,25 +39,11 @@ import java.util.Optional; import java.util.Set; import java.util.UUID; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; -import static org.mockito.ArgumentMatchers.anyBoolean; +import io.dropwizard.auth.basic.BasicCredentials; +import static org.junit.Assert.*; import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.anyInt; -import static org.mockito.Mockito.anyLong; -import static org.mockito.Mockito.anyString; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.reset; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; public class WebSocketConnectionTest { @@ -150,7 +138,7 @@ public class WebSocketConnectionTest { String userAgent = "user-agent"; - when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent, false)) + when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent)) .thenReturn(outgoingMessagesList); final List> futures = new LinkedList<>(); @@ -237,7 +225,7 @@ public class WebSocketConnectionTest { String userAgent = "user-agent"; - when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent, false)) + when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent)) .thenReturn(pendingMessagesList); final List> futures = new LinkedList<>(); @@ -286,64 +274,6 @@ public class WebSocketConnectionTest { verify(client).close(anyInt(), anyString()); } - @Test(timeout = 5_000L) - public void testOnlineSendViaKeyspaceNotification() throws Exception { - final MessagesManager messagesManager = mock(MessagesManager.class); - final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency"); - - final UUID accountUuid = UUID.randomUUID(); - - when(account.getNumber()).thenReturn("+18005551234"); - when(account.getUuid()).thenReturn(accountUuid); - when(device.getId()).thenReturn(1L); - when(client.getUserAgent()).thenReturn("Test-UA"); - - when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) - .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)) - .thenReturn(new OutgoingMessageEntityList(List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first")), false)) - .thenReturn(new OutgoingMessageEntityList(List.of(createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second")), false)); - - final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); - when(successResponse.getStatus()).thenReturn(200); - - final AtomicInteger sendCounter = new AtomicInteger(0); - - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))).thenAnswer((Answer>)invocation -> { - synchronized (sendCounter) { - sendCounter.incrementAndGet(); - sendCounter.notifyAll(); - } - - return CompletableFuture.completedFuture(successResponse); - }); - - // This is a little hacky and non-obvious, but because the first call to getMessagesForDevice returns empty list of - // messages, the call to CompletableFuture.allOf(...) in processStoredMessages will produce an instantly-succeeded - // future, and the whenComplete method will get called immediately on THIS thread, so we don't need to synchronize - // or wait for anything. - connection.onDispatchSubscribed("channel"); - - connection.handleNewMessagesAvailable(); - - synchronized (sendCounter) { - while (sendCounter.get() < 1) { - sendCounter.wait(); - } - } - - connection.handleNewMessagesAvailable(); - - synchronized (sendCounter) { - while (sendCounter.get() < 2) { - sendCounter.wait(); - } - } - - verify(client, times(1)).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); - verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class)); - } - @Test public void testPendingSend() throws Exception { MessagesManager storedMessages = mock(MessagesManager.class); @@ -406,7 +336,7 @@ public class WebSocketConnectionTest { String userAgent = "user-agent"; - when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent, false)) + when(storedMessages.getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), userAgent)) .thenReturn(pendingMessagesList); final List> futures = new LinkedList<>(); @@ -446,251 +376,6 @@ public class WebSocketConnectionTest { verify(client).close(anyInt(), anyString()); } - @Test(timeout = 5000L) - public void testProcessStoredMessageConcurrency() throws InterruptedException { - final MessagesManager messagesManager = mock(MessagesManager.class); - final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency"); - - when(account.getNumber()).thenReturn("+18005551234"); - when(account.getUuid()).thenReturn(UUID.randomUUID()); - when(device.getId()).thenReturn(1L); - when(client.getUserAgent()).thenReturn("Test-UA"); - - final AtomicBoolean threadWaiting = new AtomicBoolean(false); - final AtomicBoolean returnMessageList = new AtomicBoolean(false); - - when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent(), false)).thenAnswer((Answer)invocation -> { - synchronized (threadWaiting) { - threadWaiting.set(true); - threadWaiting.notifyAll(); - } - - synchronized (returnMessageList) { - while (!returnMessageList.get()) { - returnMessageList.wait(); - } - } - - return new OutgoingMessageEntityList(Collections.emptyList(), false); - }); - - final Thread[] threads = new Thread[10]; - final CountDownLatch unblockedThreadsLatch = new CountDownLatch(threads.length - 1); - - for (int i = 0; i < threads.length; i++) { - threads[i] = new Thread(() -> { - connection.processStoredMessages(); - unblockedThreadsLatch.countDown(); - }); - - threads[i].start(); - } - - unblockedThreadsLatch.await(); - - synchronized (threadWaiting) { - while (!threadWaiting.get()) { - threadWaiting.wait(); - } - } - - synchronized (returnMessageList) { - returnMessageList.set(true); - returnMessageList.notifyAll(); - } - - for (final Thread thread : threads) { - thread.join(); - } - - verify(messagesManager).getMessagesForDevice(anyString(), any(UUID.class), anyLong(), anyString(), eq(false)); - } - - @Test(timeout = 5000L) - public void testProcessStoredMessagesMultiplePages() throws InterruptedException { - final MessagesManager messagesManager = mock(MessagesManager.class); - final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency"); - - when(account.getNumber()).thenReturn("+18005551234"); - when(account.getUuid()).thenReturn(UUID.randomUUID()); - when(device.getId()).thenReturn(1L); - when(client.getUserAgent()).thenReturn("Test-UA"); - - final List firstPageMessages = - List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first"), - createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second")); - - final List secondPageMessages = - List.of(createMessage(3L, false, "sender1", UUID.randomUUID(), 3333, false, "third")); - - final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(firstPageMessages, true); - final OutgoingMessageEntityList secondPage = new OutgoingMessageEntityList(secondPageMessages, false); - - when(messagesManager.getMessagesForDevice(account.getNumber(), account.getUuid(), 1L, client.getUserAgent(), false)) - .thenReturn(firstPage) - .thenReturn(secondPage); - - final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); - when(successResponse.getStatus()).thenReturn(200); - - final CountDownLatch sendLatch = new CountDownLatch(firstPageMessages.size() + secondPageMessages.size()); - - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))).thenAnswer((Answer>)invocation -> { - sendLatch.countDown(); - return CompletableFuture.completedFuture(successResponse); - }); - - connection.processStoredMessages(); - - sendLatch.await(); - - verify(client, times(firstPageMessages.size() + secondPageMessages.size())).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class)); - verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); - } - - @Test - public void testProcessStoredMessagesSingleEmptyCall() { - final MessagesManager messagesManager = mock(MessagesManager.class); - final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency"); - - final UUID accountUuid = UUID.randomUUID(); - - when(account.getNumber()).thenReturn("+18005551234"); - when(account.getUuid()).thenReturn(accountUuid); - when(device.getId()).thenReturn(1L); - when(client.getUserAgent()).thenReturn("Test-UA"); - - when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) - .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); - - final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); - when(successResponse.getStatus()).thenReturn(200); - - // This is a little hacky and non-obvious, but because we're always returning an empty list of messages, the call to - // CompletableFuture.allOf(...) in processStoredMessages will produce an instantly-succeeded future, and the - // whenComplete method will get called immediately on THIS thread, so we don't need to synchronize or wait for - // anything. - connection.processStoredMessages(); - connection.processStoredMessages(); - - verify(client, times(1)).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); - } - - @Test(timeout = 5000L) - public void testRequeryOnStateMismatch() throws InterruptedException { - final MessagesManager messagesManager = mock(MessagesManager.class); - final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency"); - final UUID accountUuid = UUID.randomUUID(); - - when(account.getNumber()).thenReturn("+18005551234"); - when(account.getUuid()).thenReturn(accountUuid); - when(device.getId()).thenReturn(1L); - when(client.getUserAgent()).thenReturn("Test-UA"); - - final List firstPageMessages = - List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first"), - createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second")); - - final List secondPageMessages = - List.of(createMessage(3L, false, "sender1", UUID.randomUUID(), 3333, false, "third")); - - final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(firstPageMessages, false); - final OutgoingMessageEntityList secondPage = new OutgoingMessageEntityList(secondPageMessages, false); - - when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) - .thenReturn(firstPage) - .thenReturn(secondPage) - .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); - - final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); - when(successResponse.getStatus()).thenReturn(200); - - final byte[] queryDbMessageBytes = PubSubProtos.PubSubMessage.newBuilder() - .setType(PubSubProtos.PubSubMessage.Type.QUERY_DB) - .build() - .toByteArray(); - - final CountDownLatch sendLatch = new CountDownLatch(firstPageMessages.size() + secondPageMessages.size()); - - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))).thenAnswer((Answer>)invocation -> { - connection.onDispatchMessage("channel", queryDbMessageBytes); - sendLatch.countDown(); - - return CompletableFuture.completedFuture(successResponse); - }); - - connection.processStoredMessages(); - - sendLatch.await(); - - verify(client, times(firstPageMessages.size() + secondPageMessages.size())).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class)); - verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())); - } - - @Test - public void testProcessCachedMessagesOnly() { - final MessagesManager messagesManager = mock(MessagesManager.class); - final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency"); - - final UUID accountUuid = UUID.randomUUID(); - - when(account.getNumber()).thenReturn("+18005551234"); - when(account.getUuid()).thenReturn(accountUuid); - when(device.getId()).thenReturn(1L); - when(client.getUserAgent()).thenReturn("Test-UA"); - - when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) - .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); - - final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); - when(successResponse.getStatus()).thenReturn(200); - - // This is a little hacky and non-obvious, but because we're always returning an empty list of messages, the call to - // CompletableFuture.allOf(...) in processStoredMessages will produce an instantly-succeeded future, and the - // whenComplete method will get called immediately on THIS thread, so we don't need to synchronize or wait for - // anything. - connection.processStoredMessages(); - - verify(messagesManager).getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), false); - - connection.handleNewMessagesAvailable(); - - verify(messagesManager).getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), true); - } - - @Test - public void testProcessDatabaseMessagesAfterPersist() { - final MessagesManager messagesManager = mock(MessagesManager.class); - final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = new WebSocketConnection(pushSender, receiptSender, messagesManager, account, device, client, "concurrency"); - - final UUID accountUuid = UUID.randomUUID(); - - when(account.getNumber()).thenReturn("+18005551234"); - when(account.getUuid()).thenReturn(accountUuid); - when(device.getId()).thenReturn(1L); - when(client.getUserAgent()).thenReturn("Test-UA"); - - when(messagesManager.getMessagesForDevice(eq("+18005551234"), eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) - .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); - - final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); - when(successResponse.getStatus()).thenReturn(200); - - // This is a little hacky and non-obvious, but because we're always returning an empty list of messages, the call to - // CompletableFuture.allOf(...) in processStoredMessages will produce an instantly-succeeded future, and the - // whenComplete method will get called immediately on THIS thread, so we don't need to synchronize or wait for - // anything. - connection.processStoredMessages(); - connection.handleMessagesPersisted(); - - verify(messagesManager, times(2)).getMessagesForDevice(account.getNumber(), account.getUuid(), device.getId(), client.getUserAgent(), false); - } private OutgoingMessageEntity createMessage(long id, boolean cached, String sender, UUID senderUuid, long timestamp, boolean receipt, String content) { return new OutgoingMessageEntity(id, cached, UUID.randomUUID(), receipt ? Envelope.Type.RECEIPT_VALUE : Envelope.Type.CIPHERTEXT_VALUE,