Move rate limiter logic to Lua scripts
This commit is contained in:
parent
f5ddb0f1f8
commit
b585c6676d
|
@ -1,98 +0,0 @@
|
|||
/**
|
||||
* Copyright (C) 2013 Open WhisperSystems
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU Affero General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU Affero General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.limits;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class LeakyBucket {
|
||||
|
||||
private final int bucketSize;
|
||||
private final double leakRatePerMillis;
|
||||
|
||||
private int spaceRemaining;
|
||||
private long lastUpdateTimeMillis;
|
||||
|
||||
public LeakyBucket(int bucketSize, double leakRatePerMillis) {
|
||||
this(bucketSize, leakRatePerMillis, bucketSize, System.currentTimeMillis());
|
||||
}
|
||||
|
||||
private LeakyBucket(int bucketSize, double leakRatePerMillis, int spaceRemaining, long lastUpdateTimeMillis) {
|
||||
this.bucketSize = bucketSize;
|
||||
this.leakRatePerMillis = leakRatePerMillis;
|
||||
this.spaceRemaining = spaceRemaining;
|
||||
this.lastUpdateTimeMillis = lastUpdateTimeMillis;
|
||||
}
|
||||
|
||||
public boolean add(int amount) {
|
||||
this.spaceRemaining = getUpdatedSpaceRemaining();
|
||||
this.lastUpdateTimeMillis = System.currentTimeMillis();
|
||||
|
||||
if (this.spaceRemaining >= amount) {
|
||||
this.spaceRemaining -= amount;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private int getUpdatedSpaceRemaining() {
|
||||
long elapsedTime = System.currentTimeMillis() - this.lastUpdateTimeMillis;
|
||||
|
||||
return Math.min(this.bucketSize,
|
||||
(int)Math.floor(this.spaceRemaining + (elapsedTime * this.leakRatePerMillis)));
|
||||
}
|
||||
|
||||
public String serialize(ObjectMapper mapper) throws JsonProcessingException {
|
||||
return mapper.writeValueAsString(new LeakyBucketEntity(bucketSize, leakRatePerMillis, spaceRemaining, lastUpdateTimeMillis));
|
||||
}
|
||||
|
||||
public static LeakyBucket fromSerialized(ObjectMapper mapper, String serialized) throws IOException {
|
||||
LeakyBucketEntity entity = mapper.readValue(serialized, LeakyBucketEntity.class);
|
||||
|
||||
return new LeakyBucket(entity.bucketSize, entity.leakRatePerMillis,
|
||||
entity.spaceRemaining, entity.lastUpdateTimeMillis);
|
||||
}
|
||||
|
||||
private static class LeakyBucketEntity {
|
||||
@JsonProperty
|
||||
private int bucketSize;
|
||||
|
||||
@JsonProperty
|
||||
private double leakRatePerMillis;
|
||||
|
||||
@JsonProperty
|
||||
private int spaceRemaining;
|
||||
|
||||
@JsonProperty
|
||||
private long lastUpdateTimeMillis;
|
||||
|
||||
public LeakyBucketEntity() {}
|
||||
|
||||
private LeakyBucketEntity(int bucketSize, double leakRatePerMillis,
|
||||
int spaceRemaining, long lastUpdateTimeMillis)
|
||||
{
|
||||
this.bucketSize = bucketSize;
|
||||
this.leakRatePerMillis = leakRatePerMillis;
|
||||
this.spaceRemaining = spaceRemaining;
|
||||
this.lastUpdateTimeMillis = lastUpdateTimeMillis;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,74 +0,0 @@
|
|||
package org.whispersystems.textsecuregcm.limits;
|
||||
|
||||
import com.codahale.metrics.Meter;
|
||||
import com.codahale.metrics.MetricRegistry;
|
||||
import com.codahale.metrics.SharedMetricRegistries;
|
||||
import io.lettuce.core.SetArgs;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
||||
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
|
||||
import org.whispersystems.textsecuregcm.util.Constants;
|
||||
|
||||
import static com.codahale.metrics.MetricRegistry.name;
|
||||
import redis.clients.jedis.Jedis;
|
||||
|
||||
public class LockingRateLimiter extends RateLimiter {
|
||||
|
||||
private final Meter meter;
|
||||
|
||||
public LockingRateLimiter(ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster, String name, int bucketSize, double leakRatePerMinute) {
|
||||
super(cacheClient, cacheCluster, name, bucketSize, leakRatePerMinute);
|
||||
|
||||
MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
|
||||
this.meter = metricRegistry.meter(name(getClass(), name, "locked"));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void validate(String key, int amount) throws RateLimitExceededException {
|
||||
if (!acquireLock(key)) {
|
||||
meter.mark();
|
||||
throw new RateLimitExceededException("Locked");
|
||||
}
|
||||
|
||||
try {
|
||||
super.validate(key, amount);
|
||||
} finally {
|
||||
releaseLock(key);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void validate(String key) throws RateLimitExceededException {
|
||||
validate(key, 1);
|
||||
}
|
||||
|
||||
private void releaseLock(String key) {
|
||||
try (Jedis jedis = cacheClient.getWriteResource()) {
|
||||
final String lockName = getLockName(key);
|
||||
|
||||
jedis.del(lockName);
|
||||
cacheCluster.useWriteCluster(connection -> connection.sync().del(lockName));
|
||||
}
|
||||
}
|
||||
|
||||
private boolean acquireLock(String key) {
|
||||
try (Jedis jedis = cacheClient.getWriteResource()) {
|
||||
final String lockName = getLockName(key);
|
||||
|
||||
final boolean acquiredLock = jedis.set(lockName, "L", "NX", "EX", 10) != null;
|
||||
|
||||
if (acquiredLock) {
|
||||
// TODO Restore the NX flag when the cluster becomes the primary source of truth
|
||||
cacheCluster.useWriteCluster(connection -> connection.sync().set(lockName, "L", SetArgs.Builder.ex(10)));
|
||||
}
|
||||
|
||||
return acquiredLock;
|
||||
}
|
||||
}
|
||||
|
||||
private String getLockName(String key) {
|
||||
return "leaky_lock::" + name + "::" + key;
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -20,26 +20,25 @@ import com.codahale.metrics.Meter;
|
|||
import com.codahale.metrics.MetricRegistry;
|
||||
import com.codahale.metrics.SharedMetricRegistries;
|
||||
import com.codahale.metrics.Timer;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.lettuce.core.ScriptOutputType;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
import org.whispersystems.textsecuregcm.experiment.Experiment;
|
||||
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
||||
import org.whispersystems.textsecuregcm.redis.LuaScript;
|
||||
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
|
||||
import org.whispersystems.textsecuregcm.util.Constants;
|
||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import static com.codahale.metrics.MetricRegistry.name;
|
||||
import redis.clients.jedis.Jedis;
|
||||
|
||||
public class RateLimiter {
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
private final Logger logger = LoggerFactory.getLogger(RateLimiter.class);
|
||||
private final ObjectMapper mapper = SystemMapper.getMapper();
|
||||
import static com.codahale.metrics.MetricRegistry.name;
|
||||
|
||||
public class RateLimiter {
|
||||
|
||||
private final Meter meter;
|
||||
private final Timer validateTimer;
|
||||
|
@ -48,19 +47,12 @@ public class RateLimiter {
|
|||
protected final String name;
|
||||
private final int bucketSize;
|
||||
private final double leakRatePerMillis;
|
||||
private final boolean reportLimits;
|
||||
private final Experiment redisClusterExperiment;
|
||||
private final LuaScript validateScript;
|
||||
private final ClusterLuaScript clusterValidateScript;
|
||||
|
||||
public RateLimiter(ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster, String name,
|
||||
int bucketSize, double leakRatePerMinute)
|
||||
{
|
||||
this(cacheClient, cacheCluster, name, bucketSize, leakRatePerMinute, false);
|
||||
}
|
||||
|
||||
public RateLimiter(ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster, String name,
|
||||
int bucketSize, double leakRatePerMinute,
|
||||
boolean reportLimits)
|
||||
{
|
||||
int bucketSize, double leakRatePerMinute) throws IOException {
|
||||
MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
|
||||
|
||||
this.meter = metricRegistry.meter(name(getClass(), name, "exceeded"));
|
||||
|
@ -70,27 +62,37 @@ public class RateLimiter {
|
|||
this.name = name;
|
||||
this.bucketSize = bucketSize;
|
||||
this.leakRatePerMillis = leakRatePerMinute / (60.0 * 1000.0);
|
||||
this.reportLimits = reportLimits;
|
||||
this.redisClusterExperiment = new Experiment("RedisCluster", "RateLimiter", name);
|
||||
}
|
||||
|
||||
public void validate(String key, int amount) throws RateLimitExceededException {
|
||||
try (final Timer.Context ignored = validateTimer.time()) {
|
||||
LeakyBucket bucket = getBucket(key);
|
||||
|
||||
if (bucket.add(amount)) {
|
||||
setBucket(key, bucket);
|
||||
} else {
|
||||
meter.mark();
|
||||
throw new RateLimitExceededException(key + " , " + amount);
|
||||
}
|
||||
}
|
||||
this.validateScript = LuaScript.fromResource(cacheClient, "lua/validate_rate_limit.lua");
|
||||
this.clusterValidateScript = ClusterLuaScript.fromResource(cacheCluster, "lua/validate_rate_limit.lua", ScriptOutputType.INTEGER);
|
||||
}
|
||||
|
||||
public void validate(String key) throws RateLimitExceededException {
|
||||
validate(key, 1);
|
||||
}
|
||||
|
||||
public void validate(String key, int amount) throws RateLimitExceededException {
|
||||
validate(key, amount, System.currentTimeMillis());
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
void validate(String key, int amount, final long currentTimeMillis) throws RateLimitExceededException {
|
||||
try (final Timer.Context ignored = validateTimer.time()) {
|
||||
final List<String> keys = List.of(getBucketName(key));
|
||||
final List<String> arguments = List.of(String.valueOf(bucketSize), String.valueOf(leakRatePerMillis), String.valueOf(currentTimeMillis), String.valueOf(amount));
|
||||
|
||||
final Object result = validateScript.execute(keys.stream().map(k -> k.getBytes(StandardCharsets.UTF_8)).collect(Collectors.toList()),
|
||||
arguments.stream().map(a -> a.getBytes(StandardCharsets.UTF_8)).collect(Collectors.toList()));
|
||||
|
||||
redisClusterExperiment.compareSupplierResult(result, () -> clusterValidateScript.execute(keys, arguments));
|
||||
|
||||
if (result == null) {
|
||||
meter.mark();
|
||||
throw new RateLimitExceededException(key + " , " + amount);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void clear(String key) {
|
||||
try (Jedis jedis = cacheClient.getWriteResource()) {
|
||||
final String bucketName = getBucketName(key);
|
||||
|
@ -100,37 +102,8 @@ public class RateLimiter {
|
|||
}
|
||||
}
|
||||
|
||||
private void setBucket(String key, LeakyBucket bucket) {
|
||||
try (Jedis jedis = cacheClient.getWriteResource()) {
|
||||
final String bucketName = getBucketName(key);
|
||||
final String serialized = bucket.serialize(mapper);
|
||||
final int level = (int) Math.ceil((bucketSize / leakRatePerMillis) / 1000);
|
||||
|
||||
jedis.setex(bucketName, level, serialized);
|
||||
cacheCluster.useWriteCluster(connection -> connection.sync().setex(bucketName, level, serialized));
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new IllegalArgumentException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private LeakyBucket getBucket(String key) {
|
||||
try (Jedis jedis = cacheClient.getReadResource()) {
|
||||
final String bucketName = getBucketName(key);
|
||||
|
||||
String serialized = jedis.get(bucketName);
|
||||
redisClusterExperiment.compareSupplierResult(serialized, () -> cacheCluster.withReadCluster(connection -> connection.sync().get(bucketName)));
|
||||
|
||||
if (serialized != null) {
|
||||
return LeakyBucket.fromSerialized(mapper, serialized);
|
||||
}
|
||||
} catch (IOException e) {
|
||||
logger.warn("Deserialization error", e);
|
||||
}
|
||||
|
||||
return new LeakyBucket(bucketSize, leakRatePerMillis);
|
||||
}
|
||||
|
||||
private String getBucketName(String key) {
|
||||
@VisibleForTesting
|
||||
String getBucketName(String key) {
|
||||
return "leaky_bucket::" + name + "::" + key;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,8 @@ import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration;
|
|||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
||||
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class RateLimiters {
|
||||
|
||||
private final RateLimiter smsDestinationLimiter;
|
||||
|
@ -48,7 +50,7 @@ public class RateLimiters {
|
|||
private final RateLimiter usernameLookupLimiter;
|
||||
private final RateLimiter usernameSetLimiter;
|
||||
|
||||
public RateLimiters(RateLimitsConfiguration config, ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster) {
|
||||
public RateLimiters(RateLimitsConfiguration config, ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster) throws IOException {
|
||||
this.smsDestinationLimiter = new RateLimiter(cacheClient, cacheCluster, "smsDestination",
|
||||
config.getSmsDestination().getBucketSize(),
|
||||
config.getSmsDestination().getLeakRatePerMinute());
|
||||
|
@ -73,11 +75,11 @@ public class RateLimiters {
|
|||
config.getAutoBlock().getBucketSize(),
|
||||
config.getAutoBlock().getLeakRatePerMinute());
|
||||
|
||||
this.verifyLimiter = new LockingRateLimiter(cacheClient, cacheCluster, "verify",
|
||||
this.verifyLimiter = new RateLimiter(cacheClient, cacheCluster, "verify",
|
||||
config.getVerifyNumber().getBucketSize(),
|
||||
config.getVerifyNumber().getLeakRatePerMinute());
|
||||
|
||||
this.pinLimiter = new LockingRateLimiter(cacheClient, cacheCluster, "pin",
|
||||
this.pinLimiter = new RateLimiter(cacheClient, cacheCluster, "pin",
|
||||
config.getVerifyPin().getBucketSize(),
|
||||
config.getVerifyPin().getLeakRatePerMinute());
|
||||
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
local bucketId = KEYS[1]
|
||||
|
||||
local bucketSize = tonumber(ARGV[1])
|
||||
local leakRatePerMillis = tonumber(ARGV[2])
|
||||
local currentTimeMillis = tonumber(ARGV[3])
|
||||
local amount = tonumber(ARGV[4])
|
||||
|
||||
local leakyBucket
|
||||
|
||||
if redis.call("EXISTS", bucketId) == 1 then
|
||||
leakyBucket = cjson.decode(redis.call("GET", bucketId))
|
||||
else
|
||||
leakyBucket = {
|
||||
bucketSize = bucketSize,
|
||||
leakRatePerMillis = leakRatePerMillis,
|
||||
spaceRemaining = bucketSize,
|
||||
lastUpdateTimeMillis = currentTimeMillis
|
||||
}
|
||||
end
|
||||
|
||||
local elapsedTime = currentTimeMillis - leakyBucket["lastUpdateTimeMillis"]
|
||||
local updatedSpaceRemaining = math.min(leakyBucket["bucketSize"], math.floor(leakyBucket["spaceRemaining"] + (elapsedTime * leakyBucket["leakRatePerMillis"])))
|
||||
|
||||
redis.call("SET", "elapsedTime", elapsedTime)
|
||||
redis.call("SET", "updatedSpaceRemaining", updatedSpaceRemaining)
|
||||
|
||||
if updatedSpaceRemaining >= amount then
|
||||
leakyBucket["spaceRemaining"] = updatedSpaceRemaining - amount
|
||||
redis.call("SET", bucketId, cjson.encode(leakyBucket))
|
||||
return true
|
||||
else
|
||||
return false
|
||||
end
|
|
@ -0,0 +1,88 @@
|
|||
package org.whispersystems.textsecuregcm.limits;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
import org.whispersystems.textsecuregcm.redis.AbstractRedisTest;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
||||
import redis.clients.jedis.Jedis;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.junit.Assert.fail;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
public class RateLimiterTest extends AbstractRedisTest {
|
||||
|
||||
private static final long NOW_MILLIS = System.currentTimeMillis();
|
||||
private static final String KEY = "key";
|
||||
|
||||
@FunctionalInterface
|
||||
private interface RateLimitedTask {
|
||||
void run() throws RateLimitExceededException;
|
||||
}
|
||||
|
||||
@Before
|
||||
public void clearCache() {
|
||||
try (final Jedis jedis = getReplicatedJedisPool().getWriteResource()) {
|
||||
jedis.flushAll();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void validate() throws RateLimitExceededException, IOException {
|
||||
final RateLimiter rateLimiter = buildRateLimiter(2, 0.5);
|
||||
|
||||
rateLimiter.validate(KEY, 1, NOW_MILLIS);
|
||||
rateLimiter.validate(KEY, 1, NOW_MILLIS);
|
||||
assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void validateWithAmount() throws RateLimitExceededException, IOException {
|
||||
final RateLimiter rateLimiter = buildRateLimiter(2, 0.5);
|
||||
|
||||
rateLimiter.validate(KEY, 2, NOW_MILLIS);
|
||||
assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLapseRate() throws RateLimitExceededException, IOException {
|
||||
final RateLimiter rateLimiter = buildRateLimiter(2, 8.333333333333334E-6);
|
||||
final String leakyBucketJson = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (NOW_MILLIS - TimeUnit.MINUTES.toMillis(2)) + "}";
|
||||
|
||||
try (final Jedis jedis = getReplicatedJedisPool().getWriteResource()) {
|
||||
jedis.set(rateLimiter.getBucketName(KEY), leakyBucketJson);
|
||||
}
|
||||
|
||||
rateLimiter.validate(KEY, 1, NOW_MILLIS);
|
||||
assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLapseShort() throws IOException {
|
||||
final RateLimiter rateLimiter = buildRateLimiter(2, 8.333333333333334E-6);
|
||||
final String leakyBucketJson = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (NOW_MILLIS - TimeUnit.MINUTES.toMillis(1)) + "}";
|
||||
|
||||
try (final Jedis jedis = getReplicatedJedisPool().getWriteResource()) {
|
||||
jedis.set(rateLimiter.getBucketName(KEY), leakyBucketJson);
|
||||
}
|
||||
|
||||
assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS));
|
||||
}
|
||||
|
||||
private void assertRateLimitExceeded(final RateLimitedTask task) {
|
||||
try {
|
||||
task.run();
|
||||
fail("Expected RateLimitExceededException");
|
||||
} catch (final RateLimitExceededException ignored) {
|
||||
}
|
||||
}
|
||||
|
||||
@SuppressWarnings("SameParameterValue")
|
||||
private RateLimiter buildRateLimiter(final int bucketSize, final double leakRatePerMilli) throws IOException {
|
||||
final double leakRatePerMinute = leakRatePerMilli * 60_000d;
|
||||
return new RateLimiter(getReplicatedJedisPool(), mock(FaultTolerantRedisCluster.class), KEY, bucketSize, leakRatePerMinute);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
package org.whispersystems.textsecuregcm.redis;
|
||||
|
||||
import org.junit.AfterClass;
|
||||
import org.junit.Before;
|
||||
import org.junit.BeforeClass;
|
||||
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
|
||||
import org.whispersystems.textsecuregcm.providers.RedisClientFactory;
|
||||
import redis.embedded.RedisServer;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.ServerSocket;
|
||||
import java.net.URISyntaxException;
|
||||
import java.util.List;
|
||||
|
||||
public abstract class AbstractRedisTest {
|
||||
|
||||
private static RedisServer redisServer;
|
||||
|
||||
private ReplicatedJedisPool replicatedJedisPool;
|
||||
|
||||
@BeforeClass
|
||||
public static void setUpBeforeClass() throws IOException {
|
||||
redisServer = new RedisServer(getNextPort());
|
||||
redisServer.start();
|
||||
}
|
||||
|
||||
@Before
|
||||
public void setUp() throws URISyntaxException {
|
||||
final String redisUrl = "redis://127.0.0.1:" + redisServer.ports().get(0);
|
||||
replicatedJedisPool = new RedisClientFactory("test-pool", redisUrl, List.of(redisUrl), new CircuitBreakerConfiguration()).getRedisClientPool();
|
||||
}
|
||||
|
||||
protected ReplicatedJedisPool getReplicatedJedisPool() {
|
||||
return replicatedJedisPool;
|
||||
}
|
||||
|
||||
@AfterClass
|
||||
public static void tearDownAfterClass() {
|
||||
redisServer.stop();
|
||||
}
|
||||
|
||||
private static int getNextPort() throws IOException {
|
||||
try (ServerSocket socket = new ServerSocket(0)) {
|
||||
socket.setReuseAddress(false);
|
||||
return socket.getLocalPort();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,52 +0,0 @@
|
|||
package org.whispersystems.textsecuregcm.tests.limits;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.junit.Test;
|
||||
import org.whispersystems.textsecuregcm.limits.LeakyBucket;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
public class LeakyBucketTest {
|
||||
|
||||
@Test
|
||||
public void testFull() {
|
||||
LeakyBucket leakyBucket = new LeakyBucket(2, 1.0 / 2.0);
|
||||
|
||||
assertTrue(leakyBucket.add(1));
|
||||
assertTrue(leakyBucket.add(1));
|
||||
assertFalse(leakyBucket.add(1));
|
||||
|
||||
leakyBucket = new LeakyBucket(2, 1.0 / 2.0);
|
||||
|
||||
assertTrue(leakyBucket.add(2));
|
||||
assertFalse(leakyBucket.add(1));
|
||||
assertFalse(leakyBucket.add(2));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLapseRate() throws IOException {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(2)) + "}";
|
||||
|
||||
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
|
||||
assertTrue(leakyBucket.add(1));
|
||||
|
||||
String serializedAgain = leakyBucket.serialize(mapper);
|
||||
LeakyBucket leakyBucketAgain = LeakyBucket.fromSerialized(mapper, serializedAgain);
|
||||
|
||||
assertFalse(leakyBucketAgain.add(1));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLapseShort() throws Exception {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}";
|
||||
|
||||
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
|
||||
assertFalse(leakyBucket.add(1));
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue