queues = messagesCache.getQueuesToPersist(slot, Instant.now().plusSeconds(60), 100);
+
+ assertEquals(1, queues.size());
+ assertEquals(DESTINATION_UUID, MessagesCache.getAccountUuidFromQueueName(queues.get(0)));
+ assertEquals(DESTINATION_DEVICE_ID, MessagesCache.getDeviceIdFromQueueName(queues.get(0)));
+ }
+
+ @Test
+ void testNotifyListenerNewMessage() {
+ final AtomicBoolean notified = new AtomicBoolean(false);
+ final UUID messageGuid = UUID.randomUUID();
+
+ final MessageAvailabilityListener listener = new MessageAvailabilityListener() {
+ @Override
+ public boolean handleNewMessagesAvailable() {
+ synchronized (notified) {
+ notified.set(true);
+ notified.notifyAll();
+
+ return true;
+ }
+ }
+
+ @Override
+ public boolean handleMessagesPersisted() {
+ return true;
+ }
+ };
+
+ assertTimeoutPreemptively(Duration.ofSeconds(5), () -> {
+ messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener);
+ messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID,
+ generateRandomMessage(messageGuid, true));
+
+ synchronized (notified) {
+ while (!notified.get()) {
+ notified.wait();
+ }
+ }
+
+ assertTrue(notified.get());
+ });
+ }
+
+ @Test
+ void testNotifyListenerPersisted() {
+ final AtomicBoolean notified = new AtomicBoolean(false);
+
+ final MessageAvailabilityListener listener = new MessageAvailabilityListener() {
+ @Override
+ public boolean handleNewMessagesAvailable() {
+ return true;
+ }
+
+ @Override
+ public boolean handleMessagesPersisted() {
+ synchronized (notified) {
+ notified.set(true);
+ notified.notifyAll();
+
+ return true;
+ }
+ }
+ };
+
+ assertTimeoutPreemptively(Duration.ofSeconds(5), () -> {
+ messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener);
+
+ messagesCache.lockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID);
+ messagesCache.unlockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID);
+
+ synchronized (notified) {
+ while (!notified.get()) {
+ notified.wait();
+ }
+ }
+
+ assertTrue(notified.get());
+ });
+ }
+
+
+ /**
+ * Helper class that implements {@link MessageAvailabilityListener#handleNewMessagesAvailable()} by always returning
+ * {@code false}. Its {@code counter} field tracks how many times {@code handleNewMessagesAvailable} has been
+ * called.
+ *
+ * It uses a {@link CompletableFuture} to signal that it has received a “messages available” callback for the first
+ * time.
+ */
+ private static class NewMessagesAvailabilityClosedListener implements MessageAvailabilityListener {
+
+ private int counter;
+
+ private final Consumer messageHandledCallback;
+ private final CompletableFuture firstMessageHandled = new CompletableFuture<>();
+
+ private NewMessagesAvailabilityClosedListener(final Consumer messageHandledCallback) {
+ this.messageHandledCallback = messageHandledCallback;
+ }
+
+ @Override
+ public boolean handleNewMessagesAvailable() {
+ counter++;
+ messageHandledCallback.accept(counter);
+ firstMessageHandled.complete(null);
+
+ return false;
+
+ }
+
+ @Override
+ public boolean handleMessagesPersisted() {
+ return true;
+ }
+ }
+
+ @Test
+ void testAvailabilityListenerResponses() {
+ final NewMessagesAvailabilityClosedListener listener1 = new NewMessagesAvailabilityClosedListener(
+ count -> assertEquals(1, count));
+ final NewMessagesAvailabilityClosedListener listener2 = new NewMessagesAvailabilityClosedListener(
+ count -> assertEquals(1, count));
+
+ assertTimeoutPreemptively(Duration.ofSeconds(30), () -> {
+ messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener1);
+ final UUID messageGuid1 = UUID.randomUUID();
+ messagesCache.insert(messageGuid1, DESTINATION_UUID, DESTINATION_DEVICE_ID,
+ generateRandomMessage(messageGuid1, true));
+
+ listener1.firstMessageHandled.get();
+
+ // Avoid a race condition by blocking on the message handled future *and* the current notification executor task—
+ // the notification executor task includes unsubscribing `listener1`, and, if we don’t wait, sometimes
+ // `listener2` will get subscribed before `listener1` is cleaned up
+ sharedExecutorService.submit(() -> listener1.firstMessageHandled.get()).get();
+
+ final UUID messageGuid2 = UUID.randomUUID();
+ messagesCache.insert(messageGuid2, DESTINATION_UUID, DESTINATION_DEVICE_ID,
+ generateRandomMessage(messageGuid2, true));
+
+ messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener2);
+
+ final UUID messageGuid3 = UUID.randomUUID();
+ messagesCache.insert(messageGuid3, DESTINATION_UUID, DESTINATION_DEVICE_ID,
+ generateRandomMessage(messageGuid3, true));
+
+ listener2.firstMessageHandled.get();
+ });
+ }
+
+ private List get(final UUID destinationUuid, final long destinationDeviceId,
+ final int messageCount) {
+ return Flux.from(messagesCache.get(destinationUuid, destinationDeviceId))
+ .take(messageCount, true)
+ .collectList()
+ .block();
+ }
+
+ }
+
+ @Nested
+ class WithMockCluster {
+
+ private MessagesCache messagesCache;
+ private RedisAdvancedClusterReactiveCommands reactiveCommands;
+ private RedisAdvancedClusterAsyncCommands asyncCommands;
+
+ @SuppressWarnings("unchecked")
+ @BeforeEach
+ void setup() throws Exception {
+ reactiveCommands = mock(RedisAdvancedClusterReactiveCommands.class);
+ asyncCommands = mock(RedisAdvancedClusterAsyncCommands.class);
+
+ final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.builder()
+ .binaryReactiveCommands(reactiveCommands)
+ .binaryAsyncCommands(asyncCommands)
+ .build();
+
+ messagesCache = new MessagesCache(mockCluster, mockCluster, Clock.systemUTC(), mock(ExecutorService.class),
+ mock(ExecutorService.class));
+ }
+
+ @AfterEach
+ void teardown() {
+ StepVerifier.resetDefaultTimeout();
+ }
+
+ @Test
+ void testGetAllMessagesLimitsAndBackpressure() {
+ // this test makes sure that we don’t fetch and buffer all messages from the cache when the publisher
+ // is subscribed. Rather, we should be fetching in pages to satisfy downstream requests, so that memory usage
+ // is limited to few pages of messages
+
+ // we use a combination of Flux.just() and Sinks to control when data is “fetched” from the cache. The initial
+ // Flux.just()s are pages that are readily available, on demand. By design, there are more of these pages than
+ // the initial prefetch. The sinks allow us to create extra demand but defer producing values to satisfy the demand
+ // until later on.
+
+ final AtomicReference> page4Sink = new AtomicReference<>();
+ final AtomicReference> page56Sink = new AtomicReference<>();
+ final AtomicReference> emptyFinalPageSink = new AtomicReference<>();
+
+ final Deque> pages = new ArrayDeque<>();
+ pages.add(generatePage());
+ pages.add(generatePage());
+ pages.add(generatePage());
+ pages.add(generatePage());
+ // make sure that stale ephemeral messages are also produced by calls to getAllMessages()
+ pages.add(generateStaleEphemeralPage());
+ pages.add(generatePage());
+
+ when(reactiveCommands.evalsha(any(), any(), any(), any()))
+ .thenReturn(Flux.just(pages.pop()))
+ .thenReturn(Flux.just(pages.pop()))
+ .thenReturn(Flux.just(pages.pop()))
+ .thenReturn(Flux.create(sink -> page4Sink.compareAndSet(null, sink)))
+ .thenReturn(Flux.create(sink -> page56Sink.compareAndSet(null, sink)))
+ .thenReturn(Flux.create(sink -> emptyFinalPageSink.compareAndSet(null, sink)))
+ .thenReturn(Flux.empty());
+
+ final Flux> allMessages = messagesCache.getAllMessages(UUID.randomUUID(), 1L);
+
+ // Why initialValue = 3?
+ // 1. messagesCache.getAllMessages() above produces the first call
+ // 2. when we subscribe, the prefetch of 1 results in `expand()`, which produces a second call
+ // 3. there is an implicit “low tide mark” of 1, meaning there will be an extra call to replenish when there is
+ // 1 value remaining
+ final AtomicInteger expectedReactiveCommandInvocations = new AtomicInteger(3);
+
+ StepVerifier.setDefaultTimeout(Duration.ofSeconds(5));
+
+ final int page = 100;
+ final int halfPage = page / 2;
+
+ // in order to fully control demand and separate the prefetch mechanics, initially subscribe with a request of 0
+ StepVerifier.create(allMessages, 0)
+ .expectSubscription()
+ .then(() -> verify(reactiveCommands, times(expectedReactiveCommandInvocations.get())).evalsha(any(), any(),
+ any(), any()))
+ .thenRequest(halfPage) // page 0.5 requested
+ .expectNextCount(halfPage) // page 0.5 produced
+ // page 0.5 produced, 1.5 remain, so no additional interactions with the cache cluster
+ .then(() -> verify(reactiveCommands, times(expectedReactiveCommandInvocations.get())).evalsha(any(),
+ any(), any(), any()))
+ .then(() -> assertNull(page4Sink.get(), "page 4 should not have been fetched yet"))
+ .thenRequest(page) // page 1.5 requested
+ .expectNextCount(page) // page 1.5 produced
+
+ // we now have produced 1.5 pages, have 0.5 buffered, and two more have been prefetched.
+ // after producing more than a full page, we’ll need to replenish from the cache.
+ // future requests will depend on sink emitters.
+ // also NB: times() checks cumulative calls, hence addAndGet
+ .then(() -> verify(reactiveCommands, times(expectedReactiveCommandInvocations.addAndGet(1))).evalsha(any(),
+ any(), any(), any()))
+ .then(() -> assertNotNull(page4Sink.get(), "page 4 should have been fetched"))
+ .thenRequest(page + halfPage) // page 3 requested
+ .expectNextCount(page + halfPage) // page 1.5–3 produced
+
+ .thenRequest(halfPage) // page 3.5 requested
+ .then(() -> assertNull(page56Sink.get(), "page 5 should not have been fetched yet"))
+ .then(() -> page4Sink.get().next(pages.pop()).complete())
+ .expectNextCount(halfPage) // page 3.5 produced
+ .then(() -> verify(reactiveCommands, times(expectedReactiveCommandInvocations.addAndGet(1))).evalsha(any(),
+ any(), any(), any()))
+ .then(() -> assertNotNull(page56Sink.get(), "page 5 should have been fetched"))
+
+ .thenRequest(page) // page 4.5 requested
+ .expectNextCount(halfPage) // page 4 produced
+
+ .thenRequest(page * 4) // request more demand than we will ultimately satisfy
+
+ .then(() -> page56Sink.get().next(pages.pop()).next(pages.pop()).complete())
+ .expectNextCount(page + page) // page 5 and 6 produced
+ .then(() -> emptyFinalPageSink.get().complete())
+ // confirm that cache calls increased by 2: one for page 5-and-6 (we got a two-fer in next(pop()).next(pop()),
+ // and one for the final, empty page
+ .then(() -> verify(reactiveCommands, times(expectedReactiveCommandInvocations.addAndGet(2))).evalsha(any(),
+ any(), any(),
+ any()))
+ .expectComplete()
+ .log()
+ .verify();
+
+ // make sure that we consumed all the pages, especially in case of refactoring
+ assertTrue(pages.isEmpty());
+ }
+
+ @Test
+ void testGetDiscardsEphemeralMessages() {
+ final Deque> pages = new ArrayDeque<>();
+ pages.add(generatePage());
+ pages.add(generatePage());
+ pages.add(generateStaleEphemeralPage());
+
+ when(reactiveCommands.evalsha(any(), any(), any(), any()))
+ .thenReturn(Flux.just(pages.pop()))
+ .thenReturn(Flux.just(pages.pop()))
+ .thenReturn(Flux.just(pages.pop()))
+ .thenReturn(Flux.empty());
+
+ final AsyncCommand, ?, ?> removeSuccess = new AsyncCommand<>(mock(RedisCommand.class));
+ removeSuccess.complete();
+
+ when(asyncCommands.evalsha(any(), any(), any(), any()))
+ .thenReturn((RedisFuture) removeSuccess);
+
+ final Publisher> allMessages = messagesCache.get(UUID.randomUUID(), 1L);
+
+ StepVerifier.setDefaultTimeout(Duration.ofSeconds(5));
+
+ // async commands are used for remove(), and nothing should happen until we are subscribed
+ verify(asyncCommands, never()).evalsha(any(), any(), any(byte[][].class), any(byte[].class));
+ // the reactive commands will be called once, to prep the first page fetch (but no remote request would actually be sent)
+ verify(reactiveCommands, times(1)).evalsha(any(), any(), any(byte[][].class), any(byte[].class));
+
+ StepVerifier.create(allMessages)
+ .expectSubscription()
+ .expectNextCount(200)
+ .expectComplete()
+ .log()
+ .verify();
+
+ assertTrue(pages.isEmpty());
+ verify(asyncCommands, atLeast(1)).evalsha(any(), any(), any(), any());
+ }
+
+ private List generatePage() {
+ final List messagesAndIds = new ArrayList<>();
+
+ for (int i = 0; i < 100; i++) {
+ final MessageProtos.Envelope envelope = generateRandomMessage(UUID.randomUUID(), true);
+ messagesAndIds.add(envelope.toByteArray());
+ messagesAndIds.add(String.valueOf(serialTimestamp).getBytes());
+ }
+
+ return messagesAndIds;
+ }
+
+ private List generateStaleEphemeralPage() {
+ final List messagesAndIds = new ArrayList<>();
+
+ for (int i = 0; i < 100; i++) {
+ final MessageProtos.Envelope envelope = generateRandomMessage(UUID.randomUUID(), true)
+ .toBuilder().setEphemeral(true).build();
+ messagesAndIds.add(envelope.toByteArray());
+ messagesAndIds.add(String.valueOf(serialTimestamp).getBytes());
+ }
+
+ return messagesAndIds;
+ }
}
private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final boolean sealedSender) {
@@ -234,194 +753,4 @@ class MessagesCacheTest {
return envelopeBuilder.build();
}
-
- @Test
- void testClearNullUuid() {
- // We're happy as long as this doesn't throw an exception
- messagesCache.clear(null);
- }
-
- @Test
- void testGetAccountFromQueueName() {
- assertEquals(DESTINATION_UUID,
- MessagesCache.getAccountUuidFromQueueName(
- new String(MessagesCache.getMessageQueueKey(DESTINATION_UUID, DESTINATION_DEVICE_ID),
- StandardCharsets.UTF_8)));
- }
-
- @Test
- void testGetDeviceIdFromQueueName() {
- assertEquals(DESTINATION_DEVICE_ID,
- MessagesCache.getDeviceIdFromQueueName(
- new String(MessagesCache.getMessageQueueKey(DESTINATION_UUID, DESTINATION_DEVICE_ID),
- StandardCharsets.UTF_8)));
- }
-
- @Test
- void testGetQueueNameFromKeyspaceChannel() {
- assertEquals("1b363a31-a429-4fb6-8959-984a025e72ff::7",
- MessagesCache.getQueueNameFromKeyspaceChannel(
- "__keyspace@0__:user_queue::{1b363a31-a429-4fb6-8959-984a025e72ff::7}"));
- }
-
- @ParameterizedTest
- @ValueSource(booleans = {true, false})
- public void testGetQueuesToPersist(final boolean sealedSender) {
- final UUID messageGuid = UUID.randomUUID();
-
- messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID,
- generateRandomMessage(messageGuid, sealedSender));
- final int slot = SlotHash.getSlot(DESTINATION_UUID + "::" + DESTINATION_DEVICE_ID);
-
- assertTrue(messagesCache.getQueuesToPersist(slot + 1, Instant.now().plusSeconds(60), 100).isEmpty());
-
- final List queues = messagesCache.getQueuesToPersist(slot, Instant.now().plusSeconds(60), 100);
-
- assertEquals(1, queues.size());
- assertEquals(DESTINATION_UUID, MessagesCache.getAccountUuidFromQueueName(queues.get(0)));
- assertEquals(DESTINATION_DEVICE_ID, MessagesCache.getDeviceIdFromQueueName(queues.get(0)));
- }
-
- @Test
- void testNotifyListenerNewMessage() {
- final AtomicBoolean notified = new AtomicBoolean(false);
- final UUID messageGuid = UUID.randomUUID();
-
- final MessageAvailabilityListener listener = new MessageAvailabilityListener() {
- @Override
- public boolean handleNewMessagesAvailable() {
- synchronized (notified) {
- notified.set(true);
- notified.notifyAll();
-
- return true;
- }
- }
-
- @Override
- public boolean handleMessagesPersisted() {
- return true;
- }
- };
-
- assertTimeoutPreemptively(Duration.ofSeconds(5), () -> {
- messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener);
- messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID,
- generateRandomMessage(messageGuid, true));
-
- synchronized (notified) {
- while (!notified.get()) {
- notified.wait();
- }
- }
-
- assertTrue(notified.get());
- });
- }
-
- @Test
- void testNotifyListenerPersisted() {
- final AtomicBoolean notified = new AtomicBoolean(false);
-
- final MessageAvailabilityListener listener = new MessageAvailabilityListener() {
- @Override
- public boolean handleNewMessagesAvailable() {
- return true;
- }
-
- @Override
- public boolean handleMessagesPersisted() {
- synchronized (notified) {
- notified.set(true);
- notified.notifyAll();
-
- return true;
- }
- }
- };
-
- assertTimeoutPreemptively(Duration.ofSeconds(5), () -> {
- messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener);
-
- messagesCache.lockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID);
- messagesCache.unlockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID);
-
- synchronized (notified) {
- while (!notified.get()) {
- notified.wait();
- }
- }
-
- assertTrue(notified.get());
- });
- }
-
-
- /**
- * Helper class that implements {@link MessageAvailabilityListener#handleNewMessagesAvailable()} by always returning
- * {@code false}. Its {@code counter} field tracks how many times {@code handleNewMessagesAvailable} has been called.
- *
- * It uses a {@link CompletableFuture} to signal that it has received a “messages available” callback for the first
- * time.
- */
- private static class NewMessagesAvailabilityClosedListener implements MessageAvailabilityListener {
-
- private int counter;
-
- private final Consumer messageHandledCallback;
- private final CompletableFuture firstMessageHandled = new CompletableFuture<>();
-
- private NewMessagesAvailabilityClosedListener(final Consumer messageHandledCallback) {
- this.messageHandledCallback = messageHandledCallback;
- }
-
- @Override
- public boolean handleNewMessagesAvailable() {
- counter++;
- messageHandledCallback.accept(counter);
- firstMessageHandled.complete(null);
-
- return false;
-
- }
-
- @Override
- public boolean handleMessagesPersisted() {
- return true;
- }
- }
-
- @Test
- void testAvailabilityListenerResponses() {
- final NewMessagesAvailabilityClosedListener listener1 = new NewMessagesAvailabilityClosedListener(
- count -> assertEquals(1, count));
- final NewMessagesAvailabilityClosedListener listener2 = new NewMessagesAvailabilityClosedListener(
- count -> assertEquals(1, count));
-
- assertTimeoutPreemptively(Duration.ofSeconds(30), () -> {
- messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener1);
- final UUID messageGuid1 = UUID.randomUUID();
- messagesCache.insert(messageGuid1, DESTINATION_UUID, DESTINATION_DEVICE_ID,
- generateRandomMessage(messageGuid1, true));
-
- listener1.firstMessageHandled.get();
-
- // Avoid a race condition by blocking on the message handled future *and* the current notification executor task—
- // the notification executor task includes unsubscribing `listener1`, and, if we don’t wait, sometimes
- // `listener2` will get subscribed before `listener1` is cleaned up
- notificationExecutorService.submit(() -> listener1.firstMessageHandled.get()).get();
-
- final UUID messageGuid2 = UUID.randomUUID();
- messagesCache.insert(messageGuid2, DESTINATION_UUID, DESTINATION_DEVICE_ID,
- generateRandomMessage(messageGuid2, true));
-
- messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener2);
-
- final UUID messageGuid3 = UUID.randomUUID();
- messagesCache.insert(messageGuid3, DESTINATION_UUID, DESTINATION_DEVICE_ID,
- generateRandomMessage(messageGuid3, true));
-
- listener2.firstMessageHandled.get();
- });
- }
}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDbTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDbTest.java
index a98126059..eae1802ad 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDbTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesDynamoDbTest.java
@@ -9,14 +9,26 @@ import static org.assertj.core.api.Assertions.assertThat;
import com.google.protobuf.ByteString;
import java.time.Duration;
+import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.UUID;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+import org.reactivestreams.Publisher;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
+import org.whispersystems.textsecuregcm.tests.util.MessageHelper;
import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbExtension;
+import reactor.core.publisher.Flux;
+import reactor.test.StepVerifier;
class MessagesDynamoDbTest {
@@ -59,6 +71,7 @@ class MessagesDynamoDbTest {
MESSAGE3 = builder.build();
}
+ private ExecutorService messageDeletionExecutorService;
private MessagesDynamoDb messagesDynamoDb;
@@ -67,8 +80,18 @@ class MessagesDynamoDbTest {
@BeforeEach
void setup() {
- messagesDynamoDb = new MessagesDynamoDb(dynamoDbExtension.getDynamoDbClient(), MessagesDynamoDbExtension.TABLE_NAME,
- Duration.ofDays(14));
+ messageDeletionExecutorService = Executors.newSingleThreadExecutor();
+ messagesDynamoDb = new MessagesDynamoDb(dynamoDbExtension.getDynamoDbClient(),
+ dynamoDbExtension.getDynamoDbAsyncClient(), MessagesDynamoDbExtension.TABLE_NAME, Duration.ofDays(14),
+ messageDeletionExecutorService);
+ }
+
+ @AfterEach
+ void teardown() throws Exception {
+ messageDeletionExecutorService.shutdown();
+ messageDeletionExecutorService.awaitTermination(5, TimeUnit.SECONDS);
+
+ StepVerifier.resetDefaultTimeout();
}
@Test
@@ -77,7 +100,7 @@ class MessagesDynamoDbTest {
final int destinationDeviceId = random.nextInt(255) + 1;
messagesDynamoDb.store(List.of(MESSAGE1, MESSAGE2, MESSAGE3), destinationUuid, destinationDeviceId);
- final List messagesStored = messagesDynamoDb.load(destinationUuid, destinationDeviceId,
+ final List messagesStored = load(destinationUuid, destinationDeviceId,
MessagesDynamoDb.RESULT_SET_CHUNK_SIZE);
assertThat(messagesStored).isNotNull().hasSize(3);
final MessageProtos.Envelope firstMessage =
@@ -88,6 +111,73 @@ class MessagesDynamoDbTest {
assertThat(messagesStored).element(2).isEqualTo(MESSAGE2);
}
+ @ParameterizedTest
+ @ValueSource(ints = {10, 100, 100, 1_000, 3_000})
+ void testLoadManyAfterInsert(final int messageCount) {
+ final UUID destinationUuid = UUID.randomUUID();
+ final int destinationDeviceId = random.nextInt(255) + 1;
+
+ final List messages = new ArrayList<>(messageCount);
+ for (int i = 0; i < messageCount; i++) {
+ messages.add(MessageHelper.createMessage(UUID.randomUUID(), 1, destinationUuid, (i + 1L) * 1000, "message " + i));
+ }
+
+ messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId);
+
+ final Publisher> fetchedMessages = messagesDynamoDb.load(destinationUuid, destinationDeviceId, null);
+
+ final long firstRequest = Math.min(10, messageCount);
+ StepVerifier.setDefaultTimeout(Duration.ofSeconds(15));
+
+ StepVerifier.Step> step = StepVerifier.create(fetchedMessages, 0)
+ .expectSubscription()
+ .thenRequest(firstRequest)
+ .expectNextCount(firstRequest);
+
+ if (messageCount > firstRequest) {
+ step = step.thenRequest(messageCount)
+ .expectNextCount(messageCount - firstRequest);
+ }
+
+ step.thenCancel()
+ .verify();
+ }
+
+ @Test
+ void testLimitedLoad() {
+ final int messageCount = 200;
+ final UUID destinationUuid = UUID.randomUUID();
+ final int destinationDeviceId = random.nextInt(255) + 1;
+
+ final List messages = new ArrayList<>(messageCount);
+ for (int i = 0; i < messageCount; i++) {
+ messages.add(MessageHelper.createMessage(UUID.randomUUID(), 1, destinationUuid, (i + 1L) * 1000, "message " + i));
+ }
+
+ messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId);
+
+ final int messageLoadLimit = 100;
+ final int halfOfMessageLoadLimit = messageLoadLimit / 2;
+ final Publisher> fetchedMessages = messagesDynamoDb.load(destinationUuid, destinationDeviceId, messageLoadLimit);
+
+ StepVerifier.setDefaultTimeout(Duration.ofSeconds(10));
+
+ final AtomicInteger messagesRemaining = new AtomicInteger(messageLoadLimit);
+
+ StepVerifier.create(fetchedMessages, 0)
+ .expectSubscription()
+ .thenRequest(halfOfMessageLoadLimit)
+ .expectNextCount(halfOfMessageLoadLimit)
+ // the first 100 should be fetched and buffered, but further requests should fail
+ .then(() -> dynamoDbExtension.stopServer())
+ .thenRequest(halfOfMessageLoadLimit)
+ .expectNextCount(halfOfMessageLoadLimit)
+ // we’ve consumed all the buffered messages, so a single request will fail
+ .thenRequest(1)
+ .expectError()
+ .verify();
+ }
+
@Test
void testDeleteForDestination() {
final UUID destinationUuid = UUID.randomUUID();
@@ -96,18 +186,18 @@ class MessagesDynamoDbTest {
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
- assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
+ assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
- assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
+ assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE3);
- assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
+ assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).isEqualTo(MESSAGE2);
messagesDynamoDb.deleteAllMessagesForAccount(destinationUuid);
- assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
- assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
- assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
+ assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
+ assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
+ assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).isEqualTo(MESSAGE2);
}
@@ -119,71 +209,79 @@ class MessagesDynamoDbTest {
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
- assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
+ assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
- assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
+ assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE3);
- assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
+ assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).isEqualTo(MESSAGE2);
messagesDynamoDb.deleteAllMessagesForDevice(destinationUuid, 2);
- assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
+ assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
- assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
- assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
+ assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().isEmpty();
+ assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).isEqualTo(MESSAGE2);
}
@Test
- void testDeleteMessageByDestinationAndGuid() {
+ void testDeleteMessageByDestinationAndGuid() throws Exception {
final UUID destinationUuid = UUID.randomUUID();
final UUID secondDestinationUuid = UUID.randomUUID();
messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
- assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
+ assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
- assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
+ assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE3);
- assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
+ assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).isEqualTo(MESSAGE2);
messagesDynamoDb.deleteMessageByDestinationAndGuid(secondDestinationUuid,
- UUID.fromString(MESSAGE2.getServerGuid()));
+ UUID.fromString(MESSAGE2.getServerGuid())).get(5, TimeUnit.SECONDS);
- assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
+ assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
- assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
+ assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE3);
- assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
+ assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.isEmpty();
}
@Test
- void testDeleteSingleMessage() {
+ void testDeleteSingleMessage() throws Exception {
final UUID destinationUuid = UUID.randomUUID();
final UUID secondDestinationUuid = UUID.randomUUID();
messagesDynamoDb.store(List.of(MESSAGE1), destinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE2), secondDestinationUuid, 1);
messagesDynamoDb.store(List.of(MESSAGE3), destinationUuid, 2);
- assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
+ assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
- assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
+ assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE3);
- assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
+ assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.hasSize(1).element(0).isEqualTo(MESSAGE2);
messagesDynamoDb.deleteMessage(secondDestinationUuid, 1,
- UUID.fromString(MESSAGE2.getServerGuid()), MESSAGE2.getServerTimestamp());
+ UUID.fromString(MESSAGE2.getServerGuid()), MESSAGE2.getServerTimestamp()).get(1, TimeUnit.SECONDS);
- assertThat(messagesDynamoDb.load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
+ assertThat(load(destinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE1);
- assertThat(messagesDynamoDb.load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
+ assertThat(load(destinationUuid, 2, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull().hasSize(1)
.element(0).isEqualTo(MESSAGE3);
- assertThat(messagesDynamoDb.load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
+ assertThat(load(secondDestinationUuid, 1, MessagesDynamoDb.RESULT_SET_CHUNK_SIZE)).isNotNull()
.isEmpty();
}
+
+ private List load(final UUID destinationUuid, final long destinationDeviceId,
+ final int count) {
+ return Flux.from(messagesDynamoDb.load(destinationUuid, destinationDeviceId, count))
+ .take(count, true)
+ .collectList()
+ .block();
+ }
}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java
index 55d35fb84..109390fb3 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesManagerTest.java
@@ -14,13 +14,11 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
-import org.whispersystems.textsecuregcm.push.PushLatencyManager;
class MessagesManagerTest {
private final MessagesDynamoDb messagesDynamoDb = mock(MessagesDynamoDb.class);
private final MessagesCache messagesCache = mock(MessagesCache.class);
- private final PushLatencyManager pushLatencyManager = mock(PushLatencyManager.class);
private final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class);
private final MessagesManager messagesManager = new MessagesManager(messagesDynamoDb, messagesCache,
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/ProfilesManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/ProfilesManagerTest.java
index e3bc5ed44..894f12de0 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/ProfilesManagerTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/storage/ProfilesManagerTest.java
@@ -41,7 +41,7 @@ public class ProfilesManagerTest {
void setUp() {
//noinspection unchecked
commands = mock(RedisAdvancedClusterCommands.class);
- final FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.buildMockRedisCluster(commands);
+ final FaultTolerantRedisCluster cacheCluster = RedisClusterHelper.builder().stringCommands(commands).build();
profiles = mock(Profiles.class);
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessageHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessageHelper.java
new file mode 100644
index 000000000..0ff6e7856
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/MessageHelper.java
@@ -0,0 +1,28 @@
+/*
+ * Copyright 2022 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.tests.util;
+
+import com.google.protobuf.ByteString;
+import java.nio.charset.StandardCharsets;
+import java.util.UUID;
+import org.whispersystems.textsecuregcm.entities.MessageProtos;
+
+public class MessageHelper {
+
+ public static MessageProtos.Envelope createMessage(UUID senderUuid, final int senderDeviceId, UUID destinationUuid,
+ long timestamp, String content) {
+ return MessageProtos.Envelope.newBuilder()
+ .setServerGuid(UUID.randomUUID().toString())
+ .setType(MessageProtos.Envelope.Type.CIPHERTEXT)
+ .setTimestamp(timestamp)
+ .setServerTimestamp(0)
+ .setSourceUuid(senderUuid.toString())
+ .setSourceDevice(senderDeviceId)
+ .setDestinationUuid(destinationUuid.toString())
+ .setContent(ByteString.copyFrom(content.getBytes(StandardCharsets.UTF_8)))
+ .build();
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisClusterHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisClusterHelper.java
index 101a2575f..3112b2d48 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisClusterHelper.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisClusterHelper.java
@@ -5,70 +5,118 @@
package org.whispersystems.textsecuregcm.tests.util;
-import io.lettuce.core.cluster.api.StatefulRedisClusterConnection;
-import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
-import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
-
-import java.util.function.Consumer;
-import java.util.function.Function;
-
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
+import io.lettuce.core.cluster.api.StatefulRedisClusterConnection;
+import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
+import io.lettuce.core.cluster.api.reactive.RedisAdvancedClusterReactiveCommands;
+import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
+import java.util.function.Consumer;
+import java.util.function.Function;
+import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
+
public class RedisClusterHelper {
- @SuppressWarnings("unchecked")
- public static FaultTolerantRedisCluster buildMockRedisCluster(final RedisAdvancedClusterCommands stringCommands) {
- return buildMockRedisCluster(stringCommands, mock(RedisAdvancedClusterCommands.class));
+ public static RedisClusterHelper.Builder builder() {
+ return new Builder();
+ }
+
+ @SuppressWarnings("unchecked")
+ private static FaultTolerantRedisCluster buildMockRedisCluster(
+ final RedisAdvancedClusterCommands stringCommands,
+ final RedisAdvancedClusterCommands binaryCommands,
+ final RedisAdvancedClusterAsyncCommands binaryAsyncCommands,
+ final RedisAdvancedClusterReactiveCommands binaryReactiveCommands) {
+ final FaultTolerantRedisCluster cluster = mock(FaultTolerantRedisCluster.class);
+ final StatefulRedisClusterConnection stringConnection = mock(StatefulRedisClusterConnection.class);
+ final StatefulRedisClusterConnection binaryConnection = mock(StatefulRedisClusterConnection.class);
+
+ when(stringConnection.sync()).thenReturn(stringCommands);
+ when(binaryConnection.sync()).thenReturn(binaryCommands);
+ when(binaryConnection.async()).thenReturn(binaryAsyncCommands);
+ when(binaryConnection.reactive()).thenReturn(binaryReactiveCommands);
+
+ when(cluster.withCluster(any(Function.class))).thenAnswer(invocation -> {
+ return invocation.getArgument(0, Function.class).apply(stringConnection);
+ });
+
+ doAnswer(invocation -> {
+ invocation.getArgument(0, Consumer.class).accept(stringConnection);
+ return null;
+ }).when(cluster).useCluster(any(Consumer.class));
+
+ when(cluster.withCluster(any(Function.class))).thenAnswer(invocation -> {
+ return invocation.getArgument(0, Function.class).apply(stringConnection);
+ });
+
+ doAnswer(invocation -> {
+ invocation.getArgument(0, Consumer.class).accept(stringConnection);
+ return null;
+ }).when(cluster).useCluster(any(Consumer.class));
+
+ when(cluster.withBinaryCluster(any(Function.class))).thenAnswer(invocation -> {
+ return invocation.getArgument(0, Function.class).apply(binaryConnection);
+ });
+
+ doAnswer(invocation -> {
+ invocation.getArgument(0, Consumer.class).accept(binaryConnection);
+ return null;
+ }).when(cluster).useBinaryCluster(any(Consumer.class));
+
+ when(cluster.withBinaryCluster(any(Function.class))).thenAnswer(invocation -> {
+ return invocation.getArgument(0, Function.class).apply(binaryConnection);
+ });
+
+ doAnswer(invocation -> {
+ invocation.getArgument(0, Consumer.class).accept(binaryConnection);
+ return null;
+ }).when(cluster).useBinaryCluster(any(Consumer.class));
+
+ return cluster;
+ }
+
+ @SuppressWarnings("unchecked")
+ public static class Builder {
+
+ private RedisAdvancedClusterCommands stringCommands = mock(RedisAdvancedClusterCommands.class);
+ private RedisAdvancedClusterCommands binaryCommands = mock(RedisAdvancedClusterCommands.class);
+ private RedisAdvancedClusterAsyncCommands binaryAsyncCommands = mock(
+ RedisAdvancedClusterAsyncCommands.class);
+ private RedisAdvancedClusterReactiveCommands binaryReactiveCommands = mock(
+ RedisAdvancedClusterReactiveCommands.class);
+
+ private Builder() {
+
}
- @SuppressWarnings("unchecked")
- public static FaultTolerantRedisCluster buildMockRedisCluster(final RedisAdvancedClusterCommands stringCommands, final RedisAdvancedClusterCommands binaryCommands) {
- final FaultTolerantRedisCluster cluster = mock(FaultTolerantRedisCluster.class);
- final StatefulRedisClusterConnection stringConnection = mock(StatefulRedisClusterConnection.class);
- final StatefulRedisClusterConnection binaryConnection = mock(StatefulRedisClusterConnection.class);
-
- when(stringConnection.sync()).thenReturn(stringCommands);
- when(binaryConnection.sync()).thenReturn(binaryCommands);
-
- when(cluster.withCluster(any(Function.class))).thenAnswer(invocation -> {
- return invocation.getArgument(0, Function.class).apply(stringConnection);
- });
-
- doAnswer(invocation -> {
- invocation.getArgument(0, Consumer.class).accept(stringConnection);
- return null;
- }).when(cluster).useCluster(any(Consumer.class));
-
- when(cluster.withCluster(any(Function.class))).thenAnswer(invocation -> {
- return invocation.getArgument(0, Function.class).apply(stringConnection);
- });
-
- doAnswer(invocation -> {
- invocation.getArgument(0, Consumer.class).accept(stringConnection);
- return null;
- }).when(cluster).useCluster(any(Consumer.class));
-
- when(cluster.withBinaryCluster(any(Function.class))).thenAnswer(invocation -> {
- return invocation.getArgument(0, Function.class).apply(binaryConnection);
- });
-
- doAnswer(invocation -> {
- invocation.getArgument(0, Consumer.class).accept(binaryConnection);
- return null;
- }).when(cluster).useBinaryCluster(any(Consumer.class));
-
- when(cluster.withBinaryCluster(any(Function.class))).thenAnswer(invocation -> {
- return invocation.getArgument(0, Function.class).apply(binaryConnection);
- });
-
- doAnswer(invocation -> {
- invocation.getArgument(0, Consumer.class).accept(binaryConnection);
- return null;
- }).when(cluster).useBinaryCluster(any(Consumer.class));
-
- return cluster;
+ public Builder stringCommands(final RedisAdvancedClusterCommands stringCommands) {
+ this.stringCommands = stringCommands;
+ return this;
}
+
+ public Builder binaryCommands(final RedisAdvancedClusterCommands binaryCommands) {
+ this.binaryCommands = binaryCommands;
+ return this;
+ }
+
+ public Builder binaryAsyncCommands(final RedisAdvancedClusterAsyncCommands binaryAsyncCommands) {
+ this.binaryAsyncCommands = binaryAsyncCommands;
+ return this;
+ }
+
+ public Builder binaryReactiveCommands(
+ final RedisAdvancedClusterReactiveCommands binaryReactiveCommands) {
+ this.binaryReactiveCommands = binaryReactiveCommands;
+ return this;
+ }
+
+ public FaultTolerantRedisCluster build() {
+ return RedisClusterHelper.buildMockRedisCluster(stringCommands, binaryCommands, binaryAsyncCommands,
+ binaryReactiveCommands);
+ }
+ }
+
}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java
index 59a85f38a..6a912dff2 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2013-2021 Signal Messenger, LLC
+ * Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@@ -22,6 +22,7 @@ import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import java.io.IOException;
+import java.time.Clock;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
@@ -36,8 +37,10 @@ import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.CsvSource;
+import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
@@ -56,6 +59,7 @@ import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbExtension;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
+import reactor.core.scheduler.Schedulers;
class WebSocketConnectionIntegrationTest {
@@ -65,16 +69,13 @@ class WebSocketConnectionIntegrationTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
- private static final int SEND_FUTURES_TIMEOUT_MILLIS = 100;
-
- private ExecutorService executorService;
+ private ExecutorService sharedExecutorService;
private MessagesDynamoDb messagesDynamoDb;
private MessagesCache messagesCache;
private ReportMessageManager reportMessageManager;
private Account account;
private Device device;
private WebSocketClient webSocketClient;
- private WebSocketConnection webSocketConnection;
private ScheduledExecutorService retrySchedulingExecutor;
private long serialTimestamp = System.currentTimeMillis();
@@ -82,11 +83,12 @@ class WebSocketConnectionIntegrationTest {
@BeforeEach
void setUp() throws Exception {
- executorService = Executors.newSingleThreadExecutor();
+ sharedExecutorService = Executors.newSingleThreadExecutor();
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
- REDIS_CLUSTER_EXTENSION.getRedisCluster(), executorService);
- messagesDynamoDb = new MessagesDynamoDb(dynamoDbExtension.getDynamoDbClient(), MessagesDynamoDbExtension.TABLE_NAME,
- Duration.ofDays(7));
+ REDIS_CLUSTER_EXTENSION.getRedisCluster(), Clock.systemUTC(), sharedExecutorService, sharedExecutorService);
+ messagesDynamoDb = new MessagesDynamoDb(dynamoDbExtension.getDynamoDbClient(),
+ dynamoDbExtension.getDynamoDbAsyncClient(), MessagesDynamoDbExtension.TABLE_NAME, Duration.ofDays(7),
+ sharedExecutorService);
reportMessageManager = mock(ReportMessageManager.class);
account = mock(Account.class);
device = mock(Device.class);
@@ -96,30 +98,36 @@ class WebSocketConnectionIntegrationTest {
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(1L);
-
- webSocketConnection = new WebSocketConnection(
- mock(ReceiptSender.class),
- new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager),
- new AuthenticatedAccount(() -> new Pair<>(account, device)),
- device,
- webSocketClient,
- SEND_FUTURES_TIMEOUT_MILLIS,
- retrySchedulingExecutor);
}
@AfterEach
void tearDown() throws Exception {
- executorService.shutdown();
- executorService.awaitTermination(2, TimeUnit.SECONDS);
+ sharedExecutorService.shutdown();
+ sharedExecutorService.awaitTermination(2, TimeUnit.SECONDS);
retrySchedulingExecutor.shutdown();
retrySchedulingExecutor.awaitTermination(2, TimeUnit.SECONDS);
}
- @Test
- void testProcessStoredMessages() {
- final int persistedMessageCount = 207;
- final int cachedMessageCount = 173;
+ @ParameterizedTest
+ @CsvSource({
+ "207, 173, true",
+ "207, 173, false",
+ "323, 0, true",
+ "323, 0, false",
+ "0, 221, true",
+ "0, 221, false",
+ })
+ void testProcessStoredMessages(final int persistedMessageCount, final int cachedMessageCount,
+ final boolean useReactive) {
+ final WebSocketConnection webSocketConnection = new WebSocketConnection(
+ mock(ReceiptSender.class),
+ new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager),
+ new AuthenticatedAccount(() -> new Pair<>(account, device)),
+ device,
+ webSocketClient,
+ retrySchedulingExecutor,
+ useReactive);
final List expectedMessages = new ArrayList<>(persistedMessageCount + cachedMessageCount);
@@ -150,8 +158,8 @@ class WebSocketConnectionIntegrationTest {
final AtomicBoolean queueCleared = new AtomicBoolean(false);
when(successResponse.getStatus()).thenReturn(200);
- when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any())).thenReturn(
- CompletableFuture.completedFuture(successResponse));
+ when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/message"), anyList(), any()))
+ .thenReturn(CompletableFuture.completedFuture(successResponse));
when(webSocketClient.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), anyList(), any())).thenAnswer(
(Answer>) invocation -> {
@@ -194,8 +202,18 @@ class WebSocketConnectionIntegrationTest {
});
}
- @Test
- void testProcessStoredMessagesClientClosed() {
+ @ParameterizedTest
+ @ValueSource(booleans = {true, false})
+ void testProcessStoredMessagesClientClosed(final boolean useReactive) {
+ final WebSocketConnection webSocketConnection = new WebSocketConnection(
+ mock(ReceiptSender.class),
+ new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager),
+ new AuthenticatedAccount(() -> new Pair<>(account, device)),
+ device,
+ webSocketClient,
+ retrySchedulingExecutor,
+ useReactive);
+
final int persistedMessageCount = 207;
final int cachedMessageCount = 173;
@@ -250,8 +268,20 @@ class WebSocketConnectionIntegrationTest {
});
}
- @Test
- void testProcessStoredMessagesSendFutureTimeout() {
+ @ParameterizedTest
+ @ValueSource(booleans = {true, false})
+ void testProcessStoredMessagesSendFutureTimeout(final boolean useReactive) {
+ final WebSocketConnection webSocketConnection = new WebSocketConnection(
+ mock(ReceiptSender.class),
+ new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager),
+ new AuthenticatedAccount(() -> new Pair<>(account, device)),
+ device,
+ webSocketClient,
+ 100, // use a very short timeout, so that this test completes quickly
+ retrySchedulingExecutor,
+ useReactive,
+ Schedulers.boundedElastic());
+
final int persistedMessageCount = 207;
final int cachedMessageCount = 173;
@@ -346,4 +376,5 @@ class WebSocketConnectionIntegrationTest {
.setDestinationUuid(UUID.randomUUID().toString())
.build();
}
+
}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java
index 5f4faa293..393526e78 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2013-2021 Signal Messenger, LLC
+ * Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@@ -12,6 +12,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.ArgumentMatchers.nullable;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.anyLong;
@@ -42,17 +43,20 @@ import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
-import org.apache.commons.lang3.RandomStringUtils;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Stream;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
+import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
-import org.mockito.ArgumentMatchers;
-import org.mockito.invocation.InvocationOnMock;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
-import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
-import org.whispersystems.textsecuregcm.push.ApnPushNotificationScheduler;
+import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
@@ -65,6 +69,10 @@ import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.session.WebSocketSessionContext;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.FluxSink;
+import reactor.core.scheduler.Schedulers;
+import reactor.test.StepVerifier;
class WebSocketConnectionTest {
@@ -83,7 +91,6 @@ class WebSocketConnectionTest {
private AuthenticatedAccount auth;
private UpgradeRequest upgradeRequest;
private ReceiptSender receiptSender;
- private PushNotificationManager pushNotificationManager;
private ScheduledExecutorService retrySchedulingExecutor;
@BeforeEach
@@ -95,17 +102,21 @@ class WebSocketConnectionTest {
auth = new AuthenticatedAccount(() -> new Pair<>(account, device));
upgradeRequest = mock(UpgradeRequest.class);
receiptSender = mock(ReceiptSender.class);
- pushNotificationManager = mock(PushNotificationManager.class);
retrySchedulingExecutor = mock(ScheduledExecutorService.class);
}
+ @AfterEach
+ void teardown() {
+ StepVerifier.resetDefaultTimeout();
+ }
+
@Test
void testCredentials() {
MessagesManager storedMessages = mock(MessagesManager.class);
WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator);
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, storedMessages,
mock(PushNotificationManager.class), mock(ClientPresenceManager.class),
- retrySchedulingExecutor);
+ retrySchedulingExecutor, mock(ExperimentEnrollmentManager.class));
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
@@ -114,7 +125,6 @@ class WebSocketConnectionTest {
when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD))))
.thenReturn(Optional.empty());
-
when(upgradeRequest.getParameterMap()).thenReturn(Map.of(
"login", List.of(VALID_USER),
"password", List.of(VALID_PASSWORD)));
@@ -136,8 +146,9 @@ class WebSocketConnectionTest {
assertTrue(account.isRequired());
}
- @Test
- void testOpen() throws Exception {
+ @ParameterizedTest
+ @ValueSource(booleans = {true, false})
+ void testOpen(final boolean useReactive) throws Exception {
MessagesManager storedMessages = mock(MessagesManager.class);
UUID accountUuid = UUID.randomUUID();
@@ -166,29 +177,31 @@ class WebSocketConnectionTest {
String userAgent = "user-agent";
- when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), false))
- .thenReturn(new Pair<>(outgoingMessages, false));
+ if (useReactive) {
+ when(storedMessages.getMessagesForDeviceReactive(account.getUuid(), device.getId(), false))
+ .thenReturn(Flux.fromIterable(outgoingMessages));
+ } else {
+ when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), false))
+ .thenReturn(new Pair<>(outgoingMessages, false));
+ }
final List> futures = new LinkedList<>();
- final WebSocketClient client = mock(WebSocketClient.class);
+ final WebSocketClient client = mock(WebSocketClient.class);
when(client.getUserAgent()).thenReturn(userAgent);
- when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.>any()))
- .thenAnswer(new Answer>() {
- @Override
- public CompletableFuture answer(InvocationOnMock invocationOnMock) {
- CompletableFuture future = new CompletableFuture<>();
- futures.add(future);
- return future;
- }
+ when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), nullable(List.class), any()))
+ .thenAnswer(invocation -> {
+ CompletableFuture future = new CompletableFuture<>();
+ futures.add(future);
+ return future;
});
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages,
- auth, device, client, retrySchedulingExecutor);
+ auth, device, client, retrySchedulingExecutor, useReactive, Schedulers.immediate());
connection.start();
verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class),
- ArgumentMatchers.>any());
+ any());
assertEquals(3, futures.size());
@@ -208,12 +221,13 @@ class WebSocketConnectionTest {
verify(client).close(anyInt(), anyString());
}
- @Test
- public void testOnlineSend() {
+ @ParameterizedTest
+ @ValueSource(booleans = {true, false})
+ public void testOnlineSend(final boolean useReactive) {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
- retrySchedulingExecutor);
+ retrySchedulingExecutor, useReactive, Schedulers.immediate());
final UUID accountUuid = UUID.randomUUID();
@@ -222,24 +236,36 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn(1L);
when(client.isOpen()).thenReturn(true);
- when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), anyBoolean()))
- .thenReturn(new Pair<>(Collections.emptyList(), false))
- .thenReturn(new Pair<>(List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 1111, "first")), false))
- .thenReturn(new Pair<>(List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 2222, "second")), false));
+ if (useReactive) {
+ when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(1L), anyBoolean()))
+ .thenReturn(Flux.empty())
+ .thenReturn(Flux.just(createMessage(UUID.randomUUID(), UUID.randomUUID(), 1111, "first")))
+ .thenReturn(Flux.just(createMessage(UUID.randomUUID(), UUID.randomUUID(), 2222, "second")))
+ .thenReturn(Flux.empty());
+ } else {
+ when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), anyBoolean()))
+ .thenReturn(new Pair<>(Collections.emptyList(), false))
+ .thenReturn(new Pair<>(List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 1111, "first")),
+ false))
+ .thenReturn(new Pair<>(List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 2222, "second")),
+ false))
+ .thenReturn(new Pair<>(Collections.emptyList(), false));
+ }
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
final AtomicInteger sendCounter = new AtomicInteger(0);
- when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))).thenAnswer((Answer>)invocation -> {
- synchronized (sendCounter) {
- sendCounter.incrementAndGet();
- sendCounter.notifyAll();
- }
+ when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class)))
+ .thenAnswer(invocation -> {
+ synchronized (sendCounter) {
+ sendCounter.incrementAndGet();
+ sendCounter.notifyAll();
+ }
- return CompletableFuture.completedFuture(successResponse);
- });
+ return CompletableFuture.completedFuture(successResponse);
+ });
assertTimeoutPreemptively(Duration.ofSeconds(5), () -> {
// This is a little hacky and non-obvious, but because the first call to getMessagesForDevice returns empty list of
@@ -269,9 +295,10 @@ class WebSocketConnectionTest {
verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class));
}
- @Test
- void testPendingSend() throws Exception {
- MessagesManager storedMessages = mock(MessagesManager.class);
+ @ParameterizedTest
+ @ValueSource(booleans = {true, false})
+ void testPendingSend(final boolean useReactive) throws Exception {
+ MessagesManager storedMessages = mock(MessagesManager.class);
final UUID accountUuid = UUID.randomUUID();
final UUID senderTwoUuid = UUID.randomUUID();
@@ -311,15 +338,20 @@ class WebSocketConnectionTest {
when(sender1.getDevices()).thenReturn(sender1devices);
when(accountsManager.getByE164("sender1")).thenReturn(Optional.of(sender1));
- when(accountsManager.getByE164("sender2")).thenReturn(Optional.empty());
+ when(accountsManager.getByE164("sender2")).thenReturn(Optional.empty());
String userAgent = "user-agent";
- when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), false))
- .thenReturn(new Pair<>(pendingMessages, false));
+ if (useReactive) {
+ when(storedMessages.getMessagesForDeviceReactive(account.getUuid(), device.getId(), false))
+ .thenReturn(Flux.fromIterable(pendingMessages));
+ } else {
+ when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), false))
+ .thenReturn(new Pair<>(pendingMessages, false));
+ }
final List> futures = new LinkedList<>();
- final WebSocketClient client = mock(WebSocketClient.class);
+ final WebSocketClient client = mock(WebSocketClient.class);
when(client.getUserAgent()).thenReturn(userAgent);
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(), any()))
@@ -330,7 +362,7 @@ class WebSocketConnectionTest {
});
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages,
- auth, device, client, retrySchedulingExecutor);
+ auth, device, client, retrySchedulingExecutor, useReactive, Schedulers.immediate());
connection.start();
@@ -350,12 +382,13 @@ class WebSocketConnectionTest {
verify(client).close(anyInt(), anyString());
}
- @Test
- void testProcessStoredMessageConcurrency() throws InterruptedException {
+ @ParameterizedTest
+ @ValueSource(booleans = {true, false})
+ void testProcessStoredMessageConcurrency(final boolean useReactive) {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
- retrySchedulingExecutor);
+ retrySchedulingExecutor, useReactive, Schedulers.immediate());
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
@@ -365,26 +398,45 @@ class WebSocketConnectionTest {
final AtomicBoolean threadWaiting = new AtomicBoolean(false);
final AtomicBoolean returnMessageList = new AtomicBoolean(false);
- when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, false)).thenAnswer(
- (Answer) invocation -> {
- synchronized (threadWaiting) {
- threadWaiting.set(true);
- threadWaiting.notifyAll();
- }
+ if (useReactive) {
+ when(
+ messagesManager.getMessagesForDeviceReactive(account.getUuid(), 1L, false))
+ .thenAnswer(invocation -> {
+ synchronized (threadWaiting) {
+ threadWaiting.set(true);
+ threadWaiting.notifyAll();
+ }
- synchronized (returnMessageList) {
- while (!returnMessageList.get()) {
- returnMessageList.wait();
- }
- }
+ synchronized (returnMessageList) {
+ while (!returnMessageList.get()) {
+ returnMessageList.wait();
+ }
+ }
- return new OutgoingMessageEntityList(Collections.emptyList(), false);
- });
+ return Flux.empty();
+ });
+ } else {
+ when(
+ messagesManager.getMessagesForDevice(account.getUuid(), 1L, false))
+ .thenAnswer(invocation -> {
+ synchronized (threadWaiting) {
+ threadWaiting.set(true);
+ threadWaiting.notifyAll();
+ }
- final Thread[] threads = new Thread[10];
+ synchronized (returnMessageList) {
+ while (!returnMessageList.get()) {
+ returnMessageList.wait();
+ }
+ }
+
+ return new Pair<>(Collections.emptyList(), false);
+ });
+ }
+
+ final Thread[] threads = new Thread[10];
final CountDownLatch unblockedThreadsLatch = new CountDownLatch(threads.length - 1);
-
assertTimeoutPreemptively(Duration.ofSeconds(5), () -> {
for (int i = 0; i < threads.length; i++) {
threads[i] = new Thread(() -> {
@@ -413,18 +465,24 @@ class WebSocketConnectionTest {
}
});
- verify(messagesManager).getMessagesForDevice(any(UUID.class), anyLong(), eq(false));
+ if (useReactive) {
+ verify(messagesManager).getMessagesForDeviceReactive(any(UUID.class), anyLong(), eq(false));
+ } else {
+ verify(messagesManager).getMessagesForDevice(any(UUID.class), anyLong(), eq(false));
+ }
}
- @Test
- void testProcessStoredMessagesMultiplePages() {
+ @ParameterizedTest
+ @ValueSource(booleans = {true, false})
+ void testProcessStoredMessagesMultiplePages(final boolean useReactive) {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
- retrySchedulingExecutor);
+ retrySchedulingExecutor, useReactive, Schedulers.immediate());
when(account.getNumber()).thenReturn("+18005551234");
- when(account.getUuid()).thenReturn(UUID.randomUUID());
+ final UUID accountUuid = UUID.randomUUID();
+ when(account.getUuid()).thenReturn(accountUuid);
when(device.getId()).thenReturn(1L);
when(client.isOpen()).thenReturn(true);
@@ -435,39 +493,56 @@ class WebSocketConnectionTest {
final List secondPageMessages =
List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 3333, "third"));
- when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, false))
- .thenReturn(new Pair<>(firstPageMessages, true))
- .thenReturn(new Pair<>(secondPageMessages, false));
+ if (useReactive) {
+ when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(1L), eq(false)))
+ .thenReturn(Flux.fromStream(Stream.concat(firstPageMessages.stream(), secondPageMessages.stream())));
+ } else {
+ when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq(false)))
+ .thenReturn(new Pair<>(firstPageMessages, true))
+ .thenReturn(new Pair<>(secondPageMessages, false));
+ }
+
+ when(messagesManager.delete(eq(accountUuid), eq(1L), any(), any()))
+ .thenReturn(CompletableFuture.completedFuture(null));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
- final CountDownLatch sendLatch = new CountDownLatch(firstPageMessages.size() + secondPageMessages.size());
+ final CountDownLatch queueEmptyLatch = new CountDownLatch(1);
- when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))).thenAnswer((Answer>)invocation -> {
- sendLatch.countDown();
- return CompletableFuture.completedFuture(successResponse);
- });
+ when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class)))
+ .thenAnswer(invocation -> {
+ return CompletableFuture.completedFuture(successResponse);
+ });
+
+ when(client.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())))
+ .thenAnswer(invocation -> {
+ queueEmptyLatch.countDown();
+ return CompletableFuture.completedFuture(successResponse);
+ });
assertTimeoutPreemptively(Duration.ofSeconds(5), () -> {
connection.processStoredMessages();
- sendLatch.await();
+ queueEmptyLatch.await();
});
- verify(client, times(firstPageMessages.size() + secondPageMessages.size())).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class));
+ verify(client, times(firstPageMessages.size() + secondPageMessages.size())).sendRequest(eq("PUT"),
+ eq("/api/v1/message"), any(List.class), any(Optional.class));
verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()));
}
- @Test
- void testProcessStoredMessagesContainsSenderUuid() {
+ @ParameterizedTest
+ @ValueSource(booleans = {true, false})
+ void testProcessStoredMessagesContainsSenderUuid(final boolean useReactive) {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
- retrySchedulingExecutor);
+ retrySchedulingExecutor, useReactive, Schedulers.immediate());
when(account.getNumber()).thenReturn("+18005551234");
- when(account.getUuid()).thenReturn(UUID.randomUUID());
+ final UUID accountUuid = UUID.randomUUID();
+ when(account.getUuid()).thenReturn(accountUuid);
when(device.getId()).thenReturn(1L);
when(client.isOpen()).thenReturn(true);
@@ -475,50 +550,65 @@ class WebSocketConnectionTest {
final List messages = List.of(
createMessage(senderUuid, UUID.randomUUID(), 1111L, "message the first"));
- when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, false))
- .thenReturn(new Pair<>(messages, false));
+ if (useReactive) {
+ when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), 1L, false))
+ .thenReturn(Flux.fromIterable(messages))
+ .thenReturn(Flux.empty());
+ } else {
+ when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, false))
+ .thenReturn(new Pair<>(messages, false))
+ .thenReturn(new Pair<>(Collections.emptyList(), false));
+ }
+
+ when(messagesManager.delete(eq(accountUuid), eq(1L), any(UUID.class), any()))
+ .thenReturn(CompletableFuture.completedFuture(null));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
- final CountDownLatch sendLatch = new CountDownLatch(messages.size());
+ final CountDownLatch queueEmptyLatch = new CountDownLatch(1);
- when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))).thenAnswer(invocation -> {
- sendLatch.countDown();
- return CompletableFuture.completedFuture(successResponse);
- });
+ when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))).thenAnswer(
+ invocation -> CompletableFuture.completedFuture(successResponse));
+
+ when(client.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())))
+ .thenAnswer(invocation -> {
+ queueEmptyLatch.countDown();
+ return CompletableFuture.completedFuture(successResponse);
+ });
assertTimeoutPreemptively(Duration.ofSeconds(5), () -> {
connection.processStoredMessages();
-
- sendLatch.await();
+ queueEmptyLatch.await();
});
- verify(client, times(messages.size())).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), argThat(argument -> {
- if (argument.isEmpty()) {
- return false;
- }
+ verify(client, times(messages.size())).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class),
+ argThat(argument -> {
+ if (argument.isEmpty()) {
+ return false;
+ }
- final byte[] body = argument.get();
- try {
- final Envelope envelope = Envelope.parseFrom(body);
- if (!envelope.hasSourceUuid() || envelope.getSourceUuid().length() == 0) {
- return false;
- }
- return envelope.getSourceUuid().equals(senderUuid.toString());
- } catch (InvalidProtocolBufferException e) {
- return false;
- }
- }));
+ final byte[] body = argument.get();
+ try {
+ final Envelope envelope = Envelope.parseFrom(body);
+ if (!envelope.hasSourceUuid() || envelope.getSourceUuid().length() == 0) {
+ return false;
+ }
+ return envelope.getSourceUuid().equals(senderUuid.toString());
+ } catch (InvalidProtocolBufferException e) {
+ return false;
+ }
+ }));
verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()));
}
- @Test
- void testProcessStoredMessagesSingleEmptyCall() {
+ @ParameterizedTest
+ @ValueSource(booleans = {true, false})
+ void testProcessStoredMessagesSingleEmptyCall(final boolean useReactive) {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
- retrySchedulingExecutor);
+ retrySchedulingExecutor, useReactive, Schedulers.immediate());
final UUID accountUuid = UUID.randomUUID();
@@ -527,8 +617,13 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn(1L);
when(client.isOpen()).thenReturn(true);
- when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), anyBoolean()))
- .thenReturn(new Pair<>(Collections.emptyList(), false));
+ if (useReactive) {
+ when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(1L), anyBoolean()))
+ .thenReturn(Flux.empty());
+ } else {
+ when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), anyBoolean()))
+ .thenReturn(new Pair<>(Collections.emptyList(), false));
+ }
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
@@ -543,12 +638,13 @@ class WebSocketConnectionTest {
verify(client, times(1)).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()));
}
- @Test
- public void testRequeryOnStateMismatch() {
+ @ParameterizedTest
+ @ValueSource(booleans = {true, false})
+ public void testRequeryOnStateMismatch(final boolean useReactive) {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
- retrySchedulingExecutor);
+ retrySchedulingExecutor, useReactive, Schedulers.immediate());
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
@@ -563,39 +659,57 @@ class WebSocketConnectionTest {
final List secondPageMessages =
List.of(createMessage(UUID.randomUUID(), UUID.randomUUID(), 3333, "third"));
- when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), anyBoolean()))
- .thenReturn(new Pair<>(firstPageMessages, false))
- .thenReturn(new Pair<>(secondPageMessages, false))
- .thenReturn(new Pair<>(Collections.emptyList(), false));
+ if (useReactive) {
+ when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(1L), anyBoolean()))
+ .thenReturn(Flux.fromIterable(firstPageMessages))
+ .thenReturn(Flux.fromIterable(secondPageMessages))
+ .thenReturn(Flux.empty());
+ } else {
+ when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), anyBoolean()))
+ .thenReturn(new Pair<>(firstPageMessages, false))
+ .thenReturn(new Pair<>(secondPageMessages, false))
+ .thenReturn(new Pair<>(Collections.emptyList(), false));
+ }
+
+ when(messagesManager.delete(eq(accountUuid), eq(1L), any(), any()))
+ .thenReturn(CompletableFuture.completedFuture(null));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
- final CountDownLatch sendLatch = new CountDownLatch(firstPageMessages.size() + secondPageMessages.size());
+ final CountDownLatch queueEmptyLatch = new CountDownLatch(1);
- when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class))).thenAnswer((Answer>)invocation -> {
- connection.handleNewMessagesAvailable();
- sendLatch.countDown();
+ when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class)))
+ .thenAnswer(invocation -> {
+ connection.handleNewMessagesAvailable();
- return CompletableFuture.completedFuture(successResponse);
- });
+ return CompletableFuture.completedFuture(successResponse);
+ });
+
+ when(client.sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty())))
+ .thenAnswer(invocation -> {
+ queueEmptyLatch.countDown();
+ return CompletableFuture.completedFuture(successResponse);
+ });
assertTimeoutPreemptively(Duration.ofSeconds(5), () -> {
connection.processStoredMessages();
- sendLatch.await();
+ queueEmptyLatch.await();
});
- verify(client, times(firstPageMessages.size() + secondPageMessages.size())).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), any(Optional.class));
+ verify(client, times(firstPageMessages.size() + secondPageMessages.size())).sendRequest(eq("PUT"),
+ eq("/api/v1/message"), any(List.class), any(Optional.class));
verify(client).sendRequest(eq("PUT"), eq("/api/v1/queue/empty"), any(List.class), eq(Optional.empty()));
}
- @Test
- void testProcessCachedMessagesOnly() {
+ @ParameterizedTest
+ @ValueSource(booleans = {true, false})
+ void testProcessCachedMessagesOnly(final boolean useReactive) {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
- retrySchedulingExecutor);
+ retrySchedulingExecutor, useReactive, Schedulers.immediate());
final UUID accountUuid = UUID.randomUUID();
@@ -604,8 +718,13 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn(1L);
when(client.isOpen()).thenReturn(true);
- when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), anyBoolean()))
- .thenReturn(new Pair<>(Collections.emptyList(), false));
+ if (useReactive) {
+ when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(1L), anyBoolean()))
+ .thenReturn(Flux.empty());
+ } else {
+ when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), anyBoolean()))
+ .thenReturn(new Pair<>(Collections.emptyList(), false));
+ }
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
@@ -616,19 +735,28 @@ class WebSocketConnectionTest {
// anything.
connection.processStoredMessages();
- verify(messagesManager).getMessagesForDevice(account.getUuid(), device.getId(), false);
+ if (useReactive) {
+ verify(messagesManager).getMessagesForDeviceReactive(account.getUuid(), device.getId(), false);
+ } else {
+ verify(messagesManager).getMessagesForDevice(account.getUuid(), device.getId(), false);
+ }
connection.handleNewMessagesAvailable();
- verify(messagesManager).getMessagesForDevice(account.getUuid(), device.getId(), true);
+ if (useReactive) {
+ verify(messagesManager).getMessagesForDeviceReactive(account.getUuid(), device.getId(), true);
+ } else {
+ verify(messagesManager).getMessagesForDevice(account.getUuid(), device.getId(), true);
+ }
}
- @Test
- void testProcessDatabaseMessagesAfterPersist() {
+ @ParameterizedTest
+ @ValueSource(booleans = {true, false})
+ void testProcessDatabaseMessagesAfterPersist(final boolean useReactive) {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
- retrySchedulingExecutor);
+ retrySchedulingExecutor, useReactive, Schedulers.immediate());
final UUID accountUuid = UUID.randomUUID();
@@ -637,8 +765,13 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn(1L);
when(client.isOpen()).thenReturn(true);
- when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), anyBoolean()))
- .thenReturn(new Pair<>(Collections.emptyList(), false));
+ if (useReactive) {
+ when(messagesManager.getMessagesForDeviceReactive(eq(accountUuid), eq(1L), anyBoolean()))
+ .thenReturn(Flux.empty());
+ } else {
+ when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), anyBoolean()))
+ .thenReturn(new Pair<>(Collections.emptyList(), false));
+ }
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
@@ -650,151 +783,16 @@ class WebSocketConnectionTest {
connection.processStoredMessages();
connection.handleMessagesPersisted();
- verify(messagesManager, times(2)).getMessagesForDevice(account.getUuid(), device.getId(), false);
+ if (useReactive) {
+ verify(messagesManager, times(2)).getMessagesForDeviceReactive(account.getUuid(), device.getId(), false);
+ } else {
+ verify(messagesManager, times(2)).getMessagesForDevice(account.getUuid(), device.getId(), false);
+ }
}
- @Test
- void testDiscardOversizedMessagesForDesktop() {
- MessagesManager storedMessages = mock(MessagesManager.class);
-
- UUID accountUuid = UUID.randomUUID();
- UUID senderOneUuid = UUID.randomUUID();
- UUID senderTwoUuid = UUID.randomUUID();
-
- List outgoingMessages = List.of(
- createMessage(senderOneUuid, UUID.randomUUID(), 1111, "first"),
- createMessage(senderOneUuid, UUID.randomUUID(), 2222,
- RandomStringUtils.randomAlphanumeric(WebSocketConnection.MAX_DESKTOP_MESSAGE_SIZE + 1)),
- createMessage(senderTwoUuid, UUID.randomUUID(), 3333, "third"));
-
- when(device.getId()).thenReturn(2L);
-
- when(account.getNumber()).thenReturn("+14152222222");
- when(account.getUuid()).thenReturn(accountUuid);
-
- final Device sender1device = mock(Device.class);
-
- List sender1devices = List.of(sender1device);
-
- Account sender1 = mock(Account.class);
- when(sender1.getDevices()).thenReturn(sender1devices);
-
- when(accountsManager.getByE164("sender1")).thenReturn(Optional.of(sender1));
- when(accountsManager.getByE164("sender2")).thenReturn(Optional.empty());
-
- String userAgent = "Signal-Desktop/1.2.3";
-
- when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), false))
- .thenReturn(new Pair<>(outgoingMessages, false));
-
- final List> futures = new LinkedList<>();
- final WebSocketClient client = mock(WebSocketClient.class);
-
- when(client.getUserAgent()).thenReturn(userAgent);
- when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class),
- ArgumentMatchers.>any()))
- .thenAnswer(new Answer>() {
- @Override
- public CompletableFuture answer(InvocationOnMock invocationOnMock) {
- CompletableFuture future = new CompletableFuture<>();
- futures.add(future);
- return future;
- }
- });
-
- WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client,
- retrySchedulingExecutor);
-
- connection.start();
- verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class),
- ArgumentMatchers.>any());
-
- assertEquals(2, futures.size());
-
- WebSocketResponseMessage response = mock(WebSocketResponseMessage.class);
- when(response.getStatus()).thenReturn(200);
- futures.get(0).complete(response);
- futures.get(1).complete(response);
-
- // We should delete all three messages even though we only sent two; one got discarded because it was too big for
- // desktop clients.
- verify(storedMessages, times(3)).delete(eq(accountUuid), eq(2L), any(UUID.class), any(Long.class));
-
- connection.stop();
- verify(client).close(anyInt(), anyString());
- }
-
- @Test
- void testSendOversizedMessagesForNonDesktop() {
- MessagesManager storedMessages = mock(MessagesManager.class);
-
- UUID accountUuid = UUID.randomUUID();
- UUID senderOneUuid = UUID.randomUUID();
- UUID senderTwoUuid = UUID.randomUUID();
-
- List outgoingMessages = List.of(createMessage(senderOneUuid, UUID.randomUUID(), 1111, "first"),
- createMessage(senderOneUuid, UUID.randomUUID(), 2222,
- RandomStringUtils.randomAlphanumeric(WebSocketConnection.MAX_DESKTOP_MESSAGE_SIZE + 1)),
- createMessage(senderTwoUuid, UUID.randomUUID(), 3333, "third"));
-
- when(device.getId()).thenReturn(2L);
-
- when(account.getNumber()).thenReturn("+14152222222");
- when(account.getUuid()).thenReturn(accountUuid);
-
- final Device sender1device = mock(Device.class);
-
- List sender1devices = List.of(sender1device);
-
- Account sender1 = mock(Account.class);
- when(sender1.getDevices()).thenReturn(sender1devices);
-
- when(accountsManager.getByE164("sender1")).thenReturn(Optional.of(sender1));
- when(accountsManager.getByE164("sender2")).thenReturn(Optional.empty());
-
- String userAgent = "Signal-Android/4.68.3";
-
- when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), false))
- .thenReturn(new Pair<>(outgoingMessages, false));
-
- final List> futures = new LinkedList<>();
- final WebSocketClient client = mock(WebSocketClient.class);
-
- when(client.getUserAgent()).thenReturn(userAgent);
- when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class),
- ArgumentMatchers.>any()))
- .thenAnswer(new Answer>() {
- @Override
- public CompletableFuture answer(InvocationOnMock invocationOnMock) {
- CompletableFuture future = new CompletableFuture<>();
- futures.add(future);
- return future;
- }
- });
-
- WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client,
- retrySchedulingExecutor);
-
- connection.start();
- verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class),
- ArgumentMatchers.>any());
-
- assertEquals(3, futures.size());
-
- WebSocketResponseMessage response = mock(WebSocketResponseMessage.class);
- when(response.getStatus()).thenReturn(200);
- futures.get(0).complete(response);
- futures.get(1).complete(response);
- futures.get(2).complete(response);
-
- verify(storedMessages, times(3)).delete(eq(accountUuid), eq(2L), any(UUID.class), any(Long.class));
-
- connection.stop();
- verify(client).close(anyInt(), anyString());
- }
-
- @Test
- void testRetrieveMessageException() {
+ @ParameterizedTest
+ @ValueSource(booleans = {true, false})
+ void testRetrieveMessageException(final boolean useReactive) {
MessagesManager storedMessages = mock(MessagesManager.class);
UUID accountUuid = UUID.randomUUID();
@@ -804,10 +802,13 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
- String userAgent = "Signal-Android/4.68.3";
-
- when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), false))
- .thenThrow(new RedisException("OH NO"));
+ if (useReactive) {
+ when(storedMessages.getMessagesForDeviceReactive(account.getUuid(), device.getId(), false))
+ .thenReturn(Flux.error(new RedisException("OH NO")));
+ } else {
+ when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), false))
+ .thenThrow(new RedisException("OH NO"));
+ }
when(retrySchedulingExecutor.schedule(any(Runnable.class), anyLong(), any())).thenAnswer(
(Answer>) invocation -> {
@@ -819,7 +820,7 @@ class WebSocketConnectionTest {
when(client.isOpen()).thenReturn(true);
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client,
- retrySchedulingExecutor);
+ retrySchedulingExecutor, useReactive, Schedulers.immediate());
connection.start();
verify(retrySchedulingExecutor, times(WebSocketConnection.MAX_CONSECUTIVE_RETRIES)).schedule(any(Runnable.class),
@@ -827,8 +828,9 @@ class WebSocketConnectionTest {
verify(client).close(eq(1011), anyString());
}
- @Test
- void testRetrieveMessageExceptionClientDisconnected() {
+ @ParameterizedTest
+ @ValueSource(booleans = {true, false})
+ void testRetrieveMessageExceptionClientDisconnected(final boolean useReactive) {
MessagesManager storedMessages = mock(MessagesManager.class);
UUID accountUuid = UUID.randomUUID();
@@ -838,22 +840,143 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
- String userAgent = "Signal-Android/4.68.3";
-
- when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), false))
- .thenThrow(new RedisException("OH NO"));
+ if (useReactive) {
+ when(storedMessages.getMessagesForDeviceReactive(account.getUuid(), device.getId(), false))
+ .thenReturn(Flux.error(new RedisException("OH NO")));
+ } else {
+ when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), false))
+ .thenThrow(new RedisException("OH NO"));
+ }
final WebSocketClient client = mock(WebSocketClient.class);
when(client.isOpen()).thenReturn(false);
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client,
- retrySchedulingExecutor);
+ retrySchedulingExecutor, useReactive, Schedulers.immediate());
connection.start();
verify(retrySchedulingExecutor, never()).schedule(any(Runnable.class), anyLong(), any());
verify(client, never()).close(anyInt(), anyString());
}
+ @Test
+ @Disabled("This test is flaky")
+ void testReactivePublisherLimitRate() {
+ MessagesManager storedMessages = mock(MessagesManager.class);
+
+ final UUID accountUuid = UUID.randomUUID();
+
+ final long deviceId = 2L;
+ when(device.getId()).thenReturn(deviceId);
+
+ when(account.getNumber()).thenReturn("+14152222222");
+ when(account.getUuid()).thenReturn(accountUuid);
+
+ final int totalMessages = 10;
+ final AtomicReference> sink = new AtomicReference<>();
+
+ final AtomicLong maxRequest = new AtomicLong(-1);
+ final Flux flux = Flux.create(s -> {
+ sink.set(s);
+ s.onRequest(n -> {
+ if (maxRequest.get() < n) {
+ maxRequest.set(n);
+ }
+ });
+ });
+
+ when(storedMessages.getMessagesForDeviceReactive(eq(accountUuid), eq(deviceId), anyBoolean()))
+ .thenReturn(flux);
+
+ final WebSocketClient client = mock(WebSocketClient.class);
+ when(client.isOpen()).thenReturn(true);
+ final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
+ when(successResponse.getStatus()).thenReturn(200);
+ when(client.sendRequest(any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(successResponse));
+ when(storedMessages.delete(any(), anyLong(), any(), any())).thenReturn(
+ CompletableFuture.completedFuture(Optional.empty()));
+
+ WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client,
+ retrySchedulingExecutor, true);
+
+ connection.start();
+
+ StepVerifier.setDefaultTimeout(Duration.ofSeconds(5));
+
+ StepVerifier.create(flux, 0)
+ .expectSubscription()
+ .thenRequest(totalMessages * 2)
+ .then(() -> {
+ for (long i = 0; i < totalMessages; i++) {
+ sink.get().next(createMessage(UUID.randomUUID(), accountUuid, 1111 * i + 1, "message " + i));
+ }
+ sink.get().complete();
+ })
+ .expectNextCount(totalMessages)
+ .expectComplete()
+ .log()
+ .verify();
+
+ assertEquals(WebSocketConnection.MESSAGE_PUBLISHER_LIMIT_RATE, maxRequest.get());
+ }
+
+ @Test
+ void testReactivePublisherDisposedWhenConnectionStopped() {
+ MessagesManager storedMessages = mock(MessagesManager.class);
+
+ final UUID accountUuid = UUID.randomUUID();
+
+ final long deviceId = 2L;
+ when(device.getId()).thenReturn(deviceId);
+
+ when(account.getNumber()).thenReturn("+14152222222");
+ when(account.getUuid()).thenReturn(accountUuid);
+
+ final AtomicBoolean canceled = new AtomicBoolean();
+
+ final Flux flux = Flux.create(s -> {
+ s.onRequest(n -> {
+ // the subscriber should request more than 1 message, but we will only send one, so that
+ // we are sure the subscriber is waiting for more when we stop the connection
+ assert n > 1;
+ s.next(createMessage(UUID.randomUUID(), UUID.randomUUID(), 1111, "first"));
+ });
+
+ s.onCancel(() -> canceled.set(true));
+ });
+ when(storedMessages.getMessagesForDeviceReactive(eq(accountUuid), eq(deviceId), anyBoolean()))
+ .thenReturn(flux);
+
+ final WebSocketClient client = mock(WebSocketClient.class);
+ when(client.isOpen()).thenReturn(true);
+ final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
+ when(successResponse.getStatus()).thenReturn(200);
+ when(client.sendRequest(any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(successResponse));
+ when(storedMessages.delete(any(), anyLong(), any(), any())).thenReturn(
+ CompletableFuture.completedFuture(Optional.empty()));
+
+ WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client,
+ retrySchedulingExecutor, true, Schedulers.immediate());
+
+ connection.start();
+
+ verify(client).sendRequest(any(), any(), any(), any());
+
+ // close the connection before the publisher completes
+ connection.stop();
+
+ StepVerifier.setDefaultTimeout(Duration.ofSeconds(2));
+
+ StepVerifier.create(flux)
+ .expectSubscription()
+ .expectNextCount(1)
+ .then(() -> assertTrue(canceled.get()))
+ // this is not entirely intuitive, but expecting a timeout is the recommendation for verifying cancellation
+ .expectTimeout(Duration.ofMillis(100))
+ .log()
+ .verify();
+ }
+
private Envelope createMessage(UUID senderUuid, UUID destinationUuid, long timestamp, String content) {
return Envelope.newBuilder()
.setServerGuid(UUID.randomUUID().toString())
diff --git a/service/src/test/resources/logback-test.xml b/service/src/test/resources/logback-test.xml
index 1c2c5d01c..b01f95f92 100644
--- a/service/src/test/resources/logback-test.xml
+++ b/service/src/test/resources/logback-test.xml
@@ -1,11 +1,14 @@