Migrate remaining JUnit 4 Redis cluster tests to `RedisClusterExtension`

* Increase redis cluster initialization wait to 10 seconds
* Move to JUnit 5 `Assumptions`
This commit is contained in:
Chris Eager 2022-01-03 14:59:39 -08:00 committed by GitHub
parent c488c14d25
commit b758737907
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 268 additions and 442 deletions

View File

@ -1,194 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.redis;
import io.lettuce.core.RedisClient;
import io.lettuce.core.RedisException;
import io.lettuce.core.RedisURI;
import io.lettuce.core.api.StatefulRedisConnection;
import io.lettuce.core.api.sync.RedisCommands;
import io.lettuce.core.cluster.RedisClusterClient;
import io.lettuce.core.cluster.SlotHash;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
import org.whispersystems.textsecuregcm.configuration.RetryConfiguration;
import org.whispersystems.textsecuregcm.util.RedisClusterUtil;
import redis.embedded.RedisServer;
import java.io.File;
import java.io.IOException;
import java.net.ServerSocket;
import java.time.Duration;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import static org.junit.Assume.assumeFalse;
/**
* An abstract base class that assembles a real (local!) Redis cluster and provides a client to that cluster for
* subclasses.
*/
public abstract class AbstractRedisClusterTest {
private static final int NODE_COUNT = 2;
private static RedisServer[] clusterNodes;
private FaultTolerantRedisCluster redisCluster;
@BeforeClass
public static void setUpBeforeClass() throws Exception {
assumeFalse(System.getProperty("os.name").equalsIgnoreCase("windows"));
clusterNodes = new RedisServer[NODE_COUNT];
for (int i = 0; i < NODE_COUNT; i++) {
clusterNodes[i] = buildClusterNode(getNextRedisClusterPort());
clusterNodes[i].start();
}
assembleCluster(clusterNodes);
}
@Before
public void setUp() throws Exception {
final List<String> urls = Arrays.stream(clusterNodes)
.map(node -> String.format("redis://127.0.0.1:%d", node.ports().get(0)))
.collect(Collectors.toList());
redisCluster = new FaultTolerantRedisCluster("test-cluster",
RedisClusterClient.create(urls.stream().map(RedisURI::create).collect(Collectors.toList())),
Duration.ofSeconds(2),
new CircuitBreakerConfiguration(),
new RetryConfiguration());
redisCluster.useCluster(connection -> {
boolean setAll = false;
final String[] keys = new String[NODE_COUNT];
for (int i = 0; i < keys.length; i++) {
keys[i] = RedisClusterUtil.getMinimalHashTag(i * SlotHash.SLOT_COUNT / keys.length);
}
while (!setAll) {
try {
for (final String key : keys) {
connection.sync().set(key, "warmup");
}
setAll = true;
} catch (final RedisException ignored) {
// Cluster isn't ready; wait and retry.
try {
Thread.sleep(500);
} catch (final InterruptedException ignored2) {
}
}
}
});
redisCluster.useCluster(connection -> connection.sync().flushall());
}
protected FaultTolerantRedisCluster getRedisCluster() {
return redisCluster;
}
@After
public void tearDown() throws Exception {
redisCluster.shutdown();
}
@AfterClass
public static void tearDownAfterClass() {
for (final RedisServer node : clusterNodes) {
node.stop();
}
}
private static RedisServer buildClusterNode(final int port) throws IOException {
final File clusterConfigFile = File.createTempFile("redis", ".conf");
clusterConfigFile.deleteOnExit();
return RedisServer.builder()
.setting("cluster-enabled yes")
.setting("cluster-config-file " + clusterConfigFile.getAbsolutePath())
.setting("cluster-node-timeout 5000")
.setting("appendonly no")
.setting("dir " + System.getProperty("java.io.tmpdir"))
.port(port)
.build();
}
private static void assembleCluster(final RedisServer... nodes) throws InterruptedException {
final RedisClient meetClient = RedisClient.create(RedisURI.create("127.0.0.1", nodes[0].ports().get(0)));
try {
final StatefulRedisConnection<String, String> connection = meetClient.connect();
final RedisCommands<String, String> commands = connection.sync();
for (int i = 1; i < nodes.length; i++) {
commands.clusterMeet("127.0.0.1", nodes[i].ports().get(0));
}
} finally {
meetClient.shutdown();
}
final int slotsPerNode = SlotHash.SLOT_COUNT / nodes.length;
for (int i = 0; i < nodes.length; i++) {
final int startInclusive = i * slotsPerNode;
final int endExclusive = i == nodes.length - 1 ? SlotHash.SLOT_COUNT : (i + 1) * slotsPerNode;
final RedisClient assignSlotClient = RedisClient.create(RedisURI.create("127.0.0.1", nodes[i].ports().get(0)));
try (final StatefulRedisConnection<String, String> assignSlotConnection = assignSlotClient.connect()) {
final int[] slots = new int[endExclusive - startInclusive];
for (int s = startInclusive; s < endExclusive; s++) {
slots[s - startInclusive] = s;
}
assignSlotConnection.sync().clusterAddSlots(slots);
} finally {
assignSlotClient.shutdown();
}
}
final RedisClient waitClient = RedisClient.create(RedisURI.create("127.0.0.1", nodes[0].ports().get(0)));
try (final StatefulRedisConnection<String, String> connection = waitClient.connect()) {
// CLUSTER INFO gives us a big blob of key-value pairs, but the one we're interested in is `cluster_state`.
// According to https://redis.io/commands/cluster-info, `cluster_state:ok` means that the node is ready to
// receive queries, all slots are assigned, and a majority of master nodes are reachable.
while (!connection.sync().clusterInfo().contains("cluster_state:ok")) {
Thread.sleep(500);
}
} finally {
waitClient.shutdown();
}
}
public static int getNextRedisClusterPort() throws IOException {
final int MAX_ITERATIONS = 11_000;
int port;
for (int i = 0; i < MAX_ITERATIONS; i++) {
try (ServerSocket socket = new ServerSocket(0)) {
socket.setReuseAddress(false);
port = socket.getLocalPort();
}
if (port < 55535) {
return port;
}
}
throw new IOException("Couldn't find an open port below 55,535 in " + MAX_ITERATIONS + " tries");
}
}

View File

@ -1,11 +1,11 @@
/* /*
* Copyright 2013-2021 Signal Messenger, LLC * Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
package org.whispersystems.textsecuregcm.redis; package org.whispersystems.textsecuregcm.redis;
import static org.junit.Assume.assumeFalse; import static org.junit.jupiter.api.Assumptions.assumeFalse;
import io.lettuce.core.RedisClient; import io.lettuce.core.RedisClient;
import io.lettuce.core.RedisException; import io.lettuce.core.RedisException;
@ -178,9 +178,9 @@ public class RedisClusterExtension implements BeforeAllCallback, BeforeEachCallb
Thread.sleep(sleepMillis); Thread.sleep(sleepMillis);
tries++; tries++;
if (tries == 10) { if (tries == 20) {
throw new RuntimeException( throw new RuntimeException(
String.format("Timeout: Redis not ready after waiting %d milliseconds", sleepMillis)); String.format("Timeout: Redis not ready after waiting %d milliseconds", tries * sleepMillis));
} }
} }
} finally { } finally {

View File

@ -5,7 +5,7 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import static org.junit.Assert.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
@ -18,11 +18,15 @@ import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import org.junit.Before; import org.junit.jupiter.api.BeforeEach;
import org.junit.Test; import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
public class AccountDatabaseCrawlerIntegrationTest extends AbstractRedisClusterTest { class AccountDatabaseCrawlerIntegrationTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
private static final UUID FIRST_UUID = UUID.fromString("82339e80-81cd-48e2-9ed2-ccd5dd262ad9"); private static final UUID FIRST_UUID = UUID.fromString("82339e80-81cd-48e2-9ed2-ccd5dd262ad9");
private static final UUID SECOND_UUID = UUID.fromString("cc705c84-33cf-456b-8239-a6a34e2f561a"); private static final UUID SECOND_UUID = UUID.fromString("cc705c84-33cf-456b-8239-a6a34e2f561a");
@ -38,9 +42,8 @@ public class AccountDatabaseCrawlerIntegrationTest extends AbstractRedisClusterT
private static final int CHUNK_SIZE = 1; private static final int CHUNK_SIZE = 1;
private static final long CHUNK_INTERVAL_MS = 0; private static final long CHUNK_INTERVAL_MS = 0;
@Before @BeforeEach
public void setUp() throws Exception { void setUp() throws Exception {
super.setUp();
firstAccount = mock(Account.class); firstAccount = mock(Account.class);
secondAccount = mock(Account.class); secondAccount = mock(Account.class);
@ -57,13 +60,13 @@ public class AccountDatabaseCrawlerIntegrationTest extends AbstractRedisClusterT
.thenReturn(new AccountCrawlChunk(List.of(secondAccount), SECOND_UUID)) .thenReturn(new AccountCrawlChunk(List.of(secondAccount), SECOND_UUID))
.thenReturn(new AccountCrawlChunk(Collections.emptyList(), null)); .thenReturn(new AccountCrawlChunk(Collections.emptyList(), null));
final AccountDatabaseCrawlerCache crawlerCache = new AccountDatabaseCrawlerCache(getRedisCluster(), "test"); final AccountDatabaseCrawlerCache crawlerCache = new AccountDatabaseCrawlerCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), "test");
accountDatabaseCrawler = new AccountDatabaseCrawler("test", accountsManager, crawlerCache, List.of(listener), CHUNK_SIZE, accountDatabaseCrawler = new AccountDatabaseCrawler("test", accountsManager, crawlerCache, List.of(listener), CHUNK_SIZE,
CHUNK_INTERVAL_MS); CHUNK_INTERVAL_MS);
} }
@Test @Test
public void testCrawlUninterrupted() throws AccountDatabaseCrawlerRestartException { void testCrawlUninterrupted() throws AccountDatabaseCrawlerRestartException {
assertFalse(accountDatabaseCrawler.doPeriodicWork()); assertFalse(accountDatabaseCrawler.doPeriodicWork());
assertFalse(accountDatabaseCrawler.doPeriodicWork()); assertFalse(accountDatabaseCrawler.doPeriodicWork());
assertFalse(accountDatabaseCrawler.doPeriodicWork()); assertFalse(accountDatabaseCrawler.doPeriodicWork());
@ -79,7 +82,7 @@ public class AccountDatabaseCrawlerIntegrationTest extends AbstractRedisClusterT
} }
@Test @Test
public void testCrawlWithReset() throws AccountDatabaseCrawlerRestartException { void testCrawlWithReset() throws AccountDatabaseCrawlerRestartException {
doThrow(new AccountDatabaseCrawlerRestartException("OH NO")).doNothing() doThrow(new AccountDatabaseCrawlerRestartException("OH NO")).doNothing()
.when(listener).timeAndProcessCrawlChunk(Optional.empty(), List.of(firstAccount)); .when(listener).timeAndProcessCrawlChunk(Optional.empty(), List.of(firstAccount));

View File

@ -5,7 +5,7 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
@ -28,194 +28,204 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.RandomStringUtils; import org.apache.commons.lang3.RandomStringUtils;
import org.junit.Before; import org.junit.jupiter.api.AfterEach;
import org.junit.Test; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
public class MessagePersisterTest extends AbstractRedisClusterTest { class MessagePersisterTest {
private ExecutorService notificationExecutorService; @RegisterExtension
private MessagesCache messagesCache; static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
private MessagesDynamoDb messagesDynamoDb;
private MessagePersister messagePersister;
private AccountsManager accountsManager;
private static final UUID DESTINATION_ACCOUNT_UUID = UUID.randomUUID(); private ExecutorService notificationExecutorService;
private static final String DESTINATION_ACCOUNT_NUMBER = "+18005551234"; private MessagesCache messagesCache;
private static final long DESTINATION_DEVICE_ID = 7; private MessagesDynamoDb messagesDynamoDb;
private MessagePersister messagePersister;
private AccountsManager accountsManager;
private static final Duration PERSIST_DELAY = Duration.ofMinutes(5); private static final UUID DESTINATION_ACCOUNT_UUID = UUID.randomUUID();
private static final String DESTINATION_ACCOUNT_NUMBER = "+18005551234";
private static final long DESTINATION_DEVICE_ID = 7;
@Override private static final Duration PERSIST_DELAY = Duration.ofMinutes(5);
@Before
public void setUp() throws Exception {
super.setUp();
final MessagesManager messagesManager = mock(MessagesManager.class); @BeforeEach
final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class); void setUp() throws Exception {
messagesDynamoDb = mock(MessagesDynamoDb.class); final MessagesManager messagesManager = mock(MessagesManager.class);
accountsManager = mock(AccountsManager.class); final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
final Account account = mock(Account.class); messagesDynamoDb = mock(MessagesDynamoDb.class);
accountsManager = mock(AccountsManager.class);
when(accountsManager.getByAccountIdentifier(DESTINATION_ACCOUNT_UUID)).thenReturn(Optional.of(account)); final Account account = mock(Account.class);
when(account.getNumber()).thenReturn(DESTINATION_ACCOUNT_NUMBER);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration());
notificationExecutorService = Executors.newSingleThreadExecutor(); when(accountsManager.getByAccountIdentifier(DESTINATION_ACCOUNT_UUID)).thenReturn(Optional.of(account));
messagesCache = new MessagesCache(getRedisCluster(), getRedisCluster(), notificationExecutorService); when(account.getNumber()).thenReturn(DESTINATION_ACCOUNT_NUMBER);
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, dynamicConfigurationManager, PERSIST_DELAY); when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration());
doAnswer(invocation -> { notificationExecutorService = Executors.newSingleThreadExecutor();
final UUID destinationUuid = invocation.getArgument(0); messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
final long destinationDeviceId = invocation.getArgument(1); REDIS_CLUSTER_EXTENSION.getRedisCluster(), notificationExecutorService);
final List<MessageProtos.Envelope> messages = invocation.getArgument(2); messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager,
dynamicConfigurationManager, PERSIST_DELAY);
messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId); doAnswer(invocation -> {
final UUID destinationUuid = invocation.getArgument(0);
final long destinationDeviceId = invocation.getArgument(1);
final List<MessageProtos.Envelope> messages = invocation.getArgument(2);
for (final MessageProtos.Envelope message : messages) { messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId);
messagesCache.remove(destinationUuid, destinationDeviceId, UUID.fromString(message.getServerGuid()));
}
return null; for (final MessageProtos.Envelope message : messages) {
}).when(messagesManager).persistMessages(any(UUID.class), anyLong(), any()); messagesCache.remove(destinationUuid, destinationDeviceId, UUID.fromString(message.getServerGuid()));
}
return null;
}).when(messagesManager).persistMessages(any(UUID.class), anyLong(), any());
}
@AfterEach
void tearDown() throws Exception {
notificationExecutorService.shutdown();
notificationExecutorService.awaitTermination(1, TimeUnit.SECONDS);
}
@Test
void testPersistNextQueuesNoQueues() {
messagePersister.persistNextQueues(Instant.now());
verify(accountsManager, never()).getByAccountIdentifier(any(UUID.class));
}
@Test
void testPersistNextQueuesSingleQueue() {
final String queueName = new String(
MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8);
final int messageCount = (MessagePersister.MESSAGE_BATCH_LIMIT * 3) + 7;
final Instant now = Instant.now();
insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, now);
setNextSlotToPersist(SlotHash.getSlot(queueName));
messagePersister.persistNextQueues(now.plus(messagePersister.getPersistDelay()));
final ArgumentCaptor<List<MessageProtos.Envelope>> messagesCaptor = ArgumentCaptor.forClass(List.class);
verify(messagesDynamoDb, atLeastOnce()).store(messagesCaptor.capture(), eq(DESTINATION_ACCOUNT_UUID),
eq(DESTINATION_DEVICE_ID));
assertEquals(messageCount, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum());
}
@Test
void testPersistNextQueuesSingleQueueTooSoon() {
final String queueName = new String(
MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8);
final int messageCount = (MessagePersister.MESSAGE_BATCH_LIMIT * 3) + 7;
final Instant now = Instant.now();
insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, now);
setNextSlotToPersist(SlotHash.getSlot(queueName));
messagePersister.persistNextQueues(now);
verify(messagesDynamoDb, never()).store(any(), any(), anyLong());
}
@Test
void testPersistNextQueuesMultiplePages() {
final int slot = 7;
final int queueCount = (MessagePersister.QUEUE_BATCH_LIMIT * 3) + 7;
final int messagesPerQueue = 10;
final Instant now = Instant.now();
for (int i = 0; i < queueCount; i++) {
final String queueName = generateRandomQueueNameForSlot(slot);
final UUID accountUuid = MessagesCache.getAccountUuidFromQueueName(queueName);
final long deviceId = MessagesCache.getDeviceIdFromQueueName(queueName);
final String accountNumber = "+1" + RandomStringUtils.randomNumeric(10);
final Account account = mock(Account.class);
when(accountsManager.getByAccountIdentifier(accountUuid)).thenReturn(Optional.of(account));
when(account.getNumber()).thenReturn(accountNumber);
insertMessages(accountUuid, deviceId, messagesPerQueue, now);
} }
@Override setNextSlotToPersist(slot);
public void tearDown() throws Exception {
super.tearDown();
notificationExecutorService.shutdown(); messagePersister.persistNextQueues(now.plus(messagePersister.getPersistDelay()));
notificationExecutorService.awaitTermination(1, TimeUnit.SECONDS);
final ArgumentCaptor<List<MessageProtos.Envelope>> messagesCaptor = ArgumentCaptor.forClass(List.class);
verify(messagesDynamoDb, atLeastOnce()).store(messagesCaptor.capture(), any(UUID.class), anyLong());
assertEquals(queueCount * messagesPerQueue, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum());
}
@Test
void testPersistQueueRetry() {
final String queueName = new String(
MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8);
final int messageCount = (MessagePersister.MESSAGE_BATCH_LIMIT * 3) + 7;
final Instant now = Instant.now();
insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, now);
setNextSlotToPersist(SlotHash.getSlot(queueName));
doAnswer((Answer<Void>) invocation -> {
throw new RuntimeException("OH NO.");
}).when(messagesDynamoDb).store(any(), eq(DESTINATION_ACCOUNT_UUID), eq(DESTINATION_DEVICE_ID));
messagePersister.persistNextQueues(now.plus(messagePersister.getPersistDelay()));
assertEquals(List.of(queueName),
messagesCache.getQueuesToPersist(SlotHash.getSlot(queueName),
Instant.now().plus(messagePersister.getPersistDelay()), 1));
}
@SuppressWarnings("SameParameterValue")
private static String generateRandomQueueNameForSlot(final int slot) {
final UUID uuid = UUID.randomUUID();
final String queueNameBase = "user_queue::{" + uuid + "::";
for (int deviceId = 0; deviceId < Integer.MAX_VALUE; deviceId++) {
final String queueName = queueNameBase + deviceId + "}";
if (SlotHash.getSlot(queueName) == slot) {
return queueName;
}
} }
@Test throw new IllegalStateException("Could not find a queue name for slot " + slot);
public void testPersistNextQueuesNoQueues() { }
messagePersister.persistNextQueues(Instant.now());
verify(accountsManager, never()).getByAccountIdentifier(any(UUID.class)); private void insertMessages(final UUID accountUuid, final long deviceId, final int messageCount,
final Instant firstMessageTimestamp) {
for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = MessageProtos.Envelope.newBuilder()
.setTimestamp(firstMessageTimestamp.toEpochMilli() + i)
.setServerTimestamp(firstMessageTimestamp.toEpochMilli() + i)
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setServerGuid(messageGuid.toString())
.build();
messagesCache.insert(messageGuid, accountUuid, deviceId, envelope);
} }
}
@Test private void setNextSlotToPersist(final int nextSlot) {
public void testPersistNextQueuesSingleQueue() { REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(
final String queueName = new String(MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8); connection -> connection.sync().set(MessagesCache.NEXT_SLOT_TO_PERSIST_KEY, String.valueOf(nextSlot - 1)));
final int messageCount = (MessagePersister.MESSAGE_BATCH_LIMIT * 3) + 7; }
final Instant now = Instant.now();
insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, now);
setNextSlotToPersist(SlotHash.getSlot(queueName));
messagePersister.persistNextQueues(now.plus(messagePersister.getPersistDelay()));
final ArgumentCaptor<List<MessageProtos.Envelope>> messagesCaptor = ArgumentCaptor.forClass(List.class);
verify(messagesDynamoDb, atLeastOnce()).store(messagesCaptor.capture(), eq(DESTINATION_ACCOUNT_UUID), eq(DESTINATION_DEVICE_ID));
assertEquals(messageCount, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum());
}
@Test
public void testPersistNextQueuesSingleQueueTooSoon() {
final String queueName = new String(MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8);
final int messageCount = (MessagePersister.MESSAGE_BATCH_LIMIT * 3) + 7;
final Instant now = Instant.now();
insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, now);
setNextSlotToPersist(SlotHash.getSlot(queueName));
messagePersister.persistNextQueues(now);
verify(messagesDynamoDb, never()).store(any(), any(), anyLong());
}
@Test
public void testPersistNextQueuesMultiplePages() {
final int slot = 7;
final int queueCount = (MessagePersister.QUEUE_BATCH_LIMIT * 3) + 7;
final int messagesPerQueue = 10;
final Instant now = Instant.now();
for (int i = 0; i < queueCount; i++) {
final String queueName = generateRandomQueueNameForSlot(slot);
final UUID accountUuid = MessagesCache.getAccountUuidFromQueueName(queueName);
final long deviceId = MessagesCache.getDeviceIdFromQueueName(queueName);
final String accountNumber = "+1" + RandomStringUtils.randomNumeric(10);
final Account account = mock(Account.class);
when(accountsManager.getByAccountIdentifier(accountUuid)).thenReturn(Optional.of(account));
when(account.getNumber()).thenReturn(accountNumber);
insertMessages(accountUuid, deviceId, messagesPerQueue, now);
}
setNextSlotToPersist(slot);
messagePersister.persistNextQueues(now.plus(messagePersister.getPersistDelay()));
final ArgumentCaptor<List<MessageProtos.Envelope>> messagesCaptor = ArgumentCaptor.forClass(List.class);
verify(messagesDynamoDb, atLeastOnce()).store(messagesCaptor.capture(), any(UUID.class), anyLong());
assertEquals(queueCount * messagesPerQueue, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum());
}
@Test
public void testPersistQueueRetry() {
final String queueName = new String(MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8);
final int messageCount = (MessagePersister.MESSAGE_BATCH_LIMIT * 3) + 7;
final Instant now = Instant.now();
insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, now);
setNextSlotToPersist(SlotHash.getSlot(queueName));
doAnswer((Answer<Void>)invocation -> {
throw new RuntimeException("OH NO.");
}).when(messagesDynamoDb).store(any(), eq(DESTINATION_ACCOUNT_UUID), eq(DESTINATION_DEVICE_ID));
messagePersister.persistNextQueues(now.plus(messagePersister.getPersistDelay()));
assertEquals(List.of(queueName),
messagesCache.getQueuesToPersist(SlotHash.getSlot(queueName), Instant.now().plus(messagePersister.getPersistDelay()), 1));
}
@SuppressWarnings("SameParameterValue")
private static String generateRandomQueueNameForSlot(final int slot) {
final UUID uuid = UUID.randomUUID();
final String queueNameBase = "user_queue::{" + uuid.toString() + "::";
for (int deviceId = 0; deviceId < Integer.MAX_VALUE; deviceId++) {
final String queueName = queueNameBase + deviceId + "}";
if (SlotHash.getSlot(queueName) == slot) {
return queueName;
}
}
throw new IllegalStateException("Could not find a queue name for slot " + slot);
}
private void insertMessages(final UUID accountUuid, final long deviceId, final int messageCount, final Instant firstMessageTimestamp) {
for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = MessageProtos.Envelope.newBuilder()
.setTimestamp(firstMessageTimestamp.toEpochMilli() + i)
.setServerTimestamp(firstMessageTimestamp.toEpochMilli() + i)
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setServerGuid(messageGuid.toString())
.build();
messagesCache.insert(messageGuid, accountUuid, deviceId, envelope);
}
}
private void setNextSlotToPersist(final int nextSlot) {
getRedisCluster().useCluster(connection -> connection.sync().set(MessagesCache.NEXT_SLOT_TO_PERSIST_KEY, String.valueOf(nextSlot - 1)));
}
} }

View File

@ -5,13 +5,15 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import io.lettuce.core.cluster.SlotHash; import io.lettuce.core.cluster.SlotHash;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
@ -24,18 +26,21 @@ import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;
import org.apache.commons.lang3.RandomStringUtils; import org.apache.commons.lang3.RandomStringUtils;
import org.junit.Before; import org.junit.jupiter.api.AfterEach;
import org.junit.Test; import org.junit.jupiter.api.BeforeEach;
import org.junit.runner.RunWith; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
@RunWith(JUnitParamsRunner.class) class MessagesCacheTest {
public class MessagesCacheTest extends AbstractRedisClusterTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
private ExecutorService notificationExecutorService; private ExecutorService notificationExecutorService;
private MessagesCache messagesCache; private MessagesCache messagesCache;
@ -46,40 +51,38 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
private static final UUID DESTINATION_UUID = UUID.randomUUID(); private static final UUID DESTINATION_UUID = UUID.randomUUID();
private static final int DESTINATION_DEVICE_ID = 7; private static final int DESTINATION_DEVICE_ID = 7;
@Override @BeforeEach
@Before void setUp() throws Exception {
public void setUp() throws Exception {
super.setUp();
getRedisCluster().useCluster( REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(connection -> {
connection -> connection.sync().upstream().commands().configSet("notify-keyspace-events", "Klgz")); connection.sync().flushall();
connection.sync().upstream().commands().configSet("notify-keyspace-events", "K$glz");
});
notificationExecutorService = Executors.newSingleThreadExecutor(); notificationExecutorService = Executors.newSingleThreadExecutor();
messagesCache = new MessagesCache(getRedisCluster(), getRedisCluster(), notificationExecutorService); messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), REDIS_CLUSTER_EXTENSION.getRedisCluster(), notificationExecutorService);
messagesCache.start(); messagesCache.start();
} }
@Override @AfterEach
public void tearDown() throws Exception { void tearDown() throws Exception {
messagesCache.stop(); messagesCache.stop();
notificationExecutorService.shutdown(); notificationExecutorService.shutdown();
notificationExecutorService.awaitTermination(1, TimeUnit.SECONDS); notificationExecutorService.awaitTermination(1, TimeUnit.SECONDS);
super.tearDown();
} }
@Test @ParameterizedTest
@Parameters({"true", "false"}) @ValueSource(booleans = {true, false})
public void testInsert(final boolean sealedSender) { void testInsert(final boolean sealedSender) {
final UUID messageGuid = UUID.randomUUID(); final UUID messageGuid = UUID.randomUUID();
assertTrue(messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, assertTrue(messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID,
generateRandomMessage(messageGuid, sealedSender)) > 0); generateRandomMessage(messageGuid, sealedSender)) > 0);
} }
@Test @Test
public void testDoubleInsertGuid() { void testDoubleInsertGuid() {
final UUID duplicateGuid = UUID.randomUUID(); final UUID duplicateGuid = UUID.randomUUID();
final MessageProtos.Envelope duplicateMessage = generateRandomMessage(duplicateGuid, false); final MessageProtos.Envelope duplicateMessage = generateRandomMessage(duplicateGuid, false);
@ -90,9 +93,9 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
assertEquals(firstId, secondId); assertEquals(firstId, secondId);
} }
@Test @ParameterizedTest
@Parameters({"true", "false"}) @ValueSource(booleans = {true, false})
public void testRemoveByUUID(final boolean sealedSender) { void testRemoveByUUID(final boolean sealedSender) {
final UUID messageGuid = UUID.randomUUID(); final UUID messageGuid = UUID.randomUUID();
assertEquals(Optional.empty(), messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageGuid)); assertEquals(Optional.empty(), messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageGuid));
@ -107,9 +110,9 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
assertEquals(MessagesCache.constructEntityFromEnvelope(message), maybeRemovedMessage.get()); assertEquals(MessagesCache.constructEntityFromEnvelope(message), maybeRemovedMessage.get());
} }
@Test @ParameterizedTest
@Parameters({"true", "false"}) @ValueSource(booleans = {true, false})
public void testRemoveBatchByUUID(final boolean sealedSender) { void testRemoveBatchByUUID(final boolean sealedSender) {
final int messageCount = 10; final int messageCount = 10;
final List<MessageProtos.Envelope> messagesToRemove = new ArrayList<>(messageCount); final List<MessageProtos.Envelope> messagesToRemove = new ArrayList<>(messageCount);
@ -145,7 +148,7 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
} }
@Test @Test
public void testHasMessages() { void testHasMessages() {
assertFalse(messagesCache.hasMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID)); assertFalse(messagesCache.hasMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID));
final UUID messageGuid = UUID.randomUUID(); final UUID messageGuid = UUID.randomUUID();
@ -155,9 +158,9 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
assertTrue(messagesCache.hasMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID)); assertTrue(messagesCache.hasMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID));
} }
@Test @ParameterizedTest
@Parameters({"true", "false"}) @ValueSource(booleans = {true, false})
public void testGetMessages(final boolean sealedSender) { void testGetMessages(final boolean sealedSender) {
final int messageCount = 100; final int messageCount = 100;
final List<OutgoingMessageEntity> expectedMessages = new ArrayList<>(messageCount); final List<OutgoingMessageEntity> expectedMessages = new ArrayList<>(messageCount);
@ -173,9 +176,9 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
assertEquals(expectedMessages, messagesCache.get(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount)); assertEquals(expectedMessages, messagesCache.get(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount));
} }
@Test @ParameterizedTest
@Parameters({"true", "false"}) @ValueSource(booleans = {true, false})
public void testClearQueueForDevice(final boolean sealedSender) { void testClearQueueForDevice(final boolean sealedSender) {
final int messageCount = 100; final int messageCount = 100;
for (final int deviceId : new int[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) { for (final int deviceId : new int[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) {
@ -193,9 +196,9 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
assertEquals(messageCount, messagesCache.get(DESTINATION_UUID, DESTINATION_DEVICE_ID + 1, messageCount).size()); assertEquals(messageCount, messagesCache.get(DESTINATION_UUID, DESTINATION_DEVICE_ID + 1, messageCount).size());
} }
@Test @ParameterizedTest
@Parameters({"true", "false"}) @ValueSource(booleans = {true, false})
public void testClearQueueForAccount(final boolean sealedSender) { void testClearQueueForAccount(final boolean sealedSender) {
final int messageCount = 100; final int messageCount = 100;
for (final int deviceId : new int[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) { for (final int deviceId : new int[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) {
@ -236,13 +239,13 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
} }
@Test @Test
public void testClearNullUuid() { void testClearNullUuid() {
// We're happy as long as this doesn't throw an exception // We're happy as long as this doesn't throw an exception
messagesCache.clear(null); messagesCache.clear(null);
} }
@Test @Test
public void testGetAccountFromQueueName() { void testGetAccountFromQueueName() {
assertEquals(DESTINATION_UUID, assertEquals(DESTINATION_UUID,
MessagesCache.getAccountUuidFromQueueName( MessagesCache.getAccountUuidFromQueueName(
new String(MessagesCache.getMessageQueueKey(DESTINATION_UUID, DESTINATION_DEVICE_ID), new String(MessagesCache.getMessageQueueKey(DESTINATION_UUID, DESTINATION_DEVICE_ID),
@ -250,7 +253,7 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
} }
@Test @Test
public void testGetDeviceIdFromQueueName() { void testGetDeviceIdFromQueueName() {
assertEquals(DESTINATION_DEVICE_ID, assertEquals(DESTINATION_DEVICE_ID,
MessagesCache.getDeviceIdFromQueueName( MessagesCache.getDeviceIdFromQueueName(
new String(MessagesCache.getMessageQueueKey(DESTINATION_UUID, DESTINATION_DEVICE_ID), new String(MessagesCache.getMessageQueueKey(DESTINATION_UUID, DESTINATION_DEVICE_ID),
@ -258,14 +261,14 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
} }
@Test @Test
public void testGetQueueNameFromKeyspaceChannel() { void testGetQueueNameFromKeyspaceChannel() {
assertEquals("1b363a31-a429-4fb6-8959-984a025e72ff::7", assertEquals("1b363a31-a429-4fb6-8959-984a025e72ff::7",
MessagesCache.getQueueNameFromKeyspaceChannel( MessagesCache.getQueueNameFromKeyspaceChannel(
"__keyspace@0__:user_queue::{1b363a31-a429-4fb6-8959-984a025e72ff::7}")); "__keyspace@0__:user_queue::{1b363a31-a429-4fb6-8959-984a025e72ff::7}"));
} }
@Test @ParameterizedTest
@Parameters({"true", "false"}) @ValueSource(booleans = {true, false})
public void testGetQueuesToPersist(final boolean sealedSender) { public void testGetQueuesToPersist(final boolean sealedSender) {
final UUID messageGuid = UUID.randomUUID(); final UUID messageGuid = UUID.randomUUID();
@ -282,8 +285,8 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
assertEquals(DESTINATION_DEVICE_ID, MessagesCache.getDeviceIdFromQueueName(queues.get(0))); assertEquals(DESTINATION_DEVICE_ID, MessagesCache.getDeviceIdFromQueueName(queues.get(0)));
} }
@Test(timeout = 5_000L) @Test
public void testNotifyListenerNewMessage() throws InterruptedException { void testNotifyListenerNewMessage() {
final AtomicBoolean notified = new AtomicBoolean(false); final AtomicBoolean notified = new AtomicBoolean(false);
final UUID messageGuid = UUID.randomUUID(); final UUID messageGuid = UUID.randomUUID();
@ -301,21 +304,23 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
} }
}; };
messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener); assertTimeoutPreemptively(Duration.ofSeconds(5), () -> {
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener);
generateRandomMessage(messageGuid, true)); messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID,
generateRandomMessage(messageGuid, true));
synchronized (notified) { synchronized (notified) {
while (!notified.get()) { while (!notified.get()) {
notified.wait(); notified.wait();
}
} }
}
assertTrue(notified.get()); assertTrue(notified.get());
});
} }
@Test(timeout = 5_000L) @Test
public void testNotifyListenerPersisted() throws InterruptedException { void testNotifyListenerPersisted() {
final AtomicBoolean notified = new AtomicBoolean(false); final AtomicBoolean notified = new AtomicBoolean(false);
final MessageAvailabilityListener listener = new MessageAvailabilityListener() { final MessageAvailabilityListener listener = new MessageAvailabilityListener() {
@ -332,18 +337,20 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
} }
}; };
messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener); assertTimeoutPreemptively(Duration.ofSeconds(5), () -> {
messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener);
messagesCache.lockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID); messagesCache.lockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID);
messagesCache.unlockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID); messagesCache.unlockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID);
synchronized (notified) { synchronized (notified) {
while (!notified.get()) { while (!notified.get()) {
notified.wait(); notified.wait();
}
} }
}
assertTrue(notified.get()); assertTrue(notified.get());
});
} }
} }