diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/CardinalityRateLimiterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/CardinalityRateLimiterTest.java index 1065f615e..e670a799b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/limits/CardinalityRateLimiterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/CardinalityRateLimiterTest.java @@ -5,33 +5,26 @@ package org.whispersystems.textsecuregcm.limits; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; import java.time.Duration; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; -import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; -public class CardinalityRateLimiterTest extends AbstractRedisClusterTest { +class CardinalityRateLimiterTest { - @Before - public void setUp() throws Exception { - super.setUp(); - } - - @After - public void tearDown() throws Exception { - super.tearDown(); - } + @RegisterExtension + static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); @Test - public void testValidate() { + void testValidate() { final int maxCardinality = 10; final CardinalityRateLimiter rateLimiter = - new CardinalityRateLimiter(getRedisCluster(), "test", Duration.ofDays(1), maxCardinality); + new CardinalityRateLimiter(REDIS_CLUSTER_EXTENSION.getRedisCluster(), "test", Duration.ofDays(1), + maxCardinality); final String source = "+18005551234"; int validatedAttempts = 0; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitResetMetricsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitResetMetricsManagerTest.java index de19204e4..cfc56fc7a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitResetMetricsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitResetMetricsManagerTest.java @@ -1,6 +1,6 @@ package org.whispersystems.textsecuregcm.limits; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -8,27 +8,28 @@ import io.dropwizard.util.Duration; import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.simple.SimpleMeterRegistry; import java.util.UUID; -import org.junit.Before; -import org.junit.Test; -import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.storage.Account; -public class RateLimitResetMetricsManagerTest extends AbstractRedisClusterTest { +class RateLimitResetMetricsManagerTest { + + @RegisterExtension + static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); private RateLimitResetMetricsManager metricsManager; private SimpleMeterRegistry meterRegistry; - @Before - @Override - public void setUp() throws Exception { - super.setUp(); - + @BeforeEach + void setUp() { meterRegistry = new SimpleMeterRegistry(); - metricsManager = new RateLimitResetMetricsManager(getRedisCluster(), meterRegistry); + metricsManager = new RateLimitResetMetricsManager(REDIS_CLUSTER_EXTENSION.getRedisCluster(), meterRegistry); } @Test - public void testRecordMetrics() { + void testRecordMetrics() { final Account firstAccount = mock(Account.class); when(firstAccount.getUuid()).thenReturn(UUID.randomUUID()); @@ -45,13 +46,15 @@ public class RateLimitResetMetricsManagerTest extends AbstractRedisClusterTest { .orElseThrow(); assertEquals(3, counterTotal, 0.0); - final long enforcedCount = getRedisCluster().withCluster(conn -> conn.sync().pfcount("enforced")); + final long enforcedCount = REDIS_CLUSTER_EXTENSION.getRedisCluster() + .withCluster(conn -> conn.sync().pfcount("enforced")); assertEquals(1L, enforcedCount); - final long unenforcedCount = getRedisCluster().withCluster(conn -> conn.sync().pfcount("unenforced")); + final long unenforcedCount = REDIS_CLUSTER_EXTENSION.getRedisCluster() + .withCluster(conn -> conn.sync().pfcount("unenforced")); assertEquals(1L, unenforcedCount); - final long total = getRedisCluster().withCluster(conn -> conn.sync().pfcount("total")); + final long total = REDIS_CLUSTER_EXTENSION.getRedisCluster().withCluster(conn -> conn.sync().pfcount("total")); assertEquals(2L, total); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/UnsealedSenderRateLimiterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/UnsealedSenderRateLimiterTest.java index 6590c8866..e519e7ac4 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/limits/UnsealedSenderRateLimiterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/UnsealedSenderRateLimiterTest.java @@ -5,24 +5,27 @@ package org.whispersystems.textsecuregcm.limits; -import static org.junit.Assert.assertThrows; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import java.time.Duration; import java.util.UUID; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; -import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitChallengeConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitsConfiguration; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; -import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; -public class UnsealedSenderRateLimiterTest extends AbstractRedisClusterTest { +class UnsealedSenderRateLimiterTest { + + @RegisterExtension + static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); private Account sender; private Account firstDestination; @@ -32,14 +35,12 @@ public class UnsealedSenderRateLimiterTest extends AbstractRedisClusterTest { private DynamicRateLimitChallengeConfiguration rateLimitChallengeConfiguration; - @Before - @Override - public void setUp() throws Exception { - super.setUp(); + @BeforeEach + void setUp() throws Exception { final RateLimiters rateLimiters = mock(RateLimiters.class); final CardinalityRateLimiter cardinalityRateLimiter = - new CardinalityRateLimiter(getRedisCluster(), "test", Duration.ofDays(1), 1); + new CardinalityRateLimiter(REDIS_CLUSTER_EXTENSION.getRedisCluster(), "test", Duration.ofDays(1), 1); when(rateLimiters.getUnsealedSenderCardinalityLimiter()).thenReturn(cardinalityRateLimiter); when(rateLimiters.getRateLimitResetLimiter()).thenReturn(mock(RateLimiter.class)); @@ -56,7 +57,8 @@ public class UnsealedSenderRateLimiterTest extends AbstractRedisClusterTest { when(dynamicConfiguration.getRateLimitChallengeConfiguration()).thenReturn(rateLimitChallengeConfiguration); when(rateLimitChallengeConfiguration.isUnsealedSenderLimitEnforced()).thenReturn(true); - unsealedSenderRateLimiter = new UnsealedSenderRateLimiter(rateLimiters, getRedisCluster(), dynamicConfigurationManager, + unsealedSenderRateLimiter = new UnsealedSenderRateLimiter(rateLimiters, REDIS_CLUSTER_EXTENSION.getRedisCluster(), + dynamicConfigurationManager, mock(RateLimitResetMetricsManager.class)); sender = mock(Account.class); @@ -73,7 +75,7 @@ public class UnsealedSenderRateLimiterTest extends AbstractRedisClusterTest { } @Test - public void validate() throws RateLimitExceededException { + void validate() throws RateLimitExceededException { unsealedSenderRateLimiter.validate(sender, firstDestination); assertThrows(RateLimitExceededException.class, () -> unsealedSenderRateLimiter.validate(sender, secondDestination)); @@ -82,7 +84,7 @@ public class UnsealedSenderRateLimiterTest extends AbstractRedisClusterTest { } @Test - public void handleRateLimitReset() throws RateLimitExceededException { + void handleRateLimitReset() throws RateLimitExceededException { unsealedSenderRateLimiter.validate(sender, firstDestination); assertThrows(RateLimitExceededException.class, () -> unsealedSenderRateLimiter.validate(sender, secondDestination)); @@ -93,7 +95,7 @@ public class UnsealedSenderRateLimiterTest extends AbstractRedisClusterTest { } @Test - public void enforcementConfiguration() throws RateLimitExceededException { + void enforcementConfiguration() throws RateLimitExceededException { when(rateLimitChallengeConfiguration.isUnsealedSenderLimitEnforced()).thenReturn(false); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/PushLatencyManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/PushLatencyManagerTest.java index 63ad01463..f3cf26810 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/PushLatencyManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/PushLatencyManagerTest.java @@ -5,33 +5,39 @@ package org.whispersystems.textsecuregcm.metrics; -import org.junit.Test; -import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; import java.util.UUID; import java.util.concurrent.ExecutionException; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNull; +class PushLatencyManagerTest { -public class PushLatencyManagerTest extends AbstractRedisClusterTest { + @RegisterExtension + static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); - @Test - public void testGetLatency() throws ExecutionException, InterruptedException { - final PushLatencyManager pushLatencyManager = new PushLatencyManager(getRedisCluster()); - final UUID accountUuid = UUID.randomUUID(); - final long deviceId = 1; - final long expectedLatency = 1234; - final long pushSentTimestamp = System.currentTimeMillis(); - final long clearQueueTimestamp = pushSentTimestamp + expectedLatency; + @Test + void testGetLatency() throws ExecutionException, InterruptedException { - assertNull(pushLatencyManager.getLatencyAndClearTimestamp(accountUuid, deviceId, System.currentTimeMillis()).get()); + final PushLatencyManager pushLatencyManager = new PushLatencyManager(REDIS_CLUSTER_EXTENSION.getRedisCluster()); + final UUID accountUuid = UUID.randomUUID(); + final long deviceId = 1; + final long expectedLatency = 1234; + final long pushSentTimestamp = System.currentTimeMillis(); + final long clearQueueTimestamp = pushSentTimestamp + expectedLatency; - { - pushLatencyManager.recordPushSent(accountUuid, deviceId, pushSentTimestamp); + assertNull(pushLatencyManager.getLatencyAndClearTimestamp(accountUuid, deviceId, System.currentTimeMillis()).get()); - assertEquals(expectedLatency, (long)pushLatencyManager.getLatencyAndClearTimestamp(accountUuid, deviceId, clearQueueTimestamp).get()); - assertNull(pushLatencyManager.getLatencyAndClearTimestamp(accountUuid, deviceId, System.currentTimeMillis()).get()); - } + { + pushLatencyManager.recordPushSent(accountUuid, deviceId, pushSentTimestamp); + + assertEquals(expectedLatency, + (long) pushLatencyManager.getLatencyAndClearTimestamp(accountUuid, deviceId, clearQueueTimestamp).get()); + assertNull( + pushLatencyManager.getLatencyAndClearTimestamp(accountUuid, deviceId, System.currentTimeMillis()).get()); } + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/push/ApnFallbackManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/push/ApnFallbackManagerTest.java index 80cf0dfa9..5fc0b658c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/push/ApnFallbackManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/push/ApnFallbackManagerTest.java @@ -5,107 +5,111 @@ package org.whispersystems.textsecuregcm.push; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + import io.lettuce.core.cluster.SlotHash; +import java.util.List; +import java.util.Optional; +import java.util.UUID; import org.apache.commons.lang3.RandomStringUtils; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; import org.mockito.ArgumentCaptor; -import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisException; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.util.Pair; -import java.util.List; -import java.util.Optional; -import java.util.UUID; +class ApnFallbackManagerTest { -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; + @RegisterExtension + static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); -public class ApnFallbackManagerTest extends AbstractRedisClusterTest { + private Account account; + private Device device; - private Account account; - private Device device; + private APNSender apnSender; - private APNSender apnSender; + private ApnFallbackManager apnFallbackManager; - private ApnFallbackManager apnFallbackManager; + private static final UUID ACCOUNT_UUID = UUID.randomUUID(); + private static final String ACCOUNT_NUMBER = "+18005551234"; + private static final long DEVICE_ID = 1L; + private static final String VOIP_APN_ID = RandomStringUtils.randomAlphanumeric(32); - private static final UUID ACCOUNT_UUID = UUID.randomUUID(); - private static final String ACCOUNT_NUMBER = "+18005551234"; - private static final long DEVICE_ID = 1L; - private static final String VOIP_APN_ID = RandomStringUtils.randomAlphanumeric(32); + @BeforeEach + void setUp() throws Exception { - @Before - public void setUp() throws Exception { - super.setUp(); + device = mock(Device.class); + when(device.getId()).thenReturn(DEVICE_ID); + when(device.getVoipApnId()).thenReturn(VOIP_APN_ID); + when(device.getLastSeen()).thenReturn(System.currentTimeMillis()); - device = mock(Device.class); - when(device.getId()).thenReturn(DEVICE_ID); - when(device.getVoipApnId()).thenReturn(VOIP_APN_ID); - when(device.getLastSeen()).thenReturn(System.currentTimeMillis()); + account = mock(Account.class); + when(account.getUuid()).thenReturn(ACCOUNT_UUID); + when(account.getNumber()).thenReturn(ACCOUNT_NUMBER); + when(account.getDevice(DEVICE_ID)).thenReturn(Optional.of(device)); - account = mock(Account.class); - when(account.getUuid()).thenReturn(ACCOUNT_UUID); - when(account.getNumber()).thenReturn(ACCOUNT_NUMBER); - when(account.getDevice(DEVICE_ID)).thenReturn(Optional.of(device)); + final AccountsManager accountsManager = mock(AccountsManager.class); + when(accountsManager.get(ACCOUNT_NUMBER)).thenReturn(Optional.of(account)); + when(accountsManager.get(ACCOUNT_UUID)).thenReturn(Optional.of(account)); - final AccountsManager accountsManager = mock(AccountsManager.class); - when(accountsManager.get(ACCOUNT_NUMBER)).thenReturn(Optional.of(account)); - when(accountsManager.get(ACCOUNT_UUID)).thenReturn(Optional.of(account)); + apnSender = mock(APNSender.class); - apnSender = mock(APNSender.class); + apnFallbackManager = new ApnFallbackManager(REDIS_CLUSTER_EXTENSION.getRedisCluster(), apnSender, accountsManager); + } - apnFallbackManager = new ApnFallbackManager(getRedisCluster(), apnSender, accountsManager); - } + @Test + void testClusterInsert() throws RedisException { + final String endpoint = apnFallbackManager.getEndpointKey(account, device); - @Test - public void testClusterInsert() throws RedisException { - final String endpoint = apnFallbackManager.getEndpointKey(account, device); + assertTrue(apnFallbackManager.getPendingDestinations(SlotHash.getSlot(endpoint), 1).isEmpty()); - assertTrue(apnFallbackManager.getPendingDestinations(SlotHash.getSlot(endpoint), 1).isEmpty()); + apnFallbackManager.schedule(account, device, System.currentTimeMillis() - 30_000); - apnFallbackManager.schedule(account, device, System.currentTimeMillis() - 30_000); + final List pendingDestinations = apnFallbackManager.getPendingDestinations(SlotHash.getSlot(endpoint), 2); + assertEquals(1, pendingDestinations.size()); - final List pendingDestinations = apnFallbackManager.getPendingDestinations(SlotHash.getSlot(endpoint), 2); - assertEquals(1, pendingDestinations.size()); + final Optional> maybeUuidAndDeviceId = ApnFallbackManager.getSeparated( + pendingDestinations.get(0)); - final Optional> maybeUuidAndDeviceId = ApnFallbackManager.getSeparated(pendingDestinations.get(0)); + assertTrue(maybeUuidAndDeviceId.isPresent()); + assertEquals(ACCOUNT_UUID.toString(), maybeUuidAndDeviceId.get().first()); + assertEquals(DEVICE_ID, (long) maybeUuidAndDeviceId.get().second()); - assertTrue(maybeUuidAndDeviceId.isPresent()); - assertEquals(ACCOUNT_UUID.toString(), maybeUuidAndDeviceId.get().first()); - assertEquals(DEVICE_ID, (long)maybeUuidAndDeviceId.get().second()); + assertTrue(apnFallbackManager.getPendingDestinations(SlotHash.getSlot(endpoint), 1).isEmpty()); + } - assertTrue(apnFallbackManager.getPendingDestinations(SlotHash.getSlot(endpoint), 1).isEmpty()); - } + @Test + void testProcessNextSlot() throws RedisException { + final ApnFallbackManager.NotificationWorker worker = apnFallbackManager.new NotificationWorker(); - @Test - public void testProcessNextSlot() throws RedisException { - final ApnFallbackManager.NotificationWorker worker = apnFallbackManager.new NotificationWorker(); + apnFallbackManager.schedule(account, device, System.currentTimeMillis() - 30_000); - apnFallbackManager.schedule(account, device, System.currentTimeMillis() - 30_000); + final int slot = SlotHash.getSlot(apnFallbackManager.getEndpointKey(account, device)); + final int previousSlot = (slot + SlotHash.SLOT_COUNT - 1) % SlotHash.SLOT_COUNT; - final int slot = SlotHash.getSlot(apnFallbackManager.getEndpointKey(account, device)); - final int previousSlot = (slot + SlotHash.SLOT_COUNT - 1) % SlotHash.SLOT_COUNT; + REDIS_CLUSTER_EXTENSION.getRedisCluster().withCluster(connection -> connection.sync() + .set(ApnFallbackManager.NEXT_SLOT_TO_PERSIST_KEY, String.valueOf(previousSlot))); - getRedisCluster().withCluster(connection -> connection.sync().set(ApnFallbackManager.NEXT_SLOT_TO_PERSIST_KEY, String.valueOf(previousSlot))); + assertEquals(1, worker.processNextSlot()); - assertEquals(1, worker.processNextSlot()); + final ArgumentCaptor messageCaptor = ArgumentCaptor.forClass(ApnMessage.class); + verify(apnSender).sendMessage(messageCaptor.capture()); - final ArgumentCaptor messageCaptor = ArgumentCaptor.forClass(ApnMessage.class); - verify(apnSender).sendMessage(messageCaptor.capture()); + final ApnMessage message = messageCaptor.getValue(); - final ApnMessage message = messageCaptor.getValue(); + assertEquals(VOIP_APN_ID, message.getApnId()); + assertEquals(Optional.of(ACCOUNT_UUID), message.getUuid()); + assertEquals(DEVICE_ID, message.getDeviceId()); - assertEquals(VOIP_APN_ID, message.getApnId()); - assertEquals(Optional.of(ACCOUNT_UUID), message.getUuid()); - assertEquals(DEVICE_ID, message.getDeviceId()); - - assertEquals(0, worker.processNextSlot()); - } + assertEquals(0, worker.processNextSlot()); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/RedisClusterExtension.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/RedisClusterExtension.java new file mode 100644 index 000000000..bcf3e17dc --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/RedisClusterExtension.java @@ -0,0 +1,216 @@ +/* + * Copyright 2013-2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.redis; + +import static org.junit.Assume.assumeFalse; + +import io.lettuce.core.RedisClient; +import io.lettuce.core.RedisException; +import io.lettuce.core.RedisURI; +import io.lettuce.core.api.StatefulRedisConnection; +import io.lettuce.core.api.sync.RedisCommands; +import io.lettuce.core.cluster.RedisClusterClient; +import io.lettuce.core.cluster.SlotHash; +import java.io.File; +import java.io.IOException; +import java.net.ServerSocket; +import java.time.Duration; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import org.junit.jupiter.api.extension.AfterAllCallback; +import org.junit.jupiter.api.extension.AfterEachCallback; +import org.junit.jupiter.api.extension.BeforeAllCallback; +import org.junit.jupiter.api.extension.BeforeEachCallback; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; +import org.whispersystems.textsecuregcm.configuration.RetryConfiguration; +import org.whispersystems.textsecuregcm.util.RedisClusterUtil; +import redis.embedded.RedisServer; + +public class RedisClusterExtension implements BeforeAllCallback, BeforeEachCallback, AfterAllCallback, AfterEachCallback { + + public static RedisClusterExtensionBuilder builder() { + return new RedisClusterExtensionBuilder(); + } + + + @Override + public void afterAll(final ExtensionContext context) throws Exception { + for (final RedisServer node : clusterNodes) { + node.stop(); + } + } + + @Override + public void afterEach(final ExtensionContext context) throws Exception { + redisCluster.shutdown(); + } + + @Override + public void beforeAll(final ExtensionContext context) throws Exception { + assumeFalse(System.getProperty("os.name").equalsIgnoreCase("windows")); + + clusterNodes = new RedisServer[NODE_COUNT]; + + for (int i = 0; i < NODE_COUNT; i++) { + clusterNodes[i] = buildClusterNode(getNextRedisClusterPort()); + clusterNodes[i].start(); + } + + assembleCluster(clusterNodes); + } + + @Override + public void beforeEach(final ExtensionContext context) throws Exception { + final List urls = Arrays.stream(clusterNodes) + .map(node -> String.format("redis://127.0.0.1:%d", node.ports().get(0))) + .collect(Collectors.toList()); + + redisCluster = new FaultTolerantRedisCluster("test-cluster", + RedisClusterClient.create(urls.stream().map(RedisURI::create).collect(Collectors.toList())), + Duration.ofSeconds(2), + new CircuitBreakerConfiguration(), + new RetryConfiguration()); + + redisCluster.useCluster(connection -> { + boolean setAll = false; + + final String[] keys = new String[NODE_COUNT]; + + for (int i = 0; i < keys.length; i++) { + keys[i] = RedisClusterUtil.getMinimalHashTag(i * SlotHash.SLOT_COUNT / keys.length); + } + + while (!setAll) { + try { + for (final String key : keys) { + connection.sync().set(key, "warmup"); + } + + setAll = true; + } catch (final RedisException ignored) { + // Cluster isn't ready; wait and retry. + try { + Thread.sleep(500); + } catch (final InterruptedException ignored2) { + } + } + } + }); + + redisCluster.useCluster(connection -> connection.sync().flushall()); + } + + private static final int NODE_COUNT = 2; + + private static RedisServer[] clusterNodes; + + private FaultTolerantRedisCluster redisCluster; + + public FaultTolerantRedisCluster getRedisCluster() { + return redisCluster; + } + + private static RedisServer buildClusterNode(final int port) throws IOException { + final File clusterConfigFile = File.createTempFile("redis", ".conf"); + clusterConfigFile.deleteOnExit(); + + return RedisServer.builder() + .setting("cluster-enabled yes") + .setting("cluster-config-file " + clusterConfigFile.getAbsolutePath()) + .setting("cluster-node-timeout 5000") + .setting("appendonly no") + .setting("save \"\"") + .setting("dir " + System.getProperty("java.io.tmpdir")) + .port(port) + .build(); + } + + private static void assembleCluster(final RedisServer... nodes) throws InterruptedException { + final RedisClient meetClient = RedisClient.create(RedisURI.create("127.0.0.1", nodes[0].ports().get(0))); + + try { + final StatefulRedisConnection connection = meetClient.connect(); + final RedisCommands commands = connection.sync(); + + for (int i = 1; i < nodes.length; i++) { + commands.clusterMeet("127.0.0.1", nodes[i].ports().get(0)); + } + } finally { + meetClient.shutdown(); + } + + final int slotsPerNode = SlotHash.SLOT_COUNT / nodes.length; + + for (int i = 0; i < nodes.length; i++) { + final int startInclusive = i * slotsPerNode; + final int endExclusive = i == nodes.length - 1 ? SlotHash.SLOT_COUNT : (i + 1) * slotsPerNode; + + final RedisClient assignSlotClient = RedisClient.create(RedisURI.create("127.0.0.1", nodes[i].ports().get(0))); + + try (final StatefulRedisConnection assignSlotConnection = assignSlotClient.connect()) { + final int[] slots = new int[endExclusive - startInclusive]; + + for (int s = startInclusive; s < endExclusive; s++) { + slots[s - startInclusive] = s; + } + + assignSlotConnection.sync().clusterAddSlots(slots); + } finally { + assignSlotClient.shutdown(); + } + } + + final RedisClient waitClient = RedisClient.create(RedisURI.create("127.0.0.1", nodes[0].ports().get(0))); + + try (final StatefulRedisConnection connection = waitClient.connect()) { + // CLUSTER INFO gives us a big blob of key-value pairs, but the one we're interested in is `cluster_state`. + // According to https://redis.io/commands/cluster-info, `cluster_state:ok` means that the node is ready to + // receive queries, all slots are assigned, and a majority of master nodes are reachable. + + final int sleepMillis = 500; + int tries = 0; + while (!connection.sync().clusterInfo().contains("cluster_state:ok")) { + Thread.sleep(sleepMillis); + tries++; + + if (tries == 10) { + throw new RuntimeException( + String.format("Timeout: Redis not ready after waiting %d milliseconds", sleepMillis)); + } + } + } finally { + waitClient.shutdown(); + } + } + + public static int getNextRedisClusterPort() throws IOException { + final int MAX_ITERATIONS = 11_000; + int port; + for (int i = 0; i < MAX_ITERATIONS; i++) { + try (ServerSocket socket = new ServerSocket(0)) { + socket.setReuseAddress(false); + port = socket.getLocalPort(); + } + if (port < 55535) { + return port; + } + } + throw new IOException("Couldn't find an open port below 55,535 in " + MAX_ITERATIONS + " tries"); + } + + public static class RedisClusterExtensionBuilder { + + private RedisClusterExtensionBuilder() { + + } + + public RedisClusterExtension build() { + return new RedisClusterExtension(); + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/TorExitNodeManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/TorExitNodeManagerTest.java index 55554459d..75739a6ed 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/TorExitNodeManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/TorExitNodeManagerTest.java @@ -5,21 +5,25 @@ package org.whispersystems.textsecuregcm.util; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.mock; import java.io.ByteArrayInputStream; import java.nio.charset.StandardCharsets; import java.util.concurrent.ScheduledExecutorService; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; import org.whispersystems.textsecuregcm.configuration.MonitoredS3ObjectConfiguration; -import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; -public class TorExitNodeManagerTest extends AbstractRedisClusterTest { +class TorExitNodeManagerTest { + + @RegisterExtension + static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); @Test - public void testIsTorExitNode() { + void testIsTorExitNode() { final MonitoredS3ObjectConfiguration configuration = new MonitoredS3ObjectConfiguration(); configuration.setS3Region("ap-northeast-3");