Move rate limiter logic to Lua scripts

This commit is contained in:
Jon Chambers 2020-07-06 10:10:13 -04:00 committed by GitHub
parent f5ddb0f1f8
commit b585c6676d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 214 additions and 294 deletions

View File

@ -1,98 +0,0 @@
/**
* Copyright (C) 2013 Open WhisperSystems
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.whispersystems.textsecuregcm.limits;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
public class LeakyBucket {
private final int bucketSize;
private final double leakRatePerMillis;
private int spaceRemaining;
private long lastUpdateTimeMillis;
public LeakyBucket(int bucketSize, double leakRatePerMillis) {
this(bucketSize, leakRatePerMillis, bucketSize, System.currentTimeMillis());
}
private LeakyBucket(int bucketSize, double leakRatePerMillis, int spaceRemaining, long lastUpdateTimeMillis) {
this.bucketSize = bucketSize;
this.leakRatePerMillis = leakRatePerMillis;
this.spaceRemaining = spaceRemaining;
this.lastUpdateTimeMillis = lastUpdateTimeMillis;
}
public boolean add(int amount) {
this.spaceRemaining = getUpdatedSpaceRemaining();
this.lastUpdateTimeMillis = System.currentTimeMillis();
if (this.spaceRemaining >= amount) {
this.spaceRemaining -= amount;
return true;
} else {
return false;
}
}
private int getUpdatedSpaceRemaining() {
long elapsedTime = System.currentTimeMillis() - this.lastUpdateTimeMillis;
return Math.min(this.bucketSize,
(int)Math.floor(this.spaceRemaining + (elapsedTime * this.leakRatePerMillis)));
}
public String serialize(ObjectMapper mapper) throws JsonProcessingException {
return mapper.writeValueAsString(new LeakyBucketEntity(bucketSize, leakRatePerMillis, spaceRemaining, lastUpdateTimeMillis));
}
public static LeakyBucket fromSerialized(ObjectMapper mapper, String serialized) throws IOException {
LeakyBucketEntity entity = mapper.readValue(serialized, LeakyBucketEntity.class);
return new LeakyBucket(entity.bucketSize, entity.leakRatePerMillis,
entity.spaceRemaining, entity.lastUpdateTimeMillis);
}
private static class LeakyBucketEntity {
@JsonProperty
private int bucketSize;
@JsonProperty
private double leakRatePerMillis;
@JsonProperty
private int spaceRemaining;
@JsonProperty
private long lastUpdateTimeMillis;
public LeakyBucketEntity() {}
private LeakyBucketEntity(int bucketSize, double leakRatePerMillis,
int spaceRemaining, long lastUpdateTimeMillis)
{
this.bucketSize = bucketSize;
this.leakRatePerMillis = leakRatePerMillis;
this.spaceRemaining = spaceRemaining;
this.lastUpdateTimeMillis = lastUpdateTimeMillis;
}
}
}

View File

@ -1,74 +0,0 @@
package org.whispersystems.textsecuregcm.limits;
import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import io.lettuce.core.SetArgs;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
import org.whispersystems.textsecuregcm.util.Constants;
import static com.codahale.metrics.MetricRegistry.name;
import redis.clients.jedis.Jedis;
public class LockingRateLimiter extends RateLimiter {
private final Meter meter;
public LockingRateLimiter(ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster, String name, int bucketSize, double leakRatePerMinute) {
super(cacheClient, cacheCluster, name, bucketSize, leakRatePerMinute);
MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
this.meter = metricRegistry.meter(name(getClass(), name, "locked"));
}
@Override
public void validate(String key, int amount) throws RateLimitExceededException {
if (!acquireLock(key)) {
meter.mark();
throw new RateLimitExceededException("Locked");
}
try {
super.validate(key, amount);
} finally {
releaseLock(key);
}
}
@Override
public void validate(String key) throws RateLimitExceededException {
validate(key, 1);
}
private void releaseLock(String key) {
try (Jedis jedis = cacheClient.getWriteResource()) {
final String lockName = getLockName(key);
jedis.del(lockName);
cacheCluster.useWriteCluster(connection -> connection.sync().del(lockName));
}
}
private boolean acquireLock(String key) {
try (Jedis jedis = cacheClient.getWriteResource()) {
final String lockName = getLockName(key);
final boolean acquiredLock = jedis.set(lockName, "L", "NX", "EX", 10) != null;
if (acquiredLock) {
// TODO Restore the NX flag when the cluster becomes the primary source of truth
cacheCluster.useWriteCluster(connection -> connection.sync().set(lockName, "L", SetArgs.Builder.ex(10)));
}
return acquiredLock;
}
}
private String getLockName(String key) {
return "leaky_lock::" + name + "::" + key;
}
}

View File

@ -20,26 +20,25 @@ import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.annotations.VisibleForTesting;
import io.lettuce.core.ScriptOutputType;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.experiment.Experiment;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.redis.LuaScript;
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import java.io.IOException;
import static com.codahale.metrics.MetricRegistry.name;
import redis.clients.jedis.Jedis;
public class RateLimiter {
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.stream.Collectors;
private final Logger logger = LoggerFactory.getLogger(RateLimiter.class);
private final ObjectMapper mapper = SystemMapper.getMapper();
import static com.codahale.metrics.MetricRegistry.name;
public class RateLimiter {
private final Meter meter;
private final Timer validateTimer;
@ -48,19 +47,12 @@ public class RateLimiter {
protected final String name;
private final int bucketSize;
private final double leakRatePerMillis;
private final boolean reportLimits;
private final Experiment redisClusterExperiment;
private final LuaScript validateScript;
private final ClusterLuaScript clusterValidateScript;
public RateLimiter(ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster, String name,
int bucketSize, double leakRatePerMinute)
{
this(cacheClient, cacheCluster, name, bucketSize, leakRatePerMinute, false);
}
public RateLimiter(ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster, String name,
int bucketSize, double leakRatePerMinute,
boolean reportLimits)
{
int bucketSize, double leakRatePerMinute) throws IOException {
MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
this.meter = metricRegistry.meter(name(getClass(), name, "exceeded"));
@ -70,27 +62,37 @@ public class RateLimiter {
this.name = name;
this.bucketSize = bucketSize;
this.leakRatePerMillis = leakRatePerMinute / (60.0 * 1000.0);
this.reportLimits = reportLimits;
this.redisClusterExperiment = new Experiment("RedisCluster", "RateLimiter", name);
}
public void validate(String key, int amount) throws RateLimitExceededException {
try (final Timer.Context ignored = validateTimer.time()) {
LeakyBucket bucket = getBucket(key);
if (bucket.add(amount)) {
setBucket(key, bucket);
} else {
meter.mark();
throw new RateLimitExceededException(key + " , " + amount);
}
}
this.validateScript = LuaScript.fromResource(cacheClient, "lua/validate_rate_limit.lua");
this.clusterValidateScript = ClusterLuaScript.fromResource(cacheCluster, "lua/validate_rate_limit.lua", ScriptOutputType.INTEGER);
}
public void validate(String key) throws RateLimitExceededException {
validate(key, 1);
}
public void validate(String key, int amount) throws RateLimitExceededException {
validate(key, amount, System.currentTimeMillis());
}
@VisibleForTesting
void validate(String key, int amount, final long currentTimeMillis) throws RateLimitExceededException {
try (final Timer.Context ignored = validateTimer.time()) {
final List<String> keys = List.of(getBucketName(key));
final List<String> arguments = List.of(String.valueOf(bucketSize), String.valueOf(leakRatePerMillis), String.valueOf(currentTimeMillis), String.valueOf(amount));
final Object result = validateScript.execute(keys.stream().map(k -> k.getBytes(StandardCharsets.UTF_8)).collect(Collectors.toList()),
arguments.stream().map(a -> a.getBytes(StandardCharsets.UTF_8)).collect(Collectors.toList()));
redisClusterExperiment.compareSupplierResult(result, () -> clusterValidateScript.execute(keys, arguments));
if (result == null) {
meter.mark();
throw new RateLimitExceededException(key + " , " + amount);
}
}
}
public void clear(String key) {
try (Jedis jedis = cacheClient.getWriteResource()) {
final String bucketName = getBucketName(key);
@ -100,37 +102,8 @@ public class RateLimiter {
}
}
private void setBucket(String key, LeakyBucket bucket) {
try (Jedis jedis = cacheClient.getWriteResource()) {
final String bucketName = getBucketName(key);
final String serialized = bucket.serialize(mapper);
final int level = (int) Math.ceil((bucketSize / leakRatePerMillis) / 1000);
jedis.setex(bucketName, level, serialized);
cacheCluster.useWriteCluster(connection -> connection.sync().setex(bucketName, level, serialized));
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e);
}
}
private LeakyBucket getBucket(String key) {
try (Jedis jedis = cacheClient.getReadResource()) {
final String bucketName = getBucketName(key);
String serialized = jedis.get(bucketName);
redisClusterExperiment.compareSupplierResult(serialized, () -> cacheCluster.withReadCluster(connection -> connection.sync().get(bucketName)));
if (serialized != null) {
return LeakyBucket.fromSerialized(mapper, serialized);
}
} catch (IOException e) {
logger.warn("Deserialization error", e);
}
return new LeakyBucket(bucketSize, leakRatePerMillis);
}
private String getBucketName(String key) {
@VisibleForTesting
String getBucketName(String key) {
return "leaky_bucket::" + name + "::" + key;
}
}

View File

@ -21,6 +21,8 @@ import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
import java.io.IOException;
public class RateLimiters {
private final RateLimiter smsDestinationLimiter;
@ -48,7 +50,7 @@ public class RateLimiters {
private final RateLimiter usernameLookupLimiter;
private final RateLimiter usernameSetLimiter;
public RateLimiters(RateLimitsConfiguration config, ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster) {
public RateLimiters(RateLimitsConfiguration config, ReplicatedJedisPool cacheClient, FaultTolerantRedisCluster cacheCluster) throws IOException {
this.smsDestinationLimiter = new RateLimiter(cacheClient, cacheCluster, "smsDestination",
config.getSmsDestination().getBucketSize(),
config.getSmsDestination().getLeakRatePerMinute());
@ -73,11 +75,11 @@ public class RateLimiters {
config.getAutoBlock().getBucketSize(),
config.getAutoBlock().getLeakRatePerMinute());
this.verifyLimiter = new LockingRateLimiter(cacheClient, cacheCluster, "verify",
this.verifyLimiter = new RateLimiter(cacheClient, cacheCluster, "verify",
config.getVerifyNumber().getBucketSize(),
config.getVerifyNumber().getLeakRatePerMinute());
this.pinLimiter = new LockingRateLimiter(cacheClient, cacheCluster, "pin",
this.pinLimiter = new RateLimiter(cacheClient, cacheCluster, "pin",
config.getVerifyPin().getBucketSize(),
config.getVerifyPin().getLeakRatePerMinute());

View File

@ -0,0 +1,33 @@
local bucketId = KEYS[1]
local bucketSize = tonumber(ARGV[1])
local leakRatePerMillis = tonumber(ARGV[2])
local currentTimeMillis = tonumber(ARGV[3])
local amount = tonumber(ARGV[4])
local leakyBucket
if redis.call("EXISTS", bucketId) == 1 then
leakyBucket = cjson.decode(redis.call("GET", bucketId))
else
leakyBucket = {
bucketSize = bucketSize,
leakRatePerMillis = leakRatePerMillis,
spaceRemaining = bucketSize,
lastUpdateTimeMillis = currentTimeMillis
}
end
local elapsedTime = currentTimeMillis - leakyBucket["lastUpdateTimeMillis"]
local updatedSpaceRemaining = math.min(leakyBucket["bucketSize"], math.floor(leakyBucket["spaceRemaining"] + (elapsedTime * leakyBucket["leakRatePerMillis"])))
redis.call("SET", "elapsedTime", elapsedTime)
redis.call("SET", "updatedSpaceRemaining", updatedSpaceRemaining)
if updatedSpaceRemaining >= amount then
leakyBucket["spaceRemaining"] = updatedSpaceRemaining - amount
redis.call("SET", bucketId, cjson.encode(leakyBucket))
return true
else
return false
end

View File

@ -0,0 +1,88 @@
package org.whispersystems.textsecuregcm.limits;
import org.junit.Before;
import org.junit.Test;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.AbstractRedisTest;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import redis.clients.jedis.Jedis;
import java.io.IOException;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
public class RateLimiterTest extends AbstractRedisTest {
private static final long NOW_MILLIS = System.currentTimeMillis();
private static final String KEY = "key";
@FunctionalInterface
private interface RateLimitedTask {
void run() throws RateLimitExceededException;
}
@Before
public void clearCache() {
try (final Jedis jedis = getReplicatedJedisPool().getWriteResource()) {
jedis.flushAll();
}
}
@Test
public void validate() throws RateLimitExceededException, IOException {
final RateLimiter rateLimiter = buildRateLimiter(2, 0.5);
rateLimiter.validate(KEY, 1, NOW_MILLIS);
rateLimiter.validate(KEY, 1, NOW_MILLIS);
assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS));
}
@Test
public void validateWithAmount() throws RateLimitExceededException, IOException {
final RateLimiter rateLimiter = buildRateLimiter(2, 0.5);
rateLimiter.validate(KEY, 2, NOW_MILLIS);
assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS));
}
@Test
public void testLapseRate() throws RateLimitExceededException, IOException {
final RateLimiter rateLimiter = buildRateLimiter(2, 8.333333333333334E-6);
final String leakyBucketJson = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (NOW_MILLIS - TimeUnit.MINUTES.toMillis(2)) + "}";
try (final Jedis jedis = getReplicatedJedisPool().getWriteResource()) {
jedis.set(rateLimiter.getBucketName(KEY), leakyBucketJson);
}
rateLimiter.validate(KEY, 1, NOW_MILLIS);
assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS));
}
@Test
public void testLapseShort() throws IOException {
final RateLimiter rateLimiter = buildRateLimiter(2, 8.333333333333334E-6);
final String leakyBucketJson = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (NOW_MILLIS - TimeUnit.MINUTES.toMillis(1)) + "}";
try (final Jedis jedis = getReplicatedJedisPool().getWriteResource()) {
jedis.set(rateLimiter.getBucketName(KEY), leakyBucketJson);
}
assertRateLimitExceeded(() -> rateLimiter.validate(KEY, 1, NOW_MILLIS));
}
private void assertRateLimitExceeded(final RateLimitedTask task) {
try {
task.run();
fail("Expected RateLimitExceededException");
} catch (final RateLimitExceededException ignored) {
}
}
@SuppressWarnings("SameParameterValue")
private RateLimiter buildRateLimiter(final int bucketSize, final double leakRatePerMilli) throws IOException {
final double leakRatePerMinute = leakRatePerMilli * 60_000d;
return new RateLimiter(getReplicatedJedisPool(), mock(FaultTolerantRedisCluster.class), KEY, bucketSize, leakRatePerMinute);
}
}

View File

@ -0,0 +1,48 @@
package org.whispersystems.textsecuregcm.redis;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
import org.whispersystems.textsecuregcm.providers.RedisClientFactory;
import redis.embedded.RedisServer;
import java.io.IOException;
import java.net.ServerSocket;
import java.net.URISyntaxException;
import java.util.List;
public abstract class AbstractRedisTest {
private static RedisServer redisServer;
private ReplicatedJedisPool replicatedJedisPool;
@BeforeClass
public static void setUpBeforeClass() throws IOException {
redisServer = new RedisServer(getNextPort());
redisServer.start();
}
@Before
public void setUp() throws URISyntaxException {
final String redisUrl = "redis://127.0.0.1:" + redisServer.ports().get(0);
replicatedJedisPool = new RedisClientFactory("test-pool", redisUrl, List.of(redisUrl), new CircuitBreakerConfiguration()).getRedisClientPool();
}
protected ReplicatedJedisPool getReplicatedJedisPool() {
return replicatedJedisPool;
}
@AfterClass
public static void tearDownAfterClass() {
redisServer.stop();
}
private static int getNextPort() throws IOException {
try (ServerSocket socket = new ServerSocket(0)) {
socket.setReuseAddress(false);
return socket.getLocalPort();
}
}
}

View File

@ -1,52 +0,0 @@
package org.whispersystems.textsecuregcm.tests.limits;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.Test;
import org.whispersystems.textsecuregcm.limits.LeakyBucket;
import java.io.IOException;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
public class LeakyBucketTest {
@Test
public void testFull() {
LeakyBucket leakyBucket = new LeakyBucket(2, 1.0 / 2.0);
assertTrue(leakyBucket.add(1));
assertTrue(leakyBucket.add(1));
assertFalse(leakyBucket.add(1));
leakyBucket = new LeakyBucket(2, 1.0 / 2.0);
assertTrue(leakyBucket.add(2));
assertFalse(leakyBucket.add(1));
assertFalse(leakyBucket.add(2));
}
@Test
public void testLapseRate() throws IOException {
ObjectMapper mapper = new ObjectMapper();
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(2)) + "}";
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
assertTrue(leakyBucket.add(1));
String serializedAgain = leakyBucket.serialize(mapper);
LeakyBucket leakyBucketAgain = LeakyBucket.fromSerialized(mapper, serializedAgain);
assertFalse(leakyBucketAgain.add(1));
}
@Test
public void testLapseShort() throws Exception {
ObjectMapper mapper = new ObjectMapper();
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}";
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
assertFalse(leakyBucket.add(1));
}
}