diff --git a/pom.xml b/pom.xml
index c370c4bd8..67f544595 100644
--- a/pom.xml
+++ b/pom.xml
@@ -64,6 +64,7 @@
6.2.1.RELEASE
8.12.54
7.2
+ 3.3.0
1.10.3
4.11.0
4.1.82.Final
diff --git a/service/pom.xml b/service/pom.xml
index 5e5c6f054..88214f098 100644
--- a/service/pom.xml
+++ b/service/pom.xml
@@ -172,6 +172,26 @@
+
+ party.iroiro.luajava
+ luajava
+ ${luajava.version}
+ test
+
+
+ party.iroiro.luajava
+ lua51
+ ${luajava.version}
+ test
+
+
+ party.iroiro.luajava
+ lua51-platform
+ ${luajava.version}
+ natives-desktop
+ runtime
+
+
org.eclipse.jetty.websocket
websocket-api
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/BaseRateLimiters.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/BaseRateLimiters.java
index b92d0a786..65ad263ca 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/BaseRateLimiters.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/BaseRateLimiters.java
@@ -7,7 +7,11 @@ package org.whispersystems.textsecuregcm.limits;
import static java.util.Objects.requireNonNull;
+import io.lettuce.core.ScriptOutputType;
+import java.io.IOException;
+import java.io.UncheckedIOException;
import java.lang.invoke.MethodHandles;
+import java.time.Clock;
import java.util.Arrays;
import java.util.Map;
import java.util.Set;
@@ -17,6 +21,7 @@ import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
+import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
@@ -33,12 +38,14 @@ public abstract class BaseRateLimiters {
final T[] values,
final Map configs,
final DynamicConfigurationManager dynamicConfigurationManager,
- final FaultTolerantRedisCluster cacheCluster) {
+ final ClusterLuaScript validateScript,
+ final FaultTolerantRedisCluster cacheCluster,
+ final Clock clock) {
this.configs = configs;
this.rateLimiterByDescriptor = Arrays.stream(values)
.map(descriptor -> Pair.of(
descriptor,
- createForDescriptor(descriptor, configs, dynamicConfigurationManager, cacheCluster)))
+ createForDescriptor(descriptor, configs, dynamicConfigurationManager, validateScript, cacheCluster, clock)))
.collect(Collectors.toUnmodifiableMap(Pair::getKey, Pair::getValue));
}
@@ -62,11 +69,22 @@ public abstract class BaseRateLimiters {
}
}
+ protected static ClusterLuaScript defaultScript(final FaultTolerantRedisCluster cacheCluster) {
+ try {
+ return ClusterLuaScript.fromResource(
+ cacheCluster, "lua/validate_rate_limit.lua", ScriptOutputType.INTEGER);
+ } catch (final IOException e) {
+ throw new UncheckedIOException("Failed to load rate limit validation script", e);
+ }
+ }
+
private static RateLimiter createForDescriptor(
final RateLimiterDescriptor descriptor,
final Map configs,
final DynamicConfigurationManager dynamicConfigurationManager,
- final FaultTolerantRedisCluster cacheCluster) {
+ final ClusterLuaScript validateScript,
+ final FaultTolerantRedisCluster cacheCluster,
+ final Clock clock) {
if (descriptor.isDynamic()) {
final Supplier configResolver = () -> {
final RateLimiterConfig config = dynamicConfigurationManager.getConfiguration().getLimits().get(descriptor.id());
@@ -74,9 +92,9 @@ public abstract class BaseRateLimiters {
? config
: configs.getOrDefault(descriptor.id(), descriptor.defaultConfig());
};
- return new DynamicRateLimiter(descriptor.id(), configResolver, cacheCluster);
+ return new DynamicRateLimiter(descriptor.id(), configResolver, validateScript, cacheCluster, clock);
}
final RateLimiterConfig cfg = configs.getOrDefault(descriptor.id(), descriptor.defaultConfig());
- return new StaticRateLimiter(descriptor.id(), cfg, cacheCluster);
+ return new StaticRateLimiter(descriptor.id(), cfg, validateScript, cacheCluster, clock);
}
}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/DynamicRateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/DynamicRateLimiter.java
index dc711469a..8f7b0ec82 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/DynamicRateLimiter.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/DynamicRateLimiter.java
@@ -7,10 +7,13 @@ package org.whispersystems.textsecuregcm.limits;
import static java.util.Objects.requireNonNull;
+import java.time.Clock;
+import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import org.apache.commons.lang3.tuple.Pair;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
+import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
public class DynamicRateLimiter implements RateLimiter {
@@ -19,18 +22,26 @@ public class DynamicRateLimiter implements RateLimiter {
private final Supplier configResolver;
+ private final ClusterLuaScript validateScript;
+
private final FaultTolerantRedisCluster cluster;
+ private final Clock clock;
+
private final AtomicReference> currentHolder = new AtomicReference<>();
public DynamicRateLimiter(
final String name,
final Supplier configResolver,
- final FaultTolerantRedisCluster cluster) {
+ final ClusterLuaScript validateScript,
+ final FaultTolerantRedisCluster cluster,
+ final Clock clock) {
this.name = requireNonNull(name);
this.configResolver = requireNonNull(configResolver);
+ this.validateScript = requireNonNull(validateScript);
this.cluster = requireNonNull(cluster);
+ this.clock = requireNonNull(clock);
}
@Override
@@ -38,16 +49,31 @@ public class DynamicRateLimiter implements RateLimiter {
current().getRight().validate(key, amount);
}
+ @Override
+ public CompletionStage validateAsync(final String key, final int amount) {
+ return current().getRight().validateAsync(key, amount);
+ }
+
@Override
public boolean hasAvailablePermits(final String key, final int permits) {
return current().getRight().hasAvailablePermits(key, permits);
}
+ @Override
+ public CompletionStage hasAvailablePermitsAsync(final String key, final int amount) {
+ return current().getRight().hasAvailablePermitsAsync(key, amount);
+ }
+
@Override
public void clear(final String key) {
current().getRight().clear(key);
}
+ @Override
+ public CompletionStage clearAsync(final String key) {
+ return current().getRight().clearAsync(key);
+ }
+
@Override
public RateLimiterConfig config() {
return current().getLeft();
@@ -57,7 +83,7 @@ public class DynamicRateLimiter implements RateLimiter {
final RateLimiterConfig cfg = configResolver.get();
return currentHolder.updateAndGet(p -> p != null && p.getLeft().equals(cfg)
? p
- : Pair.of(cfg, new StaticRateLimiter(name, cfg, cluster))
+ : Pair.of(cfg, new StaticRateLimiter(name, cfg, validateScript, cluster, clock))
);
}
}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LeakyBucket.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/LeakyBucket.java
deleted file mode 100644
index 51b60ea8d..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LeakyBucket.java
+++ /dev/null
@@ -1,99 +0,0 @@
-/*
- * Copyright 2013-2020 Signal Messenger, LLC
- * SPDX-License-Identifier: AGPL-3.0-only
- */
-package org.whispersystems.textsecuregcm.limits;
-
-import com.fasterxml.jackson.annotation.JsonProperty;
-import com.fasterxml.jackson.core.JsonProcessingException;
-import com.fasterxml.jackson.databind.ObjectMapper;
-
-import java.io.IOException;
-import java.time.Duration;
-
-public class LeakyBucket {
-
- private final int bucketSize;
- private final double leakRatePerMillis;
-
- private int spaceRemaining;
- private long lastUpdateTimeMillis;
-
- public LeakyBucket(int bucketSize, double leakRatePerMillis) {
- this(bucketSize, leakRatePerMillis, bucketSize, System.currentTimeMillis());
- }
-
- private LeakyBucket(int bucketSize, double leakRatePerMillis, int spaceRemaining, long lastUpdateTimeMillis) {
- this.bucketSize = bucketSize;
- this.leakRatePerMillis = leakRatePerMillis;
- this.spaceRemaining = spaceRemaining;
- this.lastUpdateTimeMillis = lastUpdateTimeMillis;
- }
-
- public boolean add(int amount) {
- this.spaceRemaining = getUpdatedSpaceRemaining();
- this.lastUpdateTimeMillis = System.currentTimeMillis();
-
- if (this.spaceRemaining >= amount) {
- this.spaceRemaining -= amount;
- return true;
- } else {
- return false;
- }
- }
-
- private int getUpdatedSpaceRemaining() {
- long elapsedTime = System.currentTimeMillis() - this.lastUpdateTimeMillis;
-
- return Math.min(this.bucketSize,
- (int)Math.floor(this.spaceRemaining + (elapsedTime * this.leakRatePerMillis)));
- }
-
- public Duration getTimeUntilSpaceAvailable(int amount) {
- int currentSpaceRemaining = getUpdatedSpaceRemaining();
- if (currentSpaceRemaining >= amount) {
- return Duration.ZERO;
- } else if (amount > this.bucketSize) {
- // This shouldn't happen today but if so we should bubble this to the clients somehow
- throw new IllegalArgumentException("Requested permits exceed maximum bucket size");
- } else {
- return Duration.ofMillis((long)Math.ceil((double)(amount - currentSpaceRemaining) / this.leakRatePerMillis));
- }
- }
-
- public String serialize(ObjectMapper mapper) throws JsonProcessingException {
- return mapper.writeValueAsString(new LeakyBucketEntity(bucketSize, leakRatePerMillis, spaceRemaining, lastUpdateTimeMillis));
- }
-
- public static LeakyBucket fromSerialized(ObjectMapper mapper, String serialized) throws IOException {
- LeakyBucketEntity entity = mapper.readValue(serialized, LeakyBucketEntity.class);
-
- return new LeakyBucket(entity.bucketSize, entity.leakRatePerMillis,
- entity.spaceRemaining, entity.lastUpdateTimeMillis);
- }
-
- private static class LeakyBucketEntity {
- @JsonProperty
- private int bucketSize;
-
- @JsonProperty
- private double leakRatePerMillis;
-
- @JsonProperty
- private int spaceRemaining;
-
- @JsonProperty
- private long lastUpdateTimeMillis;
-
- public LeakyBucketEntity() {}
-
- private LeakyBucketEntity(int bucketSize, double leakRatePerMillis,
- int spaceRemaining, long lastUpdateTimeMillis)
- {
- this.bucketSize = bucketSize;
- this.leakRatePerMillis = leakRatePerMillis;
- this.spaceRemaining = spaceRemaining;
- this.lastUpdateTimeMillis = lastUpdateTimeMillis;
- }
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java
deleted file mode 100644
index c755f83a4..000000000
--- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Copyright 2013 Signal Messenger, LLC
- * SPDX-License-Identifier: AGPL-3.0-only
- */
-
-package org.whispersystems.textsecuregcm.limits;
-
-import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
-
-import io.lettuce.core.SetArgs;
-import io.micrometer.core.instrument.Counter;
-import io.micrometer.core.instrument.Metrics;
-import java.time.Duration;
-import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
-import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
-
-public class LockingRateLimiter extends StaticRateLimiter {
-
- private static final RateLimitExceededException REUSABLE_RATE_LIMIT_EXCEEDED_EXCEPTION
- = new RateLimitExceededException(Duration.ZERO, true);
-
- private final Counter counter;
-
-
- public LockingRateLimiter(
- final String name,
- final RateLimiterConfig config,
- final FaultTolerantRedisCluster cacheCluster) {
- super(name, config, cacheCluster);
- this.counter = Metrics.counter(name(getClass(), "locked"), "name", name);
- }
-
- @Override
- public void validate(final String key, final int amount) throws RateLimitExceededException {
- if (!acquireLock(key)) {
- counter.increment();
- throw REUSABLE_RATE_LIMIT_EXCEEDED_EXCEPTION;
- }
-
- try {
- super.validate(key, amount);
- } finally {
- releaseLock(key);
- }
- }
-
- private void releaseLock(final String key) {
- cacheCluster.useCluster(connection -> connection.sync().del(getLockName(key)));
- }
-
- private boolean acquireLock(final String key) {
- return cacheCluster.withCluster(connection -> connection.sync().set(getLockName(key), "L", SetArgs.Builder.nx().ex(10))) != null;
- }
-
- private String getLockName(final String key) {
- return "leaky_lock::" + name + "::" + key;
- }
-}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java
index 6f168166a..44b819dc4 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java
@@ -6,16 +6,23 @@
package org.whispersystems.textsecuregcm.limits;
import java.util.UUID;
+import java.util.concurrent.CompletionStage;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
public interface RateLimiter {
void validate(String key, int amount) throws RateLimitExceededException;
+ CompletionStage validateAsync(String key, int amount);
+
boolean hasAvailablePermits(String key, int permits);
+ CompletionStage hasAvailablePermitsAsync(String key, int amount);
+
void clear(String key);
+ CompletionStage clearAsync(String key);
+
RateLimiterConfig config();
default void validate(final String key) throws RateLimitExceededException {
@@ -30,14 +37,34 @@ public interface RateLimiter {
validate(srcAccountUuid.toString() + "__" + dstAccountUuid.toString());
}
+ default CompletionStage validateAsync(final String key) {
+ return validateAsync(key, 1);
+ }
+
+ default CompletionStage validateAsync(final UUID accountUuid) {
+ return validateAsync(accountUuid.toString());
+ }
+
+ default CompletionStage validateAsync(final UUID srcAccountUuid, final UUID dstAccountUuid) {
+ return validateAsync(srcAccountUuid.toString() + "__" + dstAccountUuid.toString());
+ }
+
default boolean hasAvailablePermits(final UUID accountUuid, final int permits) {
return hasAvailablePermits(accountUuid.toString(), permits);
}
+ default CompletionStage hasAvailablePermitsAsync(final UUID accountUuid, final int permits) {
+ return hasAvailablePermitsAsync(accountUuid.toString(), permits);
+ }
+
default void clear(final UUID accountUuid) {
clear(accountUuid.toString());
}
+ default CompletionStage clearAsync(final UUID accountUuid) {
+ return clearAsync(accountUuid.toString());
+ }
+
/**
* If the wrapped {@code validate()} call throws a {@link RateLimitExceededException}, it will adapt it to ensure that
* {@link RateLimitExceededException#isLegacy()} returns {@code true}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java
index 3aca56d7c..fa9262f96 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java
@@ -6,8 +6,10 @@ package org.whispersystems.textsecuregcm.limits;
import com.google.common.annotations.VisibleForTesting;
+import java.time.Clock;
import java.util.Map;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
+import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
@@ -103,7 +105,8 @@ public class RateLimiters extends BaseRateLimiters {
final Map configs,
final DynamicConfigurationManager dynamicConfigurationManager,
final FaultTolerantRedisCluster cacheCluster) {
- final RateLimiters rateLimiters = new RateLimiters(configs, dynamicConfigurationManager, cacheCluster);
+ final RateLimiters rateLimiters = new RateLimiters(
+ configs, dynamicConfigurationManager, defaultScript(cacheCluster), cacheCluster, Clock.systemUTC());
rateLimiters.validateValuesAndConfigs();
return rateLimiters;
}
@@ -112,8 +115,10 @@ public class RateLimiters extends BaseRateLimiters {
RateLimiters(
final Map configs,
final DynamicConfigurationManager dynamicConfigurationManager,
- final FaultTolerantRedisCluster cacheCluster) {
- super(For.values(), configs, dynamicConfigurationManager, cacheCluster);
+ final ClusterLuaScript validateScript,
+ final FaultTolerantRedisCluster cacheCluster,
+ final Clock clock) {
+ super(For.values(), configs, dynamicConfigurationManager, validateScript, cacheCluster, clock);
}
public RateLimiter getAllocateDeviceLimiter() {
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiter.java
index 22d6058c5..22d4f4ffc 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiter.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiter.java
@@ -5,60 +5,96 @@
package org.whispersystems.textsecuregcm.limits;
import static java.util.Objects.requireNonNull;
-import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
+import static java.util.concurrent.CompletableFuture.completedFuture;
+import static java.util.concurrent.CompletableFuture.failedFuture;
-import com.fasterxml.jackson.core.JsonProcessingException;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
-import java.io.IOException;
+import java.time.Clock;
import java.time.Duration;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+import java.util.List;
+import java.util.concurrent.CompletionStage;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
+import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
+import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
-import org.whispersystems.textsecuregcm.util.SystemMapper;
+import org.whispersystems.textsecuregcm.util.Util;
public class StaticRateLimiter implements RateLimiter {
- private static final Logger logger = LoggerFactory.getLogger(StaticRateLimiter.class);
-
protected final String name;
private final RateLimiterConfig config;
- protected final FaultTolerantRedisCluster cacheCluster;
-
private final Counter counter;
+ private final ClusterLuaScript validateScript;
+
+ private final FaultTolerantRedisCluster cacheCluster;
+
+ private final Clock clock;
+
+
public StaticRateLimiter(
final String name,
final RateLimiterConfig config,
- final FaultTolerantRedisCluster cacheCluster) {
+ final ClusterLuaScript validateScript,
+ final FaultTolerantRedisCluster cacheCluster,
+ final Clock clock) {
this.name = requireNonNull(name);
this.config = requireNonNull(config);
+ this.validateScript = requireNonNull(validateScript);
this.cacheCluster = requireNonNull(cacheCluster);
- this.counter = Metrics.counter(name(getClass(), "exceeded"), "name", name);
+ this.clock = requireNonNull(clock);
+ this.counter = Metrics.counter(MetricsUtil.name(getClass(), "exceeded"), "name", name);
}
@Override
public void validate(final String key, final int amount) throws RateLimitExceededException {
- final LeakyBucket bucket = getBucket(key);
- if (bucket.add(amount)) {
- setBucket(key, bucket);
- } else {
+ final long deficitPermitsAmount = executeValidateScript(key, amount, true);
+ if (deficitPermitsAmount > 0) {
counter.increment();
- throw new RateLimitExceededException(bucket.getTimeUntilSpaceAvailable(amount), true);
+ final Duration retryAfter = Duration.ofMillis(
+ (long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
+ throw new RateLimitExceededException(retryAfter, true);
}
}
@Override
- public boolean hasAvailablePermits(final String key, final int permits) {
- return getBucket(key).getTimeUntilSpaceAvailable(permits).equals(Duration.ZERO);
+ public CompletionStage validateAsync(final String key, final int amount) {
+ return executeValidateScriptAsync(key, amount, true)
+ .thenCompose(deficitPermitsAmount -> {
+ if (deficitPermitsAmount == 0) {
+ return completedFuture(null);
+ }
+ counter.increment();
+ final Duration retryAfter = Duration.ofMillis(
+ (long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
+ return failedFuture(new RateLimitExceededException(retryAfter, true));
+ });
+ }
+
+ @Override
+ public boolean hasAvailablePermits(final String key, final int amount) {
+ final long deficitPermitsAmount = executeValidateScript(key, amount, false);
+ return deficitPermitsAmount == 0;
+ }
+
+ @Override
+ public CompletionStage hasAvailablePermitsAsync(final String key, final int amount) {
+ return executeValidateScriptAsync(key, amount, false)
+ .thenApply(deficitPermitsAmount -> deficitPermitsAmount == 0);
}
@Override
public void clear(final String key) {
- cacheCluster.useCluster(connection -> connection.sync().del(getBucketName(key)));
+ cacheCluster.useCluster(connection -> connection.sync().del(bucketName(key)));
+ }
+
+ @Override
+ public CompletionStage clearAsync(final String key) {
+ return cacheCluster.withCluster(connection -> connection.async().del(bucketName(key)))
+ .thenRun(Util.NOOP);
}
@Override
@@ -66,33 +102,31 @@ public class StaticRateLimiter implements RateLimiter {
return config;
}
- private void setBucket(final String key, final LeakyBucket bucket) {
- try {
- final String serialized = bucket.serialize(SystemMapper.jsonMapper());
- cacheCluster.useCluster(connection -> connection.sync().setex(
- getBucketName(key),
- (int) Math.ceil((config.bucketSize() / config.leakRatePerMillis()) / 1000),
- serialized));
- } catch (final JsonProcessingException e) {
- throw new IllegalArgumentException(e);
- }
+ private long executeValidateScript(final String key, final int amount, final boolean applyChanges) {
+ final List keys = List.of(bucketName(key));
+ final List arguments = List.of(
+ String.valueOf(config.bucketSize()),
+ String.valueOf(config.leakRatePerMillis()),
+ String.valueOf(clock.millis()),
+ String.valueOf(amount),
+ String.valueOf(applyChanges)
+ );
+ return (Long) validateScript.execute(keys, arguments);
}
- private LeakyBucket getBucket(final String key) {
- try {
- final String serialized = cacheCluster.withCluster(connection -> connection.sync().get(getBucketName(key)));
-
- if (serialized != null) {
- return LeakyBucket.fromSerialized(SystemMapper.jsonMapper(), serialized);
- }
- } catch (final IOException e) {
- logger.warn("Deserialization error", e);
- }
-
- return new LeakyBucket(config.bucketSize(), config.leakRatePerMillis());
+ private CompletionStage executeValidateScriptAsync(final String key, final int amount, final boolean applyChanges) {
+ final List keys = List.of(bucketName(key));
+ final List arguments = List.of(
+ String.valueOf(config.bucketSize()),
+ String.valueOf(config.leakRatePerMillis()),
+ String.valueOf(clock.millis()),
+ String.valueOf(amount),
+ String.valueOf(applyChanges)
+ );
+ return validateScript.executeAsync(keys, arguments).thenApply(o -> (Long) o);
}
- private String getBucketName(final String key) {
+ private String bucketName(final String key) {
return "leaky_bucket::" + name + "::" + key;
}
}
diff --git a/service/src/main/resources/lua/validate_rate_limit.lua b/service/src/main/resources/lua/validate_rate_limit.lua
new file mode 100644
index 000000000..5c092320a
--- /dev/null
+++ b/service/src/main/resources/lua/validate_rate_limit.lua
@@ -0,0 +1,67 @@
+-- The script encapsulates the logic of a token bucket rate limiter.
+-- Two types of operations are supported: 'check-only' and 'use-if-available' (controlled by the 'useTokens' arg).
+-- Both operations take in rate limiter configuration parameters and the requested amount of tokens.
+-- Both operations return 0, if the rate limiter has enough tokens to cover the requested amount,
+-- and the deficit amount otherwise.
+-- However, 'check-only' operation doesn't modify the bucket, while 'use-if-available' (if successful)
+-- reduces the amount of available tokens by the requested amount.
+
+local bucketId = KEYS[1]
+
+local bucketSize = tonumber(ARGV[1])
+local refillRatePerMillis = tonumber(ARGV[2])
+local currentTimeMillis = tonumber(ARGV[3])
+local requestedAmount = tonumber(ARGV[4])
+local useTokens = ARGV[5] and string.lower(ARGV[5]) == "true"
+
+local tokenBucketJson = redis.call("GET", bucketId)
+local tokenBucket
+local changesMade = false
+
+if tokenBucketJson then
+ tokenBucket = cjson.decode(tokenBucketJson)
+else
+ tokenBucket = {
+ ["bucketSize"] = bucketSize,
+ ["leakRatePerMillis"] = refillRatePerMillis,
+ ["spaceRemaining"] = bucketSize,
+ ["lastUpdateTimeMillis"] = currentTimeMillis
+ }
+end
+
+-- this can happen if rate limiter configuration has changed while the key is still in Redis
+if tokenBucket["bucketSize"] ~= bucketSize or tokenBucket["leakRatePerMillis"] ~= refillRatePerMillis then
+ tokenBucket["bucketSize"] = bucketSize
+ tokenBucket["leakRatePerMillis"] = refillRatePerMillis
+ changesMade = true
+end
+
+local elapsedTime = currentTimeMillis - tokenBucket["lastUpdateTimeMillis"]
+local availableAmount = math.min(
+ tokenBucket["bucketSize"],
+ math.floor(tokenBucket["spaceRemaining"] + (elapsedTime * tokenBucket["leakRatePerMillis"]))
+)
+
+if availableAmount >= requestedAmount then
+ if useTokens then
+ tokenBucket["spaceRemaining"] = availableAmount - requestedAmount
+ tokenBucket["lastUpdateTimeMillis"] = currentTimeMillis
+ changesMade = true
+ end
+ if changesMade then
+ local tokensUsed = tokenBucket["bucketSize"] - tokenBucket["spaceRemaining"]
+ -- Storing a 'full' bucket is equivalent of not storing any state at all
+ -- (in which case a bucket will be just initialized from the input configs as a 'full' one).
+ -- For this reason, we either set an expiration time on the record (calculated to let the bucket fully replenish)
+ -- or we just delete the key if the bucket is full.
+ if tokensUsed > 0 then
+ local ttlMillis = math.ceil(tokensUsed / tokenBucket["leakRatePerMillis"])
+ redis.call("SET", bucketId, cjson.encode(tokenBucket), "PX", ttlMillis)
+ else
+ redis.call("DEL", bucketId)
+ end
+ end
+ return 0
+else
+ return requestedAmount - availableAmount
+end
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/LeakyBucketTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/LeakyBucketTest.java
deleted file mode 100644
index 76ae6bf1b..000000000
--- a/service/src/test/java/org/whispersystems/textsecuregcm/limits/LeakyBucketTest.java
+++ /dev/null
@@ -1,85 +0,0 @@
-/*
- * Copyright 2013 Signal Messenger, LLC
- * SPDX-License-Identifier: AGPL-3.0-only
- */
-
-package org.whispersystems.textsecuregcm.limits;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertFalse;
-import static org.junit.jupiter.api.Assertions.assertThrows;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-import com.fasterxml.jackson.databind.ObjectMapper;
-import java.io.IOException;
-import java.time.Duration;
-import java.util.concurrent.TimeUnit;
-import org.junit.jupiter.api.Test;
-import org.whispersystems.textsecuregcm.util.SystemMapper;
-
-class LeakyBucketTest {
-
- @Test
- void testFull() {
- LeakyBucket leakyBucket = new LeakyBucket(2, 1.0 / 2.0);
-
- assertTrue(leakyBucket.add(1));
- assertTrue(leakyBucket.add(1));
- assertFalse(leakyBucket.add(1));
-
- leakyBucket = new LeakyBucket(2, 1.0 / 2.0);
-
- assertTrue(leakyBucket.add(2));
- assertFalse(leakyBucket.add(1));
- assertFalse(leakyBucket.add(2));
- }
-
- @Test
- void testLapseRate() throws IOException {
- ObjectMapper mapper = SystemMapper.jsonMapper();
- String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(2)) + "}";
-
- LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
- assertTrue(leakyBucket.add(1));
-
- String serializedAgain = leakyBucket.serialize(mapper);
- LeakyBucket leakyBucketAgain = LeakyBucket.fromSerialized(mapper, serializedAgain);
-
- assertFalse(leakyBucketAgain.add(1));
- }
-
- @Test
- void testLapseShort() throws Exception {
- ObjectMapper mapper = new ObjectMapper();
- String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}";
-
- LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
- assertFalse(leakyBucket.add(1));
- }
-
- @Test
- void testGetTimeUntilSpaceAvailable() throws Exception {
- ObjectMapper mapper = new ObjectMapper();
-
- {
- String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":2,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}";
-
- LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
-
- assertEquals(Duration.ZERO, leakyBucket.getTimeUntilSpaceAvailable(1));
- assertThrows(IllegalArgumentException.class, () -> leakyBucket.getTimeUntilSpaceAvailable(5000));
- }
-
- {
- String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}";
-
- LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
-
- Duration timeUntilSpaceAvailable = leakyBucket.getTimeUntilSpaceAvailable(1);
-
- // TODO Refactor LeakyBucket to be more test-friendly and accept a Clock
- assertTrue(timeUntilSpaceAvailable.compareTo(Duration.ofMillis(119_000)) > 0);
- assertTrue(timeUntilSpaceAvailable.compareTo(Duration.ofMinutes(2)) <= 0);
- }
- }
-}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersLuaScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersLuaScriptTest.java
new file mode 100644
index 000000000..424c0fff1
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersLuaScriptTest.java
@@ -0,0 +1,147 @@
+/*
+ * Copyright 2023 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.limits;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import io.lettuce.core.ScriptOutputType;
+import java.time.Clock;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
+import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
+import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
+import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
+import org.whispersystems.textsecuregcm.util.MockUtils;
+import org.whispersystems.textsecuregcm.util.MutableClock;
+import org.whispersystems.textsecuregcm.util.SystemMapper;
+import org.whispersystems.textsecuregcm.util.redis.RedisLuaScriptSandbox;
+import org.whispersystems.textsecuregcm.util.redis.SimpleCacheCommandsHandler;
+
+public class RateLimitersLuaScriptTest {
+
+ @RegisterExtension
+ private static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
+
+ private final DynamicConfiguration configuration = mock(DynamicConfiguration.class);
+
+ private final MutableClock clock = MockUtils.mutableClock(0);
+
+ private final RedisLuaScriptSandbox sandbox = RedisLuaScriptSandbox.fromResource(
+ "lua/validate_rate_limit.lua",
+ ScriptOutputType.INTEGER);
+
+ private final SimpleCacheCommandsHandler redisCommandsHandler = new SimpleCacheCommandsHandler(clock);
+
+ private final DynamicConfigurationManager dynamicConfig =
+ MockUtils.buildMock(DynamicConfigurationManager.class, cfg -> when(cfg.getConfiguration()).thenReturn(configuration));
+
+ @Test
+ public void testWithEmbeddedRedis() throws Exception {
+ final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION;
+ final FaultTolerantRedisCluster redisCluster = REDIS_CLUSTER_EXTENSION.getRedisCluster();
+ final RateLimiters limiters = new RateLimiters(
+ Map.of(descriptor.id(), new RateLimiterConfig(60, 60)),
+ dynamicConfig,
+ RateLimiters.defaultScript(redisCluster),
+ redisCluster,
+ Clock.systemUTC());
+
+ final RateLimiter rateLimiter = limiters.forDescriptor(descriptor);
+ rateLimiter.validate("test", 25);
+ rateLimiter.validate("test", 25);
+ assertThrows(Exception.class, () -> rateLimiter.validate("test", 25));
+ }
+
+ @Test
+ public void testLuaBucketConfigurationUpdates() throws Exception {
+ final String key = "key1";
+ clock.setTimeMillis(0);
+ long result = (long) sandbox.execute(
+ List.of(key),
+ scriptArgs(1000, 1, 1, true),
+ redisCommandsHandler
+ );
+ assertEquals(0L, result);
+ assertEquals(1000L, decodeBucket(key).orElseThrow().bucketSize);
+
+ // now making a check-only call, but changing the bucket size
+ result = (long) sandbox.execute(
+ List.of(key),
+ scriptArgs(2000, 1, 1, false),
+ redisCommandsHandler
+ );
+ assertEquals(0L, result);
+ assertEquals(2000L, decodeBucket(key).orElseThrow().bucketSize);
+ }
+
+ @Test
+ public void testLuaUpdatesTokenBucket() throws Exception {
+ final String key = "key1";
+ clock.setTimeMillis(0);
+ long result = (long) sandbox.execute(
+ List.of(key),
+ scriptArgs(1000, 1, 200, true),
+ redisCommandsHandler
+ );
+ assertEquals(0L, result);
+ assertEquals(800L, decodeBucket(key).orElseThrow().spaceRemaining);
+
+ // 50 tokens replenished, acquiring 100 more, should end up with 750 available
+ clock.setTimeMillis(50);
+ result = (long) sandbox.execute(
+ List.of(key),
+ scriptArgs(1000, 1, 100, true),
+ redisCommandsHandler
+ );
+ assertEquals(0L, result);
+ assertEquals(750L, decodeBucket(key).orElseThrow().spaceRemaining);
+
+ // now checking without an update, should not affect the count
+ result = (long) sandbox.execute(
+ List.of(key),
+ scriptArgs(1000, 1, 100, false),
+ redisCommandsHandler
+ );
+ assertEquals(0L, result);
+ assertEquals(750L, decodeBucket(key).orElseThrow().spaceRemaining);
+ }
+
+ private Optional decodeBucket(final String key) {
+ try {
+ final String json = redisCommandsHandler.get(key);
+ return json == null
+ ? Optional.empty()
+ : Optional.of(SystemMapper.jsonMapper().readValue(json, TokenBucket.class));
+ } catch (JsonProcessingException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private List scriptArgs(
+ final long bucketSize,
+ final long ratePerMillis,
+ final long requestedAmount,
+ final boolean useTokens) {
+ return List.of(
+ String.valueOf(bucketSize),
+ String.valueOf(ratePerMillis),
+ String.valueOf(clock.millis()),
+ String.valueOf(requestedAmount),
+ String.valueOf(useTokens)
+ );
+ }
+
+ private record TokenBucket(long bucketSize, long leakRatePerMillis, long spaceRemaining, long lastUpdateTimeMillis) {
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersTest.java
index 77bcf124a..d1705e7f9 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersTest.java
@@ -18,9 +18,11 @@ import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
+import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.MockUtils;
+import org.whispersystems.textsecuregcm.util.MutableClock;
@SuppressWarnings("unchecked")
public class RateLimitersTest {
@@ -30,8 +32,12 @@ public class RateLimitersTest {
private final DynamicConfigurationManager dynamicConfig =
MockUtils.buildMock(DynamicConfigurationManager.class, cfg -> when(cfg.getConfiguration()).thenReturn(configuration));
+ private final ClusterLuaScript validateScript = mock(ClusterLuaScript.class);
+
private final FaultTolerantRedisCluster redisCluster = mock(FaultTolerantRedisCluster.class);
+ private final MutableClock clock = MockUtils.mutableClock(0);
+
private static final String BAD_YAML = """
limits:
smsVoicePrefix:
@@ -59,12 +65,12 @@ public class RateLimitersTest {
public void testValidateConfigs() throws Exception {
assertThrows(IllegalArgumentException.class, () -> {
final GenericHolder cfg = DynamicConfigurationManager.parseConfiguration(BAD_YAML, GenericHolder.class).orElseThrow();
- final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, redisCluster);
+ final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, validateScript, redisCluster, clock);
rateLimiters.validateValuesAndConfigs();
});
final GenericHolder cfg = DynamicConfigurationManager.parseConfiguration(GOOD_YAML, GenericHolder.class).orElseThrow();
- final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, redisCluster);
+ final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, validateScript, redisCluster, clock);
rateLimiters.validateValuesAndConfigs();
}
@@ -79,18 +85,22 @@ public class RateLimitersTest {
new TestDescriptor[] { td1, td2, td3, tdDup },
Collections.emptyMap(),
dynamicConfig,
- redisCluster) {});
+ validateScript,
+ redisCluster,
+ clock) {});
new BaseRateLimiters<>(
new TestDescriptor[] { td1, td2, td3 },
Collections.emptyMap(),
dynamicConfig,
- redisCluster) {};
+ validateScript,
+ redisCluster,
+ clock) {};
}
@Test
void testUnchangingConfiguration() {
- final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, redisCluster);
+ final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, validateScript, redisCluster, clock);
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
final RateLimiterConfig expected = RateLimiters.For.RATE_LIMIT_RESET.defaultConfig();
assertEquals(expected, limiter.config());
@@ -109,7 +119,7 @@ public class RateLimitersTest {
when(configuration.getLimits()).thenReturn(limitsConfigMap);
- final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, redisCluster);
+ final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, validateScript, redisCluster, clock);
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), initialRateLimiterConfig);
@@ -137,7 +147,7 @@ public class RateLimitersTest {
when(configuration.getLimits()).thenReturn(mapForDynamic);
- final RateLimiters rateLimiters = new RateLimiters(mapForStatic, dynamicConfig, redisCluster);
+ final RateLimiters rateLimiters = new RateLimiters(mapForStatic, dynamicConfig, validateScript, redisCluster, clock);
final RateLimiter limiter = rateLimiters.forDescriptor(descriptor);
// test only default is present
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/BaseRedisCommandsHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/BaseRedisCommandsHandler.java
new file mode 100644
index 000000000..0619ad8c9
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/BaseRedisCommandsHandler.java
@@ -0,0 +1,54 @@
+/*
+ * Copyright 2023 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.util.redis;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.whispersystems.textsecuregcm.util.redis.RedisLuaScriptSandbox.tail;
+
+import java.util.List;
+
+/**
+ * This class is to be extended with implementations of Redis commands as needed.
+ */
+public class BaseRedisCommandsHandler implements RedisCommandsHandler {
+
+ @Override
+ public Object redisCommand(final String command, final List