Revert "Move rate limiter logic to Lua scripts"
This reverts commit b585c6676d
.
This commit is contained in:
parent
062bf737c2
commit
c5d0d4acd0
|
@ -0,0 +1,98 @@
|
||||||
|
/**
|
||||||
|
* Copyright (C) 2013 Open WhisperSystems
|
||||||
|
*
|
||||||
|
* This program is free software: you can redistribute it and/or modify
|
||||||
|
* it under the terms of the GNU Affero General Public License as published by
|
||||||
|
* the Free Software Foundation, either version 3 of the License, or
|
||||||
|
* (at your option) any later version.
|
||||||
|
*
|
||||||
|
* This program is distributed in the hope that it will be useful,
|
||||||
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
* GNU Affero General Public License for more details.
|
||||||
|
*
|
||||||
|
* You should have received a copy of the GNU Affero General Public License
|
||||||
|
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
package org.whispersystems.textsecuregcm.limits;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
public class LeakyBucket {
|
||||||
|
|
||||||
|
private final int bucketSize;
|
||||||
|
private final double leakRatePerMillis;
|
||||||
|
|
||||||
|
private int spaceRemaining;
|
||||||
|
private long lastUpdateTimeMillis;
|
||||||
|
|
||||||
|
public LeakyBucket(int bucketSize, double leakRatePerMillis) {
|
||||||
|
this(bucketSize, leakRatePerMillis, bucketSize, System.currentTimeMillis());
|
||||||
|
}
|
||||||
|
|
||||||
|
private LeakyBucket(int bucketSize, double leakRatePerMillis, int spaceRemaining, long lastUpdateTimeMillis) {
|
||||||
|
this.bucketSize = bucketSize;
|
||||||
|
this.leakRatePerMillis = leakRatePerMillis;
|
||||||
|
this.spaceRemaining = spaceRemaining;
|
||||||
|
this.lastUpdateTimeMillis = lastUpdateTimeMillis;
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean add(int amount) {
|
||||||
|
this.spaceRemaining = getUpdatedSpaceRemaining();
|
||||||
|
this.lastUpdateTimeMillis = System.currentTimeMillis();
|
||||||
|
|
||||||
|
if (this.spaceRemaining >= amount) {
|
||||||
|
this.spaceRemaining -= amount;
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private int getUpdatedSpaceRemaining() {
|
||||||
|
long elapsedTime = System.currentTimeMillis() - this.lastUpdateTimeMillis;
|
||||||
|
|
||||||
|
return Math.min(this.bucketSize,
|
||||||
|
(int)Math.floor(this.spaceRemaining + (elapsedTime * this.leakRatePerMillis)));
|
||||||
|
}
|
||||||
|
|
||||||
|
public String serialize(ObjectMapper mapper) throws JsonProcessingException {
|
||||||
|
return mapper.writeValueAsString(new LeakyBucketEntity(bucketSize, leakRatePerMillis, spaceRemaining, lastUpdateTimeMillis));
|
||||||
|
}
|
||||||
|
|
||||||
|
public static LeakyBucket fromSerialized(ObjectMapper mapper, String serialized) throws IOException {
|
||||||
|
LeakyBucketEntity entity = mapper.readValue(serialized, LeakyBucketEntity.class);
|
||||||
|
|
||||||
|
return new LeakyBucket(entity.bucketSize, entity.leakRatePerMillis,
|
||||||
|
entity.spaceRemaining, entity.lastUpdateTimeMillis);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class LeakyBucketEntity {
|
||||||
|
@JsonProperty
|
||||||
|
private int bucketSize;
|
||||||
|
|
||||||
|
@JsonProperty
|
||||||
|
private double leakRatePerMillis;
|
||||||
|
|
||||||
|
@JsonProperty
|
||||||
|
private int spaceRemaining;
|
||||||
|
|
||||||
|
@JsonProperty
|
||||||
|
private long lastUpdateTimeMillis;
|
||||||
|
|
||||||
|
public LeakyBucketEntity() {}
|
||||||
|
|
||||||
|
private LeakyBucketEntity(int bucketSize, double leakRatePerMillis,
|
||||||
|
int spaceRemaining, long lastUpdateTimeMillis)
|
||||||
|
{
|
||||||
|
this.bucketSize = bucketSize;
|
||||||
|
this.leakRatePerMillis = leakRatePerMillis;
|
||||||
|
this.spaceRemaining = spaceRemaining;
|
||||||
|
this.lastUpdateTimeMillis = lastUpdateTimeMillis;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,74 @@
|
||||||
|
package org.whispersystems.textsecuregcm.limits;
|
||||||
|
|
||||||
|
import com.codahale.metrics.Meter;
|
||||||
|
import com.codahale.metrics.MetricRegistry;
|
||||||
|
import com.codahale.metrics.SharedMetricRegistries;
|
||||||
|
import io.lettuce.core.SetArgs;
|
||||||
|
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||||
|
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
||||||
|
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
|
||||||
|
import org.whispersystems.textsecuregcm.util.Constants;
|
||||||
|
|
||||||
|
import static com.codahale.metrics.MetricRegistry.name;
|
||||||
|
import redis.clients.jedis.Jedis;
|
||||||
|
|
||||||
|
public class LockingRateLimiter extends RateLimiter {
|
||||||
|
|
||||||
|
private final Meter meter;
|
||||||
|
|
||||||
|
public LockingRateLimiter(ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster, String name, int bucketSize, double leakRatePerMinute) {
|
||||||
|
super(cacheClient, cacheCluster, name, bucketSize, leakRatePerMinute);
|
||||||
|
|
||||||
|
MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
|
||||||
|
this.meter = metricRegistry.meter(name(getClass(), name, "locked"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void validate(String key, int amount) throws RateLimitExceededException {
|
||||||
|
if (!acquireLock(key)) {
|
||||||
|
meter.mark();
|
||||||
|
throw new RateLimitExceededException("Locked");
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
super.validate(key, amount);
|
||||||
|
} finally {
|
||||||
|
releaseLock(key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void validate(String key) throws RateLimitExceededException {
|
||||||
|
validate(key, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void releaseLock(String key) {
|
||||||
|
try (Jedis jedis = cacheClient.getWriteResource()) {
|
||||||
|
final String lockName = getLockName(key);
|
||||||
|
|
||||||
|
jedis.del(lockName);
|
||||||
|
cacheCluster.useWriteCluster(connection -> connection.sync().del(lockName));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean acquireLock(String key) {
|
||||||
|
try (Jedis jedis = cacheClient.getWriteResource()) {
|
||||||
|
final String lockName = getLockName(key);
|
||||||
|
|
||||||
|
final boolean acquiredLock = jedis.set(lockName, "L", "NX", "EX", 10) != null;
|
||||||
|
|
||||||
|
if (acquiredLock) {
|
||||||
|
// TODO Restore the NX flag when the cluster becomes the primary source of truth
|
||||||
|
cacheCluster.useWriteCluster(connection -> connection.sync().set(lockName, "L", SetArgs.Builder.ex(10)));
|
||||||
|
}
|
||||||
|
|
||||||
|
return acquiredLock;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private String getLockName(String key) {
|
||||||
|
return "leaky_lock::" + name + "::" + key;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -20,26 +20,27 @@ import com.codahale.metrics.Meter;
|
||||||
import com.codahale.metrics.MetricRegistry;
|
import com.codahale.metrics.MetricRegistry;
|
||||||
import com.codahale.metrics.SharedMetricRegistries;
|
import com.codahale.metrics.SharedMetricRegistries;
|
||||||
import com.codahale.metrics.Timer;
|
import com.codahale.metrics.Timer;
|
||||||
import com.google.common.annotations.VisibleForTesting;
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
import io.lettuce.core.ScriptOutputType;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||||
import org.whispersystems.textsecuregcm.experiment.Experiment;
|
import org.whispersystems.textsecuregcm.experiment.Experiment;
|
||||||
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
|
||||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
||||||
import org.whispersystems.textsecuregcm.redis.LuaScript;
|
|
||||||
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
|
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
|
||||||
import org.whispersystems.textsecuregcm.util.Constants;
|
import org.whispersystems.textsecuregcm.util.Constants;
|
||||||
import redis.clients.jedis.Jedis;
|
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.nio.charset.StandardCharsets;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
import static com.codahale.metrics.MetricRegistry.name;
|
import static com.codahale.metrics.MetricRegistry.name;
|
||||||
|
import redis.clients.jedis.Jedis;
|
||||||
|
|
||||||
public class RateLimiter {
|
public class RateLimiter {
|
||||||
|
|
||||||
|
private final Logger logger = LoggerFactory.getLogger(RateLimiter.class);
|
||||||
|
private final ObjectMapper mapper = SystemMapper.getMapper();
|
||||||
|
|
||||||
private final Meter meter;
|
private final Meter meter;
|
||||||
private final Timer validateTimer;
|
private final Timer validateTimer;
|
||||||
protected final ReplicatedJedisPool cacheClient;
|
protected final ReplicatedJedisPool cacheClient;
|
||||||
|
@ -47,12 +48,19 @@ public class RateLimiter {
|
||||||
protected final String name;
|
protected final String name;
|
||||||
private final int bucketSize;
|
private final int bucketSize;
|
||||||
private final double leakRatePerMillis;
|
private final double leakRatePerMillis;
|
||||||
|
private final boolean reportLimits;
|
||||||
private final Experiment redisClusterExperiment;
|
private final Experiment redisClusterExperiment;
|
||||||
private final LuaScript validateScript;
|
|
||||||
private final ClusterLuaScript clusterValidateScript;
|
|
||||||
|
|
||||||
public RateLimiter(ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster, String name,
|
public RateLimiter(ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster, String name,
|
||||||
int bucketSize, double leakRatePerMinute) throws IOException {
|
int bucketSize, double leakRatePerMinute)
|
||||||
|
{
|
||||||
|
this(cacheClient, cacheCluster, name, bucketSize, leakRatePerMinute, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
public RateLimiter(ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster, String name,
|
||||||
|
int bucketSize, double leakRatePerMinute,
|
||||||
|
boolean reportLimits)
|
||||||
|
{
|
||||||
MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
|
MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
|
||||||
|
|
||||||
this.meter = metricRegistry.meter(name(getClass(), name, "exceeded"));
|
this.meter = metricRegistry.meter(name(getClass(), name, "exceeded"));
|
||||||
|
@ -62,37 +70,27 @@ public class RateLimiter {
|
||||||
this.name = name;
|
this.name = name;
|
||||||
this.bucketSize = bucketSize;
|
this.bucketSize = bucketSize;
|
||||||
this.leakRatePerMillis = leakRatePerMinute / (60.0 * 1000.0);
|
this.leakRatePerMillis = leakRatePerMinute / (60.0 * 1000.0);
|
||||||
|
this.reportLimits = reportLimits;
|
||||||
this.redisClusterExperiment = new Experiment("RedisCluster", "RateLimiter", name);
|
this.redisClusterExperiment = new Experiment("RedisCluster", "RateLimiter", name);
|
||||||
this.validateScript = LuaScript.fromResource(cacheClient, "lua/validate_rate_limit.lua");
|
|
||||||
this.clusterValidateScript = ClusterLuaScript.fromResource(cacheCluster, "lua/validate_rate_limit.lua", ScriptOutputType.INTEGER);
|
|
||||||
}
|
|
||||||
|
|
||||||
public void validate(String key) throws RateLimitExceededException {
|
|
||||||
validate(key, 1);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void validate(String key, int amount) throws RateLimitExceededException {
|
public void validate(String key, int amount) throws RateLimitExceededException {
|
||||||
validate(key, amount, System.currentTimeMillis());
|
|
||||||
}
|
|
||||||
|
|
||||||
@VisibleForTesting
|
|
||||||
void validate(String key, int amount, final long currentTimeMillis) throws RateLimitExceededException {
|
|
||||||
try (final Timer.Context ignored = validateTimer.time()) {
|
try (final Timer.Context ignored = validateTimer.time()) {
|
||||||
final List<String> keys = List.of(getBucketName(key));
|
LeakyBucket bucket = getBucket(key);
|
||||||
final List<String> arguments = List.of(String.valueOf(bucketSize), String.valueOf(leakRatePerMillis), String.valueOf(currentTimeMillis), String.valueOf(amount));
|
|
||||||
|
|
||||||
final Object result = validateScript.execute(keys.stream().map(k -> k.getBytes(StandardCharsets.UTF_8)).collect(Collectors.toList()),
|
if (bucket.add(amount)) {
|
||||||
arguments.stream().map(a -> a.getBytes(StandardCharsets.UTF_8)).collect(Collectors.toList()));
|
setBucket(key, bucket);
|
||||||
|
} else {
|
||||||
redisClusterExperiment.compareSupplierResult(result, () -> clusterValidateScript.execute(keys, arguments));
|
|
||||||
|
|
||||||
if (result == null) {
|
|
||||||
meter.mark();
|
meter.mark();
|
||||||
throw new RateLimitExceededException(key + " , " + amount);
|
throw new RateLimitExceededException(key + " , " + amount);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void validate(String key) throws RateLimitExceededException {
|
||||||
|
validate(key, 1);
|
||||||
|
}
|
||||||
|
|
||||||
public void clear(String key) {
|
public void clear(String key) {
|
||||||
try (Jedis jedis = cacheClient.getWriteResource()) {
|
try (Jedis jedis = cacheClient.getWriteResource()) {
|
||||||
final String bucketName = getBucketName(key);
|
final String bucketName = getBucketName(key);
|
||||||
|
@ -102,8 +100,37 @@ public class RateLimiter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@VisibleForTesting
|
private void setBucket(String key, LeakyBucket bucket) {
|
||||||
String getBucketName(String key) {
|
try (Jedis jedis = cacheClient.getWriteResource()) {
|
||||||
|
final String bucketName = getBucketName(key);
|
||||||
|
final String serialized = bucket.serialize(mapper);
|
||||||
|
final int level = (int) Math.ceil((bucketSize / leakRatePerMillis) / 1000);
|
||||||
|
|
||||||
|
jedis.setex(bucketName, level, serialized);
|
||||||
|
cacheCluster.useWriteCluster(connection -> connection.sync().setex(bucketName, level, serialized));
|
||||||
|
} catch (JsonProcessingException e) {
|
||||||
|
throw new IllegalArgumentException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private LeakyBucket getBucket(String key) {
|
||||||
|
try (Jedis jedis = cacheClient.getReadResource()) {
|
||||||
|
final String bucketName = getBucketName(key);
|
||||||
|
|
||||||
|
String serialized = jedis.get(bucketName);
|
||||||
|
redisClusterExperiment.compareSupplierResult(serialized, () -> cacheCluster.withReadCluster(connection -> connection.sync().get(bucketName)));
|
||||||
|
|
||||||
|
if (serialized != null) {
|
||||||
|
return LeakyBucket.fromSerialized(mapper, serialized);
|
||||||
|
}
|
||||||
|
} catch (IOException e) {
|
||||||
|
logger.warn("Deserialization error", e);
|
||||||
|
}
|
||||||
|
|
||||||
|
return new LeakyBucket(bucketSize, leakRatePerMillis);
|
||||||
|
}
|
||||||
|
|
||||||
|
private String getBucketName(String key) {
|
||||||
return "leaky_bucket::" + name + "::" + key;
|
return "leaky_bucket::" + name + "::" + key;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,8 +21,6 @@ import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration;
|
||||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
||||||
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
|
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
|
|
||||||
public class RateLimiters {
|
public class RateLimiters {
|
||||||
|
|
||||||
private final RateLimiter smsDestinationLimiter;
|
private final RateLimiter smsDestinationLimiter;
|
||||||
|
@ -50,7 +48,7 @@ public class RateLimiters {
|
||||||
private final RateLimiter usernameLookupLimiter;
|
private final RateLimiter usernameLookupLimiter;
|
||||||
private final RateLimiter usernameSetLimiter;
|
private final RateLimiter usernameSetLimiter;
|
||||||
|
|
||||||
public RateLimiters(RateLimitsConfiguration config, ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster) throws IOException {
|
public RateLimiters(RateLimitsConfiguration config, ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster) {
|
||||||
this.smsDestinationLimiter = new RateLimiter(cacheClient, cacheCluster, "smsDestination",
|
this.smsDestinationLimiter = new RateLimiter(cacheClient, cacheCluster, "smsDestination",
|
||||||
config.getSmsDestination().getBucketSize(),
|
config.getSmsDestination().getBucketSize(),
|
||||||
config.getSmsDestination().getLeakRatePerMinute());
|
config.getSmsDestination().getLeakRatePerMinute());
|
||||||
|
@ -75,11 +73,11 @@ public class RateLimiters {
|
||||||
config.getAutoBlock().getBucketSize(),
|
config.getAutoBlock().getBucketSize(),
|
||||||
config.getAutoBlock().getLeakRatePerMinute());
|
config.getAutoBlock().getLeakRatePerMinute());
|
||||||
|
|
||||||
this.verifyLimiter = new RateLimiter(cacheClient, cacheCluster, "verify",
|
this.verifyLimiter = new LockingRateLimiter(cacheClient, cacheCluster, "verify",
|
||||||
config.getVerifyNumber().getBucketSize(),
|
config.getVerifyNumber().getBucketSize(),
|
||||||
config.getVerifyNumber().getLeakRatePerMinute());
|
config.getVerifyNumber().getLeakRatePerMinute());
|
||||||
|
|
||||||
this.pinLimiter = new RateLimiter(cacheClient, cacheCluster, "pin",
|
this.pinLimiter = new LockingRateLimiter(cacheClient, cacheCluster, "pin",
|
||||||
config.getVerifyPin().getBucketSize(),
|
config.getVerifyPin().getBucketSize(),
|
||||||
config.getVerifyPin().getLeakRatePerMinute());
|
config.getVerifyPin().getLeakRatePerMinute());
|
||||||
|
|
||||||
|
|
|
@ -1,30 +0,0 @@
|
||||||
local bucketId = KEYS[1]
|
|
||||||
|
|
||||||
local bucketSize = tonumber(ARGV[1])
|
|
||||||
local leakRatePerMillis = tonumber(ARGV[2])
|
|
||||||
local currentTimeMillis = tonumber(ARGV[3])
|
|
||||||
local amount = tonumber(ARGV[4])
|
|
||||||
|
|
||||||
local leakyBucket
|
|
||||||
|
|
||||||
if redis.call("EXISTS", bucketId) == 1 then
|
|
||||||
leakyBucket = cjson.decode(redis.call("GET", bucketId))
|
|
||||||
else
|
|
||||||
leakyBucket = {
|
|
||||||
bucketSize = bucketSize,
|
|
||||||
leakRatePerMillis = leakRatePerMillis,
|
|
||||||
spaceRemaining = bucketSize,
|
|
||||||
lastUpdateTimeMillis = currentTimeMillis
|
|
||||||
}
|
|
||||||
end
|
|
||||||
|
|
||||||
local elapsedTime = currentTimeMillis - leakyBucket["lastUpdateTimeMillis"]
|
|
||||||
local updatedSpaceRemaining = math.min(leakyBucket["bucketSize"], math.floor(leakyBucket["spaceRemaining"] + (elapsedTime * leakyBucket["leakRatePerMillis"])))
|
|
||||||
|
|
||||||
if updatedSpaceRemaining >= amount then
|
|
||||||
leakyBucket["spaceRemaining"] = updatedSpaceRemaining - amount
|
|
||||||
redis.call("SET", bucketId, cjson.encode(leakyBucket))
|
|
||||||
return true
|
|
||||||
else
|
|
||||||
return false
|
|
||||||
end
|
|
|
@ -1,114 +0,0 @@
|
||||||
package org.whispersystems.textsecuregcm.limits;
|
|
||||||
|
|
||||||
import org.junit.After;
|
|
||||||
import org.junit.Before;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
|
|
||||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
|
||||||
import org.whispersystems.textsecuregcm.providers.RedisClientFactory;
|
|
||||||
import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest;
|
|
||||||
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
|
|
||||||
import redis.clients.jedis.Jedis;
|
|
||||||
import redis.embedded.RedisServer;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.net.URISyntaxException;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.concurrent.TimeUnit;
|
|
||||||
|
|
||||||
import static org.junit.Assert.fail;
|
|
||||||
|
|
||||||
public class RateLimiterTest extends AbstractRedisClusterTest {
|
|
||||||
|
|
||||||
private static final long NOW_MILLIS = System.currentTimeMillis();
|
|
||||||
private static final String KEY = "key";
|
|
||||||
|
|
||||||
private RedisServer redisServer;
|
|
||||||
|
|
||||||
private ReplicatedJedisPool replicatedJedisPool;
|
|
||||||
|
|
||||||
@FunctionalInterface
|
|
||||||
private interface RateLimitedTask {
|
|
||||||
void run() throws RateLimitExceededException;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Before
|
|
||||||
public void clearCache() throws URISyntaxException, IOException {
|
|
||||||
redisServer = new RedisServer(AbstractRedisClusterTest.getNextRedisClusterPort());
|
|
||||||
redisServer.start();
|
|
||||||
|
|
||||||
final String redisUrl = "redis://127.0.0.1:" + redisServer.ports().get(0);
|
|
||||||
replicatedJedisPool = new RedisClientFactory("test-pool", redisUrl, List.of(redisUrl), new CircuitBreakerConfiguration()).getRedisClientPool();
|
|
||||||
|
|
||||||
try (final Jedis jedis = replicatedJedisPool.getWriteResource()) {
|
|
||||||
jedis.flushAll();
|
|
||||||
}
|
|
||||||
|
|
||||||
getRedisCluster().useWriteCluster(connection -> connection.sync().flushall());
|
|
||||||
}
|
|
||||||
|
|
||||||
@After
|
|
||||||
public void stopServer() {
|
|
||||||
redisServer.stop();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void validate() throws RateLimitExceededException, IOException {
|
|
||||||
final RateLimiter rateLimiter = buildRateLimiter(2, 0.5);
|
|
||||||
|
|
||||||
rateLimiter.validate(KEY, 1, NOW_MILLIS);
|
|
||||||
rateLimiter.validate(KEY, 1, NOW_MILLIS);
|
|
||||||
assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void validateWithAmount() throws RateLimitExceededException, IOException {
|
|
||||||
final RateLimiter rateLimiter = buildRateLimiter(2, 0.5);
|
|
||||||
|
|
||||||
rateLimiter.validate(KEY, 2, NOW_MILLIS);
|
|
||||||
assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testLapseRate() throws RateLimitExceededException, IOException {
|
|
||||||
final RateLimiter rateLimiter = buildRateLimiter(2, 8.333333333333334E-6);
|
|
||||||
final String leakyBucketJson = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (NOW_MILLIS - TimeUnit.MINUTES.toMillis(2)) + "}";
|
|
||||||
|
|
||||||
try (final Jedis jedis = replicatedJedisPool.getWriteResource()) {
|
|
||||||
jedis.set(rateLimiter.getBucketName(KEY), leakyBucketJson);
|
|
||||||
}
|
|
||||||
|
|
||||||
getRedisCluster().useWriteCluster(connection -> connection.sync().set(rateLimiter.getBucketName(KEY), leakyBucketJson));
|
|
||||||
|
|
||||||
rateLimiter.validate(KEY, 1, NOW_MILLIS);
|
|
||||||
assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testLapseShort() throws IOException {
|
|
||||||
final RateLimiter rateLimiter = buildRateLimiter(2, 8.333333333333334E-6);
|
|
||||||
final String leakyBucketJson = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (NOW_MILLIS - TimeUnit.MINUTES.toMillis(1)) + "}";
|
|
||||||
|
|
||||||
try (final Jedis jedis = replicatedJedisPool.getWriteResource()) {
|
|
||||||
jedis.set(rateLimiter.getBucketName(KEY), leakyBucketJson);
|
|
||||||
}
|
|
||||||
|
|
||||||
getRedisCluster().useWriteCluster(connection -> connection.sync().set(rateLimiter.getBucketName(KEY), leakyBucketJson));
|
|
||||||
|
|
||||||
assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS));
|
|
||||||
}
|
|
||||||
|
|
||||||
private void assertRateLimitExceeded(final RateLimitedTask task) {
|
|
||||||
try {
|
|
||||||
task.run();
|
|
||||||
fail("Expected RateLimitExceededException");
|
|
||||||
} catch (final RateLimitExceededException ignored) {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@SuppressWarnings("SameParameterValue")
|
|
||||||
private RateLimiter buildRateLimiter(final int bucketSize, final double leakRatePerMilli) throws IOException {
|
|
||||||
final double leakRatePerMinute = leakRatePerMilli * 60_000d;
|
|
||||||
return new RateLimiter(replicatedJedisPool, getRedisCluster(), KEY, bucketSize, leakRatePerMinute);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,48 +0,0 @@
|
||||||
package org.whispersystems.textsecuregcm.redis;
|
|
||||||
|
|
||||||
import org.junit.AfterClass;
|
|
||||||
import org.junit.Before;
|
|
||||||
import org.junit.BeforeClass;
|
|
||||||
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
|
|
||||||
import org.whispersystems.textsecuregcm.providers.RedisClientFactory;
|
|
||||||
import redis.embedded.RedisServer;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.net.ServerSocket;
|
|
||||||
import java.net.URISyntaxException;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public abstract class AbstractRedisTest {
|
|
||||||
|
|
||||||
private static RedisServer redisServer;
|
|
||||||
|
|
||||||
private ReplicatedJedisPool replicatedJedisPool;
|
|
||||||
|
|
||||||
@BeforeClass
|
|
||||||
public static void setUpBeforeClass() throws IOException {
|
|
||||||
redisServer = new RedisServer(getNextPort());
|
|
||||||
redisServer.start();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Before
|
|
||||||
public void setUp() throws URISyntaxException {
|
|
||||||
final String redisUrl = "redis://127.0.0.1:" + redisServer.ports().get(0);
|
|
||||||
replicatedJedisPool = new RedisClientFactory("test-pool", redisUrl, List.of(redisUrl), new CircuitBreakerConfiguration()).getRedisClientPool();
|
|
||||||
}
|
|
||||||
|
|
||||||
protected ReplicatedJedisPool getReplicatedJedisPool() {
|
|
||||||
return replicatedJedisPool;
|
|
||||||
}
|
|
||||||
|
|
||||||
@AfterClass
|
|
||||||
public static void tearDownAfterClass() {
|
|
||||||
redisServer.stop();
|
|
||||||
}
|
|
||||||
|
|
||||||
private static int getNextPort() throws IOException {
|
|
||||||
try (ServerSocket socket = new ServerSocket(0)) {
|
|
||||||
socket.setReuseAddress(false);
|
|
||||||
return socket.getLocalPort();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,52 @@
|
||||||
|
package org.whispersystems.textsecuregcm.tests.limits;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.whispersystems.textsecuregcm.limits.LeakyBucket;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertFalse;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
|
public class LeakyBucketTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testFull() {
|
||||||
|
LeakyBucket leakyBucket = new LeakyBucket(2, 1.0 / 2.0);
|
||||||
|
|
||||||
|
assertTrue(leakyBucket.add(1));
|
||||||
|
assertTrue(leakyBucket.add(1));
|
||||||
|
assertFalse(leakyBucket.add(1));
|
||||||
|
|
||||||
|
leakyBucket = new LeakyBucket(2, 1.0 / 2.0);
|
||||||
|
|
||||||
|
assertTrue(leakyBucket.add(2));
|
||||||
|
assertFalse(leakyBucket.add(1));
|
||||||
|
assertFalse(leakyBucket.add(2));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testLapseRate() throws IOException {
|
||||||
|
ObjectMapper mapper = new ObjectMapper();
|
||||||
|
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(2)) + "}";
|
||||||
|
|
||||||
|
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
|
||||||
|
assertTrue(leakyBucket.add(1));
|
||||||
|
|
||||||
|
String serializedAgain = leakyBucket.serialize(mapper);
|
||||||
|
LeakyBucket leakyBucketAgain = LeakyBucket.fromSerialized(mapper, serializedAgain);
|
||||||
|
|
||||||
|
assertFalse(leakyBucketAgain.add(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testLapseShort() throws Exception {
|
||||||
|
ObjectMapper mapper = new ObjectMapper();
|
||||||
|
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}";
|
||||||
|
|
||||||
|
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
|
||||||
|
assertFalse(leakyBucket.add(1));
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue