From b585c6676ded56c9cf203d8bda3ac323ca41b5d6 Mon Sep 17 00:00:00 2001
From: Jon Chambers <63609320+jon-signal@users.noreply.github.com>
Date: Mon, 6 Jul 2020 10:10:13 -0400
Subject: [PATCH] Move rate limiter logic to Lua scripts
---
.../textsecuregcm/limits/LeakyBucket.java | 98 ----------------
.../limits/LockingRateLimiter.java | 74 ------------
.../textsecuregcm/limits/RateLimiter.java | 107 +++++++-----------
.../textsecuregcm/limits/RateLimiters.java | 8 +-
.../resources/lua/validate_rate_limit.lua | 33 ++++++
.../textsecuregcm/limits/RateLimiterTest.java | 88 ++++++++++++++
.../redis/AbstractRedisTest.java | 48 ++++++++
.../tests/limits/LeakyBucketTest.java | 52 ---------
8 files changed, 214 insertions(+), 294 deletions(-)
delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/limits/LeakyBucket.java
delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java
create mode 100644 service/src/main/resources/lua/validate_rate_limit.lua
create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimiterTest.java
create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisTest.java
delete mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/tests/limits/LeakyBucketTest.java
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LeakyBucket.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/LeakyBucket.java
deleted file mode 100644
index 332a86869..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LeakyBucket.java
+++ /dev/null
@@ -1,98 +0,0 @@
-/**
- * Copyright (C) 2013 Open WhisperSystems
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with this program. If not, see .
- */
-package org.whispersystems.textsecuregcm.limits;
-
-import com.fasterxml.jackson.annotation.JsonProperty;
-import com.fasterxml.jackson.core.JsonProcessingException;
-import com.fasterxml.jackson.databind.ObjectMapper;
-
-import java.io.IOException;
-
-public class LeakyBucket {
-
- private final int bucketSize;
- private final double leakRatePerMillis;
-
- private int spaceRemaining;
- private long lastUpdateTimeMillis;
-
- public LeakyBucket(int bucketSize, double leakRatePerMillis) {
- this(bucketSize, leakRatePerMillis, bucketSize, System.currentTimeMillis());
- }
-
- private LeakyBucket(int bucketSize, double leakRatePerMillis, int spaceRemaining, long lastUpdateTimeMillis) {
- this.bucketSize = bucketSize;
- this.leakRatePerMillis = leakRatePerMillis;
- this.spaceRemaining = spaceRemaining;
- this.lastUpdateTimeMillis = lastUpdateTimeMillis;
- }
-
- public boolean add(int amount) {
- this.spaceRemaining = getUpdatedSpaceRemaining();
- this.lastUpdateTimeMillis = System.currentTimeMillis();
-
- if (this.spaceRemaining >= amount) {
- this.spaceRemaining -= amount;
- return true;
- } else {
- return false;
- }
- }
-
- private int getUpdatedSpaceRemaining() {
- long elapsedTime = System.currentTimeMillis() - this.lastUpdateTimeMillis;
-
- return Math.min(this.bucketSize,
- (int)Math.floor(this.spaceRemaining + (elapsedTime * this.leakRatePerMillis)));
- }
-
- public String serialize(ObjectMapper mapper) throws JsonProcessingException {
- return mapper.writeValueAsString(new LeakyBucketEntity(bucketSize, leakRatePerMillis, spaceRemaining, lastUpdateTimeMillis));
- }
-
- public static LeakyBucket fromSerialized(ObjectMapper mapper, String serialized) throws IOException {
- LeakyBucketEntity entity = mapper.readValue(serialized, LeakyBucketEntity.class);
-
- return new LeakyBucket(entity.bucketSize, entity.leakRatePerMillis,
- entity.spaceRemaining, entity.lastUpdateTimeMillis);
- }
-
- private static class LeakyBucketEntity {
- @JsonProperty
- private int bucketSize;
-
- @JsonProperty
- private double leakRatePerMillis;
-
- @JsonProperty
- private int spaceRemaining;
-
- @JsonProperty
- private long lastUpdateTimeMillis;
-
- public LeakyBucketEntity() {}
-
- private LeakyBucketEntity(int bucketSize, double leakRatePerMillis,
- int spaceRemaining, long lastUpdateTimeMillis)
- {
- this.bucketSize = bucketSize;
- this.leakRatePerMillis = leakRatePerMillis;
- this.spaceRemaining = spaceRemaining;
- this.lastUpdateTimeMillis = lastUpdateTimeMillis;
- }
- }
-}
\ No newline at end of file
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java
deleted file mode 100644
index c16079790..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java
+++ /dev/null
@@ -1,74 +0,0 @@
-package org.whispersystems.textsecuregcm.limits;
-
-import com.codahale.metrics.Meter;
-import com.codahale.metrics.MetricRegistry;
-import com.codahale.metrics.SharedMetricRegistries;
-import io.lettuce.core.SetArgs;
-import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
-import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
-import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
-import org.whispersystems.textsecuregcm.util.Constants;
-
-import static com.codahale.metrics.MetricRegistry.name;
-import redis.clients.jedis.Jedis;
-
-public class LockingRateLimiter extends RateLimiter {
-
- private final Meter meter;
-
- public LockingRateLimiter(ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster, String name, int bucketSize, double leakRatePerMinute) {
- super(cacheClient, cacheCluster, name, bucketSize, leakRatePerMinute);
-
- MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
- this.meter = metricRegistry.meter(name(getClass(), name, "locked"));
- }
-
- @Override
- public void validate(String key, int amount) throws RateLimitExceededException {
- if (!acquireLock(key)) {
- meter.mark();
- throw new RateLimitExceededException("Locked");
- }
-
- try {
- super.validate(key, amount);
- } finally {
- releaseLock(key);
- }
- }
-
- @Override
- public void validate(String key) throws RateLimitExceededException {
- validate(key, 1);
- }
-
- private void releaseLock(String key) {
- try (Jedis jedis = cacheClient.getWriteResource()) {
- final String lockName = getLockName(key);
-
- jedis.del(lockName);
- cacheCluster.useWriteCluster(connection -> connection.sync().del(lockName));
- }
- }
-
- private boolean acquireLock(String key) {
- try (Jedis jedis = cacheClient.getWriteResource()) {
- final String lockName = getLockName(key);
-
- final boolean acquiredLock = jedis.set(lockName, "L", "NX", "EX", 10) != null;
-
- if (acquiredLock) {
- // TODO Restore the NX flag when the cluster becomes the primary source of truth
- cacheCluster.useWriteCluster(connection -> connection.sync().set(lockName, "L", SetArgs.Builder.ex(10)));
- }
-
- return acquiredLock;
- }
- }
-
- private String getLockName(String key) {
- return "leaky_lock::" + name + "::" + key;
- }
-
-
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java
index 81780df7a..8d0a82e55 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java
@@ -20,26 +20,25 @@ import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
-import com.fasterxml.jackson.core.JsonProcessingException;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+import com.google.common.annotations.VisibleForTesting;
+import io.lettuce.core.ScriptOutputType;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.experiment.Experiment;
+import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
+import org.whispersystems.textsecuregcm.redis.LuaScript;
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
import org.whispersystems.textsecuregcm.util.Constants;
-import org.whispersystems.textsecuregcm.util.SystemMapper;
-
-import java.io.IOException;
-
-import static com.codahale.metrics.MetricRegistry.name;
import redis.clients.jedis.Jedis;
-public class RateLimiter {
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+import java.util.stream.Collectors;
- private final Logger logger = LoggerFactory.getLogger(RateLimiter.class);
- private final ObjectMapper mapper = SystemMapper.getMapper();
+import static com.codahale.metrics.MetricRegistry.name;
+
+public class RateLimiter {
private final Meter meter;
private final Timer validateTimer;
@@ -48,19 +47,12 @@ public class RateLimiter {
protected final String name;
private final int bucketSize;
private final double leakRatePerMillis;
- private final boolean reportLimits;
private final Experiment redisClusterExperiment;
+ private final LuaScript validateScript;
+ private final ClusterLuaScript clusterValidateScript;
public RateLimiter(ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster, String name,
- int bucketSize, double leakRatePerMinute)
- {
- this(cacheClient, cacheCluster, name, bucketSize, leakRatePerMinute, false);
- }
-
- public RateLimiter(ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster, String name,
- int bucketSize, double leakRatePerMinute,
- boolean reportLimits)
- {
+ int bucketSize, double leakRatePerMinute) throws IOException {
MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
this.meter = metricRegistry.meter(name(getClass(), name, "exceeded"));
@@ -70,27 +62,37 @@ public class RateLimiter {
this.name = name;
this.bucketSize = bucketSize;
this.leakRatePerMillis = leakRatePerMinute / (60.0 * 1000.0);
- this.reportLimits = reportLimits;
this.redisClusterExperiment = new Experiment("RedisCluster", "RateLimiter", name);
- }
-
- public void validate(String key, int amount) throws RateLimitExceededException {
- try (final Timer.Context ignored = validateTimer.time()) {
- LeakyBucket bucket = getBucket(key);
-
- if (bucket.add(amount)) {
- setBucket(key, bucket);
- } else {
- meter.mark();
- throw new RateLimitExceededException(key + " , " + amount);
- }
- }
+ this.validateScript = LuaScript.fromResource(cacheClient, "lua/validate_rate_limit.lua");
+ this.clusterValidateScript = ClusterLuaScript.fromResource(cacheCluster, "lua/validate_rate_limit.lua", ScriptOutputType.INTEGER);
}
public void validate(String key) throws RateLimitExceededException {
validate(key, 1);
}
+ public void validate(String key, int amount) throws RateLimitExceededException {
+ validate(key, amount, System.currentTimeMillis());
+ }
+
+ @VisibleForTesting
+ void validate(String key, int amount, final long currentTimeMillis) throws RateLimitExceededException {
+ try (final Timer.Context ignored = validateTimer.time()) {
+ final List keys = List.of(getBucketName(key));
+ final List arguments = List.of(String.valueOf(bucketSize), String.valueOf(leakRatePerMillis), String.valueOf(currentTimeMillis), String.valueOf(amount));
+
+ final Object result = validateScript.execute(keys.stream().map(k -> k.getBytes(StandardCharsets.UTF_8)).collect(Collectors.toList()),
+ arguments.stream().map(a -> a.getBytes(StandardCharsets.UTF_8)).collect(Collectors.toList()));
+
+ redisClusterExperiment.compareSupplierResult(result, () -> clusterValidateScript.execute(keys, arguments));
+
+ if (result == null) {
+ meter.mark();
+ throw new RateLimitExceededException(key + " , " + amount);
+ }
+ }
+ }
+
public void clear(String key) {
try (Jedis jedis = cacheClient.getWriteResource()) {
final String bucketName = getBucketName(key);
@@ -100,37 +102,8 @@ public class RateLimiter {
}
}
- private void setBucket(String key, LeakyBucket bucket) {
- try (Jedis jedis = cacheClient.getWriteResource()) {
- final String bucketName = getBucketName(key);
- final String serialized = bucket.serialize(mapper);
- final int level = (int) Math.ceil((bucketSize / leakRatePerMillis) / 1000);
-
- jedis.setex(bucketName, level, serialized);
- cacheCluster.useWriteCluster(connection -> connection.sync().setex(bucketName, level, serialized));
- } catch (JsonProcessingException e) {
- throw new IllegalArgumentException(e);
- }
- }
-
- private LeakyBucket getBucket(String key) {
- try (Jedis jedis = cacheClient.getReadResource()) {
- final String bucketName = getBucketName(key);
-
- String serialized = jedis.get(bucketName);
- redisClusterExperiment.compareSupplierResult(serialized, () -> cacheCluster.withReadCluster(connection -> connection.sync().get(bucketName)));
-
- if (serialized != null) {
- return LeakyBucket.fromSerialized(mapper, serialized);
- }
- } catch (IOException e) {
- logger.warn("Deserialization error", e);
- }
-
- return new LeakyBucket(bucketSize, leakRatePerMillis);
- }
-
- private String getBucketName(String key) {
+ @VisibleForTesting
+ String getBucketName(String key) {
return "leaky_bucket::" + name + "::" + key;
}
}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java
index 70e4974aa..1803fa3a2 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java
@@ -21,6 +21,8 @@ import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
+import java.io.IOException;
+
public class RateLimiters {
private final RateLimiter smsDestinationLimiter;
@@ -48,7 +50,7 @@ public class RateLimiters {
private final RateLimiter usernameLookupLimiter;
private final RateLimiter usernameSetLimiter;
- public RateLimiters(RateLimitsConfiguration config, ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster) {
+ public RateLimiters(RateLimitsConfiguration config, ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster) throws IOException {
this.smsDestinationLimiter = new RateLimiter(cacheClient, cacheCluster, "smsDestination",
config.getSmsDestination().getBucketSize(),
config.getSmsDestination().getLeakRatePerMinute());
@@ -73,11 +75,11 @@ public class RateLimiters {
config.getAutoBlock().getBucketSize(),
config.getAutoBlock().getLeakRatePerMinute());
- this.verifyLimiter = new LockingRateLimiter(cacheClient, cacheCluster, "verify",
+ this.verifyLimiter = new RateLimiter(cacheClient, cacheCluster, "verify",
config.getVerifyNumber().getBucketSize(),
config.getVerifyNumber().getLeakRatePerMinute());
- this.pinLimiter = new LockingRateLimiter(cacheClient, cacheCluster, "pin",
+ this.pinLimiter = new RateLimiter(cacheClient, cacheCluster, "pin",
config.getVerifyPin().getBucketSize(),
config.getVerifyPin().getLeakRatePerMinute());
diff --git a/service/src/main/resources/lua/validate_rate_limit.lua b/service/src/main/resources/lua/validate_rate_limit.lua
new file mode 100644
index 000000000..17c1e2215
--- /dev/null
+++ b/service/src/main/resources/lua/validate_rate_limit.lua
@@ -0,0 +1,33 @@
+local bucketId = KEYS[1]
+
+local bucketSize = tonumber(ARGV[1])
+local leakRatePerMillis = tonumber(ARGV[2])
+local currentTimeMillis = tonumber(ARGV[3])
+local amount = tonumber(ARGV[4])
+
+local leakyBucket
+
+if redis.call("EXISTS", bucketId) == 1 then
+ leakyBucket = cjson.decode(redis.call("GET", bucketId))
+else
+ leakyBucket = {
+ bucketSize = bucketSize,
+ leakRatePerMillis = leakRatePerMillis,
+ spaceRemaining = bucketSize,
+ lastUpdateTimeMillis = currentTimeMillis
+ }
+end
+
+local elapsedTime = currentTimeMillis - leakyBucket["lastUpdateTimeMillis"]
+local updatedSpaceRemaining = math.min(leakyBucket["bucketSize"], math.floor(leakyBucket["spaceRemaining"] + (elapsedTime * leakyBucket["leakRatePerMillis"])))
+
+redis.call("SET", "elapsedTime", elapsedTime)
+redis.call("SET", "updatedSpaceRemaining", updatedSpaceRemaining)
+
+if updatedSpaceRemaining >= amount then
+ leakyBucket["spaceRemaining"] = updatedSpaceRemaining - amount
+ redis.call("SET", bucketId, cjson.encode(leakyBucket))
+ return true
+else
+ return false
+end
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimiterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimiterTest.java
new file mode 100644
index 000000000..59ce691ea
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimiterTest.java
@@ -0,0 +1,88 @@
+package org.whispersystems.textsecuregcm.limits;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
+import org.whispersystems.textsecuregcm.redis.AbstractRedisTest;
+import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
+import redis.clients.jedis.Jedis;
+
+import java.io.IOException;
+import java.util.concurrent.TimeUnit;
+
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.mock;
+
+public class RateLimiterTest extends AbstractRedisTest {
+
+ private static final long NOW_MILLIS = System.currentTimeMillis();
+ private static final String KEY = "key";
+
+ @FunctionalInterface
+ private interface RateLimitedTask {
+ void run() throws RateLimitExceededException;
+ }
+
+ @Before
+ public void clearCache() {
+ try (final Jedis jedis = getReplicatedJedisPool().getWriteResource()) {
+ jedis.flushAll();
+ }
+ }
+
+ @Test
+ public void validate() throws RateLimitExceededException, IOException {
+ final RateLimiter rateLimiter = buildRateLimiter(2, 0.5);
+
+ rateLimiter.validate(KEY, 1, NOW_MILLIS);
+ rateLimiter.validate(KEY, 1, NOW_MILLIS);
+ assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS));
+ }
+
+ @Test
+ public void validateWithAmount() throws RateLimitExceededException, IOException {
+ final RateLimiter rateLimiter = buildRateLimiter(2, 0.5);
+
+ rateLimiter.validate(KEY, 2, NOW_MILLIS);
+ assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS));
+ }
+
+ @Test
+ public void testLapseRate() throws RateLimitExceededException, IOException {
+ final RateLimiter rateLimiter = buildRateLimiter(2, 8.333333333333334E-6);
+ final String leakyBucketJson = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (NOW_MILLIS - TimeUnit.MINUTES.toMillis(2)) + "}";
+
+ try (final Jedis jedis = getReplicatedJedisPool().getWriteResource()) {
+ jedis.set(rateLimiter.getBucketName(KEY), leakyBucketJson);
+ }
+
+ rateLimiter.validate(KEY, 1, NOW_MILLIS);
+ assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS));
+ }
+
+ @Test
+ public void testLapseShort() throws IOException {
+ final RateLimiter rateLimiter = buildRateLimiter(2, 8.333333333333334E-6);
+ final String leakyBucketJson = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (NOW_MILLIS - TimeUnit.MINUTES.toMillis(1)) + "}";
+
+ try (final Jedis jedis = getReplicatedJedisPool().getWriteResource()) {
+ jedis.set(rateLimiter.getBucketName(KEY), leakyBucketJson);
+ }
+
+ assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS));
+ }
+
+ private void assertRateLimitExceeded(final RateLimitedTask task) {
+ try {
+ task.run();
+ fail("Expected RateLimitExceededException");
+ } catch (final RateLimitExceededException ignored) {
+ }
+ }
+
+ @SuppressWarnings("SameParameterValue")
+ private RateLimiter buildRateLimiter(final int bucketSize, final double leakRatePerMilli) throws IOException {
+ final double leakRatePerMinute = leakRatePerMilli * 60_000d;
+ return new RateLimiter(getReplicatedJedisPool(), mock(FaultTolerantRedisCluster.class), KEY, bucketSize, leakRatePerMinute);
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisTest.java
new file mode 100644
index 000000000..2531fc6b0
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisTest.java
@@ -0,0 +1,48 @@
+package org.whispersystems.textsecuregcm.redis;
+
+import org.junit.AfterClass;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
+import org.whispersystems.textsecuregcm.providers.RedisClientFactory;
+import redis.embedded.RedisServer;
+
+import java.io.IOException;
+import java.net.ServerSocket;
+import java.net.URISyntaxException;
+import java.util.List;
+
+public abstract class AbstractRedisTest {
+
+ private static RedisServer redisServer;
+
+ private ReplicatedJedisPool replicatedJedisPool;
+
+ @BeforeClass
+ public static void setUpBeforeClass() throws IOException {
+ redisServer = new RedisServer(getNextPort());
+ redisServer.start();
+ }
+
+ @Before
+ public void setUp() throws URISyntaxException {
+ final String redisUrl = "redis://127.0.0.1:" + redisServer.ports().get(0);
+ replicatedJedisPool = new RedisClientFactory("test-pool", redisUrl, List.of(redisUrl), new CircuitBreakerConfiguration()).getRedisClientPool();
+ }
+
+ protected ReplicatedJedisPool getReplicatedJedisPool() {
+ return replicatedJedisPool;
+ }
+
+ @AfterClass
+ public static void tearDownAfterClass() {
+ redisServer.stop();
+ }
+
+ private static int getNextPort() throws IOException {
+ try (ServerSocket socket = new ServerSocket(0)) {
+ socket.setReuseAddress(false);
+ return socket.getLocalPort();
+ }
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/limits/LeakyBucketTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/limits/LeakyBucketTest.java
deleted file mode 100644
index 213baec79..000000000
--- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/limits/LeakyBucketTest.java
+++ /dev/null
@@ -1,52 +0,0 @@
-package org.whispersystems.textsecuregcm.tests.limits;
-
-import com.fasterxml.jackson.databind.ObjectMapper;
-import org.junit.Test;
-import org.whispersystems.textsecuregcm.limits.LeakyBucket;
-
-import java.io.IOException;
-import java.util.concurrent.TimeUnit;
-
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
-
-public class LeakyBucketTest {
-
- @Test
- public void testFull() {
- LeakyBucket leakyBucket = new LeakyBucket(2, 1.0 / 2.0);
-
- assertTrue(leakyBucket.add(1));
- assertTrue(leakyBucket.add(1));
- assertFalse(leakyBucket.add(1));
-
- leakyBucket = new LeakyBucket(2, 1.0 / 2.0);
-
- assertTrue(leakyBucket.add(2));
- assertFalse(leakyBucket.add(1));
- assertFalse(leakyBucket.add(2));
- }
-
- @Test
- public void testLapseRate() throws IOException {
- ObjectMapper mapper = new ObjectMapper();
- String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(2)) + "}";
-
- LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
- assertTrue(leakyBucket.add(1));
-
- String serializedAgain = leakyBucket.serialize(mapper);
- LeakyBucket leakyBucketAgain = LeakyBucket.fromSerialized(mapper, serializedAgain);
-
- assertFalse(leakyBucketAgain.add(1));
- }
-
- @Test
- public void testLapseShort() throws Exception {
- ObjectMapper mapper = new ObjectMapper();
- String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}";
-
- LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
- assertFalse(leakyBucket.add(1));
- }
-}