From b585c6676ded56c9cf203d8bda3ac323ca41b5d6 Mon Sep 17 00:00:00 2001 From: Jon Chambers <63609320+jon-signal@users.noreply.github.com> Date: Mon, 6 Jul 2020 10:10:13 -0400 Subject: [PATCH] Move rate limiter logic to Lua scripts --- .../textsecuregcm/limits/LeakyBucket.java | 98 ---------------- .../limits/LockingRateLimiter.java | 74 ------------ .../textsecuregcm/limits/RateLimiter.java | 107 +++++++----------- .../textsecuregcm/limits/RateLimiters.java | 8 +- .../resources/lua/validate_rate_limit.lua | 33 ++++++ .../textsecuregcm/limits/RateLimiterTest.java | 88 ++++++++++++++ .../redis/AbstractRedisTest.java | 48 ++++++++ .../tests/limits/LeakyBucketTest.java | 52 --------- 8 files changed, 214 insertions(+), 294 deletions(-) delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/limits/LeakyBucket.java delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java create mode 100644 service/src/main/resources/lua/validate_rate_limit.lua create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimiterTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisTest.java delete mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/tests/limits/LeakyBucketTest.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LeakyBucket.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/LeakyBucket.java deleted file mode 100644 index 332a86869..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LeakyBucket.java +++ /dev/null @@ -1,98 +0,0 @@ -/** - * Copyright (C) 2013 Open WhisperSystems - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ -package org.whispersystems.textsecuregcm.limits; - -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; - -import java.io.IOException; - -public class LeakyBucket { - - private final int bucketSize; - private final double leakRatePerMillis; - - private int spaceRemaining; - private long lastUpdateTimeMillis; - - public LeakyBucket(int bucketSize, double leakRatePerMillis) { - this(bucketSize, leakRatePerMillis, bucketSize, System.currentTimeMillis()); - } - - private LeakyBucket(int bucketSize, double leakRatePerMillis, int spaceRemaining, long lastUpdateTimeMillis) { - this.bucketSize = bucketSize; - this.leakRatePerMillis = leakRatePerMillis; - this.spaceRemaining = spaceRemaining; - this.lastUpdateTimeMillis = lastUpdateTimeMillis; - } - - public boolean add(int amount) { - this.spaceRemaining = getUpdatedSpaceRemaining(); - this.lastUpdateTimeMillis = System.currentTimeMillis(); - - if (this.spaceRemaining >= amount) { - this.spaceRemaining -= amount; - return true; - } else { - return false; - } - } - - private int getUpdatedSpaceRemaining() { - long elapsedTime = System.currentTimeMillis() - this.lastUpdateTimeMillis; - - return Math.min(this.bucketSize, - (int)Math.floor(this.spaceRemaining + (elapsedTime * this.leakRatePerMillis))); - } - - public String serialize(ObjectMapper mapper) throws JsonProcessingException { - return mapper.writeValueAsString(new LeakyBucketEntity(bucketSize, leakRatePerMillis, spaceRemaining, lastUpdateTimeMillis)); - } - - public static LeakyBucket fromSerialized(ObjectMapper mapper, String serialized) throws IOException { - LeakyBucketEntity entity = mapper.readValue(serialized, LeakyBucketEntity.class); - - return new LeakyBucket(entity.bucketSize, entity.leakRatePerMillis, - entity.spaceRemaining, entity.lastUpdateTimeMillis); - } - - private static class LeakyBucketEntity { - @JsonProperty - private int bucketSize; - - @JsonProperty - private double leakRatePerMillis; - - @JsonProperty - private int spaceRemaining; - - @JsonProperty - private long lastUpdateTimeMillis; - - public LeakyBucketEntity() {} - - private LeakyBucketEntity(int bucketSize, double leakRatePerMillis, - int spaceRemaining, long lastUpdateTimeMillis) - { - this.bucketSize = bucketSize; - this.leakRatePerMillis = leakRatePerMillis; - this.spaceRemaining = spaceRemaining; - this.lastUpdateTimeMillis = lastUpdateTimeMillis; - } - } -} \ No newline at end of file diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java deleted file mode 100644 index c16079790..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java +++ /dev/null @@ -1,74 +0,0 @@ -package org.whispersystems.textsecuregcm.limits; - -import com.codahale.metrics.Meter; -import com.codahale.metrics.MetricRegistry; -import com.codahale.metrics.SharedMetricRegistries; -import io.lettuce.core.SetArgs; -import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; -import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; -import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool; -import org.whispersystems.textsecuregcm.util.Constants; - -import static com.codahale.metrics.MetricRegistry.name; -import redis.clients.jedis.Jedis; - -public class LockingRateLimiter extends RateLimiter { - - private final Meter meter; - - public LockingRateLimiter(ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster, String name, int bucketSize, double leakRatePerMinute) { - super(cacheClient, cacheCluster, name, bucketSize, leakRatePerMinute); - - MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); - this.meter = metricRegistry.meter(name(getClass(), name, "locked")); - } - - @Override - public void validate(String key, int amount) throws RateLimitExceededException { - if (!acquireLock(key)) { - meter.mark(); - throw new RateLimitExceededException("Locked"); - } - - try { - super.validate(key, amount); - } finally { - releaseLock(key); - } - } - - @Override - public void validate(String key) throws RateLimitExceededException { - validate(key, 1); - } - - private void releaseLock(String key) { - try (Jedis jedis = cacheClient.getWriteResource()) { - final String lockName = getLockName(key); - - jedis.del(lockName); - cacheCluster.useWriteCluster(connection -> connection.sync().del(lockName)); - } - } - - private boolean acquireLock(String key) { - try (Jedis jedis = cacheClient.getWriteResource()) { - final String lockName = getLockName(key); - - final boolean acquiredLock = jedis.set(lockName, "L", "NX", "EX", 10) != null; - - if (acquiredLock) { - // TODO Restore the NX flag when the cluster becomes the primary source of truth - cacheCluster.useWriteCluster(connection -> connection.sync().set(lockName, "L", SetArgs.Builder.ex(10))); - } - - return acquiredLock; - } - } - - private String getLockName(String key) { - return "leaky_lock::" + name + "::" + key; - } - - -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java index 81780df7a..8d0a82e55 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java @@ -20,26 +20,25 @@ import com.codahale.metrics.Meter; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.Timer; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import com.google.common.annotations.VisibleForTesting; +import io.lettuce.core.ScriptOutputType; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.experiment.Experiment; +import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; +import org.whispersystems.textsecuregcm.redis.LuaScript; import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool; import org.whispersystems.textsecuregcm.util.Constants; -import org.whispersystems.textsecuregcm.util.SystemMapper; - -import java.io.IOException; - -import static com.codahale.metrics.MetricRegistry.name; import redis.clients.jedis.Jedis; -public class RateLimiter { +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.stream.Collectors; - private final Logger logger = LoggerFactory.getLogger(RateLimiter.class); - private final ObjectMapper mapper = SystemMapper.getMapper(); +import static com.codahale.metrics.MetricRegistry.name; + +public class RateLimiter { private final Meter meter; private final Timer validateTimer; @@ -48,19 +47,12 @@ public class RateLimiter { protected final String name; private final int bucketSize; private final double leakRatePerMillis; - private final boolean reportLimits; private final Experiment redisClusterExperiment; + private final LuaScript validateScript; + private final ClusterLuaScript clusterValidateScript; public RateLimiter(ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster, String name, - int bucketSize, double leakRatePerMinute) - { - this(cacheClient, cacheCluster, name, bucketSize, leakRatePerMinute, false); - } - - public RateLimiter(ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster, String name, - int bucketSize, double leakRatePerMinute, - boolean reportLimits) - { + int bucketSize, double leakRatePerMinute) throws IOException { MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); this.meter = metricRegistry.meter(name(getClass(), name, "exceeded")); @@ -70,27 +62,37 @@ public class RateLimiter { this.name = name; this.bucketSize = bucketSize; this.leakRatePerMillis = leakRatePerMinute / (60.0 * 1000.0); - this.reportLimits = reportLimits; this.redisClusterExperiment = new Experiment("RedisCluster", "RateLimiter", name); - } - - public void validate(String key, int amount) throws RateLimitExceededException { - try (final Timer.Context ignored = validateTimer.time()) { - LeakyBucket bucket = getBucket(key); - - if (bucket.add(amount)) { - setBucket(key, bucket); - } else { - meter.mark(); - throw new RateLimitExceededException(key + " , " + amount); - } - } + this.validateScript = LuaScript.fromResource(cacheClient, "lua/validate_rate_limit.lua"); + this.clusterValidateScript = ClusterLuaScript.fromResource(cacheCluster, "lua/validate_rate_limit.lua", ScriptOutputType.INTEGER); } public void validate(String key) throws RateLimitExceededException { validate(key, 1); } + public void validate(String key, int amount) throws RateLimitExceededException { + validate(key, amount, System.currentTimeMillis()); + } + + @VisibleForTesting + void validate(String key, int amount, final long currentTimeMillis) throws RateLimitExceededException { + try (final Timer.Context ignored = validateTimer.time()) { + final List keys = List.of(getBucketName(key)); + final List arguments = List.of(String.valueOf(bucketSize), String.valueOf(leakRatePerMillis), String.valueOf(currentTimeMillis), String.valueOf(amount)); + + final Object result = validateScript.execute(keys.stream().map(k -> k.getBytes(StandardCharsets.UTF_8)).collect(Collectors.toList()), + arguments.stream().map(a -> a.getBytes(StandardCharsets.UTF_8)).collect(Collectors.toList())); + + redisClusterExperiment.compareSupplierResult(result, () -> clusterValidateScript.execute(keys, arguments)); + + if (result == null) { + meter.mark(); + throw new RateLimitExceededException(key + " , " + amount); + } + } + } + public void clear(String key) { try (Jedis jedis = cacheClient.getWriteResource()) { final String bucketName = getBucketName(key); @@ -100,37 +102,8 @@ public class RateLimiter { } } - private void setBucket(String key, LeakyBucket bucket) { - try (Jedis jedis = cacheClient.getWriteResource()) { - final String bucketName = getBucketName(key); - final String serialized = bucket.serialize(mapper); - final int level = (int) Math.ceil((bucketSize / leakRatePerMillis) / 1000); - - jedis.setex(bucketName, level, serialized); - cacheCluster.useWriteCluster(connection -> connection.sync().setex(bucketName, level, serialized)); - } catch (JsonProcessingException e) { - throw new IllegalArgumentException(e); - } - } - - private LeakyBucket getBucket(String key) { - try (Jedis jedis = cacheClient.getReadResource()) { - final String bucketName = getBucketName(key); - - String serialized = jedis.get(bucketName); - redisClusterExperiment.compareSupplierResult(serialized, () -> cacheCluster.withReadCluster(connection -> connection.sync().get(bucketName))); - - if (serialized != null) { - return LeakyBucket.fromSerialized(mapper, serialized); - } - } catch (IOException e) { - logger.warn("Deserialization error", e); - } - - return new LeakyBucket(bucketSize, leakRatePerMillis); - } - - private String getBucketName(String key) { + @VisibleForTesting + String getBucketName(String key) { return "leaky_bucket::" + name + "::" + key; } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java index 70e4974aa..1803fa3a2 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java @@ -21,6 +21,8 @@ import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool; +import java.io.IOException; + public class RateLimiters { private final RateLimiter smsDestinationLimiter; @@ -48,7 +50,7 @@ public class RateLimiters { private final RateLimiter usernameLookupLimiter; private final RateLimiter usernameSetLimiter; - public RateLimiters(RateLimitsConfiguration config, ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster) { + public RateLimiters(RateLimitsConfiguration config, ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster) throws IOException { this.smsDestinationLimiter = new RateLimiter(cacheClient, cacheCluster, "smsDestination", config.getSmsDestination().getBucketSize(), config.getSmsDestination().getLeakRatePerMinute()); @@ -73,11 +75,11 @@ public class RateLimiters { config.getAutoBlock().getBucketSize(), config.getAutoBlock().getLeakRatePerMinute()); - this.verifyLimiter = new LockingRateLimiter(cacheClient, cacheCluster, "verify", + this.verifyLimiter = new RateLimiter(cacheClient, cacheCluster, "verify", config.getVerifyNumber().getBucketSize(), config.getVerifyNumber().getLeakRatePerMinute()); - this.pinLimiter = new LockingRateLimiter(cacheClient, cacheCluster, "pin", + this.pinLimiter = new RateLimiter(cacheClient, cacheCluster, "pin", config.getVerifyPin().getBucketSize(), config.getVerifyPin().getLeakRatePerMinute()); diff --git a/service/src/main/resources/lua/validate_rate_limit.lua b/service/src/main/resources/lua/validate_rate_limit.lua new file mode 100644 index 000000000..17c1e2215 --- /dev/null +++ b/service/src/main/resources/lua/validate_rate_limit.lua @@ -0,0 +1,33 @@ +local bucketId = KEYS[1] + +local bucketSize = tonumber(ARGV[1]) +local leakRatePerMillis = tonumber(ARGV[2]) +local currentTimeMillis = tonumber(ARGV[3]) +local amount = tonumber(ARGV[4]) + +local leakyBucket + +if redis.call("EXISTS", bucketId) == 1 then + leakyBucket = cjson.decode(redis.call("GET", bucketId)) +else + leakyBucket = { + bucketSize = bucketSize, + leakRatePerMillis = leakRatePerMillis, + spaceRemaining = bucketSize, + lastUpdateTimeMillis = currentTimeMillis + } +end + +local elapsedTime = currentTimeMillis - leakyBucket["lastUpdateTimeMillis"] +local updatedSpaceRemaining = math.min(leakyBucket["bucketSize"], math.floor(leakyBucket["spaceRemaining"] + (elapsedTime * leakyBucket["leakRatePerMillis"]))) + +redis.call("SET", "elapsedTime", elapsedTime) +redis.call("SET", "updatedSpaceRemaining", updatedSpaceRemaining) + +if updatedSpaceRemaining >= amount then + leakyBucket["spaceRemaining"] = updatedSpaceRemaining - amount + redis.call("SET", bucketId, cjson.encode(leakyBucket)) + return true +else + return false +end diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimiterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimiterTest.java new file mode 100644 index 000000000..59ce691ea --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimiterTest.java @@ -0,0 +1,88 @@ +package org.whispersystems.textsecuregcm.limits; + +import org.junit.Before; +import org.junit.Test; +import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import org.whispersystems.textsecuregcm.redis.AbstractRedisTest; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; +import redis.clients.jedis.Jedis; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; + +public class RateLimiterTest extends AbstractRedisTest { + + private static final long NOW_MILLIS = System.currentTimeMillis(); + private static final String KEY = "key"; + + @FunctionalInterface + private interface RateLimitedTask { + void run() throws RateLimitExceededException; + } + + @Before + public void clearCache() { + try (final Jedis jedis = getReplicatedJedisPool().getWriteResource()) { + jedis.flushAll(); + } + } + + @Test + public void validate() throws RateLimitExceededException, IOException { + final RateLimiter rateLimiter = buildRateLimiter(2, 0.5); + + rateLimiter.validate(KEY, 1, NOW_MILLIS); + rateLimiter.validate(KEY, 1, NOW_MILLIS); + assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS)); + } + + @Test + public void validateWithAmount() throws RateLimitExceededException, IOException { + final RateLimiter rateLimiter = buildRateLimiter(2, 0.5); + + rateLimiter.validate(KEY, 2, NOW_MILLIS); + assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS)); + } + + @Test + public void testLapseRate() throws RateLimitExceededException, IOException { + final RateLimiter rateLimiter = buildRateLimiter(2, 8.333333333333334E-6); + final String leakyBucketJson = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (NOW_MILLIS - TimeUnit.MINUTES.toMillis(2)) + "}"; + + try (final Jedis jedis = getReplicatedJedisPool().getWriteResource()) { + jedis.set(rateLimiter.getBucketName(KEY), leakyBucketJson); + } + + rateLimiter.validate(KEY, 1, NOW_MILLIS); + assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS)); + } + + @Test + public void testLapseShort() throws IOException { + final RateLimiter rateLimiter = buildRateLimiter(2, 8.333333333333334E-6); + final String leakyBucketJson = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (NOW_MILLIS - TimeUnit.MINUTES.toMillis(1)) + "}"; + + try (final Jedis jedis = getReplicatedJedisPool().getWriteResource()) { + jedis.set(rateLimiter.getBucketName(KEY), leakyBucketJson); + } + + assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS)); + } + + private void assertRateLimitExceeded(final RateLimitedTask task) { + try { + task.run(); + fail("Expected RateLimitExceededException"); + } catch (final RateLimitExceededException ignored) { + } + } + + @SuppressWarnings("SameParameterValue") + private RateLimiter buildRateLimiter(final int bucketSize, final double leakRatePerMilli) throws IOException { + final double leakRatePerMinute = leakRatePerMilli * 60_000d; + return new RateLimiter(getReplicatedJedisPool(), mock(FaultTolerantRedisCluster.class), KEY, bucketSize, leakRatePerMinute); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisTest.java new file mode 100644 index 000000000..2531fc6b0 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisTest.java @@ -0,0 +1,48 @@ +package org.whispersystems.textsecuregcm.redis; + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; +import org.whispersystems.textsecuregcm.providers.RedisClientFactory; +import redis.embedded.RedisServer; + +import java.io.IOException; +import java.net.ServerSocket; +import java.net.URISyntaxException; +import java.util.List; + +public abstract class AbstractRedisTest { + + private static RedisServer redisServer; + + private ReplicatedJedisPool replicatedJedisPool; + + @BeforeClass + public static void setUpBeforeClass() throws IOException { + redisServer = new RedisServer(getNextPort()); + redisServer.start(); + } + + @Before + public void setUp() throws URISyntaxException { + final String redisUrl = "redis://127.0.0.1:" + redisServer.ports().get(0); + replicatedJedisPool = new RedisClientFactory("test-pool", redisUrl, List.of(redisUrl), new CircuitBreakerConfiguration()).getRedisClientPool(); + } + + protected ReplicatedJedisPool getReplicatedJedisPool() { + return replicatedJedisPool; + } + + @AfterClass + public static void tearDownAfterClass() { + redisServer.stop(); + } + + private static int getNextPort() throws IOException { + try (ServerSocket socket = new ServerSocket(0)) { + socket.setReuseAddress(false); + return socket.getLocalPort(); + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/limits/LeakyBucketTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/limits/LeakyBucketTest.java deleted file mode 100644 index 213baec79..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/limits/LeakyBucketTest.java +++ /dev/null @@ -1,52 +0,0 @@ -package org.whispersystems.textsecuregcm.tests.limits; - -import com.fasterxml.jackson.databind.ObjectMapper; -import org.junit.Test; -import org.whispersystems.textsecuregcm.limits.LeakyBucket; - -import java.io.IOException; -import java.util.concurrent.TimeUnit; - -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - -public class LeakyBucketTest { - - @Test - public void testFull() { - LeakyBucket leakyBucket = new LeakyBucket(2, 1.0 / 2.0); - - assertTrue(leakyBucket.add(1)); - assertTrue(leakyBucket.add(1)); - assertFalse(leakyBucket.add(1)); - - leakyBucket = new LeakyBucket(2, 1.0 / 2.0); - - assertTrue(leakyBucket.add(2)); - assertFalse(leakyBucket.add(1)); - assertFalse(leakyBucket.add(2)); - } - - @Test - public void testLapseRate() throws IOException { - ObjectMapper mapper = new ObjectMapper(); - String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(2)) + "}"; - - LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized); - assertTrue(leakyBucket.add(1)); - - String serializedAgain = leakyBucket.serialize(mapper); - LeakyBucket leakyBucketAgain = LeakyBucket.fromSerialized(mapper, serializedAgain); - - assertFalse(leakyBucketAgain.add(1)); - } - - @Test - public void testLapseShort() throws Exception { - ObjectMapper mapper = new ObjectMapper(); - String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}"; - - LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized); - assertFalse(leakyBucket.add(1)); - } -}