diff --git a/pom.xml b/pom.xml index c370c4bd8..67f544595 100644 --- a/pom.xml +++ b/pom.xml @@ -64,6 +64,7 @@ 6.2.1.RELEASE 8.12.54 7.2 + 3.3.0 1.10.3 4.11.0 4.1.82.Final diff --git a/service/pom.xml b/service/pom.xml index 5e5c6f054..88214f098 100644 --- a/service/pom.xml +++ b/service/pom.xml @@ -172,6 +172,26 @@ + + party.iroiro.luajava + luajava + ${luajava.version} + test + + + party.iroiro.luajava + lua51 + ${luajava.version} + test + + + party.iroiro.luajava + lua51-platform + ${luajava.version} + natives-desktop + runtime + + org.eclipse.jetty.websocket websocket-api diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/BaseRateLimiters.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/BaseRateLimiters.java index b92d0a786..65ad263ca 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/BaseRateLimiters.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/BaseRateLimiters.java @@ -7,7 +7,11 @@ package org.whispersystems.textsecuregcm.limits; import static java.util.Objects.requireNonNull; +import io.lettuce.core.ScriptOutputType; +import java.io.IOException; +import java.io.UncheckedIOException; import java.lang.invoke.MethodHandles; +import java.time.Clock; import java.util.Arrays; import java.util.Map; import java.util.Set; @@ -17,6 +21,7 @@ import org.apache.commons.lang3.tuple.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; @@ -33,12 +38,14 @@ public abstract class BaseRateLimiters { final T[] values, final Map configs, final DynamicConfigurationManager dynamicConfigurationManager, - final FaultTolerantRedisCluster cacheCluster) { + final ClusterLuaScript validateScript, + final FaultTolerantRedisCluster cacheCluster, + final Clock clock) { this.configs = configs; this.rateLimiterByDescriptor = Arrays.stream(values) .map(descriptor -> Pair.of( descriptor, - createForDescriptor(descriptor, configs, dynamicConfigurationManager, cacheCluster))) + createForDescriptor(descriptor, configs, dynamicConfigurationManager, validateScript, cacheCluster, clock))) .collect(Collectors.toUnmodifiableMap(Pair::getKey, Pair::getValue)); } @@ -62,11 +69,22 @@ public abstract class BaseRateLimiters { } } + protected static ClusterLuaScript defaultScript(final FaultTolerantRedisCluster cacheCluster) { + try { + return ClusterLuaScript.fromResource( + cacheCluster, "lua/validate_rate_limit.lua", ScriptOutputType.INTEGER); + } catch (final IOException e) { + throw new UncheckedIOException("Failed to load rate limit validation script", e); + } + } + private static RateLimiter createForDescriptor( final RateLimiterDescriptor descriptor, final Map configs, final DynamicConfigurationManager dynamicConfigurationManager, - final FaultTolerantRedisCluster cacheCluster) { + final ClusterLuaScript validateScript, + final FaultTolerantRedisCluster cacheCluster, + final Clock clock) { if (descriptor.isDynamic()) { final Supplier configResolver = () -> { final RateLimiterConfig config = dynamicConfigurationManager.getConfiguration().getLimits().get(descriptor.id()); @@ -74,9 +92,9 @@ public abstract class BaseRateLimiters { ? config : configs.getOrDefault(descriptor.id(), descriptor.defaultConfig()); }; - return new DynamicRateLimiter(descriptor.id(), configResolver, cacheCluster); + return new DynamicRateLimiter(descriptor.id(), configResolver, validateScript, cacheCluster, clock); } final RateLimiterConfig cfg = configs.getOrDefault(descriptor.id(), descriptor.defaultConfig()); - return new StaticRateLimiter(descriptor.id(), cfg, cacheCluster); + return new StaticRateLimiter(descriptor.id(), cfg, validateScript, cacheCluster, clock); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/DynamicRateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/DynamicRateLimiter.java index dc711469a..8f7b0ec82 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/DynamicRateLimiter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/DynamicRateLimiter.java @@ -7,10 +7,13 @@ package org.whispersystems.textsecuregcm.limits; import static java.util.Objects.requireNonNull; +import java.time.Clock; +import java.util.concurrent.CompletionStage; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; import org.apache.commons.lang3.tuple.Pair; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; public class DynamicRateLimiter implements RateLimiter { @@ -19,18 +22,26 @@ public class DynamicRateLimiter implements RateLimiter { private final Supplier configResolver; + private final ClusterLuaScript validateScript; + private final FaultTolerantRedisCluster cluster; + private final Clock clock; + private final AtomicReference> currentHolder = new AtomicReference<>(); public DynamicRateLimiter( final String name, final Supplier configResolver, - final FaultTolerantRedisCluster cluster) { + final ClusterLuaScript validateScript, + final FaultTolerantRedisCluster cluster, + final Clock clock) { this.name = requireNonNull(name); this.configResolver = requireNonNull(configResolver); + this.validateScript = requireNonNull(validateScript); this.cluster = requireNonNull(cluster); + this.clock = requireNonNull(clock); } @Override @@ -38,16 +49,31 @@ public class DynamicRateLimiter implements RateLimiter { current().getRight().validate(key, amount); } + @Override + public CompletionStage validateAsync(final String key, final int amount) { + return current().getRight().validateAsync(key, amount); + } + @Override public boolean hasAvailablePermits(final String key, final int permits) { return current().getRight().hasAvailablePermits(key, permits); } + @Override + public CompletionStage hasAvailablePermitsAsync(final String key, final int amount) { + return current().getRight().hasAvailablePermitsAsync(key, amount); + } + @Override public void clear(final String key) { current().getRight().clear(key); } + @Override + public CompletionStage clearAsync(final String key) { + return current().getRight().clearAsync(key); + } + @Override public RateLimiterConfig config() { return current().getLeft(); @@ -57,7 +83,7 @@ public class DynamicRateLimiter implements RateLimiter { final RateLimiterConfig cfg = configResolver.get(); return currentHolder.updateAndGet(p -> p != null && p.getLeft().equals(cfg) ? p - : Pair.of(cfg, new StaticRateLimiter(name, cfg, cluster)) + : Pair.of(cfg, new StaticRateLimiter(name, cfg, validateScript, cluster, clock)) ); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LeakyBucket.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/LeakyBucket.java deleted file mode 100644 index 51b60ea8d..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LeakyBucket.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright 2013-2020 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.textsecuregcm.limits; - -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; - -import java.io.IOException; -import java.time.Duration; - -public class LeakyBucket { - - private final int bucketSize; - private final double leakRatePerMillis; - - private int spaceRemaining; - private long lastUpdateTimeMillis; - - public LeakyBucket(int bucketSize, double leakRatePerMillis) { - this(bucketSize, leakRatePerMillis, bucketSize, System.currentTimeMillis()); - } - - private LeakyBucket(int bucketSize, double leakRatePerMillis, int spaceRemaining, long lastUpdateTimeMillis) { - this.bucketSize = bucketSize; - this.leakRatePerMillis = leakRatePerMillis; - this.spaceRemaining = spaceRemaining; - this.lastUpdateTimeMillis = lastUpdateTimeMillis; - } - - public boolean add(int amount) { - this.spaceRemaining = getUpdatedSpaceRemaining(); - this.lastUpdateTimeMillis = System.currentTimeMillis(); - - if (this.spaceRemaining >= amount) { - this.spaceRemaining -= amount; - return true; - } else { - return false; - } - } - - private int getUpdatedSpaceRemaining() { - long elapsedTime = System.currentTimeMillis() - this.lastUpdateTimeMillis; - - return Math.min(this.bucketSize, - (int)Math.floor(this.spaceRemaining + (elapsedTime * this.leakRatePerMillis))); - } - - public Duration getTimeUntilSpaceAvailable(int amount) { - int currentSpaceRemaining = getUpdatedSpaceRemaining(); - if (currentSpaceRemaining >= amount) { - return Duration.ZERO; - } else if (amount > this.bucketSize) { - // This shouldn't happen today but if so we should bubble this to the clients somehow - throw new IllegalArgumentException("Requested permits exceed maximum bucket size"); - } else { - return Duration.ofMillis((long)Math.ceil((double)(amount - currentSpaceRemaining) / this.leakRatePerMillis)); - } - } - - public String serialize(ObjectMapper mapper) throws JsonProcessingException { - return mapper.writeValueAsString(new LeakyBucketEntity(bucketSize, leakRatePerMillis, spaceRemaining, lastUpdateTimeMillis)); - } - - public static LeakyBucket fromSerialized(ObjectMapper mapper, String serialized) throws IOException { - LeakyBucketEntity entity = mapper.readValue(serialized, LeakyBucketEntity.class); - - return new LeakyBucket(entity.bucketSize, entity.leakRatePerMillis, - entity.spaceRemaining, entity.lastUpdateTimeMillis); - } - - private static class LeakyBucketEntity { - @JsonProperty - private int bucketSize; - - @JsonProperty - private double leakRatePerMillis; - - @JsonProperty - private int spaceRemaining; - - @JsonProperty - private long lastUpdateTimeMillis; - - public LeakyBucketEntity() {} - - private LeakyBucketEntity(int bucketSize, double leakRatePerMillis, - int spaceRemaining, long lastUpdateTimeMillis) - { - this.bucketSize = bucketSize; - this.leakRatePerMillis = leakRatePerMillis; - this.spaceRemaining = spaceRemaining; - this.lastUpdateTimeMillis = lastUpdateTimeMillis; - } - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java deleted file mode 100644 index c755f83a4..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright 2013 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.limits; - -import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; - -import io.lettuce.core.SetArgs; -import io.micrometer.core.instrument.Counter; -import io.micrometer.core.instrument.Metrics; -import java.time.Duration; -import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; -import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; - -public class LockingRateLimiter extends StaticRateLimiter { - - private static final RateLimitExceededException REUSABLE_RATE_LIMIT_EXCEEDED_EXCEPTION - = new RateLimitExceededException(Duration.ZERO, true); - - private final Counter counter; - - - public LockingRateLimiter( - final String name, - final RateLimiterConfig config, - final FaultTolerantRedisCluster cacheCluster) { - super(name, config, cacheCluster); - this.counter = Metrics.counter(name(getClass(), "locked"), "name", name); - } - - @Override - public void validate(final String key, final int amount) throws RateLimitExceededException { - if (!acquireLock(key)) { - counter.increment(); - throw REUSABLE_RATE_LIMIT_EXCEEDED_EXCEPTION; - } - - try { - super.validate(key, amount); - } finally { - releaseLock(key); - } - } - - private void releaseLock(final String key) { - cacheCluster.useCluster(connection -> connection.sync().del(getLockName(key))); - } - - private boolean acquireLock(final String key) { - return cacheCluster.withCluster(connection -> connection.sync().set(getLockName(key), "L", SetArgs.Builder.nx().ex(10))) != null; - } - - private String getLockName(final String key) { - return "leaky_lock::" + name + "::" + key; - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java index 6f168166a..44b819dc4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java @@ -6,16 +6,23 @@ package org.whispersystems.textsecuregcm.limits; import java.util.UUID; +import java.util.concurrent.CompletionStage; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; public interface RateLimiter { void validate(String key, int amount) throws RateLimitExceededException; + CompletionStage validateAsync(String key, int amount); + boolean hasAvailablePermits(String key, int permits); + CompletionStage hasAvailablePermitsAsync(String key, int amount); + void clear(String key); + CompletionStage clearAsync(String key); + RateLimiterConfig config(); default void validate(final String key) throws RateLimitExceededException { @@ -30,14 +37,34 @@ public interface RateLimiter { validate(srcAccountUuid.toString() + "__" + dstAccountUuid.toString()); } + default CompletionStage validateAsync(final String key) { + return validateAsync(key, 1); + } + + default CompletionStage validateAsync(final UUID accountUuid) { + return validateAsync(accountUuid.toString()); + } + + default CompletionStage validateAsync(final UUID srcAccountUuid, final UUID dstAccountUuid) { + return validateAsync(srcAccountUuid.toString() + "__" + dstAccountUuid.toString()); + } + default boolean hasAvailablePermits(final UUID accountUuid, final int permits) { return hasAvailablePermits(accountUuid.toString(), permits); } + default CompletionStage hasAvailablePermitsAsync(final UUID accountUuid, final int permits) { + return hasAvailablePermitsAsync(accountUuid.toString(), permits); + } + default void clear(final UUID accountUuid) { clear(accountUuid.toString()); } + default CompletionStage clearAsync(final UUID accountUuid) { + return clearAsync(accountUuid.toString()); + } + /** * If the wrapped {@code validate()} call throws a {@link RateLimitExceededException}, it will adapt it to ensure that * {@link RateLimitExceededException#isLegacy()} returns {@code true} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java index 3aca56d7c..fa9262f96 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java @@ -6,8 +6,10 @@ package org.whispersystems.textsecuregcm.limits; import com.google.common.annotations.VisibleForTesting; +import java.time.Clock; import java.util.Map; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; @@ -103,7 +105,8 @@ public class RateLimiters extends BaseRateLimiters { final Map configs, final DynamicConfigurationManager dynamicConfigurationManager, final FaultTolerantRedisCluster cacheCluster) { - final RateLimiters rateLimiters = new RateLimiters(configs, dynamicConfigurationManager, cacheCluster); + final RateLimiters rateLimiters = new RateLimiters( + configs, dynamicConfigurationManager, defaultScript(cacheCluster), cacheCluster, Clock.systemUTC()); rateLimiters.validateValuesAndConfigs(); return rateLimiters; } @@ -112,8 +115,10 @@ public class RateLimiters extends BaseRateLimiters { RateLimiters( final Map configs, final DynamicConfigurationManager dynamicConfigurationManager, - final FaultTolerantRedisCluster cacheCluster) { - super(For.values(), configs, dynamicConfigurationManager, cacheCluster); + final ClusterLuaScript validateScript, + final FaultTolerantRedisCluster cacheCluster, + final Clock clock) { + super(For.values(), configs, dynamicConfigurationManager, validateScript, cacheCluster, clock); } public RateLimiter getAllocateDeviceLimiter() { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiter.java index 22d6058c5..22d4f4ffc 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiter.java @@ -5,60 +5,96 @@ package org.whispersystems.textsecuregcm.limits; import static java.util.Objects.requireNonNull; -import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; +import static java.util.concurrent.CompletableFuture.completedFuture; +import static java.util.concurrent.CompletableFuture.failedFuture; -import com.fasterxml.jackson.core.JsonProcessingException; import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Metrics; -import java.io.IOException; +import java.time.Clock; import java.time.Duration; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import java.util.List; +import java.util.concurrent.CompletionStage; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; +import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; -import org.whispersystems.textsecuregcm.util.SystemMapper; +import org.whispersystems.textsecuregcm.util.Util; public class StaticRateLimiter implements RateLimiter { - private static final Logger logger = LoggerFactory.getLogger(StaticRateLimiter.class); - protected final String name; private final RateLimiterConfig config; - protected final FaultTolerantRedisCluster cacheCluster; - private final Counter counter; + private final ClusterLuaScript validateScript; + + private final FaultTolerantRedisCluster cacheCluster; + + private final Clock clock; + + public StaticRateLimiter( final String name, final RateLimiterConfig config, - final FaultTolerantRedisCluster cacheCluster) { + final ClusterLuaScript validateScript, + final FaultTolerantRedisCluster cacheCluster, + final Clock clock) { this.name = requireNonNull(name); this.config = requireNonNull(config); + this.validateScript = requireNonNull(validateScript); this.cacheCluster = requireNonNull(cacheCluster); - this.counter = Metrics.counter(name(getClass(), "exceeded"), "name", name); + this.clock = requireNonNull(clock); + this.counter = Metrics.counter(MetricsUtil.name(getClass(), "exceeded"), "name", name); } @Override public void validate(final String key, final int amount) throws RateLimitExceededException { - final LeakyBucket bucket = getBucket(key); - if (bucket.add(amount)) { - setBucket(key, bucket); - } else { + final long deficitPermitsAmount = executeValidateScript(key, amount, true); + if (deficitPermitsAmount > 0) { counter.increment(); - throw new RateLimitExceededException(bucket.getTimeUntilSpaceAvailable(amount), true); + final Duration retryAfter = Duration.ofMillis( + (long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis())); + throw new RateLimitExceededException(retryAfter, true); } } @Override - public boolean hasAvailablePermits(final String key, final int permits) { - return getBucket(key).getTimeUntilSpaceAvailable(permits).equals(Duration.ZERO); + public CompletionStage validateAsync(final String key, final int amount) { + return executeValidateScriptAsync(key, amount, true) + .thenCompose(deficitPermitsAmount -> { + if (deficitPermitsAmount == 0) { + return completedFuture(null); + } + counter.increment(); + final Duration retryAfter = Duration.ofMillis( + (long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis())); + return failedFuture(new RateLimitExceededException(retryAfter, true)); + }); + } + + @Override + public boolean hasAvailablePermits(final String key, final int amount) { + final long deficitPermitsAmount = executeValidateScript(key, amount, false); + return deficitPermitsAmount == 0; + } + + @Override + public CompletionStage hasAvailablePermitsAsync(final String key, final int amount) { + return executeValidateScriptAsync(key, amount, false) + .thenApply(deficitPermitsAmount -> deficitPermitsAmount == 0); } @Override public void clear(final String key) { - cacheCluster.useCluster(connection -> connection.sync().del(getBucketName(key))); + cacheCluster.useCluster(connection -> connection.sync().del(bucketName(key))); + } + + @Override + public CompletionStage clearAsync(final String key) { + return cacheCluster.withCluster(connection -> connection.async().del(bucketName(key))) + .thenRun(Util.NOOP); } @Override @@ -66,33 +102,31 @@ public class StaticRateLimiter implements RateLimiter { return config; } - private void setBucket(final String key, final LeakyBucket bucket) { - try { - final String serialized = bucket.serialize(SystemMapper.jsonMapper()); - cacheCluster.useCluster(connection -> connection.sync().setex( - getBucketName(key), - (int) Math.ceil((config.bucketSize() / config.leakRatePerMillis()) / 1000), - serialized)); - } catch (final JsonProcessingException e) { - throw new IllegalArgumentException(e); - } + private long executeValidateScript(final String key, final int amount, final boolean applyChanges) { + final List keys = List.of(bucketName(key)); + final List arguments = List.of( + String.valueOf(config.bucketSize()), + String.valueOf(config.leakRatePerMillis()), + String.valueOf(clock.millis()), + String.valueOf(amount), + String.valueOf(applyChanges) + ); + return (Long) validateScript.execute(keys, arguments); } - private LeakyBucket getBucket(final String key) { - try { - final String serialized = cacheCluster.withCluster(connection -> connection.sync().get(getBucketName(key))); - - if (serialized != null) { - return LeakyBucket.fromSerialized(SystemMapper.jsonMapper(), serialized); - } - } catch (final IOException e) { - logger.warn("Deserialization error", e); - } - - return new LeakyBucket(config.bucketSize(), config.leakRatePerMillis()); + private CompletionStage executeValidateScriptAsync(final String key, final int amount, final boolean applyChanges) { + final List keys = List.of(bucketName(key)); + final List arguments = List.of( + String.valueOf(config.bucketSize()), + String.valueOf(config.leakRatePerMillis()), + String.valueOf(clock.millis()), + String.valueOf(amount), + String.valueOf(applyChanges) + ); + return validateScript.executeAsync(keys, arguments).thenApply(o -> (Long) o); } - private String getBucketName(final String key) { + private String bucketName(final String key) { return "leaky_bucket::" + name + "::" + key; } } diff --git a/service/src/main/resources/lua/validate_rate_limit.lua b/service/src/main/resources/lua/validate_rate_limit.lua new file mode 100644 index 000000000..5c092320a --- /dev/null +++ b/service/src/main/resources/lua/validate_rate_limit.lua @@ -0,0 +1,67 @@ +-- The script encapsulates the logic of a token bucket rate limiter. +-- Two types of operations are supported: 'check-only' and 'use-if-available' (controlled by the 'useTokens' arg). +-- Both operations take in rate limiter configuration parameters and the requested amount of tokens. +-- Both operations return 0, if the rate limiter has enough tokens to cover the requested amount, +-- and the deficit amount otherwise. +-- However, 'check-only' operation doesn't modify the bucket, while 'use-if-available' (if successful) +-- reduces the amount of available tokens by the requested amount. + +local bucketId = KEYS[1] + +local bucketSize = tonumber(ARGV[1]) +local refillRatePerMillis = tonumber(ARGV[2]) +local currentTimeMillis = tonumber(ARGV[3]) +local requestedAmount = tonumber(ARGV[4]) +local useTokens = ARGV[5] and string.lower(ARGV[5]) == "true" + +local tokenBucketJson = redis.call("GET", bucketId) +local tokenBucket +local changesMade = false + +if tokenBucketJson then + tokenBucket = cjson.decode(tokenBucketJson) +else + tokenBucket = { + ["bucketSize"] = bucketSize, + ["leakRatePerMillis"] = refillRatePerMillis, + ["spaceRemaining"] = bucketSize, + ["lastUpdateTimeMillis"] = currentTimeMillis + } +end + +-- this can happen if rate limiter configuration has changed while the key is still in Redis +if tokenBucket["bucketSize"] ~= bucketSize or tokenBucket["leakRatePerMillis"] ~= refillRatePerMillis then + tokenBucket["bucketSize"] = bucketSize + tokenBucket["leakRatePerMillis"] = refillRatePerMillis + changesMade = true +end + +local elapsedTime = currentTimeMillis - tokenBucket["lastUpdateTimeMillis"] +local availableAmount = math.min( + tokenBucket["bucketSize"], + math.floor(tokenBucket["spaceRemaining"] + (elapsedTime * tokenBucket["leakRatePerMillis"])) +) + +if availableAmount >= requestedAmount then + if useTokens then + tokenBucket["spaceRemaining"] = availableAmount - requestedAmount + tokenBucket["lastUpdateTimeMillis"] = currentTimeMillis + changesMade = true + end + if changesMade then + local tokensUsed = tokenBucket["bucketSize"] - tokenBucket["spaceRemaining"] + -- Storing a 'full' bucket is equivalent of not storing any state at all + -- (in which case a bucket will be just initialized from the input configs as a 'full' one). + -- For this reason, we either set an expiration time on the record (calculated to let the bucket fully replenish) + -- or we just delete the key if the bucket is full. + if tokensUsed > 0 then + local ttlMillis = math.ceil(tokensUsed / tokenBucket["leakRatePerMillis"]) + redis.call("SET", bucketId, cjson.encode(tokenBucket), "PX", ttlMillis) + else + redis.call("DEL", bucketId) + end + end + return 0 +else + return requestedAmount - availableAmount +end diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/LeakyBucketTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/LeakyBucketTest.java deleted file mode 100644 index 76ae6bf1b..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/limits/LeakyBucketTest.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Copyright 2013 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.limits; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import com.fasterxml.jackson.databind.ObjectMapper; -import java.io.IOException; -import java.time.Duration; -import java.util.concurrent.TimeUnit; -import org.junit.jupiter.api.Test; -import org.whispersystems.textsecuregcm.util.SystemMapper; - -class LeakyBucketTest { - - @Test - void testFull() { - LeakyBucket leakyBucket = new LeakyBucket(2, 1.0 / 2.0); - - assertTrue(leakyBucket.add(1)); - assertTrue(leakyBucket.add(1)); - assertFalse(leakyBucket.add(1)); - - leakyBucket = new LeakyBucket(2, 1.0 / 2.0); - - assertTrue(leakyBucket.add(2)); - assertFalse(leakyBucket.add(1)); - assertFalse(leakyBucket.add(2)); - } - - @Test - void testLapseRate() throws IOException { - ObjectMapper mapper = SystemMapper.jsonMapper(); - String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(2)) + "}"; - - LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized); - assertTrue(leakyBucket.add(1)); - - String serializedAgain = leakyBucket.serialize(mapper); - LeakyBucket leakyBucketAgain = LeakyBucket.fromSerialized(mapper, serializedAgain); - - assertFalse(leakyBucketAgain.add(1)); - } - - @Test - void testLapseShort() throws Exception { - ObjectMapper mapper = new ObjectMapper(); - String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}"; - - LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized); - assertFalse(leakyBucket.add(1)); - } - - @Test - void testGetTimeUntilSpaceAvailable() throws Exception { - ObjectMapper mapper = new ObjectMapper(); - - { - String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":2,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}"; - - LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized); - - assertEquals(Duration.ZERO, leakyBucket.getTimeUntilSpaceAvailable(1)); - assertThrows(IllegalArgumentException.class, () -> leakyBucket.getTimeUntilSpaceAvailable(5000)); - } - - { - String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}"; - - LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized); - - Duration timeUntilSpaceAvailable = leakyBucket.getTimeUntilSpaceAvailable(1); - - // TODO Refactor LeakyBucket to be more test-friendly and accept a Clock - assertTrue(timeUntilSpaceAvailable.compareTo(Duration.ofMillis(119_000)) > 0); - assertTrue(timeUntilSpaceAvailable.compareTo(Duration.ofMinutes(2)) <= 0); - } - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersLuaScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersLuaScriptTest.java new file mode 100644 index 000000000..424c0fff1 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersLuaScriptTest.java @@ -0,0 +1,147 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.limits; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.core.JsonProcessingException; +import io.lettuce.core.ScriptOutputType; +import java.time.Clock; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; +import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; +import org.whispersystems.textsecuregcm.util.MockUtils; +import org.whispersystems.textsecuregcm.util.MutableClock; +import org.whispersystems.textsecuregcm.util.SystemMapper; +import org.whispersystems.textsecuregcm.util.redis.RedisLuaScriptSandbox; +import org.whispersystems.textsecuregcm.util.redis.SimpleCacheCommandsHandler; + +public class RateLimitersLuaScriptTest { + + @RegisterExtension + private static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); + + private final DynamicConfiguration configuration = mock(DynamicConfiguration.class); + + private final MutableClock clock = MockUtils.mutableClock(0); + + private final RedisLuaScriptSandbox sandbox = RedisLuaScriptSandbox.fromResource( + "lua/validate_rate_limit.lua", + ScriptOutputType.INTEGER); + + private final SimpleCacheCommandsHandler redisCommandsHandler = new SimpleCacheCommandsHandler(clock); + + private final DynamicConfigurationManager dynamicConfig = + MockUtils.buildMock(DynamicConfigurationManager.class, cfg -> when(cfg.getConfiguration()).thenReturn(configuration)); + + @Test + public void testWithEmbeddedRedis() throws Exception { + final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION; + final FaultTolerantRedisCluster redisCluster = REDIS_CLUSTER_EXTENSION.getRedisCluster(); + final RateLimiters limiters = new RateLimiters( + Map.of(descriptor.id(), new RateLimiterConfig(60, 60)), + dynamicConfig, + RateLimiters.defaultScript(redisCluster), + redisCluster, + Clock.systemUTC()); + + final RateLimiter rateLimiter = limiters.forDescriptor(descriptor); + rateLimiter.validate("test", 25); + rateLimiter.validate("test", 25); + assertThrows(Exception.class, () -> rateLimiter.validate("test", 25)); + } + + @Test + public void testLuaBucketConfigurationUpdates() throws Exception { + final String key = "key1"; + clock.setTimeMillis(0); + long result = (long) sandbox.execute( + List.of(key), + scriptArgs(1000, 1, 1, true), + redisCommandsHandler + ); + assertEquals(0L, result); + assertEquals(1000L, decodeBucket(key).orElseThrow().bucketSize); + + // now making a check-only call, but changing the bucket size + result = (long) sandbox.execute( + List.of(key), + scriptArgs(2000, 1, 1, false), + redisCommandsHandler + ); + assertEquals(0L, result); + assertEquals(2000L, decodeBucket(key).orElseThrow().bucketSize); + } + + @Test + public void testLuaUpdatesTokenBucket() throws Exception { + final String key = "key1"; + clock.setTimeMillis(0); + long result = (long) sandbox.execute( + List.of(key), + scriptArgs(1000, 1, 200, true), + redisCommandsHandler + ); + assertEquals(0L, result); + assertEquals(800L, decodeBucket(key).orElseThrow().spaceRemaining); + + // 50 tokens replenished, acquiring 100 more, should end up with 750 available + clock.setTimeMillis(50); + result = (long) sandbox.execute( + List.of(key), + scriptArgs(1000, 1, 100, true), + redisCommandsHandler + ); + assertEquals(0L, result); + assertEquals(750L, decodeBucket(key).orElseThrow().spaceRemaining); + + // now checking without an update, should not affect the count + result = (long) sandbox.execute( + List.of(key), + scriptArgs(1000, 1, 100, false), + redisCommandsHandler + ); + assertEquals(0L, result); + assertEquals(750L, decodeBucket(key).orElseThrow().spaceRemaining); + } + + private Optional decodeBucket(final String key) { + try { + final String json = redisCommandsHandler.get(key); + return json == null + ? Optional.empty() + : Optional.of(SystemMapper.jsonMapper().readValue(json, TokenBucket.class)); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + private List scriptArgs( + final long bucketSize, + final long ratePerMillis, + final long requestedAmount, + final boolean useTokens) { + return List.of( + String.valueOf(bucketSize), + String.valueOf(ratePerMillis), + String.valueOf(clock.millis()), + String.valueOf(requestedAmount), + String.valueOf(useTokens) + ); + } + + private record TokenBucket(long bucketSize, long leakRatePerMillis, long spaceRemaining, long lastUpdateTimeMillis) { + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersTest.java index 77bcf124a..d1705e7f9 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersTest.java @@ -18,9 +18,11 @@ import javax.validation.Valid; import javax.validation.constraints.NotNull; import org.junit.jupiter.api.Test; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.util.MockUtils; +import org.whispersystems.textsecuregcm.util.MutableClock; @SuppressWarnings("unchecked") public class RateLimitersTest { @@ -30,8 +32,12 @@ public class RateLimitersTest { private final DynamicConfigurationManager dynamicConfig = MockUtils.buildMock(DynamicConfigurationManager.class, cfg -> when(cfg.getConfiguration()).thenReturn(configuration)); + private final ClusterLuaScript validateScript = mock(ClusterLuaScript.class); + private final FaultTolerantRedisCluster redisCluster = mock(FaultTolerantRedisCluster.class); + private final MutableClock clock = MockUtils.mutableClock(0); + private static final String BAD_YAML = """ limits: smsVoicePrefix: @@ -59,12 +65,12 @@ public class RateLimitersTest { public void testValidateConfigs() throws Exception { assertThrows(IllegalArgumentException.class, () -> { final GenericHolder cfg = DynamicConfigurationManager.parseConfiguration(BAD_YAML, GenericHolder.class).orElseThrow(); - final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, redisCluster); + final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, validateScript, redisCluster, clock); rateLimiters.validateValuesAndConfigs(); }); final GenericHolder cfg = DynamicConfigurationManager.parseConfiguration(GOOD_YAML, GenericHolder.class).orElseThrow(); - final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, redisCluster); + final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, validateScript, redisCluster, clock); rateLimiters.validateValuesAndConfigs(); } @@ -79,18 +85,22 @@ public class RateLimitersTest { new TestDescriptor[] { td1, td2, td3, tdDup }, Collections.emptyMap(), dynamicConfig, - redisCluster) {}); + validateScript, + redisCluster, + clock) {}); new BaseRateLimiters<>( new TestDescriptor[] { td1, td2, td3 }, Collections.emptyMap(), dynamicConfig, - redisCluster) {}; + validateScript, + redisCluster, + clock) {}; } @Test void testUnchangingConfiguration() { - final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, redisCluster); + final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, validateScript, redisCluster, clock); final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter(); final RateLimiterConfig expected = RateLimiters.For.RATE_LIMIT_RESET.defaultConfig(); assertEquals(expected, limiter.config()); @@ -109,7 +119,7 @@ public class RateLimitersTest { when(configuration.getLimits()).thenReturn(limitsConfigMap); - final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, redisCluster); + final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, validateScript, redisCluster, clock); final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter(); limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), initialRateLimiterConfig); @@ -137,7 +147,7 @@ public class RateLimitersTest { when(configuration.getLimits()).thenReturn(mapForDynamic); - final RateLimiters rateLimiters = new RateLimiters(mapForStatic, dynamicConfig, redisCluster); + final RateLimiters rateLimiters = new RateLimiters(mapForStatic, dynamicConfig, validateScript, redisCluster, clock); final RateLimiter limiter = rateLimiters.forDescriptor(descriptor); // test only default is present diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/BaseRedisCommandsHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/BaseRedisCommandsHandler.java new file mode 100644 index 000000000..0619ad8c9 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/BaseRedisCommandsHandler.java @@ -0,0 +1,54 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.util.redis; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.whispersystems.textsecuregcm.util.redis.RedisLuaScriptSandbox.tail; + +import java.util.List; + +/** + * This class is to be extended with implementations of Redis commands as needed. + */ +public class BaseRedisCommandsHandler implements RedisCommandsHandler { + + @Override + public Object redisCommand(final String command, final List args) { + return switch (command) { + case "SET" -> { + assertTrue(args.size() > 2); + yield set(args.get(0).toString(), args.get(1).toString(), tail(args, 2)); + } + case "GET" -> { + assertEquals(1, args.size()); + yield get(args.get(0).toString()); + } + case "DEL" -> { + assertTrue(args.size() > 1); + yield del(args.get(0).toString()); + } + default -> other(command, args); + }; + } + + public Object set(final String key, final String value, final List tail) { + return "OK"; + } + + public String get(final String key) { + return null; + + } + + public int del(final String key) { + return 0; + } + + public Object other(final String command, final List args) { + return "OK"; + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/RedisCommandsHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/RedisCommandsHandler.java new file mode 100644 index 000000000..e287577df --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/RedisCommandsHandler.java @@ -0,0 +1,14 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.util.redis; + +import java.util.List; + +@FunctionalInterface +public interface RedisCommandsHandler { + + Object redisCommand(String command, List args); +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/RedisLuaScriptSandbox.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/RedisLuaScriptSandbox.java new file mode 100644 index 000000000..06452e5bb --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/RedisLuaScriptSandbox.java @@ -0,0 +1,167 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.util.redis; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.google.common.io.Resources; +import io.lettuce.core.ScriptOutputType; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.whispersystems.textsecuregcm.util.SystemMapper; +import party.iroiro.luajava.Lua; +import party.iroiro.luajava.lua51.Lua51; +import party.iroiro.luajava.value.ImmutableLuaValue; + +public class RedisLuaScriptSandbox { + + private static final String PREFIX = """ + function redis_call(...) + -- variable name needs to match the one used in the `L.setGlobal()` call + -- method name needs to match method name of the Java class + return proxy:redisCall(arg) + end + + function json_encode(obj) + return mapper:encode(obj) + end + + function json_decode(json) + return java.luaify(mapper:decode(json)) + end + + local redis = { call = redis_call } + local cjson = { encode = json_encode, decode = json_decode } + + """; + + private final String luaScript; + + private final ScriptOutputType scriptOutputType; + + + public static RedisLuaScriptSandbox fromResource( + final String resource, + final ScriptOutputType scriptOutputType) { + try { + final String src = Resources.toString(Resources.getResource(resource), StandardCharsets.UTF_8); + return new RedisLuaScriptSandbox(src, scriptOutputType); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public RedisLuaScriptSandbox(final String luaScript, final ScriptOutputType scriptOutputType) { + this.luaScript = luaScript; + this.scriptOutputType = scriptOutputType; + } + + public Object execute( + final List keys, + final List args, + final RedisCommandsHandler redisCallsHandler) { + + try (final Lua lua = new Lua51()) { + lua.openLibraries(); + final RedisLuaProxy proxy = new RedisLuaProxy(redisCallsHandler); + lua.push(MapperLuaProxy.INSTANCE, Lua.Conversion.FULL); + lua.setGlobal("mapper"); + lua.push(proxy, Lua.Conversion.FULL); + lua.setGlobal("proxy"); + lua.push(keys, Lua.Conversion.FULL); + lua.setGlobal("KEYS"); + lua.push(args, Lua.Conversion.FULL); + lua.setGlobal("ARGV"); + final Lua.LuaError executionResult = lua.run(PREFIX + luaScript); + assertEquals("OK", executionResult.name(), "Runtime error during Lua script execution"); + return adaptOutputResult(lua.get()); + } + } + + protected Object adaptOutputResult(final Object luaObject) { + if (luaObject instanceof ImmutableLuaValue luaValue) { + final Object javaValue = luaValue.toJavaObject(); + // validate expected script output type + switch (scriptOutputType) { + case INTEGER -> assertTrue(javaValue instanceof Double); // lua number is always Double + case STATUS -> assertTrue(javaValue instanceof String); + case BOOLEAN -> assertTrue(javaValue instanceof Boolean); + }; + if (javaValue instanceof Double d) { + return d.longValue(); + } + if (javaValue instanceof String s) { + return s; + } + if (javaValue instanceof Boolean b) { + return b; + } + if (javaValue == null) { + return null; + } + throw new IllegalStateException("unexpected script result java type: " + javaValue.getClass().getName()); + } + throw new IllegalStateException("unexpected script result lua type: " + luaObject.getClass().getName()); + } + + public static List tail(final List list, final int fromIdx) { + return fromIdx < list.size() ? list.subList(fromIdx, list.size()) : Collections.emptyList(); + } + + public static final class MapperLuaProxy { + + public static final MapperLuaProxy INSTANCE = new MapperLuaProxy(); + + public String encode(final Map obj) { + try { + return SystemMapper.jsonMapper().writeValueAsString(obj); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + public Map decode(final Object json) { + try { + //noinspection unchecked + return SystemMapper.jsonMapper().readValue(json.toString(), Map.class); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + } + + /** + * Instances of this class are passed to the Lua scripting engine + * and serve as a stubs for the calls to `redis.call()`. + * + * @see #PREFIX + */ + public static final class RedisLuaProxy { + + private final RedisCommandsHandler handler; + + public RedisLuaProxy(final RedisCommandsHandler handler) { + this.handler = handler; + } + + /** + * Method name needs to match the one from the {@link #PREFIX} code. + * The method is getting called from the Lua scripting engine. + */ + @SuppressWarnings("unused") + public Object redisCall(final List args) { + assertFalse(args.isEmpty(), "`redis.call()` in Lua script invoked without arguments"); + assertTrue(args.get(0) instanceof String, "first argument to `redis.call()` must be of type `String`"); + return handler.redisCommand((String) args.get(0), tail(args, 1)); + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/SimpleCacheCommandsHandler.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/SimpleCacheCommandsHandler.java new file mode 100644 index 000000000..290da8e48 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/redis/SimpleCacheCommandsHandler.java @@ -0,0 +1,73 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.util.redis; + +import java.time.Clock; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +public class SimpleCacheCommandsHandler extends BaseRedisCommandsHandler { + + public record Entry(String value, long expirationEpochMillis) { + } + + private final Map cache = new ConcurrentHashMap<>(); + + private final Clock clock; + + + public SimpleCacheCommandsHandler(final Clock clock) { + this.clock = clock; + } + + @Override + public Object set(final String key, final String value, final List tail) { + cache.put(key, new Entry(value, resolveExpirationEpochMillis(tail))); + return "OK"; + } + + @Override + public String get(final String key) { + final Entry entry = cache.get(key); + if (entry == null) { + return null; + } + if (entry.expirationEpochMillis() < clock.millis()) { + del(key); + return null; + } + return entry.value(); + } + + @Override + public int del(final String key) { + return cache.remove(key) != null ? 1 : 0; + } + + protected long resolveExpirationEpochMillis(final List args) { + for (int i = 0; i < args.size() - 1; i++) { + final long currentTimeMillis = clock.millis(); + final String param = args.get(i).toString(); + final String value = args.get(i + 1).toString(); + switch (param) { + case "EX" -> { + return currentTimeMillis + Double.valueOf(value).longValue() * 1000; + } + case "PX" -> { + return currentTimeMillis + Double.valueOf(value).longValue(); + } + case "EXAT" -> { + return Double.valueOf(value).longValue() * 1000; + } + case "PXAT" -> { + return Double.valueOf(value).longValue(); + } + } + } + return Long.MAX_VALUE; + } +}