Migrate remaining JUnit 4 Redis cluster tests to `RedisClusterExtension`

* Increase redis cluster initialization wait to 10 seconds
* Move to JUnit 5 `Assumptions`
This commit is contained in:
Chris Eager 2022-01-03 14:59:39 -08:00 committed by GitHub
parent c488c14d25
commit b758737907
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 268 additions and 442 deletions

View File

@ -1,194 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.redis;
import io.lettuce.core.RedisClient;
import io.lettuce.core.RedisException;
import io.lettuce.core.RedisURI;
import io.lettuce.core.api.StatefulRedisConnection;
import io.lettuce.core.api.sync.RedisCommands;
import io.lettuce.core.cluster.RedisClusterClient;
import io.lettuce.core.cluster.SlotHash;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
import org.whispersystems.textsecuregcm.configuration.RetryConfiguration;
import org.whispersystems.textsecuregcm.util.RedisClusterUtil;
import redis.embedded.RedisServer;
import java.io.File;
import java.io.IOException;
import java.net.ServerSocket;
import java.time.Duration;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import static org.junit.Assume.assumeFalse;
/**
* An abstract base class that assembles a real (local!) Redis cluster and provides a client to that cluster for
* subclasses.
*/
public abstract class AbstractRedisClusterTest {
private static final int NODE_COUNT = 2;
private static RedisServer[] clusterNodes;
private FaultTolerantRedisCluster redisCluster;
@BeforeClass
public static void setUpBeforeClass() throws Exception {
assumeFalse(System.getProperty("os.name").equalsIgnoreCase("windows"));
clusterNodes = new RedisServer[NODE_COUNT];
for (int i = 0; i < NODE_COUNT; i++) {
clusterNodes[i] = buildClusterNode(getNextRedisClusterPort());
clusterNodes[i].start();
}
assembleCluster(clusterNodes);
}
@Before
public void setUp() throws Exception {
final List<String> urls = Arrays.stream(clusterNodes)
.map(node -> String.format("redis://127.0.0.1:%d", node.ports().get(0)))
.collect(Collectors.toList());
redisCluster = new FaultTolerantRedisCluster("test-cluster",
RedisClusterClient.create(urls.stream().map(RedisURI::create).collect(Collectors.toList())),
Duration.ofSeconds(2),
new CircuitBreakerConfiguration(),
new RetryConfiguration());
redisCluster.useCluster(connection -> {
boolean setAll = false;
final String[] keys = new String[NODE_COUNT];
for (int i = 0; i < keys.length; i++) {
keys[i] = RedisClusterUtil.getMinimalHashTag(i * SlotHash.SLOT_COUNT / keys.length);
}
while (!setAll) {
try {
for (final String key : keys) {
connection.sync().set(key, "warmup");
}
setAll = true;
} catch (final RedisException ignored) {
// Cluster isn't ready; wait and retry.
try {
Thread.sleep(500);
} catch (final InterruptedException ignored2) {
}
}
}
});
redisCluster.useCluster(connection -> connection.sync().flushall());
}
protected FaultTolerantRedisCluster getRedisCluster() {
return redisCluster;
}
@After
public void tearDown() throws Exception {
redisCluster.shutdown();
}
@AfterClass
public static void tearDownAfterClass() {
for (final RedisServer node : clusterNodes) {
node.stop();
}
}
private static RedisServer buildClusterNode(final int port) throws IOException {
final File clusterConfigFile = File.createTempFile("redis", ".conf");
clusterConfigFile.deleteOnExit();
return RedisServer.builder()
.setting("cluster-enabled yes")
.setting("cluster-config-file " + clusterConfigFile.getAbsolutePath())
.setting("cluster-node-timeout 5000")
.setting("appendonly no")
.setting("dir " + System.getProperty("java.io.tmpdir"))
.port(port)
.build();
}
private static void assembleCluster(final RedisServer... nodes) throws InterruptedException {
final RedisClient meetClient = RedisClient.create(RedisURI.create("127.0.0.1", nodes[0].ports().get(0)));
try {
final StatefulRedisConnection<String, String> connection = meetClient.connect();
final RedisCommands<String, String> commands = connection.sync();
for (int i = 1; i < nodes.length; i++) {
commands.clusterMeet("127.0.0.1", nodes[i].ports().get(0));
}
} finally {
meetClient.shutdown();
}
final int slotsPerNode = SlotHash.SLOT_COUNT / nodes.length;
for (int i = 0; i < nodes.length; i++) {
final int startInclusive = i * slotsPerNode;
final int endExclusive = i == nodes.length - 1 ? SlotHash.SLOT_COUNT : (i + 1) * slotsPerNode;
final RedisClient assignSlotClient = RedisClient.create(RedisURI.create("127.0.0.1", nodes[i].ports().get(0)));
try (final StatefulRedisConnection<String, String> assignSlotConnection = assignSlotClient.connect()) {
final int[] slots = new int[endExclusive - startInclusive];
for (int s = startInclusive; s < endExclusive; s++) {
slots[s - startInclusive] = s;
}
assignSlotConnection.sync().clusterAddSlots(slots);
} finally {
assignSlotClient.shutdown();
}
}
final RedisClient waitClient = RedisClient.create(RedisURI.create("127.0.0.1", nodes[0].ports().get(0)));
try (final StatefulRedisConnection<String, String> connection = waitClient.connect()) {
// CLUSTER INFO gives us a big blob of key-value pairs, but the one we're interested in is `cluster_state`.
// According to https://redis.io/commands/cluster-info, `cluster_state:ok` means that the node is ready to
// receive queries, all slots are assigned, and a majority of master nodes are reachable.
while (!connection.sync().clusterInfo().contains("cluster_state:ok")) {
Thread.sleep(500);
}
} finally {
waitClient.shutdown();
}
}
public static int getNextRedisClusterPort() throws IOException {
final int MAX_ITERATIONS = 11_000;
int port;
for (int i = 0; i < MAX_ITERATIONS; i++) {
try (ServerSocket socket = new ServerSocket(0)) {
socket.setReuseAddress(false);
port = socket.getLocalPort();
}
if (port < 55535) {
return port;
}
}
throw new IOException("Couldn't find an open port below 55,535 in " + MAX_ITERATIONS + " tries");
}
}

View File

@ -1,11 +1,11 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* Copyright 2013-2022 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.redis;
import static org.junit.Assume.assumeFalse;
import static org.junit.jupiter.api.Assumptions.assumeFalse;
import io.lettuce.core.RedisClient;
import io.lettuce.core.RedisException;
@ -178,9 +178,9 @@ public class RedisClusterExtension implements BeforeAllCallback, BeforeEachCallb
Thread.sleep(sleepMillis);
tries++;
if (tries == 10) {
if (tries == 20) {
throw new RuntimeException(
String.format("Timeout: Redis not ready after waiting %d milliseconds", sleepMillis));
String.format("Timeout: Redis not ready after waiting %d milliseconds", tries * sleepMillis));
}
}
} finally {

View File

@ -5,7 +5,7 @@
package org.whispersystems.textsecuregcm.storage;
import static org.junit.Assert.assertFalse;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow;
@ -18,11 +18,15 @@ import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import org.junit.Before;
import org.junit.Test;
import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
public class AccountDatabaseCrawlerIntegrationTest extends AbstractRedisClusterTest {
class AccountDatabaseCrawlerIntegrationTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
private static final UUID FIRST_UUID = UUID.fromString("82339e80-81cd-48e2-9ed2-ccd5dd262ad9");
private static final UUID SECOND_UUID = UUID.fromString("cc705c84-33cf-456b-8239-a6a34e2f561a");
@ -38,9 +42,8 @@ public class AccountDatabaseCrawlerIntegrationTest extends AbstractRedisClusterT
private static final int CHUNK_SIZE = 1;
private static final long CHUNK_INTERVAL_MS = 0;
@Before
public void setUp() throws Exception {
super.setUp();
@BeforeEach
void setUp() throws Exception {
firstAccount = mock(Account.class);
secondAccount = mock(Account.class);
@ -57,13 +60,13 @@ public class AccountDatabaseCrawlerIntegrationTest extends AbstractRedisClusterT
.thenReturn(new AccountCrawlChunk(List.of(secondAccount), SECOND_UUID))
.thenReturn(new AccountCrawlChunk(Collections.emptyList(), null));
final AccountDatabaseCrawlerCache crawlerCache = new AccountDatabaseCrawlerCache(getRedisCluster(), "test");
final AccountDatabaseCrawlerCache crawlerCache = new AccountDatabaseCrawlerCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), "test");
accountDatabaseCrawler = new AccountDatabaseCrawler("test", accountsManager, crawlerCache, List.of(listener), CHUNK_SIZE,
CHUNK_INTERVAL_MS);
}
@Test
public void testCrawlUninterrupted() throws AccountDatabaseCrawlerRestartException {
void testCrawlUninterrupted() throws AccountDatabaseCrawlerRestartException {
assertFalse(accountDatabaseCrawler.doPeriodicWork());
assertFalse(accountDatabaseCrawler.doPeriodicWork());
assertFalse(accountDatabaseCrawler.doPeriodicWork());
@ -79,7 +82,7 @@ public class AccountDatabaseCrawlerIntegrationTest extends AbstractRedisClusterT
}
@Test
public void testCrawlWithReset() throws AccountDatabaseCrawlerRestartException {
void testCrawlWithReset() throws AccountDatabaseCrawlerRestartException {
doThrow(new AccountDatabaseCrawlerRestartException("OH NO")).doNothing()
.when(listener).timeAndProcessCrawlChunk(Optional.empty(), List.of(firstAccount));

View File

@ -5,7 +5,7 @@
package org.whispersystems.textsecuregcm.storage;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
@ -28,194 +28,204 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
public class MessagePersisterTest extends AbstractRedisClusterTest {
class MessagePersisterTest {
private ExecutorService notificationExecutorService;
private MessagesCache messagesCache;
private MessagesDynamoDb messagesDynamoDb;
private MessagePersister messagePersister;
private AccountsManager accountsManager;
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
private static final UUID DESTINATION_ACCOUNT_UUID = UUID.randomUUID();
private static final String DESTINATION_ACCOUNT_NUMBER = "+18005551234";
private static final long DESTINATION_DEVICE_ID = 7;
private ExecutorService notificationExecutorService;
private MessagesCache messagesCache;
private MessagesDynamoDb messagesDynamoDb;
private MessagePersister messagePersister;
private AccountsManager accountsManager;
private static final Duration PERSIST_DELAY = Duration.ofMinutes(5);
private static final UUID DESTINATION_ACCOUNT_UUID = UUID.randomUUID();
private static final String DESTINATION_ACCOUNT_NUMBER = "+18005551234";
private static final long DESTINATION_DEVICE_ID = 7;
@Override
@Before
public void setUp() throws Exception {
super.setUp();
private static final Duration PERSIST_DELAY = Duration.ofMinutes(5);
final MessagesManager messagesManager = mock(MessagesManager.class);
final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
@BeforeEach
void setUp() throws Exception {
messagesDynamoDb = mock(MessagesDynamoDb.class);
accountsManager = mock(AccountsManager.class);
final MessagesManager messagesManager = mock(MessagesManager.class);
final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
final Account account = mock(Account.class);
messagesDynamoDb = mock(MessagesDynamoDb.class);
accountsManager = mock(AccountsManager.class);
when(accountsManager.getByAccountIdentifier(DESTINATION_ACCOUNT_UUID)).thenReturn(Optional.of(account));
when(account.getNumber()).thenReturn(DESTINATION_ACCOUNT_NUMBER);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration());
final Account account = mock(Account.class);
notificationExecutorService = Executors.newSingleThreadExecutor();
messagesCache = new MessagesCache(getRedisCluster(), getRedisCluster(), notificationExecutorService);
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, dynamicConfigurationManager, PERSIST_DELAY);
when(accountsManager.getByAccountIdentifier(DESTINATION_ACCOUNT_UUID)).thenReturn(Optional.of(account));
when(account.getNumber()).thenReturn(DESTINATION_ACCOUNT_NUMBER);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration());
doAnswer(invocation -> {
final UUID destinationUuid = invocation.getArgument(0);
final long destinationDeviceId = invocation.getArgument(1);
final List<MessageProtos.Envelope> messages = invocation.getArgument(2);
notificationExecutorService = Executors.newSingleThreadExecutor();
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
REDIS_CLUSTER_EXTENSION.getRedisCluster(), notificationExecutorService);
messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager,
dynamicConfigurationManager, PERSIST_DELAY);
messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId);
doAnswer(invocation -> {
final UUID destinationUuid = invocation.getArgument(0);
final long destinationDeviceId = invocation.getArgument(1);
final List<MessageProtos.Envelope> messages = invocation.getArgument(2);
for (final MessageProtos.Envelope message : messages) {
messagesCache.remove(destinationUuid, destinationDeviceId, UUID.fromString(message.getServerGuid()));
}
messagesDynamoDb.store(messages, destinationUuid, destinationDeviceId);
return null;
}).when(messagesManager).persistMessages(any(UUID.class), anyLong(), any());
for (final MessageProtos.Envelope message : messages) {
messagesCache.remove(destinationUuid, destinationDeviceId, UUID.fromString(message.getServerGuid()));
}
return null;
}).when(messagesManager).persistMessages(any(UUID.class), anyLong(), any());
}
@AfterEach
void tearDown() throws Exception {
notificationExecutorService.shutdown();
notificationExecutorService.awaitTermination(1, TimeUnit.SECONDS);
}
@Test
void testPersistNextQueuesNoQueues() {
messagePersister.persistNextQueues(Instant.now());
verify(accountsManager, never()).getByAccountIdentifier(any(UUID.class));
}
@Test
void testPersistNextQueuesSingleQueue() {
final String queueName = new String(
MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8);
final int messageCount = (MessagePersister.MESSAGE_BATCH_LIMIT * 3) + 7;
final Instant now = Instant.now();
insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, now);
setNextSlotToPersist(SlotHash.getSlot(queueName));
messagePersister.persistNextQueues(now.plus(messagePersister.getPersistDelay()));
final ArgumentCaptor<List<MessageProtos.Envelope>> messagesCaptor = ArgumentCaptor.forClass(List.class);
verify(messagesDynamoDb, atLeastOnce()).store(messagesCaptor.capture(), eq(DESTINATION_ACCOUNT_UUID),
eq(DESTINATION_DEVICE_ID));
assertEquals(messageCount, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum());
}
@Test
void testPersistNextQueuesSingleQueueTooSoon() {
final String queueName = new String(
MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8);
final int messageCount = (MessagePersister.MESSAGE_BATCH_LIMIT * 3) + 7;
final Instant now = Instant.now();
insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, now);
setNextSlotToPersist(SlotHash.getSlot(queueName));
messagePersister.persistNextQueues(now);
verify(messagesDynamoDb, never()).store(any(), any(), anyLong());
}
@Test
void testPersistNextQueuesMultiplePages() {
final int slot = 7;
final int queueCount = (MessagePersister.QUEUE_BATCH_LIMIT * 3) + 7;
final int messagesPerQueue = 10;
final Instant now = Instant.now();
for (int i = 0; i < queueCount; i++) {
final String queueName = generateRandomQueueNameForSlot(slot);
final UUID accountUuid = MessagesCache.getAccountUuidFromQueueName(queueName);
final long deviceId = MessagesCache.getDeviceIdFromQueueName(queueName);
final String accountNumber = "+1" + RandomStringUtils.randomNumeric(10);
final Account account = mock(Account.class);
when(accountsManager.getByAccountIdentifier(accountUuid)).thenReturn(Optional.of(account));
when(account.getNumber()).thenReturn(accountNumber);
insertMessages(accountUuid, deviceId, messagesPerQueue, now);
}
@Override
public void tearDown() throws Exception {
super.tearDown();
setNextSlotToPersist(slot);
notificationExecutorService.shutdown();
notificationExecutorService.awaitTermination(1, TimeUnit.SECONDS);
messagePersister.persistNextQueues(now.plus(messagePersister.getPersistDelay()));
final ArgumentCaptor<List<MessageProtos.Envelope>> messagesCaptor = ArgumentCaptor.forClass(List.class);
verify(messagesDynamoDb, atLeastOnce()).store(messagesCaptor.capture(), any(UUID.class), anyLong());
assertEquals(queueCount * messagesPerQueue, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum());
}
@Test
void testPersistQueueRetry() {
final String queueName = new String(
MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8);
final int messageCount = (MessagePersister.MESSAGE_BATCH_LIMIT * 3) + 7;
final Instant now = Instant.now();
insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, now);
setNextSlotToPersist(SlotHash.getSlot(queueName));
doAnswer((Answer<Void>) invocation -> {
throw new RuntimeException("OH NO.");
}).when(messagesDynamoDb).store(any(), eq(DESTINATION_ACCOUNT_UUID), eq(DESTINATION_DEVICE_ID));
messagePersister.persistNextQueues(now.plus(messagePersister.getPersistDelay()));
assertEquals(List.of(queueName),
messagesCache.getQueuesToPersist(SlotHash.getSlot(queueName),
Instant.now().plus(messagePersister.getPersistDelay()), 1));
}
@SuppressWarnings("SameParameterValue")
private static String generateRandomQueueNameForSlot(final int slot) {
final UUID uuid = UUID.randomUUID();
final String queueNameBase = "user_queue::{" + uuid + "::";
for (int deviceId = 0; deviceId < Integer.MAX_VALUE; deviceId++) {
final String queueName = queueNameBase + deviceId + "}";
if (SlotHash.getSlot(queueName) == slot) {
return queueName;
}
}
@Test
public void testPersistNextQueuesNoQueues() {
messagePersister.persistNextQueues(Instant.now());
throw new IllegalStateException("Could not find a queue name for slot " + slot);
}
verify(accountsManager, never()).getByAccountIdentifier(any(UUID.class));
private void insertMessages(final UUID accountUuid, final long deviceId, final int messageCount,
final Instant firstMessageTimestamp) {
for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = MessageProtos.Envelope.newBuilder()
.setTimestamp(firstMessageTimestamp.toEpochMilli() + i)
.setServerTimestamp(firstMessageTimestamp.toEpochMilli() + i)
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setServerGuid(messageGuid.toString())
.build();
messagesCache.insert(messageGuid, accountUuid, deviceId, envelope);
}
}
@Test
public void testPersistNextQueuesSingleQueue() {
final String queueName = new String(MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8);
final int messageCount = (MessagePersister.MESSAGE_BATCH_LIMIT * 3) + 7;
final Instant now = Instant.now();
insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, now);
setNextSlotToPersist(SlotHash.getSlot(queueName));
messagePersister.persistNextQueues(now.plus(messagePersister.getPersistDelay()));
final ArgumentCaptor<List<MessageProtos.Envelope>> messagesCaptor = ArgumentCaptor.forClass(List.class);
verify(messagesDynamoDb, atLeastOnce()).store(messagesCaptor.capture(), eq(DESTINATION_ACCOUNT_UUID), eq(DESTINATION_DEVICE_ID));
assertEquals(messageCount, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum());
}
@Test
public void testPersistNextQueuesSingleQueueTooSoon() {
final String queueName = new String(MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8);
final int messageCount = (MessagePersister.MESSAGE_BATCH_LIMIT * 3) + 7;
final Instant now = Instant.now();
insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, now);
setNextSlotToPersist(SlotHash.getSlot(queueName));
messagePersister.persistNextQueues(now);
verify(messagesDynamoDb, never()).store(any(), any(), anyLong());
}
@Test
public void testPersistNextQueuesMultiplePages() {
final int slot = 7;
final int queueCount = (MessagePersister.QUEUE_BATCH_LIMIT * 3) + 7;
final int messagesPerQueue = 10;
final Instant now = Instant.now();
for (int i = 0; i < queueCount; i++) {
final String queueName = generateRandomQueueNameForSlot(slot);
final UUID accountUuid = MessagesCache.getAccountUuidFromQueueName(queueName);
final long deviceId = MessagesCache.getDeviceIdFromQueueName(queueName);
final String accountNumber = "+1" + RandomStringUtils.randomNumeric(10);
final Account account = mock(Account.class);
when(accountsManager.getByAccountIdentifier(accountUuid)).thenReturn(Optional.of(account));
when(account.getNumber()).thenReturn(accountNumber);
insertMessages(accountUuid, deviceId, messagesPerQueue, now);
}
setNextSlotToPersist(slot);
messagePersister.persistNextQueues(now.plus(messagePersister.getPersistDelay()));
final ArgumentCaptor<List<MessageProtos.Envelope>> messagesCaptor = ArgumentCaptor.forClass(List.class);
verify(messagesDynamoDb, atLeastOnce()).store(messagesCaptor.capture(), any(UUID.class), anyLong());
assertEquals(queueCount * messagesPerQueue, messagesCaptor.getAllValues().stream().mapToInt(List::size).sum());
}
@Test
public void testPersistQueueRetry() {
final String queueName = new String(MessagesCache.getMessageQueueKey(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID), StandardCharsets.UTF_8);
final int messageCount = (MessagePersister.MESSAGE_BATCH_LIMIT * 3) + 7;
final Instant now = Instant.now();
insertMessages(DESTINATION_ACCOUNT_UUID, DESTINATION_DEVICE_ID, messageCount, now);
setNextSlotToPersist(SlotHash.getSlot(queueName));
doAnswer((Answer<Void>)invocation -> {
throw new RuntimeException("OH NO.");
}).when(messagesDynamoDb).store(any(), eq(DESTINATION_ACCOUNT_UUID), eq(DESTINATION_DEVICE_ID));
messagePersister.persistNextQueues(now.plus(messagePersister.getPersistDelay()));
assertEquals(List.of(queueName),
messagesCache.getQueuesToPersist(SlotHash.getSlot(queueName), Instant.now().plus(messagePersister.getPersistDelay()), 1));
}
@SuppressWarnings("SameParameterValue")
private static String generateRandomQueueNameForSlot(final int slot) {
final UUID uuid = UUID.randomUUID();
final String queueNameBase = "user_queue::{" + uuid.toString() + "::";
for (int deviceId = 0; deviceId < Integer.MAX_VALUE; deviceId++) {
final String queueName = queueNameBase + deviceId + "}";
if (SlotHash.getSlot(queueName) == slot) {
return queueName;
}
}
throw new IllegalStateException("Could not find a queue name for slot " + slot);
}
private void insertMessages(final UUID accountUuid, final long deviceId, final int messageCount, final Instant firstMessageTimestamp) {
for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = MessageProtos.Envelope.newBuilder()
.setTimestamp(firstMessageTimestamp.toEpochMilli() + i)
.setServerTimestamp(firstMessageTimestamp.toEpochMilli() + i)
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setServerGuid(messageGuid.toString())
.build();
messagesCache.insert(messageGuid, accountUuid, deviceId, envelope);
}
}
private void setNextSlotToPersist(final int nextSlot) {
getRedisCluster().useCluster(connection -> connection.sync().set(MessagesCache.NEXT_SLOT_TO_PERSIST_KEY, String.valueOf(nextSlot - 1)));
}
private void setNextSlotToPersist(final int nextSlot) {
REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(
connection -> connection.sync().set(MessagesCache.NEXT_SLOT_TO_PERSIST_KEY, String.valueOf(nextSlot - 1)));
}
}

View File

@ -5,13 +5,15 @@
package org.whispersystems.textsecuregcm.storage;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.google.protobuf.ByteString;
import io.lettuce.core.cluster.SlotHash;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
@ -24,18 +26,21 @@ import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
@RunWith(JUnitParamsRunner.class)
public class MessagesCacheTest extends AbstractRedisClusterTest {
class MessagesCacheTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
private ExecutorService notificationExecutorService;
private MessagesCache messagesCache;
@ -46,40 +51,38 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
private static final UUID DESTINATION_UUID = UUID.randomUUID();
private static final int DESTINATION_DEVICE_ID = 7;
@Override
@Before
public void setUp() throws Exception {
super.setUp();
@BeforeEach
void setUp() throws Exception {
getRedisCluster().useCluster(
connection -> connection.sync().upstream().commands().configSet("notify-keyspace-events", "Klgz"));
REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(connection -> {
connection.sync().flushall();
connection.sync().upstream().commands().configSet("notify-keyspace-events", "K$glz");
});
notificationExecutorService = Executors.newSingleThreadExecutor();
messagesCache = new MessagesCache(getRedisCluster(), getRedisCluster(), notificationExecutorService);
messagesCache = new MessagesCache(REDIS_CLUSTER_EXTENSION.getRedisCluster(), REDIS_CLUSTER_EXTENSION.getRedisCluster(), notificationExecutorService);
messagesCache.start();
}
@Override
public void tearDown() throws Exception {
@AfterEach
void tearDown() throws Exception {
messagesCache.stop();
notificationExecutorService.shutdown();
notificationExecutorService.awaitTermination(1, TimeUnit.SECONDS);
super.tearDown();
}
@Test
@Parameters({"true", "false"})
public void testInsert(final boolean sealedSender) {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testInsert(final boolean sealedSender) {
final UUID messageGuid = UUID.randomUUID();
assertTrue(messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID,
generateRandomMessage(messageGuid, sealedSender)) > 0);
}
@Test
public void testDoubleInsertGuid() {
void testDoubleInsertGuid() {
final UUID duplicateGuid = UUID.randomUUID();
final MessageProtos.Envelope duplicateMessage = generateRandomMessage(duplicateGuid, false);
@ -90,9 +93,9 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
assertEquals(firstId, secondId);
}
@Test
@Parameters({"true", "false"})
public void testRemoveByUUID(final boolean sealedSender) {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testRemoveByUUID(final boolean sealedSender) {
final UUID messageGuid = UUID.randomUUID();
assertEquals(Optional.empty(), messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageGuid));
@ -107,9 +110,9 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
assertEquals(MessagesCache.constructEntityFromEnvelope(message), maybeRemovedMessage.get());
}
@Test
@Parameters({"true", "false"})
public void testRemoveBatchByUUID(final boolean sealedSender) {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testRemoveBatchByUUID(final boolean sealedSender) {
final int messageCount = 10;
final List<MessageProtos.Envelope> messagesToRemove = new ArrayList<>(messageCount);
@ -145,7 +148,7 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
}
@Test
public void testHasMessages() {
void testHasMessages() {
assertFalse(messagesCache.hasMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID));
final UUID messageGuid = UUID.randomUUID();
@ -155,9 +158,9 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
assertTrue(messagesCache.hasMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID));
}
@Test
@Parameters({"true", "false"})
public void testGetMessages(final boolean sealedSender) {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testGetMessages(final boolean sealedSender) {
final int messageCount = 100;
final List<OutgoingMessageEntity> expectedMessages = new ArrayList<>(messageCount);
@ -173,9 +176,9 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
assertEquals(expectedMessages, messagesCache.get(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageCount));
}
@Test
@Parameters({"true", "false"})
public void testClearQueueForDevice(final boolean sealedSender) {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testClearQueueForDevice(final boolean sealedSender) {
final int messageCount = 100;
for (final int deviceId : new int[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) {
@ -193,9 +196,9 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
assertEquals(messageCount, messagesCache.get(DESTINATION_UUID, DESTINATION_DEVICE_ID + 1, messageCount).size());
}
@Test
@Parameters({"true", "false"})
public void testClearQueueForAccount(final boolean sealedSender) {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testClearQueueForAccount(final boolean sealedSender) {
final int messageCount = 100;
for (final int deviceId : new int[]{DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1}) {
@ -236,13 +239,13 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
}
@Test
public void testClearNullUuid() {
void testClearNullUuid() {
// We're happy as long as this doesn't throw an exception
messagesCache.clear(null);
}
@Test
public void testGetAccountFromQueueName() {
void testGetAccountFromQueueName() {
assertEquals(DESTINATION_UUID,
MessagesCache.getAccountUuidFromQueueName(
new String(MessagesCache.getMessageQueueKey(DESTINATION_UUID, DESTINATION_DEVICE_ID),
@ -250,7 +253,7 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
}
@Test
public void testGetDeviceIdFromQueueName() {
void testGetDeviceIdFromQueueName() {
assertEquals(DESTINATION_DEVICE_ID,
MessagesCache.getDeviceIdFromQueueName(
new String(MessagesCache.getMessageQueueKey(DESTINATION_UUID, DESTINATION_DEVICE_ID),
@ -258,14 +261,14 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
}
@Test
public void testGetQueueNameFromKeyspaceChannel() {
void testGetQueueNameFromKeyspaceChannel() {
assertEquals("1b363a31-a429-4fb6-8959-984a025e72ff::7",
MessagesCache.getQueueNameFromKeyspaceChannel(
"__keyspace@0__:user_queue::{1b363a31-a429-4fb6-8959-984a025e72ff::7}"));
}
@Test
@Parameters({"true", "false"})
@ParameterizedTest
@ValueSource(booleans = {true, false})
public void testGetQueuesToPersist(final boolean sealedSender) {
final UUID messageGuid = UUID.randomUUID();
@ -282,8 +285,8 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
assertEquals(DESTINATION_DEVICE_ID, MessagesCache.getDeviceIdFromQueueName(queues.get(0)));
}
@Test(timeout = 5_000L)
public void testNotifyListenerNewMessage() throws InterruptedException {
@Test
void testNotifyListenerNewMessage() {
final AtomicBoolean notified = new AtomicBoolean(false);
final UUID messageGuid = UUID.randomUUID();
@ -301,21 +304,23 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
}
};
messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID,
generateRandomMessage(messageGuid, true));
assertTimeoutPreemptively(Duration.ofSeconds(5), () -> {
messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID,
generateRandomMessage(messageGuid, true));
synchronized (notified) {
while (!notified.get()) {
notified.wait();
synchronized (notified) {
while (!notified.get()) {
notified.wait();
}
}
}
assertTrue(notified.get());
assertTrue(notified.get());
});
}
@Test(timeout = 5_000L)
public void testNotifyListenerPersisted() throws InterruptedException {
@Test
void testNotifyListenerPersisted() {
final AtomicBoolean notified = new AtomicBoolean(false);
final MessageAvailabilityListener listener = new MessageAvailabilityListener() {
@ -332,18 +337,20 @@ public class MessagesCacheTest extends AbstractRedisClusterTest {
}
};
messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener);
assertTimeoutPreemptively(Duration.ofSeconds(5), () -> {
messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener);
messagesCache.lockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID);
messagesCache.unlockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID);
messagesCache.lockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID);
messagesCache.unlockQueueForPersistence(DESTINATION_UUID, DESTINATION_DEVICE_ID);
synchronized (notified) {
while (!notified.get()) {
notified.wait();
synchronized (notified) {
while (!notified.get()) {
notified.wait();
}
}
}
assertTrue(notified.get());
assertTrue(notified.get());
});
}
}