diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LeakyBucket.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/LeakyBucket.java
new file mode 100644
index 000000000..332a86869
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/LeakyBucket.java
@@ -0,0 +1,98 @@
+/**
+ * Copyright (C) 2013 Open WhisperSystems
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+package org.whispersystems.textsecuregcm.limits;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+import java.io.IOException;
+
+public class LeakyBucket {
+
+ private final int bucketSize;
+ private final double leakRatePerMillis;
+
+ private int spaceRemaining;
+ private long lastUpdateTimeMillis;
+
+ public LeakyBucket(int bucketSize, double leakRatePerMillis) {
+ this(bucketSize, leakRatePerMillis, bucketSize, System.currentTimeMillis());
+ }
+
+ private LeakyBucket(int bucketSize, double leakRatePerMillis, int spaceRemaining, long lastUpdateTimeMillis) {
+ this.bucketSize = bucketSize;
+ this.leakRatePerMillis = leakRatePerMillis;
+ this.spaceRemaining = spaceRemaining;
+ this.lastUpdateTimeMillis = lastUpdateTimeMillis;
+ }
+
+ public boolean add(int amount) {
+ this.spaceRemaining = getUpdatedSpaceRemaining();
+ this.lastUpdateTimeMillis = System.currentTimeMillis();
+
+ if (this.spaceRemaining >= amount) {
+ this.spaceRemaining -= amount;
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ private int getUpdatedSpaceRemaining() {
+ long elapsedTime = System.currentTimeMillis() - this.lastUpdateTimeMillis;
+
+ return Math.min(this.bucketSize,
+ (int)Math.floor(this.spaceRemaining + (elapsedTime * this.leakRatePerMillis)));
+ }
+
+ public String serialize(ObjectMapper mapper) throws JsonProcessingException {
+ return mapper.writeValueAsString(new LeakyBucketEntity(bucketSize, leakRatePerMillis, spaceRemaining, lastUpdateTimeMillis));
+ }
+
+ public static LeakyBucket fromSerialized(ObjectMapper mapper, String serialized) throws IOException {
+ LeakyBucketEntity entity = mapper.readValue(serialized, LeakyBucketEntity.class);
+
+ return new LeakyBucket(entity.bucketSize, entity.leakRatePerMillis,
+ entity.spaceRemaining, entity.lastUpdateTimeMillis);
+ }
+
+ private static class LeakyBucketEntity {
+ @JsonProperty
+ private int bucketSize;
+
+ @JsonProperty
+ private double leakRatePerMillis;
+
+ @JsonProperty
+ private int spaceRemaining;
+
+ @JsonProperty
+ private long lastUpdateTimeMillis;
+
+ public LeakyBucketEntity() {}
+
+ private LeakyBucketEntity(int bucketSize, double leakRatePerMillis,
+ int spaceRemaining, long lastUpdateTimeMillis)
+ {
+ this.bucketSize = bucketSize;
+ this.leakRatePerMillis = leakRatePerMillis;
+ this.spaceRemaining = spaceRemaining;
+ this.lastUpdateTimeMillis = lastUpdateTimeMillis;
+ }
+ }
+}
\ No newline at end of file
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java
new file mode 100644
index 000000000..c16079790
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java
@@ -0,0 +1,74 @@
+package org.whispersystems.textsecuregcm.limits;
+
+import com.codahale.metrics.Meter;
+import com.codahale.metrics.MetricRegistry;
+import com.codahale.metrics.SharedMetricRegistries;
+import io.lettuce.core.SetArgs;
+import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
+import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
+import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
+import org.whispersystems.textsecuregcm.util.Constants;
+
+import static com.codahale.metrics.MetricRegistry.name;
+import redis.clients.jedis.Jedis;
+
+public class LockingRateLimiter extends RateLimiter {
+
+ private final Meter meter;
+
+ public LockingRateLimiter(ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster, String name, int bucketSize, double leakRatePerMinute) {
+ super(cacheClient, cacheCluster, name, bucketSize, leakRatePerMinute);
+
+ MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
+ this.meter = metricRegistry.meter(name(getClass(), name, "locked"));
+ }
+
+ @Override
+ public void validate(String key, int amount) throws RateLimitExceededException {
+ if (!acquireLock(key)) {
+ meter.mark();
+ throw new RateLimitExceededException("Locked");
+ }
+
+ try {
+ super.validate(key, amount);
+ } finally {
+ releaseLock(key);
+ }
+ }
+
+ @Override
+ public void validate(String key) throws RateLimitExceededException {
+ validate(key, 1);
+ }
+
+ private void releaseLock(String key) {
+ try (Jedis jedis = cacheClient.getWriteResource()) {
+ final String lockName = getLockName(key);
+
+ jedis.del(lockName);
+ cacheCluster.useWriteCluster(connection -> connection.sync().del(lockName));
+ }
+ }
+
+ private boolean acquireLock(String key) {
+ try (Jedis jedis = cacheClient.getWriteResource()) {
+ final String lockName = getLockName(key);
+
+ final boolean acquiredLock = jedis.set(lockName, "L", "NX", "EX", 10) != null;
+
+ if (acquiredLock) {
+ // TODO Restore the NX flag when the cluster becomes the primary source of truth
+ cacheCluster.useWriteCluster(connection -> connection.sync().set(lockName, "L", SetArgs.Builder.ex(10)));
+ }
+
+ return acquiredLock;
+ }
+ }
+
+ private String getLockName(String key) {
+ return "leaky_lock::" + name + "::" + key;
+ }
+
+
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java
index 8d0a82e55..81780df7a 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java
@@ -20,26 +20,27 @@ import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
-import com.google.common.annotations.VisibleForTesting;
-import io.lettuce.core.ScriptOutputType;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.experiment.Experiment;
-import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
-import org.whispersystems.textsecuregcm.redis.LuaScript;
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
import org.whispersystems.textsecuregcm.util.Constants;
-import redis.clients.jedis.Jedis;
+import org.whispersystems.textsecuregcm.util.SystemMapper;
import java.io.IOException;
-import java.nio.charset.StandardCharsets;
-import java.util.List;
-import java.util.stream.Collectors;
import static com.codahale.metrics.MetricRegistry.name;
+import redis.clients.jedis.Jedis;
public class RateLimiter {
+ private final Logger logger = LoggerFactory.getLogger(RateLimiter.class);
+ private final ObjectMapper mapper = SystemMapper.getMapper();
+
private final Meter meter;
private final Timer validateTimer;
protected final ReplicatedJedisPool cacheClient;
@@ -47,12 +48,19 @@ public class RateLimiter {
protected final String name;
private final int bucketSize;
private final double leakRatePerMillis;
+ private final boolean reportLimits;
private final Experiment redisClusterExperiment;
- private final LuaScript validateScript;
- private final ClusterLuaScript clusterValidateScript;
public RateLimiter(ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster, String name,
- int bucketSize, double leakRatePerMinute) throws IOException {
+ int bucketSize, double leakRatePerMinute)
+ {
+ this(cacheClient, cacheCluster, name, bucketSize, leakRatePerMinute, false);
+ }
+
+ public RateLimiter(ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster, String name,
+ int bucketSize, double leakRatePerMinute,
+ boolean reportLimits)
+ {
MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
this.meter = metricRegistry.meter(name(getClass(), name, "exceeded"));
@@ -62,37 +70,27 @@ public class RateLimiter {
this.name = name;
this.bucketSize = bucketSize;
this.leakRatePerMillis = leakRatePerMinute / (60.0 * 1000.0);
+ this.reportLimits = reportLimits;
this.redisClusterExperiment = new Experiment("RedisCluster", "RateLimiter", name);
- this.validateScript = LuaScript.fromResource(cacheClient, "lua/validate_rate_limit.lua");
- this.clusterValidateScript = ClusterLuaScript.fromResource(cacheCluster, "lua/validate_rate_limit.lua", ScriptOutputType.INTEGER);
- }
-
- public void validate(String key) throws RateLimitExceededException {
- validate(key, 1);
}
public void validate(String key, int amount) throws RateLimitExceededException {
- validate(key, amount, System.currentTimeMillis());
- }
-
- @VisibleForTesting
- void validate(String key, int amount, final long currentTimeMillis) throws RateLimitExceededException {
try (final Timer.Context ignored = validateTimer.time()) {
- final List keys = List.of(getBucketName(key));
- final List arguments = List.of(String.valueOf(bucketSize), String.valueOf(leakRatePerMillis), String.valueOf(currentTimeMillis), String.valueOf(amount));
+ LeakyBucket bucket = getBucket(key);
- final Object result = validateScript.execute(keys.stream().map(k -> k.getBytes(StandardCharsets.UTF_8)).collect(Collectors.toList()),
- arguments.stream().map(a -> a.getBytes(StandardCharsets.UTF_8)).collect(Collectors.toList()));
-
- redisClusterExperiment.compareSupplierResult(result, () -> clusterValidateScript.execute(keys, arguments));
-
- if (result == null) {
+ if (bucket.add(amount)) {
+ setBucket(key, bucket);
+ } else {
meter.mark();
throw new RateLimitExceededException(key + " , " + amount);
}
}
}
+ public void validate(String key) throws RateLimitExceededException {
+ validate(key, 1);
+ }
+
public void clear(String key) {
try (Jedis jedis = cacheClient.getWriteResource()) {
final String bucketName = getBucketName(key);
@@ -102,8 +100,37 @@ public class RateLimiter {
}
}
- @VisibleForTesting
- String getBucketName(String key) {
+ private void setBucket(String key, LeakyBucket bucket) {
+ try (Jedis jedis = cacheClient.getWriteResource()) {
+ final String bucketName = getBucketName(key);
+ final String serialized = bucket.serialize(mapper);
+ final int level = (int) Math.ceil((bucketSize / leakRatePerMillis) / 1000);
+
+ jedis.setex(bucketName, level, serialized);
+ cacheCluster.useWriteCluster(connection -> connection.sync().setex(bucketName, level, serialized));
+ } catch (JsonProcessingException e) {
+ throw new IllegalArgumentException(e);
+ }
+ }
+
+ private LeakyBucket getBucket(String key) {
+ try (Jedis jedis = cacheClient.getReadResource()) {
+ final String bucketName = getBucketName(key);
+
+ String serialized = jedis.get(bucketName);
+ redisClusterExperiment.compareSupplierResult(serialized, () -> cacheCluster.withReadCluster(connection -> connection.sync().get(bucketName)));
+
+ if (serialized != null) {
+ return LeakyBucket.fromSerialized(mapper, serialized);
+ }
+ } catch (IOException e) {
+ logger.warn("Deserialization error", e);
+ }
+
+ return new LeakyBucket(bucketSize, leakRatePerMillis);
+ }
+
+ private String getBucketName(String key) {
return "leaky_bucket::" + name + "::" + key;
}
}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java
index 1803fa3a2..70e4974aa 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java
@@ -21,8 +21,6 @@ import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
-import java.io.IOException;
-
public class RateLimiters {
private final RateLimiter smsDestinationLimiter;
@@ -50,7 +48,7 @@ public class RateLimiters {
private final RateLimiter usernameLookupLimiter;
private final RateLimiter usernameSetLimiter;
- public RateLimiters(RateLimitsConfiguration config, ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster) throws IOException {
+ public RateLimiters(RateLimitsConfiguration config, ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster) {
this.smsDestinationLimiter = new RateLimiter(cacheClient, cacheCluster, "smsDestination",
config.getSmsDestination().getBucketSize(),
config.getSmsDestination().getLeakRatePerMinute());
@@ -75,11 +73,11 @@ public class RateLimiters {
config.getAutoBlock().getBucketSize(),
config.getAutoBlock().getLeakRatePerMinute());
- this.verifyLimiter = new RateLimiter(cacheClient, cacheCluster, "verify",
+ this.verifyLimiter = new LockingRateLimiter(cacheClient, cacheCluster, "verify",
config.getVerifyNumber().getBucketSize(),
config.getVerifyNumber().getLeakRatePerMinute());
- this.pinLimiter = new RateLimiter(cacheClient, cacheCluster, "pin",
+ this.pinLimiter = new LockingRateLimiter(cacheClient, cacheCluster, "pin",
config.getVerifyPin().getBucketSize(),
config.getVerifyPin().getLeakRatePerMinute());
diff --git a/service/src/main/resources/lua/validate_rate_limit.lua b/service/src/main/resources/lua/validate_rate_limit.lua
deleted file mode 100644
index 442e8c330..000000000
--- a/service/src/main/resources/lua/validate_rate_limit.lua
+++ /dev/null
@@ -1,30 +0,0 @@
-local bucketId = KEYS[1]
-
-local bucketSize = tonumber(ARGV[1])
-local leakRatePerMillis = tonumber(ARGV[2])
-local currentTimeMillis = tonumber(ARGV[3])
-local amount = tonumber(ARGV[4])
-
-local leakyBucket
-
-if redis.call("EXISTS", bucketId) == 1 then
- leakyBucket = cjson.decode(redis.call("GET", bucketId))
-else
- leakyBucket = {
- bucketSize = bucketSize,
- leakRatePerMillis = leakRatePerMillis,
- spaceRemaining = bucketSize,
- lastUpdateTimeMillis = currentTimeMillis
- }
-end
-
-local elapsedTime = currentTimeMillis - leakyBucket["lastUpdateTimeMillis"]
-local updatedSpaceRemaining = math.min(leakyBucket["bucketSize"], math.floor(leakyBucket["spaceRemaining"] + (elapsedTime * leakyBucket["leakRatePerMillis"])))
-
-if updatedSpaceRemaining >= amount then
- leakyBucket["spaceRemaining"] = updatedSpaceRemaining - amount
- redis.call("SET", bucketId, cjson.encode(leakyBucket))
- return true
-else
- return false
-end
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimiterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimiterTest.java
deleted file mode 100644
index 294ea323b..000000000
--- a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimiterTest.java
+++ /dev/null
@@ -1,114 +0,0 @@
-package org.whispersystems.textsecuregcm.limits;
-
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Test;
-import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
-import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
-import org.whispersystems.textsecuregcm.providers.RedisClientFactory;
-import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest;
-import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
-import redis.clients.jedis.Jedis;
-import redis.embedded.RedisServer;
-
-import java.io.IOException;
-import java.net.URISyntaxException;
-import java.util.List;
-import java.util.concurrent.TimeUnit;
-
-import static org.junit.Assert.fail;
-
-public class RateLimiterTest extends AbstractRedisClusterTest {
-
- private static final long NOW_MILLIS = System.currentTimeMillis();
- private static final String KEY = "key";
-
- private RedisServer redisServer;
-
- private ReplicatedJedisPool replicatedJedisPool;
-
- @FunctionalInterface
- private interface RateLimitedTask {
- void run() throws RateLimitExceededException;
- }
-
- @Before
- public void clearCache() throws URISyntaxException, IOException {
- redisServer = new RedisServer(AbstractRedisClusterTest.getNextRedisClusterPort());
- redisServer.start();
-
- final String redisUrl = "redis://127.0.0.1:" + redisServer.ports().get(0);
- replicatedJedisPool = new RedisClientFactory("test-pool", redisUrl, List.of(redisUrl), new CircuitBreakerConfiguration()).getRedisClientPool();
-
- try (final Jedis jedis = replicatedJedisPool.getWriteResource()) {
- jedis.flushAll();
- }
-
- getRedisCluster().useWriteCluster(connection -> connection.sync().flushall());
- }
-
- @After
- public void stopServer() {
- redisServer.stop();
- }
-
- @Test
- public void validate() throws RateLimitExceededException, IOException {
- final RateLimiter rateLimiter = buildRateLimiter(2, 0.5);
-
- rateLimiter.validate(KEY, 1, NOW_MILLIS);
- rateLimiter.validate(KEY, 1, NOW_MILLIS);
- assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS));
- }
-
- @Test
- public void validateWithAmount() throws RateLimitExceededException, IOException {
- final RateLimiter rateLimiter = buildRateLimiter(2, 0.5);
-
- rateLimiter.validate(KEY, 2, NOW_MILLIS);
- assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS));
- }
-
- @Test
- public void testLapseRate() throws RateLimitExceededException, IOException {
- final RateLimiter rateLimiter = buildRateLimiter(2, 8.333333333333334E-6);
- final String leakyBucketJson = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (NOW_MILLIS - TimeUnit.MINUTES.toMillis(2)) + "}";
-
- try (final Jedis jedis = replicatedJedisPool.getWriteResource()) {
- jedis.set(rateLimiter.getBucketName(KEY), leakyBucketJson);
- }
-
- getRedisCluster().useWriteCluster(connection -> connection.sync().set(rateLimiter.getBucketName(KEY), leakyBucketJson));
-
- rateLimiter.validate(KEY, 1, NOW_MILLIS);
- assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS));
- }
-
- @Test
- public void testLapseShort() throws IOException {
- final RateLimiter rateLimiter = buildRateLimiter(2, 8.333333333333334E-6);
- final String leakyBucketJson = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (NOW_MILLIS - TimeUnit.MINUTES.toMillis(1)) + "}";
-
- try (final Jedis jedis = replicatedJedisPool.getWriteResource()) {
- jedis.set(rateLimiter.getBucketName(KEY), leakyBucketJson);
- }
-
- getRedisCluster().useWriteCluster(connection -> connection.sync().set(rateLimiter.getBucketName(KEY), leakyBucketJson));
-
- assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS));
- }
-
- private void assertRateLimitExceeded(final RateLimitedTask task) {
- try {
- task.run();
- fail("Expected RateLimitExceededException");
- } catch (final RateLimitExceededException ignored) {
- }
- }
-
- @SuppressWarnings("SameParameterValue")
- private RateLimiter buildRateLimiter(final int bucketSize, final double leakRatePerMilli) throws IOException {
- final double leakRatePerMinute = leakRatePerMilli * 60_000d;
- return new RateLimiter(replicatedJedisPool, getRedisCluster(), KEY, bucketSize, leakRatePerMinute);
- }
-}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisTest.java
deleted file mode 100644
index 2531fc6b0..000000000
--- a/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisTest.java
+++ /dev/null
@@ -1,48 +0,0 @@
-package org.whispersystems.textsecuregcm.redis;
-
-import org.junit.AfterClass;
-import org.junit.Before;
-import org.junit.BeforeClass;
-import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
-import org.whispersystems.textsecuregcm.providers.RedisClientFactory;
-import redis.embedded.RedisServer;
-
-import java.io.IOException;
-import java.net.ServerSocket;
-import java.net.URISyntaxException;
-import java.util.List;
-
-public abstract class AbstractRedisTest {
-
- private static RedisServer redisServer;
-
- private ReplicatedJedisPool replicatedJedisPool;
-
- @BeforeClass
- public static void setUpBeforeClass() throws IOException {
- redisServer = new RedisServer(getNextPort());
- redisServer.start();
- }
-
- @Before
- public void setUp() throws URISyntaxException {
- final String redisUrl = "redis://127.0.0.1:" + redisServer.ports().get(0);
- replicatedJedisPool = new RedisClientFactory("test-pool", redisUrl, List.of(redisUrl), new CircuitBreakerConfiguration()).getRedisClientPool();
- }
-
- protected ReplicatedJedisPool getReplicatedJedisPool() {
- return replicatedJedisPool;
- }
-
- @AfterClass
- public static void tearDownAfterClass() {
- redisServer.stop();
- }
-
- private static int getNextPort() throws IOException {
- try (ServerSocket socket = new ServerSocket(0)) {
- socket.setReuseAddress(false);
- return socket.getLocalPort();
- }
- }
-}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/limits/LeakyBucketTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/limits/LeakyBucketTest.java
new file mode 100644
index 000000000..213baec79
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/limits/LeakyBucketTest.java
@@ -0,0 +1,52 @@
+package org.whispersystems.textsecuregcm.tests.limits;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.junit.Test;
+import org.whispersystems.textsecuregcm.limits.LeakyBucket;
+
+import java.io.IOException;
+import java.util.concurrent.TimeUnit;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+public class LeakyBucketTest {
+
+ @Test
+ public void testFull() {
+ LeakyBucket leakyBucket = new LeakyBucket(2, 1.0 / 2.0);
+
+ assertTrue(leakyBucket.add(1));
+ assertTrue(leakyBucket.add(1));
+ assertFalse(leakyBucket.add(1));
+
+ leakyBucket = new LeakyBucket(2, 1.0 / 2.0);
+
+ assertTrue(leakyBucket.add(2));
+ assertFalse(leakyBucket.add(1));
+ assertFalse(leakyBucket.add(2));
+ }
+
+ @Test
+ public void testLapseRate() throws IOException {
+ ObjectMapper mapper = new ObjectMapper();
+ String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(2)) + "}";
+
+ LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
+ assertTrue(leakyBucket.add(1));
+
+ String serializedAgain = leakyBucket.serialize(mapper);
+ LeakyBucket leakyBucketAgain = LeakyBucket.fromSerialized(mapper, serializedAgain);
+
+ assertFalse(leakyBucketAgain.add(1));
+ }
+
+ @Test
+ public void testLapseShort() throws Exception {
+ ObjectMapper mapper = new ObjectMapper();
+ String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}";
+
+ LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
+ assertFalse(leakyBucket.add(1));
+ }
+}