From b75873790786416db93b7b974ed27207de05ec5f Mon Sep 17 00:00:00 2001 From: Chris Eager <79161849+eager-signal@users.noreply.github.com> Date: Mon, 3 Jan 2022 14:59:39 -0800 Subject: [PATCH] Migrate remaining JUnit 4 Redis cluster tests to `RedisClusterExtension` * Increase redis cluster initialization wait to 10 seconds * Move to JUnit 5 `Assumptions` --- .../redis/AbstractRedisClusterTest.java | 194 ---------- .../redis/RedisClusterExtension.java | 8 +- ...AccountDatabaseCrawlerIntegrationTest.java | 25 +- .../storage/MessagePersisterTest.java | 340 +++++++++--------- .../storage/MessagesCacheTest.java | 143 ++++---- 5 files changed, 268 insertions(+), 442 deletions(-) delete mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisClusterTest.java diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisClusterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisClusterTest.java deleted file mode 100644 index 547817355..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisClusterTest.java +++ /dev/null @@ -1,194 +0,0 @@ -/* - * Copyright 2013-2020 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.redis; - -import io.lettuce.core.RedisClient; -import io.lettuce.core.RedisException; -import io.lettuce.core.RedisURI; -import io.lettuce.core.api.StatefulRedisConnection; -import io.lettuce.core.api.sync.RedisCommands; -import io.lettuce.core.cluster.RedisClusterClient; -import io.lettuce.core.cluster.SlotHash; -import org.junit.After; -import org.junit.AfterClass; -import org.junit.Before; -import org.junit.BeforeClass; -import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; -import org.whispersystems.textsecuregcm.configuration.RetryConfiguration; -import org.whispersystems.textsecuregcm.util.RedisClusterUtil; -import redis.embedded.RedisServer; - -import java.io.File; -import java.io.IOException; -import java.net.ServerSocket; -import java.time.Duration; -import java.util.Arrays; -import java.util.List; -import java.util.stream.Collectors; - -import static org.junit.Assume.assumeFalse; - -/** - * An abstract base class that assembles a real (local!) Redis cluster and provides a client to that cluster for - * subclasses. - */ -public abstract class AbstractRedisClusterTest { - - private static final int NODE_COUNT = 2; - - private static RedisServer[] clusterNodes; - - private FaultTolerantRedisCluster redisCluster; - - @BeforeClass - public static void setUpBeforeClass() throws Exception { - assumeFalse(System.getProperty("os.name").equalsIgnoreCase("windows")); - - clusterNodes = new RedisServer[NODE_COUNT]; - - for (int i = 0; i < NODE_COUNT; i++) { - clusterNodes[i] = buildClusterNode(getNextRedisClusterPort()); - clusterNodes[i].start(); - } - - assembleCluster(clusterNodes); - } - - @Before - public void setUp() throws Exception { - final List urls = Arrays.stream(clusterNodes) - .map(node -> String.format("redis://127.0.0.1:%d", node.ports().get(0))) - .collect(Collectors.toList()); - - redisCluster = new FaultTolerantRedisCluster("test-cluster", - RedisClusterClient.create(urls.stream().map(RedisURI::create).collect(Collectors.toList())), - Duration.ofSeconds(2), - new CircuitBreakerConfiguration(), - new RetryConfiguration()); - - redisCluster.useCluster(connection -> { - boolean setAll = false; - - final String[] keys = new String[NODE_COUNT]; - - for (int i = 0; i < keys.length; i++) { - keys[i] = RedisClusterUtil.getMinimalHashTag(i * SlotHash.SLOT_COUNT / keys.length); - } - - while (!setAll) { - try { - for (final String key : keys) { - connection.sync().set(key, "warmup"); - } - - setAll = true; - } catch (final RedisException ignored) { - // Cluster isn't ready; wait and retry. - try { - Thread.sleep(500); - } catch (final InterruptedException ignored2) { - } - } - } - }); - - redisCluster.useCluster(connection -> connection.sync().flushall()); - } - - protected FaultTolerantRedisCluster getRedisCluster() { - return redisCluster; - } - - @After - public void tearDown() throws Exception { - redisCluster.shutdown(); - } - - @AfterClass - public static void tearDownAfterClass() { - for (final RedisServer node : clusterNodes) { - node.stop(); - } - } - - private static RedisServer buildClusterNode(final int port) throws IOException { - final File clusterConfigFile = File.createTempFile("redis", ".conf"); - clusterConfigFile.deleteOnExit(); - - return RedisServer.builder() - .setting("cluster-enabled yes") - .setting("cluster-config-file " + clusterConfigFile.getAbsolutePath()) - .setting("cluster-node-timeout 5000") - .setting("appendonly no") - .setting("dir " + System.getProperty("java.io.tmpdir")) - .port(port) - .build(); - } - - private static void assembleCluster(final RedisServer... nodes) throws InterruptedException { - final RedisClient meetClient = RedisClient.create(RedisURI.create("127.0.0.1", nodes[0].ports().get(0))); - - try { - final StatefulRedisConnection connection = meetClient.connect(); - final RedisCommands commands = connection.sync(); - - for (int i = 1; i < nodes.length; i++) { - commands.clusterMeet("127.0.0.1", nodes[i].ports().get(0)); - } - } finally { - meetClient.shutdown(); - } - - final int slotsPerNode = SlotHash.SLOT_COUNT / nodes.length; - - for (int i = 0; i < nodes.length; i++) { - final int startInclusive = i * slotsPerNode; - final int endExclusive = i == nodes.length - 1 ? SlotHash.SLOT_COUNT : (i + 1) * slotsPerNode; - - final RedisClient assignSlotClient = RedisClient.create(RedisURI.create("127.0.0.1", nodes[i].ports().get(0))); - - try (final StatefulRedisConnection assignSlotConnection = assignSlotClient.connect()) { - final int[] slots = new int[endExclusive - startInclusive]; - - for (int s = startInclusive; s < endExclusive; s++) { - slots[s - startInclusive] = s; - } - - assignSlotConnection.sync().clusterAddSlots(slots); - } finally { - assignSlotClient.shutdown(); - } - } - - final RedisClient waitClient = RedisClient.create(RedisURI.create("127.0.0.1", nodes[0].ports().get(0))); - - try (final StatefulRedisConnection connection = waitClient.connect()) { - // CLUSTER INFO gives us a big blob of key-value pairs, but the one we're interested in is `cluster_state`. - // According to https://redis.io/commands/cluster-info, `cluster_state:ok` means that the node is ready to - // receive queries, all slots are assigned, and a majority of master nodes are reachable. - while (!connection.sync().clusterInfo().contains("cluster_state:ok")) { - Thread.sleep(500); - } - } finally { - waitClient.shutdown(); - } - } - - public static int getNextRedisClusterPort() throws IOException { - final int MAX_ITERATIONS = 11_000; - int port; - for (int i = 0; i < MAX_ITERATIONS; i++) { - try (ServerSocket socket = new ServerSocket(0)) { - socket.setReuseAddress(false); - port = socket.getLocalPort(); - } - if (port < 55535) { - return port; - } - } - throw new IOException("Couldn't find an open port below 55,535 in " + MAX_ITERATIONS + " tries"); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/RedisClusterExtension.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/RedisClusterExtension.java index 6cc961fc9..96cde43c8 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/redis/RedisClusterExtension.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/RedisClusterExtension.java @@ -1,11 +1,11 @@ /* - * Copyright 2013-2021 Signal Messenger, LLC + * Copyright 2013-2022 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.textsecuregcm.redis; -import static org.junit.Assume.assumeFalse; +import static org.junit.jupiter.api.Assumptions.assumeFalse; import io.lettuce.core.RedisClient; import io.lettuce.core.RedisException; @@ -178,9 +178,9 @@ public class RedisClusterExtension implements BeforeAllCallback, BeforeEachCallb Thread.sleep(sleepMillis); tries++; - if (tries == 10) { + if (tries == 20) { throw new RuntimeException( - String.format("Timeout: Redis not ready after waiting %d milliseconds", sleepMillis)); + String.format("Timeout: Redis not ready after waiting %d milliseconds", tries * sleepMillis)); } } } finally { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountDatabaseCrawlerIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountDatabaseCrawlerIntegrationTest.java index 636677fc7..2dcef4636 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountDatabaseCrawlerIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountDatabaseCrawlerIntegrationTest.java @@ -5,7 +5,7 @@ package org.whispersystems.textsecuregcm.storage; -import static org.junit.Assert.assertFalse; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doThrow; @@ -18,11 +18,15 @@ import java.util.Collections; import java.util.List; import java.util.Optional; import java.util.UUID; -import org.junit.Before; -import org.junit.Test; -import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; -public class AccountDatabaseCrawlerIntegrationTest extends AbstractRedisClusterTest { +class AccountDatabaseCrawlerIntegrationTest { + + @RegisterExtension + static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); private static final UUID FIRST_UUID = UUID.fromString("82339e80-81cd-48e2-9ed2-ccd5dd262ad9"); private static final UUID SECOND_UUID = UUID.fromString("cc705c84-33cf-456b-8239-a6a34e2f561a"); @@ -38,9 +42,8 @@ public class AccountDatabaseCrawlerIntegrationTest extends AbstractRedisClusterT private static final int CHUNK_SIZE = 1; private static final long CHUNK_INTERVAL_MS = 0; - @Before - public void setUp() throws Exception { - super.setUp(); + @BeforeEach + void setUp() throws Exception { firstAccount = mock(Account.class); secondAccount = mock(Account.class); @@ -57,13 +60,13 @@ public class AccountDatabaseCrawlerIntegrationTest extends AbstractRedisClusterT .thenReturn(new AccountCrawlChunk(List.of(secondAccount), SECOND_UUID)) .thenReturn(new AccountCrawlChunk(Collections.emptyList(), null)); - final AccountDatabaseCrawlerCache crawlerCache = new AccountDatabaseCrawlerCache(getRedisCluster(), "test"); + final AccountDatabaseCrawlerCache crawlerCache = new AccountDatabaseCrawlerCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), "test"); accountDatabaseCrawler = new AccountDatabaseCrawler("test", accountsManager, crawlerCache, List.of(listener), CHUNK_SIZE, CHUNK_INTERVAL_MS); } @Test - public void testCrawlUninterrupted() throws AccountDatabaseCrawlerRestartException { + void testCrawlUninterrupted() throws AccountDatabaseCrawlerRestartException { assertFalse(accountDatabaseCrawler.doPeriodicWork()); assertFalse(accountDatabaseCrawler.doPeriodicWork()); assertFalse(accountDatabaseCrawler.doPeriodicWork()); @@ -79,7 +82,7 @@ public class AccountDatabaseCrawlerIntegrationTest extends AbstractRedisClusterT } @Test - public void testCrawlWithReset() throws AccountDatabaseCrawlerRestartException { + void testCrawlWithReset() throws AccountDatabaseCrawlerRestartException { doThrow(new AccountDatabaseCrawlerRestartException("OH NO")).doNothing() .when(listener).timeAndProcessCrawlChunk(Optional.empty(), List.of(firstAccount)); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java index 77721ee76..2bf11dcbc 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java @@ -5,7 +5,7 @@ package org.whispersystems.textsecuregcm.storage; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; @@ -28,194 +28,204 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import org.apache.commons.lang3.RandomStringUtils; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.entities.MessageProtos; -import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; -public class MessagePersisterTest extends AbstractRedisClusterTest { +class MessagePersisterTest { - private ExecutorService notificationExecutorService; - private MessagesCache messagesCache; - private MessagesDynamoDb messagesDynamoDb; - private MessagePersister messagePersister; - private AccountsManager accountsManager; + @RegisterExtension + static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); - private static final UUID DESTINATION_ACCOUNT_UUID = UUID.randomUUID(); - private static final String DESTINATION_ACCOUNT_NUMBER = "+18005551234"; - private static final long DESTINATION_DEVICE_ID = 7; + private ExecutorService notificationExecutorService; + private MessagesCache messagesCache; + private MessagesDynamoDb messagesDynamoDb; + private MessagePersister messagePersister; + private AccountsManager accountsManager; - private static final Duration PERSIST_DELAY = Duration.ofMinutes(5); + private static final UUID DESTINATION_ACCOUNT_UUID = UUID.randomUUID(); + private static final String DESTINATION_ACCOUNT_NUMBER = "+18005551234"; + private static final long DESTINATION_DEVICE_ID = 7; - @Override - @Before - public void setUp() throws Exception { - super.setUp(); + private static final Duration PERSIST_DELAY = Duration.ofMinutes(5); - final MessagesManager messagesManager = mock(MessagesManager.class); - final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class); + @BeforeEach + void setUp() throws Exception { - messagesDynamoDb = mock(MessagesDynamoDb.class); - accountsManager = mock(AccountsManager.class); + final MessagesManager messagesManager = mock(MessagesManager.class); + final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class); - final Account account = mock(Account.class); + messagesDynamoDb = mock(MessagesDynamoDb.class); + accountsManager = mock(AccountsManager.class); - when(accountsManager.getByAccountIdentifier(DESTINATION_ACCOUNT_UUID)).thenReturn(Optional.of(account)); - when(account.getNumber()).thenReturn(DESTINATION_ACCOUNT_NUMBER); - when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration()); + final Account account = mock(Account.class); - notificationExecutorService = Executors.newSingleThreadExecutor(); - messagesCache = new MessagesCache(getRedisCluster(), getRedisCluster(), notificationExecutorService); - messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, dynamicConfigurationManager, PERSIST_DELAY); + when(accountsManager.getByAccountIdentifier(DESTINATION_ACCOUNT_UUID)).thenReturn(Optional.of(account)); + when(account.getNumber()).thenReturn(DESTINATION_ACCOUNT_NUMBER); + when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration()); - doAnswer(invocation -> { - final UUID destinationUuid = invocation.getArgument(0); - final long destinationDeviceId = invocation.getArgument(1); - final List messages = invocation.getArgument(2); + notificationExecutorService = Executors.newSingleThreadExecutor(); + messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), + REDIS_CLUSTER_EXTENSION.getRedisCluster(), notificationExecutorService); + messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, + dynamicConfigurationManager, PERSIST_DELAY); - messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId); + doAnswer(invocation -> { + final UUID destinationUuid = invocation.getArgument(0); + final long destinationDeviceId = invocation.getArgument(1); + final List messages = invocation.getArgument(2); - for (final MessageProtos.Envelope message : messages) { - messagesCache.remove(destinationUuid, destinationDeviceId, UUID.fromString(message.getServerGuid())); - } + messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId); - return null; - }).when(messagesManager).persistMessages(any(UUID.class), anyLong(), any()); + for (final MessageProtos.Envelope message : messages) { + messagesCache.remove(destinationUuid, destinationDeviceId, UUID.fromString(message.getServerGuid())); + } + + return null; + }).when(messagesManager).persistMessages(any(UUID.class), anyLong(), any()); + } + + @AfterEach + void tearDown() throws Exception { + notificationExecutorService.shutdown(); + notificationExecutorService.awaitTermination(1, TimeUnit.SECONDS); + } + + @Test + void testPersistNextQueuesNoQueues() { + messagePersister.persistNextQueues(Instant.now()); + + verify(accountsManager, never()).getByAccountIdentifier(any(UUID.class)); + } + + @Test + void testPersistNextQueuesSingleQueue() { + final String queueName = new String( + MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8); + final int messageCount = (MessagePersister.MESSAGE_BATCH_LIMIT * 3) + 7; + final Instant now = Instant.now(); + + insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, now); + setNextSlotToPersist(SlotHash.getSlot(queueName)); + + messagePersister.persistNextQueues(now.plus(messagePersister.getPersistDelay())); + + final ArgumentCaptor> messagesCaptor = ArgumentCaptor.forClass(List.class); + + verify(messagesDynamoDb, atLeastOnce()).store(messagesCaptor.capture(), eq(DESTINATION_ACCOUNT_UUID), + eq(DESTINATION_DEVICE_ID)); + assertEquals(messageCount, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum()); + } + + @Test + void testPersistNextQueuesSingleQueueTooSoon() { + final String queueName = new String( + MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8); + final int messageCount = (MessagePersister.MESSAGE_BATCH_LIMIT * 3) + 7; + final Instant now = Instant.now(); + + insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, now); + setNextSlotToPersist(SlotHash.getSlot(queueName)); + + messagePersister.persistNextQueues(now); + + verify(messagesDynamoDb, never()).store(any(), any(), anyLong()); + } + + @Test + void testPersistNextQueuesMultiplePages() { + final int slot = 7; + final int queueCount = (MessagePersister.QUEUE_BATCH_LIMIT * 3) + 7; + final int messagesPerQueue = 10; + final Instant now = Instant.now(); + + for (int i = 0; i < queueCount; i++) { + final String queueName = generateRandomQueueNameForSlot(slot); + final UUID accountUuid = MessagesCache.getAccountUuidFromQueueName(queueName); + final long deviceId = MessagesCache.getDeviceIdFromQueueName(queueName); + final String accountNumber = "+1" + RandomStringUtils.randomNumeric(10); + + final Account account = mock(Account.class); + + when(accountsManager.getByAccountIdentifier(accountUuid)).thenReturn(Optional.of(account)); + when(account.getNumber()).thenReturn(accountNumber); + + insertMessages(accountUuid, deviceId, messagesPerQueue, now); } - @Override - public void tearDown() throws Exception { - super.tearDown(); + setNextSlotToPersist(slot); - notificationExecutorService.shutdown(); - notificationExecutorService.awaitTermination(1, TimeUnit.SECONDS); + messagePersister.persistNextQueues(now.plus(messagePersister.getPersistDelay())); + + final ArgumentCaptor> messagesCaptor = ArgumentCaptor.forClass(List.class); + + verify(messagesDynamoDb, atLeastOnce()).store(messagesCaptor.capture(), any(UUID.class), anyLong()); + assertEquals(queueCount * messagesPerQueue, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum()); + } + + @Test + void testPersistQueueRetry() { + final String queueName = new String( + MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8); + final int messageCount = (MessagePersister.MESSAGE_BATCH_LIMIT * 3) + 7; + final Instant now = Instant.now(); + + insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, now); + setNextSlotToPersist(SlotHash.getSlot(queueName)); + + doAnswer((Answer) invocation -> { + throw new RuntimeException("OH NO."); + }).when(messagesDynamoDb).store(any(), eq(DESTINATION_ACCOUNT_UUID), eq(DESTINATION_DEVICE_ID)); + + messagePersister.persistNextQueues(now.plus(messagePersister.getPersistDelay())); + + assertEquals(List.of(queueName), + messagesCache.getQueuesToPersist(SlotHash.getSlot(queueName), + Instant.now().plus(messagePersister.getPersistDelay()), 1)); + } + + @SuppressWarnings("SameParameterValue") + private static String generateRandomQueueNameForSlot(final int slot) { + final UUID uuid = UUID.randomUUID(); + + final String queueNameBase = "user_queue::{" + uuid + "::"; + + for (int deviceId = 0; deviceId < Integer.MAX_VALUE; deviceId++) { + final String queueName = queueNameBase + deviceId + "}"; + + if (SlotHash.getSlot(queueName) == slot) { + return queueName; + } } - @Test - public void testPersistNextQueuesNoQueues() { - messagePersister.persistNextQueues(Instant.now()); + throw new IllegalStateException("Could not find a queue name for slot " + slot); + } - verify(accountsManager, never()).getByAccountIdentifier(any(UUID.class)); + private void insertMessages(final UUID accountUuid, final long deviceId, final int messageCount, + final Instant firstMessageTimestamp) { + for (int i = 0; i < messageCount; i++) { + final UUID messageGuid = UUID.randomUUID(); + + final MessageProtos.Envelope envelope = MessageProtos.Envelope.newBuilder() + .setTimestamp(firstMessageTimestamp.toEpochMilli() + i) + .setServerTimestamp(firstMessageTimestamp.toEpochMilli() + i) + .setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256))) + .setType(MessageProtos.Envelope.Type.CIPHERTEXT) + .setServerGuid(messageGuid.toString()) + .build(); + + messagesCache.insert(messageGuid, accountUuid, deviceId, envelope); } + } - @Test - public void testPersistNextQueuesSingleQueue() { - final String queueName = new String(MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8); - final int messageCount = (MessagePersister.MESSAGE_BATCH_LIMIT * 3) + 7; - final Instant now = Instant.now(); - - insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, now); - setNextSlotToPersist(SlotHash.getSlot(queueName)); - - messagePersister.persistNextQueues(now.plus(messagePersister.getPersistDelay())); - - final ArgumentCaptor> messagesCaptor = ArgumentCaptor.forClass(List.class); - - verify(messagesDynamoDb, atLeastOnce()).store(messagesCaptor.capture(), eq(DESTINATION_ACCOUNT_UUID), eq(DESTINATION_DEVICE_ID)); - assertEquals(messageCount, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum()); - } - - @Test - public void testPersistNextQueuesSingleQueueTooSoon() { - final String queueName = new String(MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8); - final int messageCount = (MessagePersister.MESSAGE_BATCH_LIMIT * 3) + 7; - final Instant now = Instant.now(); - - insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, now); - setNextSlotToPersist(SlotHash.getSlot(queueName)); - - messagePersister.persistNextQueues(now); - - verify(messagesDynamoDb, never()).store(any(), any(), anyLong()); - } - - @Test - public void testPersistNextQueuesMultiplePages() { - final int slot = 7; - final int queueCount = (MessagePersister.QUEUE_BATCH_LIMIT * 3) + 7; - final int messagesPerQueue = 10; - final Instant now = Instant.now(); - - for (int i = 0; i < queueCount; i++) { - final String queueName = generateRandomQueueNameForSlot(slot); - final UUID accountUuid = MessagesCache.getAccountUuidFromQueueName(queueName); - final long deviceId = MessagesCache.getDeviceIdFromQueueName(queueName); - final String accountNumber = "+1" + RandomStringUtils.randomNumeric(10); - - final Account account = mock(Account.class); - - when(accountsManager.getByAccountIdentifier(accountUuid)).thenReturn(Optional.of(account)); - when(account.getNumber()).thenReturn(accountNumber); - - insertMessages(accountUuid, deviceId, messagesPerQueue, now); - } - - setNextSlotToPersist(slot); - - messagePersister.persistNextQueues(now.plus(messagePersister.getPersistDelay())); - - final ArgumentCaptor> messagesCaptor = ArgumentCaptor.forClass(List.class); - - verify(messagesDynamoDb, atLeastOnce()).store(messagesCaptor.capture(), any(UUID.class), anyLong()); - assertEquals(queueCount * messagesPerQueue, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum()); - } - - @Test - public void testPersistQueueRetry() { - final String queueName = new String(MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8); - final int messageCount = (MessagePersister.MESSAGE_BATCH_LIMIT * 3) + 7; - final Instant now = Instant.now(); - - insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, now); - setNextSlotToPersist(SlotHash.getSlot(queueName)); - - doAnswer((Answer)invocation -> { - throw new RuntimeException("OH NO."); - }).when(messagesDynamoDb).store(any(), eq(DESTINATION_ACCOUNT_UUID), eq(DESTINATION_DEVICE_ID)); - - messagePersister.persistNextQueues(now.plus(messagePersister.getPersistDelay())); - - assertEquals(List.of(queueName), - messagesCache.getQueuesToPersist(SlotHash.getSlot(queueName), Instant.now().plus(messagePersister.getPersistDelay()), 1)); - } - - @SuppressWarnings("SameParameterValue") - private static String generateRandomQueueNameForSlot(final int slot) { - final UUID uuid = UUID.randomUUID(); - - final String queueNameBase = "user_queue::{" + uuid.toString() + "::"; - - for (int deviceId = 0; deviceId < Integer.MAX_VALUE; deviceId++) { - final String queueName = queueNameBase + deviceId + "}"; - - if (SlotHash.getSlot(queueName) == slot) { - return queueName; - } - } - - throw new IllegalStateException("Could not find a queue name for slot " + slot); - } - - private void insertMessages(final UUID accountUuid, final long deviceId, final int messageCount, final Instant firstMessageTimestamp) { - for (int i = 0; i < messageCount; i++) { - final UUID messageGuid = UUID.randomUUID(); - - final MessageProtos.Envelope envelope = MessageProtos.Envelope.newBuilder() - .setTimestamp(firstMessageTimestamp.toEpochMilli() + i) - .setServerTimestamp(firstMessageTimestamp.toEpochMilli() + i) - .setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256))) - .setType(MessageProtos.Envelope.Type.CIPHERTEXT) - .setServerGuid(messageGuid.toString()) - .build(); - - messagesCache.insert(messageGuid, accountUuid, deviceId, envelope); - } - } - - private void setNextSlotToPersist(final int nextSlot) { - getRedisCluster().useCluster(connection -> connection.sync().set(MessagesCache.NEXT_SLOT_TO_PERSIST_KEY, String.valueOf(nextSlot - 1))); - } + private void setNextSlotToPersist(final int nextSlot) { + REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster( + connection -> connection.sync().set(MessagesCache.NEXT_SLOT_TO_PERSIST_KEY, String.valueOf(nextSlot - 1))); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java index acc333a35..a7abfbb86 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java @@ -5,13 +5,15 @@ package org.whispersystems.textsecuregcm.storage; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; +import static org.junit.jupiter.api.Assertions.assertTrue; import com.google.protobuf.ByteString; import io.lettuce.core.cluster.SlotHash; import java.nio.charset.StandardCharsets; +import java.time.Duration; import java.time.Instant; import java.util.ArrayList; import java.util.Collections; @@ -24,18 +26,21 @@ import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; -import junitparams.JUnitParamsRunner; -import junitparams.Parameters; import org.apache.commons.lang3.RandomStringUtils; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; -import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; -@RunWith(JUnitParamsRunner.class) -public class MessagesCacheTest extends AbstractRedisClusterTest { +class MessagesCacheTest { + + @RegisterExtension + static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); private ExecutorService notificationExecutorService; private MessagesCache messagesCache; @@ -46,40 +51,38 @@ public class MessagesCacheTest extends AbstractRedisClusterTest { private static final UUID DESTINATION_UUID = UUID.randomUUID(); private static final int DESTINATION_DEVICE_ID = 7; - @Override - @Before - public void setUp() throws Exception { - super.setUp(); + @BeforeEach + void setUp() throws Exception { - getRedisCluster().useCluster( - connection -> connection.sync().upstream().commands().configSet("notify-keyspace-events", "Klgz")); + REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(connection -> { + connection.sync().flushall(); + connection.sync().upstream().commands().configSet("notify-keyspace-events", "K$glz"); + }); notificationExecutorService = Executors.newSingleThreadExecutor(); - messagesCache = new MessagesCache(getRedisCluster(), getRedisCluster(), notificationExecutorService); + messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), REDIS_CLUSTER_EXTENSION.getRedisCluster(), notificationExecutorService); messagesCache.start(); } - @Override - public void tearDown() throws Exception { + @AfterEach + void tearDown() throws Exception { messagesCache.stop(); notificationExecutorService.shutdown(); notificationExecutorService.awaitTermination(1, TimeUnit.SECONDS); - - super.tearDown(); } - @Test - @Parameters({"true", "false"}) - public void testInsert(final boolean sealedSender) { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testInsert(final boolean sealedSender) { final UUID messageGuid = UUID.randomUUID(); assertTrue(messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, generateRandomMessage(messageGuid, sealedSender)) > 0); } @Test - public void testDoubleInsertGuid() { + void testDoubleInsertGuid() { final UUID duplicateGuid = UUID.randomUUID(); final MessageProtos.Envelope duplicateMessage = generateRandomMessage(duplicateGuid, false); @@ -90,9 +93,9 @@ public class MessagesCacheTest extends AbstractRedisClusterTest { assertEquals(firstId, secondId); } - @Test - @Parameters({"true", "false"}) - public void testRemoveByUUID(final boolean sealedSender) { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testRemoveByUUID(final boolean sealedSender) { final UUID messageGuid = UUID.randomUUID(); assertEquals(Optional.empty(), messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageGuid)); @@ -107,9 +110,9 @@ public class MessagesCacheTest extends AbstractRedisClusterTest { assertEquals(MessagesCache.constructEntityFromEnvelope(message), maybeRemovedMessage.get()); } - @Test - @Parameters({"true", "false"}) - public void testRemoveBatchByUUID(final boolean sealedSender) { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testRemoveBatchByUUID(final boolean sealedSender) { final int messageCount = 10; final List messagesToRemove = new ArrayList<>(messageCount); @@ -145,7 +148,7 @@ public class MessagesCacheTest extends AbstractRedisClusterTest { } @Test - public void testHasMessages() { + void testHasMessages() { assertFalse(messagesCache.hasMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID)); final UUID messageGuid = UUID.randomUUID(); @@ -155,9 +158,9 @@ public class MessagesCacheTest extends AbstractRedisClusterTest { assertTrue(messagesCache.hasMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID)); } - @Test - @Parameters({"true", "false"}) - public void testGetMessages(final boolean sealedSender) { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testGetMessages(final boolean sealedSender) { final int messageCount = 100; final List expectedMessages = new ArrayList<>(messageCount); @@ -173,9 +176,9 @@ public class MessagesCacheTest extends AbstractRedisClusterTest { assertEquals(expectedMessages, messagesCache.get(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); } - @Test - @Parameters({"true", "false"}) - public void testClearQueueForDevice(final boolean sealedSender) { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testClearQueueForDevice(final boolean sealedSender) { final int messageCount = 100; for (final int deviceId : new int[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) { @@ -193,9 +196,9 @@ public class MessagesCacheTest extends AbstractRedisClusterTest { assertEquals(messageCount, messagesCache.get(DESTINATION_UUID, DESTINATION_DEVICE_ID + 1, messageCount).size()); } - @Test - @Parameters({"true", "false"}) - public void testClearQueueForAccount(final boolean sealedSender) { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testClearQueueForAccount(final boolean sealedSender) { final int messageCount = 100; for (final int deviceId : new int[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) { @@ -236,13 +239,13 @@ public class MessagesCacheTest extends AbstractRedisClusterTest { } @Test - public void testClearNullUuid() { + void testClearNullUuid() { // We're happy as long as this doesn't throw an exception messagesCache.clear(null); } @Test - public void testGetAccountFromQueueName() { + void testGetAccountFromQueueName() { assertEquals(DESTINATION_UUID, MessagesCache.getAccountUuidFromQueueName( new String(MessagesCache.getMessageQueueKey(DESTINATION_UUID, DESTINATION_DEVICE_ID), @@ -250,7 +253,7 @@ public class MessagesCacheTest extends AbstractRedisClusterTest { } @Test - public void testGetDeviceIdFromQueueName() { + void testGetDeviceIdFromQueueName() { assertEquals(DESTINATION_DEVICE_ID, MessagesCache.getDeviceIdFromQueueName( new String(MessagesCache.getMessageQueueKey(DESTINATION_UUID, DESTINATION_DEVICE_ID), @@ -258,14 +261,14 @@ public class MessagesCacheTest extends AbstractRedisClusterTest { } @Test - public void testGetQueueNameFromKeyspaceChannel() { + void testGetQueueNameFromKeyspaceChannel() { assertEquals("1b363a31-a429-4fb6-8959-984a025e72ff::7", MessagesCache.getQueueNameFromKeyspaceChannel( "__keyspace@0__:user_queue::{1b363a31-a429-4fb6-8959-984a025e72ff::7}")); } - @Test - @Parameters({"true", "false"}) + @ParameterizedTest + @ValueSource(booleans = {true, false}) public void testGetQueuesToPersist(final boolean sealedSender) { final UUID messageGuid = UUID.randomUUID(); @@ -282,8 +285,8 @@ public class MessagesCacheTest extends AbstractRedisClusterTest { assertEquals(DESTINATION_DEVICE_ID, MessagesCache.getDeviceIdFromQueueName(queues.get(0))); } - @Test(timeout = 5_000L) - public void testNotifyListenerNewMessage() throws InterruptedException { + @Test + void testNotifyListenerNewMessage() { final AtomicBoolean notified = new AtomicBoolean(false); final UUID messageGuid = UUID.randomUUID(); @@ -301,21 +304,23 @@ public class MessagesCacheTest extends AbstractRedisClusterTest { } }; - messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener); - messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, - generateRandomMessage(messageGuid, true)); + assertTimeoutPreemptively(Duration.ofSeconds(5), () -> { + messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener); + messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, + generateRandomMessage(messageGuid, true)); - synchronized (notified) { - while (!notified.get()) { - notified.wait(); + synchronized (notified) { + while (!notified.get()) { + notified.wait(); + } } - } - assertTrue(notified.get()); + assertTrue(notified.get()); + }); } - @Test(timeout = 5_000L) - public void testNotifyListenerPersisted() throws InterruptedException { + @Test + void testNotifyListenerPersisted() { final AtomicBoolean notified = new AtomicBoolean(false); final MessageAvailabilityListener listener = new MessageAvailabilityListener() { @@ -332,18 +337,20 @@ public class MessagesCacheTest extends AbstractRedisClusterTest { } }; - messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener); + assertTimeoutPreemptively(Duration.ofSeconds(5), () -> { + messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener); - messagesCache.lockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID); - messagesCache.unlockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID); + messagesCache.lockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID); + messagesCache.unlockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID); - synchronized (notified) { - while (!notified.get()) { - notified.wait(); + synchronized (notified) { + while (!notified.get()) { + notified.wait(); + } } - } - assertTrue(notified.get()); + assertTrue(notified.get()); + }); } }