diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimiterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimiterTest.java index 59ce691ea..294ea323b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimiterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimiterTest.java @@ -1,33 +1,55 @@ package org.whispersystems.textsecuregcm.limits; +import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; -import org.whispersystems.textsecuregcm.redis.AbstractRedisTest; -import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; +import org.whispersystems.textsecuregcm.providers.RedisClientFactory; +import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; +import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool; import redis.clients.jedis.Jedis; +import redis.embedded.RedisServer; import java.io.IOException; +import java.net.URISyntaxException; +import java.util.List; import java.util.concurrent.TimeUnit; import static org.junit.Assert.fail; -import static org.mockito.Mockito.mock; -public class RateLimiterTest extends AbstractRedisTest { +public class RateLimiterTest extends AbstractRedisClusterTest { private static final long NOW_MILLIS = System.currentTimeMillis(); private static final String KEY = "key"; + private RedisServer redisServer; + + private ReplicatedJedisPool replicatedJedisPool; + @FunctionalInterface private interface RateLimitedTask { void run() throws RateLimitExceededException; } @Before - public void clearCache() { - try (final Jedis jedis = getReplicatedJedisPool().getWriteResource()) { + public void clearCache() throws URISyntaxException, IOException { + redisServer = new RedisServer(AbstractRedisClusterTest.getNextRedisClusterPort()); + redisServer.start(); + + final String redisUrl = "redis://127.0.0.1:" + redisServer.ports().get(0); + replicatedJedisPool = new RedisClientFactory("test-pool", redisUrl, List.of(redisUrl), new CircuitBreakerConfiguration()).getRedisClientPool(); + + try (final Jedis jedis = replicatedJedisPool.getWriteResource()) { jedis.flushAll(); } + + getRedisCluster().useWriteCluster(connection -> connection.sync().flushall()); + } + + @After + public void stopServer() { + redisServer.stop(); } @Test @@ -52,10 +74,12 @@ public class RateLimiterTest extends AbstractRedisTest { final RateLimiter rateLimiter = buildRateLimiter(2, 8.333333333333334E-6); final String leakyBucketJson = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (NOW_MILLIS - TimeUnit.MINUTES.toMillis(2)) + "}"; - try (final Jedis jedis = getReplicatedJedisPool().getWriteResource()) { + try (final Jedis jedis = replicatedJedisPool.getWriteResource()) { jedis.set(rateLimiter.getBucketName(KEY), leakyBucketJson); } + getRedisCluster().useWriteCluster(connection -> connection.sync().set(rateLimiter.getBucketName(KEY), leakyBucketJson)); + rateLimiter.validate(KEY, 1, NOW_MILLIS); assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS)); } @@ -65,10 +89,12 @@ public class RateLimiterTest extends AbstractRedisTest { final RateLimiter rateLimiter = buildRateLimiter(2, 8.333333333333334E-6); final String leakyBucketJson = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (NOW_MILLIS - TimeUnit.MINUTES.toMillis(1)) + "}"; - try (final Jedis jedis = getReplicatedJedisPool().getWriteResource()) { + try (final Jedis jedis = replicatedJedisPool.getWriteResource()) { jedis.set(rateLimiter.getBucketName(KEY), leakyBucketJson); } + getRedisCluster().useWriteCluster(connection -> connection.sync().set(rateLimiter.getBucketName(KEY), leakyBucketJson)); + assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS)); } @@ -83,6 +109,6 @@ public class RateLimiterTest extends AbstractRedisTest { @SuppressWarnings("SameParameterValue") private RateLimiter buildRateLimiter(final int bucketSize, final double leakRatePerMilli) throws IOException { final double leakRatePerMinute = leakRatePerMilli * 60_000d; - return new RateLimiter(getReplicatedJedisPool(), mock(FaultTolerantRedisCluster.class), KEY, bucketSize, leakRatePerMinute); + return new RateLimiter(replicatedJedisPool, getRedisCluster(), KEY, bucketSize, leakRatePerMinute); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisClusterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisClusterTest.java index 9b5908add..f32903596 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisClusterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisClusterTest.java @@ -143,7 +143,7 @@ public abstract class AbstractRedisClusterTest { } } - private static int getNextRedisClusterPort() throws IOException { + public static int getNextRedisClusterPort() throws IOException { final int MAX_ITERATIONS = 11_000; int port; for (int i = 0; i < MAX_ITERATIONS; i++) {