Moving RateLimiter logic to Redis Lua and adding async API
This commit is contained in:
parent
46fef4082c
commit
4c85e7ba66
1
pom.xml
1
pom.xml
|
@ -64,6 +64,7 @@
|
|||
<lettuce.version>6.2.1.RELEASE</lettuce.version>
|
||||
<libphonenumber.version>8.12.54</libphonenumber.version>
|
||||
<logstash.logback.version>7.2</logstash.logback.version>
|
||||
<luajava.version>3.3.0</luajava.version>
|
||||
<micrometer.version>1.10.3</micrometer.version>
|
||||
<mockito.version>4.11.0</mockito.version>
|
||||
<netty.version>4.1.82.Final</netty.version>
|
||||
|
|
|
@ -172,6 +172,26 @@
|
|||
</exclusions>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>party.iroiro.luajava</groupId>
|
||||
<artifactId>luajava</artifactId>
|
||||
<version>${luajava.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>party.iroiro.luajava</groupId>
|
||||
<artifactId>lua51</artifactId>
|
||||
<version>${luajava.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>party.iroiro.luajava</groupId>
|
||||
<artifactId>lua51-platform</artifactId>
|
||||
<version>${luajava.version}</version>
|
||||
<classifier>natives-desktop</classifier>
|
||||
<scope>runtime</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.eclipse.jetty.websocket</groupId>
|
||||
<artifactId>websocket-api</artifactId>
|
||||
|
|
|
@ -7,7 +7,11 @@ package org.whispersystems.textsecuregcm.limits;
|
|||
|
||||
import static java.util.Objects.requireNonNull;
|
||||
|
||||
import io.lettuce.core.ScriptOutputType;
|
||||
import java.io.IOException;
|
||||
import java.io.UncheckedIOException;
|
||||
import java.lang.invoke.MethodHandles;
|
||||
import java.time.Clock;
|
||||
import java.util.Arrays;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
@ -17,6 +21,7 @@ import org.apache.commons.lang3.tuple.Pair;
|
|||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
||||
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
|
||||
|
||||
|
@ -33,12 +38,14 @@ public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
|
|||
final T[] values,
|
||||
final Map<String, RateLimiterConfig> configs,
|
||||
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
||||
final FaultTolerantRedisCluster cacheCluster) {
|
||||
final ClusterLuaScript validateScript,
|
||||
final FaultTolerantRedisCluster cacheCluster,
|
||||
final Clock clock) {
|
||||
this.configs = configs;
|
||||
this.rateLimiterByDescriptor = Arrays.stream(values)
|
||||
.map(descriptor -> Pair.of(
|
||||
descriptor,
|
||||
createForDescriptor(descriptor, configs, dynamicConfigurationManager, cacheCluster)))
|
||||
createForDescriptor(descriptor, configs, dynamicConfigurationManager, validateScript, cacheCluster, clock)))
|
||||
.collect(Collectors.toUnmodifiableMap(Pair::getKey, Pair::getValue));
|
||||
}
|
||||
|
||||
|
@ -62,11 +69,22 @@ public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
|
|||
}
|
||||
}
|
||||
|
||||
protected static ClusterLuaScript defaultScript(final FaultTolerantRedisCluster cacheCluster) {
|
||||
try {
|
||||
return ClusterLuaScript.fromResource(
|
||||
cacheCluster, "lua/validate_rate_limit.lua", ScriptOutputType.INTEGER);
|
||||
} catch (final IOException e) {
|
||||
throw new UncheckedIOException("Failed to load rate limit validation script", e);
|
||||
}
|
||||
}
|
||||
|
||||
private static RateLimiter createForDescriptor(
|
||||
final RateLimiterDescriptor descriptor,
|
||||
final Map<String, RateLimiterConfig> configs,
|
||||
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
||||
final FaultTolerantRedisCluster cacheCluster) {
|
||||
final ClusterLuaScript validateScript,
|
||||
final FaultTolerantRedisCluster cacheCluster,
|
||||
final Clock clock) {
|
||||
if (descriptor.isDynamic()) {
|
||||
final Supplier<RateLimiterConfig> configResolver = () -> {
|
||||
final RateLimiterConfig config = dynamicConfigurationManager.getConfiguration().getLimits().get(descriptor.id());
|
||||
|
@ -74,9 +92,9 @@ public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
|
|||
? config
|
||||
: configs.getOrDefault(descriptor.id(), descriptor.defaultConfig());
|
||||
};
|
||||
return new DynamicRateLimiter(descriptor.id(), configResolver, cacheCluster);
|
||||
return new DynamicRateLimiter(descriptor.id(), configResolver, validateScript, cacheCluster, clock);
|
||||
}
|
||||
final RateLimiterConfig cfg = configs.getOrDefault(descriptor.id(), descriptor.defaultConfig());
|
||||
return new StaticRateLimiter(descriptor.id(), cfg, cacheCluster);
|
||||
return new StaticRateLimiter(descriptor.id(), cfg, validateScript, cacheCluster, clock);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,10 +7,13 @@ package org.whispersystems.textsecuregcm.limits;
|
|||
|
||||
import static java.util.Objects.requireNonNull;
|
||||
|
||||
import java.time.Clock;
|
||||
import java.util.concurrent.CompletionStage;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.function.Supplier;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
||||
|
||||
public class DynamicRateLimiter implements RateLimiter {
|
||||
|
@ -19,18 +22,26 @@ public class DynamicRateLimiter implements RateLimiter {
|
|||
|
||||
private final Supplier<RateLimiterConfig> configResolver;
|
||||
|
||||
private final ClusterLuaScript validateScript;
|
||||
|
||||
private final FaultTolerantRedisCluster cluster;
|
||||
|
||||
private final Clock clock;
|
||||
|
||||
private final AtomicReference<Pair<RateLimiterConfig, RateLimiter>> currentHolder = new AtomicReference<>();
|
||||
|
||||
|
||||
public DynamicRateLimiter(
|
||||
final String name,
|
||||
final Supplier<RateLimiterConfig> configResolver,
|
||||
final FaultTolerantRedisCluster cluster) {
|
||||
final ClusterLuaScript validateScript,
|
||||
final FaultTolerantRedisCluster cluster,
|
||||
final Clock clock) {
|
||||
this.name = requireNonNull(name);
|
||||
this.configResolver = requireNonNull(configResolver);
|
||||
this.validateScript = requireNonNull(validateScript);
|
||||
this.cluster = requireNonNull(cluster);
|
||||
this.clock = requireNonNull(clock);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -38,16 +49,31 @@ public class DynamicRateLimiter implements RateLimiter {
|
|||
current().getRight().validate(key, amount);
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletionStage<Void> validateAsync(final String key, final int amount) {
|
||||
return current().getRight().validateAsync(key, amount);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasAvailablePermits(final String key, final int permits) {
|
||||
return current().getRight().hasAvailablePermits(key, permits);
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletionStage<Boolean> hasAvailablePermitsAsync(final String key, final int amount) {
|
||||
return current().getRight().hasAvailablePermitsAsync(key, amount);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void clear(final String key) {
|
||||
current().getRight().clear(key);
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletionStage<Void> clearAsync(final String key) {
|
||||
return current().getRight().clearAsync(key);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RateLimiterConfig config() {
|
||||
return current().getLeft();
|
||||
|
@ -57,7 +83,7 @@ public class DynamicRateLimiter implements RateLimiter {
|
|||
final RateLimiterConfig cfg = configResolver.get();
|
||||
return currentHolder.updateAndGet(p -> p != null && p.getLeft().equals(cfg)
|
||||
? p
|
||||
: Pair.of(cfg, new StaticRateLimiter(name, cfg, cluster))
|
||||
: Pair.of(cfg, new StaticRateLimiter(name, cfg, validateScript, cluster, clock))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,99 +0,0 @@
|
|||
/*
|
||||
* Copyright 2013-2020 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.limits;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Duration;
|
||||
|
||||
public class LeakyBucket {
|
||||
|
||||
private final int bucketSize;
|
||||
private final double leakRatePerMillis;
|
||||
|
||||
private int spaceRemaining;
|
||||
private long lastUpdateTimeMillis;
|
||||
|
||||
public LeakyBucket(int bucketSize, double leakRatePerMillis) {
|
||||
this(bucketSize, leakRatePerMillis, bucketSize, System.currentTimeMillis());
|
||||
}
|
||||
|
||||
private LeakyBucket(int bucketSize, double leakRatePerMillis, int spaceRemaining, long lastUpdateTimeMillis) {
|
||||
this.bucketSize = bucketSize;
|
||||
this.leakRatePerMillis = leakRatePerMillis;
|
||||
this.spaceRemaining = spaceRemaining;
|
||||
this.lastUpdateTimeMillis = lastUpdateTimeMillis;
|
||||
}
|
||||
|
||||
public boolean add(int amount) {
|
||||
this.spaceRemaining = getUpdatedSpaceRemaining();
|
||||
this.lastUpdateTimeMillis = System.currentTimeMillis();
|
||||
|
||||
if (this.spaceRemaining >= amount) {
|
||||
this.spaceRemaining -= amount;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private int getUpdatedSpaceRemaining() {
|
||||
long elapsedTime = System.currentTimeMillis() - this.lastUpdateTimeMillis;
|
||||
|
||||
return Math.min(this.bucketSize,
|
||||
(int)Math.floor(this.spaceRemaining + (elapsedTime * this.leakRatePerMillis)));
|
||||
}
|
||||
|
||||
public Duration getTimeUntilSpaceAvailable(int amount) {
|
||||
int currentSpaceRemaining = getUpdatedSpaceRemaining();
|
||||
if (currentSpaceRemaining >= amount) {
|
||||
return Duration.ZERO;
|
||||
} else if (amount > this.bucketSize) {
|
||||
// This shouldn't happen today but if so we should bubble this to the clients somehow
|
||||
throw new IllegalArgumentException("Requested permits exceed maximum bucket size");
|
||||
} else {
|
||||
return Duration.ofMillis((long)Math.ceil((double)(amount - currentSpaceRemaining) / this.leakRatePerMillis));
|
||||
}
|
||||
}
|
||||
|
||||
public String serialize(ObjectMapper mapper) throws JsonProcessingException {
|
||||
return mapper.writeValueAsString(new LeakyBucketEntity(bucketSize, leakRatePerMillis, spaceRemaining, lastUpdateTimeMillis));
|
||||
}
|
||||
|
||||
public static LeakyBucket fromSerialized(ObjectMapper mapper, String serialized) throws IOException {
|
||||
LeakyBucketEntity entity = mapper.readValue(serialized, LeakyBucketEntity.class);
|
||||
|
||||
return new LeakyBucket(entity.bucketSize, entity.leakRatePerMillis,
|
||||
entity.spaceRemaining, entity.lastUpdateTimeMillis);
|
||||
}
|
||||
|
||||
private static class LeakyBucketEntity {
|
||||
@JsonProperty
|
||||
private int bucketSize;
|
||||
|
||||
@JsonProperty
|
||||
private double leakRatePerMillis;
|
||||
|
||||
@JsonProperty
|
||||
private int spaceRemaining;
|
||||
|
||||
@JsonProperty
|
||||
private long lastUpdateTimeMillis;
|
||||
|
||||
public LeakyBucketEntity() {}
|
||||
|
||||
private LeakyBucketEntity(int bucketSize, double leakRatePerMillis,
|
||||
int spaceRemaining, long lastUpdateTimeMillis)
|
||||
{
|
||||
this.bucketSize = bucketSize;
|
||||
this.leakRatePerMillis = leakRatePerMillis;
|
||||
this.spaceRemaining = spaceRemaining;
|
||||
this.lastUpdateTimeMillis = lastUpdateTimeMillis;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,58 +0,0 @@
|
|||
/*
|
||||
* Copyright 2013 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.limits;
|
||||
|
||||
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
|
||||
|
||||
import io.lettuce.core.SetArgs;
|
||||
import io.micrometer.core.instrument.Counter;
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import java.time.Duration;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
||||
|
||||
public class LockingRateLimiter extends StaticRateLimiter {
|
||||
|
||||
private static final RateLimitExceededException REUSABLE_RATE_LIMIT_EXCEEDED_EXCEPTION
|
||||
= new RateLimitExceededException(Duration.ZERO, true);
|
||||
|
||||
private final Counter counter;
|
||||
|
||||
|
||||
public LockingRateLimiter(
|
||||
final String name,
|
||||
final RateLimiterConfig config,
|
||||
final FaultTolerantRedisCluster cacheCluster) {
|
||||
super(name, config, cacheCluster);
|
||||
this.counter = Metrics.counter(name(getClass(), "locked"), "name", name);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void validate(final String key, final int amount) throws RateLimitExceededException {
|
||||
if (!acquireLock(key)) {
|
||||
counter.increment();
|
||||
throw REUSABLE_RATE_LIMIT_EXCEEDED_EXCEPTION;
|
||||
}
|
||||
|
||||
try {
|
||||
super.validate(key, amount);
|
||||
} finally {
|
||||
releaseLock(key);
|
||||
}
|
||||
}
|
||||
|
||||
private void releaseLock(final String key) {
|
||||
cacheCluster.useCluster(connection -> connection.sync().del(getLockName(key)));
|
||||
}
|
||||
|
||||
private boolean acquireLock(final String key) {
|
||||
return cacheCluster.withCluster(connection -> connection.sync().set(getLockName(key), "L", SetArgs.Builder.nx().ex(10))) != null;
|
||||
}
|
||||
|
||||
private String getLockName(final String key) {
|
||||
return "leaky_lock::" + name + "::" + key;
|
||||
}
|
||||
}
|
|
@ -6,16 +6,23 @@
|
|||
package org.whispersystems.textsecuregcm.limits;
|
||||
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletionStage;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
|
||||
public interface RateLimiter {
|
||||
|
||||
void validate(String key, int amount) throws RateLimitExceededException;
|
||||
|
||||
CompletionStage<Void> validateAsync(String key, int amount);
|
||||
|
||||
boolean hasAvailablePermits(String key, int permits);
|
||||
|
||||
CompletionStage<Boolean> hasAvailablePermitsAsync(String key, int amount);
|
||||
|
||||
void clear(String key);
|
||||
|
||||
CompletionStage<Void> clearAsync(String key);
|
||||
|
||||
RateLimiterConfig config();
|
||||
|
||||
default void validate(final String key) throws RateLimitExceededException {
|
||||
|
@ -30,14 +37,34 @@ public interface RateLimiter {
|
|||
validate(srcAccountUuid.toString() + "__" + dstAccountUuid.toString());
|
||||
}
|
||||
|
||||
default CompletionStage<Void> validateAsync(final String key) {
|
||||
return validateAsync(key, 1);
|
||||
}
|
||||
|
||||
default CompletionStage<Void> validateAsync(final UUID accountUuid) {
|
||||
return validateAsync(accountUuid.toString());
|
||||
}
|
||||
|
||||
default CompletionStage<Void> validateAsync(final UUID srcAccountUuid, final UUID dstAccountUuid) {
|
||||
return validateAsync(srcAccountUuid.toString() + "__" + dstAccountUuid.toString());
|
||||
}
|
||||
|
||||
default boolean hasAvailablePermits(final UUID accountUuid, final int permits) {
|
||||
return hasAvailablePermits(accountUuid.toString(), permits);
|
||||
}
|
||||
|
||||
default CompletionStage<Boolean> hasAvailablePermitsAsync(final UUID accountUuid, final int permits) {
|
||||
return hasAvailablePermitsAsync(accountUuid.toString(), permits);
|
||||
}
|
||||
|
||||
default void clear(final UUID accountUuid) {
|
||||
clear(accountUuid.toString());
|
||||
}
|
||||
|
||||
default CompletionStage<Void> clearAsync(final UUID accountUuid) {
|
||||
return clearAsync(accountUuid.toString());
|
||||
}
|
||||
|
||||
/**
|
||||
* If the wrapped {@code validate()} call throws a {@link RateLimitExceededException}, it will adapt it to ensure that
|
||||
* {@link RateLimitExceededException#isLegacy()} returns {@code true}
|
||||
|
|
|
@ -6,8 +6,10 @@ package org.whispersystems.textsecuregcm.limits;
|
|||
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import java.time.Clock;
|
||||
import java.util.Map;
|
||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
||||
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
|
||||
|
||||
|
@ -103,7 +105,8 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
|
|||
final Map<String, RateLimiterConfig> configs,
|
||||
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
||||
final FaultTolerantRedisCluster cacheCluster) {
|
||||
final RateLimiters rateLimiters = new RateLimiters(configs, dynamicConfigurationManager, cacheCluster);
|
||||
final RateLimiters rateLimiters = new RateLimiters(
|
||||
configs, dynamicConfigurationManager, defaultScript(cacheCluster), cacheCluster, Clock.systemUTC());
|
||||
rateLimiters.validateValuesAndConfigs();
|
||||
return rateLimiters;
|
||||
}
|
||||
|
@ -112,8 +115,10 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
|
|||
RateLimiters(
|
||||
final Map<String, RateLimiterConfig> configs,
|
||||
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
||||
final FaultTolerantRedisCluster cacheCluster) {
|
||||
super(For.values(), configs, dynamicConfigurationManager, cacheCluster);
|
||||
final ClusterLuaScript validateScript,
|
||||
final FaultTolerantRedisCluster cacheCluster,
|
||||
final Clock clock) {
|
||||
super(For.values(), configs, dynamicConfigurationManager, validateScript, cacheCluster, clock);
|
||||
}
|
||||
|
||||
public RateLimiter getAllocateDeviceLimiter() {
|
||||
|
|
|
@ -5,60 +5,96 @@
|
|||
package org.whispersystems.textsecuregcm.limits;
|
||||
|
||||
import static java.util.Objects.requireNonNull;
|
||||
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
|
||||
import static java.util.concurrent.CompletableFuture.completedFuture;
|
||||
import static java.util.concurrent.CompletableFuture.failedFuture;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import io.micrometer.core.instrument.Counter;
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import java.io.IOException;
|
||||
import java.time.Clock;
|
||||
import java.time.Duration;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CompletionStage;
|
||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||
import org.whispersystems.textsecuregcm.util.Util;
|
||||
|
||||
public class StaticRateLimiter implements RateLimiter {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(StaticRateLimiter.class);
|
||||
|
||||
protected final String name;
|
||||
|
||||
private final RateLimiterConfig config;
|
||||
|
||||
protected final FaultTolerantRedisCluster cacheCluster;
|
||||
|
||||
private final Counter counter;
|
||||
|
||||
private final ClusterLuaScript validateScript;
|
||||
|
||||
private final FaultTolerantRedisCluster cacheCluster;
|
||||
|
||||
private final Clock clock;
|
||||
|
||||
|
||||
public StaticRateLimiter(
|
||||
final String name,
|
||||
final RateLimiterConfig config,
|
||||
final FaultTolerantRedisCluster cacheCluster) {
|
||||
final ClusterLuaScript validateScript,
|
||||
final FaultTolerantRedisCluster cacheCluster,
|
||||
final Clock clock) {
|
||||
this.name = requireNonNull(name);
|
||||
this.config = requireNonNull(config);
|
||||
this.validateScript = requireNonNull(validateScript);
|
||||
this.cacheCluster = requireNonNull(cacheCluster);
|
||||
this.counter = Metrics.counter(name(getClass(), "exceeded"), "name", name);
|
||||
this.clock = requireNonNull(clock);
|
||||
this.counter = Metrics.counter(MetricsUtil.name(getClass(), "exceeded"), "name", name);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void validate(final String key, final int amount) throws RateLimitExceededException {
|
||||
final LeakyBucket bucket = getBucket(key);
|
||||
if (bucket.add(amount)) {
|
||||
setBucket(key, bucket);
|
||||
} else {
|
||||
final long deficitPermitsAmount = executeValidateScript(key, amount, true);
|
||||
if (deficitPermitsAmount > 0) {
|
||||
counter.increment();
|
||||
throw new RateLimitExceededException(bucket.getTimeUntilSpaceAvailable(amount), true);
|
||||
final Duration retryAfter = Duration.ofMillis(
|
||||
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
|
||||
throw new RateLimitExceededException(retryAfter, true);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasAvailablePermits(final String key, final int permits) {
|
||||
return getBucket(key).getTimeUntilSpaceAvailable(permits).equals(Duration.ZERO);
|
||||
public CompletionStage<Void> validateAsync(final String key, final int amount) {
|
||||
return executeValidateScriptAsync(key, amount, true)
|
||||
.thenCompose(deficitPermitsAmount -> {
|
||||
if (deficitPermitsAmount == 0) {
|
||||
return completedFuture(null);
|
||||
}
|
||||
counter.increment();
|
||||
final Duration retryAfter = Duration.ofMillis(
|
||||
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
|
||||
return failedFuture(new RateLimitExceededException(retryAfter, true));
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasAvailablePermits(final String key, final int amount) {
|
||||
final long deficitPermitsAmount = executeValidateScript(key, amount, false);
|
||||
return deficitPermitsAmount == 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletionStage<Boolean> hasAvailablePermitsAsync(final String key, final int amount) {
|
||||
return executeValidateScriptAsync(key, amount, false)
|
||||
.thenApply(deficitPermitsAmount -> deficitPermitsAmount == 0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void clear(final String key) {
|
||||
cacheCluster.useCluster(connection -> connection.sync().del(getBucketName(key)));
|
||||
cacheCluster.useCluster(connection -> connection.sync().del(bucketName(key)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletionStage<Void> clearAsync(final String key) {
|
||||
return cacheCluster.withCluster(connection -> connection.async().del(bucketName(key)))
|
||||
.thenRun(Util.NOOP);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -66,33 +102,31 @@ public class StaticRateLimiter implements RateLimiter {
|
|||
return config;
|
||||
}
|
||||
|
||||
private void setBucket(final String key, final LeakyBucket bucket) {
|
||||
try {
|
||||
final String serialized = bucket.serialize(SystemMapper.jsonMapper());
|
||||
cacheCluster.useCluster(connection -> connection.sync().setex(
|
||||
getBucketName(key),
|
||||
(int) Math.ceil((config.bucketSize() / config.leakRatePerMillis()) / 1000),
|
||||
serialized));
|
||||
} catch (final JsonProcessingException e) {
|
||||
throw new IllegalArgumentException(e);
|
||||
}
|
||||
private long executeValidateScript(final String key, final int amount, final boolean applyChanges) {
|
||||
final List<String> keys = List.of(bucketName(key));
|
||||
final List<String> arguments = List.of(
|
||||
String.valueOf(config.bucketSize()),
|
||||
String.valueOf(config.leakRatePerMillis()),
|
||||
String.valueOf(clock.millis()),
|
||||
String.valueOf(amount),
|
||||
String.valueOf(applyChanges)
|
||||
);
|
||||
return (Long) validateScript.execute(keys, arguments);
|
||||
}
|
||||
|
||||
private LeakyBucket getBucket(final String key) {
|
||||
try {
|
||||
final String serialized = cacheCluster.withCluster(connection -> connection.sync().get(getBucketName(key)));
|
||||
|
||||
if (serialized != null) {
|
||||
return LeakyBucket.fromSerialized(SystemMapper.jsonMapper(), serialized);
|
||||
}
|
||||
} catch (final IOException e) {
|
||||
logger.warn("Deserialization error", e);
|
||||
}
|
||||
|
||||
return new LeakyBucket(config.bucketSize(), config.leakRatePerMillis());
|
||||
private CompletionStage<Long> executeValidateScriptAsync(final String key, final int amount, final boolean applyChanges) {
|
||||
final List<String> keys = List.of(bucketName(key));
|
||||
final List<String> arguments = List.of(
|
||||
String.valueOf(config.bucketSize()),
|
||||
String.valueOf(config.leakRatePerMillis()),
|
||||
String.valueOf(clock.millis()),
|
||||
String.valueOf(amount),
|
||||
String.valueOf(applyChanges)
|
||||
);
|
||||
return validateScript.executeAsync(keys, arguments).thenApply(o -> (Long) o);
|
||||
}
|
||||
|
||||
private String getBucketName(final String key) {
|
||||
private String bucketName(final String key) {
|
||||
return "leaky_bucket::" + name + "::" + key;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
-- The script encapsulates the logic of a token bucket rate limiter.
|
||||
-- Two types of operations are supported: 'check-only' and 'use-if-available' (controlled by the 'useTokens' arg).
|
||||
-- Both operations take in rate limiter configuration parameters and the requested amount of tokens.
|
||||
-- Both operations return 0, if the rate limiter has enough tokens to cover the requested amount,
|
||||
-- and the deficit amount otherwise.
|
||||
-- However, 'check-only' operation doesn't modify the bucket, while 'use-if-available' (if successful)
|
||||
-- reduces the amount of available tokens by the requested amount.
|
||||
|
||||
local bucketId = KEYS[1]
|
||||
|
||||
local bucketSize = tonumber(ARGV[1])
|
||||
local refillRatePerMillis = tonumber(ARGV[2])
|
||||
local currentTimeMillis = tonumber(ARGV[3])
|
||||
local requestedAmount = tonumber(ARGV[4])
|
||||
local useTokens = ARGV[5] and string.lower(ARGV[5]) == "true"
|
||||
|
||||
local tokenBucketJson = redis.call("GET", bucketId)
|
||||
local tokenBucket
|
||||
local changesMade = false
|
||||
|
||||
if tokenBucketJson then
|
||||
tokenBucket = cjson.decode(tokenBucketJson)
|
||||
else
|
||||
tokenBucket = {
|
||||
["bucketSize"] = bucketSize,
|
||||
["leakRatePerMillis"] = refillRatePerMillis,
|
||||
["spaceRemaining"] = bucketSize,
|
||||
["lastUpdateTimeMillis"] = currentTimeMillis
|
||||
}
|
||||
end
|
||||
|
||||
-- this can happen if rate limiter configuration has changed while the key is still in Redis
|
||||
if tokenBucket["bucketSize"] ~= bucketSize or tokenBucket["leakRatePerMillis"] ~= refillRatePerMillis then
|
||||
tokenBucket["bucketSize"] = bucketSize
|
||||
tokenBucket["leakRatePerMillis"] = refillRatePerMillis
|
||||
changesMade = true
|
||||
end
|
||||
|
||||
local elapsedTime = currentTimeMillis - tokenBucket["lastUpdateTimeMillis"]
|
||||
local availableAmount = math.min(
|
||||
tokenBucket["bucketSize"],
|
||||
math.floor(tokenBucket["spaceRemaining"] + (elapsedTime * tokenBucket["leakRatePerMillis"]))
|
||||
)
|
||||
|
||||
if availableAmount >= requestedAmount then
|
||||
if useTokens then
|
||||
tokenBucket["spaceRemaining"] = availableAmount - requestedAmount
|
||||
tokenBucket["lastUpdateTimeMillis"] = currentTimeMillis
|
||||
changesMade = true
|
||||
end
|
||||
if changesMade then
|
||||
local tokensUsed = tokenBucket["bucketSize"] - tokenBucket["spaceRemaining"]
|
||||
-- Storing a 'full' bucket is equivalent of not storing any state at all
|
||||
-- (in which case a bucket will be just initialized from the input configs as a 'full' one).
|
||||
-- For this reason, we either set an expiration time on the record (calculated to let the bucket fully replenish)
|
||||
-- or we just delete the key if the bucket is full.
|
||||
if tokensUsed > 0 then
|
||||
local ttlMillis = math.ceil(tokensUsed / tokenBucket["leakRatePerMillis"])
|
||||
redis.call("SET", bucketId, cjson.encode(tokenBucket), "PX", ttlMillis)
|
||||
else
|
||||
redis.call("DEL", bucketId)
|
||||
end
|
||||
end
|
||||
return 0
|
||||
else
|
||||
return requestedAmount - availableAmount
|
||||
end
|
|
@ -1,85 +0,0 @@
|
|||
/*
|
||||
* Copyright 2013 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.limits;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import java.io.IOException;
|
||||
import java.time.Duration;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||
|
||||
class LeakyBucketTest {
|
||||
|
||||
@Test
|
||||
void testFull() {
|
||||
LeakyBucket leakyBucket = new LeakyBucket(2, 1.0 / 2.0);
|
||||
|
||||
assertTrue(leakyBucket.add(1));
|
||||
assertTrue(leakyBucket.add(1));
|
||||
assertFalse(leakyBucket.add(1));
|
||||
|
||||
leakyBucket = new LeakyBucket(2, 1.0 / 2.0);
|
||||
|
||||
assertTrue(leakyBucket.add(2));
|
||||
assertFalse(leakyBucket.add(1));
|
||||
assertFalse(leakyBucket.add(2));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testLapseRate() throws IOException {
|
||||
ObjectMapper mapper = SystemMapper.jsonMapper();
|
||||
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(2)) + "}";
|
||||
|
||||
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
|
||||
assertTrue(leakyBucket.add(1));
|
||||
|
||||
String serializedAgain = leakyBucket.serialize(mapper);
|
||||
LeakyBucket leakyBucketAgain = LeakyBucket.fromSerialized(mapper, serializedAgain);
|
||||
|
||||
assertFalse(leakyBucketAgain.add(1));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testLapseShort() throws Exception {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}";
|
||||
|
||||
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
|
||||
assertFalse(leakyBucket.add(1));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testGetTimeUntilSpaceAvailable() throws Exception {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
|
||||
{
|
||||
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":2,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}";
|
||||
|
||||
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
|
||||
|
||||
assertEquals(Duration.ZERO, leakyBucket.getTimeUntilSpaceAvailable(1));
|
||||
assertThrows(IllegalArgumentException.class, () -> leakyBucket.getTimeUntilSpaceAvailable(5000));
|
||||
}
|
||||
|
||||
{
|
||||
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}";
|
||||
|
||||
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
|
||||
|
||||
Duration timeUntilSpaceAvailable = leakyBucket.getTimeUntilSpaceAvailable(1);
|
||||
|
||||
// TODO Refactor LeakyBucket to be more test-friendly and accept a Clock
|
||||
assertTrue(timeUntilSpaceAvailable.compareTo(Duration.ofMillis(119_000)) > 0);
|
||||
assertTrue(timeUntilSpaceAvailable.compareTo(Duration.ofMinutes(2)) <= 0);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,147 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.limits;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import io.lettuce.core.ScriptOutputType;
|
||||
import java.time.Clock;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
||||
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
|
||||
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
|
||||
import org.whispersystems.textsecuregcm.util.MockUtils;
|
||||
import org.whispersystems.textsecuregcm.util.MutableClock;
|
||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||
import org.whispersystems.textsecuregcm.util.redis.RedisLuaScriptSandbox;
|
||||
import org.whispersystems.textsecuregcm.util.redis.SimpleCacheCommandsHandler;
|
||||
|
||||
public class RateLimitersLuaScriptTest {
|
||||
|
||||
@RegisterExtension
|
||||
private static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
|
||||
|
||||
private final DynamicConfiguration configuration = mock(DynamicConfiguration.class);
|
||||
|
||||
private final MutableClock clock = MockUtils.mutableClock(0);
|
||||
|
||||
private final RedisLuaScriptSandbox sandbox = RedisLuaScriptSandbox.fromResource(
|
||||
"lua/validate_rate_limit.lua",
|
||||
ScriptOutputType.INTEGER);
|
||||
|
||||
private final SimpleCacheCommandsHandler redisCommandsHandler = new SimpleCacheCommandsHandler(clock);
|
||||
|
||||
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfig =
|
||||
MockUtils.buildMock(DynamicConfigurationManager.class, cfg -> when(cfg.getConfiguration()).thenReturn(configuration));
|
||||
|
||||
@Test
|
||||
public void testWithEmbeddedRedis() throws Exception {
|
||||
final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION;
|
||||
final FaultTolerantRedisCluster redisCluster = REDIS_CLUSTER_EXTENSION.getRedisCluster();
|
||||
final RateLimiters limiters = new RateLimiters(
|
||||
Map.of(descriptor.id(), new RateLimiterConfig(60, 60)),
|
||||
dynamicConfig,
|
||||
RateLimiters.defaultScript(redisCluster),
|
||||
redisCluster,
|
||||
Clock.systemUTC());
|
||||
|
||||
final RateLimiter rateLimiter = limiters.forDescriptor(descriptor);
|
||||
rateLimiter.validate("test", 25);
|
||||
rateLimiter.validate("test", 25);
|
||||
assertThrows(Exception.class, () -> rateLimiter.validate("test", 25));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLuaBucketConfigurationUpdates() throws Exception {
|
||||
final String key = "key1";
|
||||
clock.setTimeMillis(0);
|
||||
long result = (long) sandbox.execute(
|
||||
List.of(key),
|
||||
scriptArgs(1000, 1, 1, true),
|
||||
redisCommandsHandler
|
||||
);
|
||||
assertEquals(0L, result);
|
||||
assertEquals(1000L, decodeBucket(key).orElseThrow().bucketSize);
|
||||
|
||||
// now making a check-only call, but changing the bucket size
|
||||
result = (long) sandbox.execute(
|
||||
List.of(key),
|
||||
scriptArgs(2000, 1, 1, false),
|
||||
redisCommandsHandler
|
||||
);
|
||||
assertEquals(0L, result);
|
||||
assertEquals(2000L, decodeBucket(key).orElseThrow().bucketSize);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLuaUpdatesTokenBucket() throws Exception {
|
||||
final String key = "key1";
|
||||
clock.setTimeMillis(0);
|
||||
long result = (long) sandbox.execute(
|
||||
List.of(key),
|
||||
scriptArgs(1000, 1, 200, true),
|
||||
redisCommandsHandler
|
||||
);
|
||||
assertEquals(0L, result);
|
||||
assertEquals(800L, decodeBucket(key).orElseThrow().spaceRemaining);
|
||||
|
||||
// 50 tokens replenished, acquiring 100 more, should end up with 750 available
|
||||
clock.setTimeMillis(50);
|
||||
result = (long) sandbox.execute(
|
||||
List.of(key),
|
||||
scriptArgs(1000, 1, 100, true),
|
||||
redisCommandsHandler
|
||||
);
|
||||
assertEquals(0L, result);
|
||||
assertEquals(750L, decodeBucket(key).orElseThrow().spaceRemaining);
|
||||
|
||||
// now checking without an update, should not affect the count
|
||||
result = (long) sandbox.execute(
|
||||
List.of(key),
|
||||
scriptArgs(1000, 1, 100, false),
|
||||
redisCommandsHandler
|
||||
);
|
||||
assertEquals(0L, result);
|
||||
assertEquals(750L, decodeBucket(key).orElseThrow().spaceRemaining);
|
||||
}
|
||||
|
||||
private Optional<TokenBucket> decodeBucket(final String key) {
|
||||
try {
|
||||
final String json = redisCommandsHandler.get(key);
|
||||
return json == null
|
||||
? Optional.empty()
|
||||
: Optional.of(SystemMapper.jsonMapper().readValue(json, TokenBucket.class));
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
private List<String> scriptArgs(
|
||||
final long bucketSize,
|
||||
final long ratePerMillis,
|
||||
final long requestedAmount,
|
||||
final boolean useTokens) {
|
||||
return List.of(
|
||||
String.valueOf(bucketSize),
|
||||
String.valueOf(ratePerMillis),
|
||||
String.valueOf(clock.millis()),
|
||||
String.valueOf(requestedAmount),
|
||||
String.valueOf(useTokens)
|
||||
);
|
||||
}
|
||||
|
||||
private record TokenBucket(long bucketSize, long leakRatePerMillis, long spaceRemaining, long lastUpdateTimeMillis) {
|
||||
}
|
||||
}
|
|
@ -18,9 +18,11 @@ import javax.validation.Valid;
|
|||
import javax.validation.constraints.NotNull;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
||||
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
|
||||
import org.whispersystems.textsecuregcm.util.MockUtils;
|
||||
import org.whispersystems.textsecuregcm.util.MutableClock;
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public class RateLimitersTest {
|
||||
|
@ -30,8 +32,12 @@ public class RateLimitersTest {
|
|||
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfig =
|
||||
MockUtils.buildMock(DynamicConfigurationManager.class, cfg -> when(cfg.getConfiguration()).thenReturn(configuration));
|
||||
|
||||
private final ClusterLuaScript validateScript = mock(ClusterLuaScript.class);
|
||||
|
||||
private final FaultTolerantRedisCluster redisCluster = mock(FaultTolerantRedisCluster.class);
|
||||
|
||||
private final MutableClock clock = MockUtils.mutableClock(0);
|
||||
|
||||
private static final String BAD_YAML = """
|
||||
limits:
|
||||
smsVoicePrefix:
|
||||
|
@ -59,12 +65,12 @@ public class RateLimitersTest {
|
|||
public void testValidateConfigs() throws Exception {
|
||||
assertThrows(IllegalArgumentException.class, () -> {
|
||||
final GenericHolder cfg = DynamicConfigurationManager.parseConfiguration(BAD_YAML, GenericHolder.class).orElseThrow();
|
||||
final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, redisCluster);
|
||||
final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, validateScript, redisCluster, clock);
|
||||
rateLimiters.validateValuesAndConfigs();
|
||||
});
|
||||
|
||||
final GenericHolder cfg = DynamicConfigurationManager.parseConfiguration(GOOD_YAML, GenericHolder.class).orElseThrow();
|
||||
final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, redisCluster);
|
||||
final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, validateScript, redisCluster, clock);
|
||||
rateLimiters.validateValuesAndConfigs();
|
||||
}
|
||||
|
||||
|
@ -79,18 +85,22 @@ public class RateLimitersTest {
|
|||
new TestDescriptor[] { td1, td2, td3, tdDup },
|
||||
Collections.emptyMap(),
|
||||
dynamicConfig,
|
||||
redisCluster) {});
|
||||
validateScript,
|
||||
redisCluster,
|
||||
clock) {});
|
||||
|
||||
new BaseRateLimiters<>(
|
||||
new TestDescriptor[] { td1, td2, td3 },
|
||||
Collections.emptyMap(),
|
||||
dynamicConfig,
|
||||
redisCluster) {};
|
||||
validateScript,
|
||||
redisCluster,
|
||||
clock) {};
|
||||
}
|
||||
|
||||
@Test
|
||||
void testUnchangingConfiguration() {
|
||||
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, redisCluster);
|
||||
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, validateScript, redisCluster, clock);
|
||||
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
|
||||
final RateLimiterConfig expected = RateLimiters.For.RATE_LIMIT_RESET.defaultConfig();
|
||||
assertEquals(expected, limiter.config());
|
||||
|
@ -109,7 +119,7 @@ public class RateLimitersTest {
|
|||
|
||||
when(configuration.getLimits()).thenReturn(limitsConfigMap);
|
||||
|
||||
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, redisCluster);
|
||||
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, validateScript, redisCluster, clock);
|
||||
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
|
||||
|
||||
limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), initialRateLimiterConfig);
|
||||
|
@ -137,7 +147,7 @@ public class RateLimitersTest {
|
|||
|
||||
when(configuration.getLimits()).thenReturn(mapForDynamic);
|
||||
|
||||
final RateLimiters rateLimiters = new RateLimiters(mapForStatic, dynamicConfig, redisCluster);
|
||||
final RateLimiters rateLimiters = new RateLimiters(mapForStatic, dynamicConfig, validateScript, redisCluster, clock);
|
||||
final RateLimiter limiter = rateLimiters.forDescriptor(descriptor);
|
||||
|
||||
// test only default is present
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.util.redis;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.whispersystems.textsecuregcm.util.redis.RedisLuaScriptSandbox.tail;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* This class is to be extended with implementations of Redis commands as needed.
|
||||
*/
|
||||
public class BaseRedisCommandsHandler implements RedisCommandsHandler {
|
||||
|
||||
@Override
|
||||
public Object redisCommand(final String command, final List<Object> args) {
|
||||
return switch (command) {
|
||||
case "SET" -> {
|
||||
assertTrue(args.size() > 2);
|
||||
yield set(args.get(0).toString(), args.get(1).toString(), tail(args, 2));
|
||||
}
|
||||
case "GET" -> {
|
||||
assertEquals(1, args.size());
|
||||
yield get(args.get(0).toString());
|
||||
}
|
||||
case "DEL" -> {
|
||||
assertTrue(args.size() > 1);
|
||||
yield del(args.get(0).toString());
|
||||
}
|
||||
default -> other(command, args);
|
||||
};
|
||||
}
|
||||
|
||||
public Object set(final String key, final String value, final List<Object> tail) {
|
||||
return "OK";
|
||||
}
|
||||
|
||||
public String get(final String key) {
|
||||
return null;
|
||||
|
||||
}
|
||||
|
||||
public int del(final String key) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
public Object other(final String command, final List<Object> args) {
|
||||
return "OK";
|
||||
}
|
||||
}
|
|
@ -0,0 +1,14 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.util.redis;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@FunctionalInterface
|
||||
public interface RedisCommandsHandler {
|
||||
|
||||
Object redisCommand(String command, List<Object> args);
|
||||
}
|
|
@ -0,0 +1,167 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.util.redis;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.google.common.io.Resources;
|
||||
import io.lettuce.core.ScriptOutputType;
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||
import party.iroiro.luajava.Lua;
|
||||
import party.iroiro.luajava.lua51.Lua51;
|
||||
import party.iroiro.luajava.value.ImmutableLuaValue;
|
||||
|
||||
public class RedisLuaScriptSandbox {
|
||||
|
||||
private static final String PREFIX = """
|
||||
function redis_call(...)
|
||||
-- variable name needs to match the one used in the `L.setGlobal()` call
|
||||
-- method name needs to match method name of the Java class
|
||||
return proxy:redisCall(arg)
|
||||
end
|
||||
|
||||
function json_encode(obj)
|
||||
return mapper:encode(obj)
|
||||
end
|
||||
|
||||
function json_decode(json)
|
||||
return java.luaify(mapper:decode(json))
|
||||
end
|
||||
|
||||
local redis = { call = redis_call }
|
||||
local cjson = { encode = json_encode, decode = json_decode }
|
||||
|
||||
""";
|
||||
|
||||
private final String luaScript;
|
||||
|
||||
private final ScriptOutputType scriptOutputType;
|
||||
|
||||
|
||||
public static RedisLuaScriptSandbox fromResource(
|
||||
final String resource,
|
||||
final ScriptOutputType scriptOutputType) {
|
||||
try {
|
||||
final String src = Resources.toString(Resources.getResource(resource), StandardCharsets.UTF_8);
|
||||
return new RedisLuaScriptSandbox(src, scriptOutputType);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public RedisLuaScriptSandbox(final String luaScript, final ScriptOutputType scriptOutputType) {
|
||||
this.luaScript = luaScript;
|
||||
this.scriptOutputType = scriptOutputType;
|
||||
}
|
||||
|
||||
public Object execute(
|
||||
final List<String> keys,
|
||||
final List<String> args,
|
||||
final RedisCommandsHandler redisCallsHandler) {
|
||||
|
||||
try (final Lua lua = new Lua51()) {
|
||||
lua.openLibraries();
|
||||
final RedisLuaProxy proxy = new RedisLuaProxy(redisCallsHandler);
|
||||
lua.push(MapperLuaProxy.INSTANCE, Lua.Conversion.FULL);
|
||||
lua.setGlobal("mapper");
|
||||
lua.push(proxy, Lua.Conversion.FULL);
|
||||
lua.setGlobal("proxy");
|
||||
lua.push(keys, Lua.Conversion.FULL);
|
||||
lua.setGlobal("KEYS");
|
||||
lua.push(args, Lua.Conversion.FULL);
|
||||
lua.setGlobal("ARGV");
|
||||
final Lua.LuaError executionResult = lua.run(PREFIX + luaScript);
|
||||
assertEquals("OK", executionResult.name(), "Runtime error during Lua script execution");
|
||||
return adaptOutputResult(lua.get());
|
||||
}
|
||||
}
|
||||
|
||||
protected Object adaptOutputResult(final Object luaObject) {
|
||||
if (luaObject instanceof ImmutableLuaValue<?> luaValue) {
|
||||
final Object javaValue = luaValue.toJavaObject();
|
||||
// validate expected script output type
|
||||
switch (scriptOutputType) {
|
||||
case INTEGER -> assertTrue(javaValue instanceof Double); // lua number is always Double
|
||||
case STATUS -> assertTrue(javaValue instanceof String);
|
||||
case BOOLEAN -> assertTrue(javaValue instanceof Boolean);
|
||||
};
|
||||
if (javaValue instanceof Double d) {
|
||||
return d.longValue();
|
||||
}
|
||||
if (javaValue instanceof String s) {
|
||||
return s;
|
||||
}
|
||||
if (javaValue instanceof Boolean b) {
|
||||
return b;
|
||||
}
|
||||
if (javaValue == null) {
|
||||
return null;
|
||||
}
|
||||
throw new IllegalStateException("unexpected script result java type: " + javaValue.getClass().getName());
|
||||
}
|
||||
throw new IllegalStateException("unexpected script result lua type: " + luaObject.getClass().getName());
|
||||
}
|
||||
|
||||
public static <T> List<T> tail(final List<T> list, final int fromIdx) {
|
||||
return fromIdx < list.size() ? list.subList(fromIdx, list.size()) : Collections.emptyList();
|
||||
}
|
||||
|
||||
public static final class MapperLuaProxy {
|
||||
|
||||
public static final MapperLuaProxy INSTANCE = new MapperLuaProxy();
|
||||
|
||||
public String encode(final Map<Object, Object> obj) {
|
||||
try {
|
||||
return SystemMapper.jsonMapper().writeValueAsString(obj);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public Map<Object, Object> decode(final Object json) {
|
||||
try {
|
||||
//noinspection unchecked
|
||||
return SystemMapper.jsonMapper().readValue(json.toString(), Map.class);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Instances of this class are passed to the Lua scripting engine
|
||||
* and serve as a stubs for the calls to `redis.call()`.
|
||||
*
|
||||
* @see #PREFIX
|
||||
*/
|
||||
public static final class RedisLuaProxy {
|
||||
|
||||
private final RedisCommandsHandler handler;
|
||||
|
||||
public RedisLuaProxy(final RedisCommandsHandler handler) {
|
||||
this.handler = handler;
|
||||
}
|
||||
|
||||
/**
|
||||
* Method name needs to match the one from the {@link #PREFIX} code.
|
||||
* The method is getting called from the Lua scripting engine.
|
||||
*/
|
||||
@SuppressWarnings("unused")
|
||||
public Object redisCall(final List<Object> args) {
|
||||
assertFalse(args.isEmpty(), "`redis.call()` in Lua script invoked without arguments");
|
||||
assertTrue(args.get(0) instanceof String, "first argument to `redis.call()` must be of type `String`");
|
||||
return handler.redisCommand((String) args.get(0), tail(args, 1));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,73 @@
|
|||
/*
|
||||
* Copyright 2023 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.util.redis;
|
||||
|
||||
import java.time.Clock;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
public class SimpleCacheCommandsHandler extends BaseRedisCommandsHandler {
|
||||
|
||||
public record Entry(String value, long expirationEpochMillis) {
|
||||
}
|
||||
|
||||
private final Map<String, Entry> cache = new ConcurrentHashMap<>();
|
||||
|
||||
private final Clock clock;
|
||||
|
||||
|
||||
public SimpleCacheCommandsHandler(final Clock clock) {
|
||||
this.clock = clock;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object set(final String key, final String value, final List<Object> tail) {
|
||||
cache.put(key, new Entry(value, resolveExpirationEpochMillis(tail)));
|
||||
return "OK";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String get(final String key) {
|
||||
final Entry entry = cache.get(key);
|
||||
if (entry == null) {
|
||||
return null;
|
||||
}
|
||||
if (entry.expirationEpochMillis() < clock.millis()) {
|
||||
del(key);
|
||||
return null;
|
||||
}
|
||||
return entry.value();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int del(final String key) {
|
||||
return cache.remove(key) != null ? 1 : 0;
|
||||
}
|
||||
|
||||
protected long resolveExpirationEpochMillis(final List<Object> args) {
|
||||
for (int i = 0; i < args.size() - 1; i++) {
|
||||
final long currentTimeMillis = clock.millis();
|
||||
final String param = args.get(i).toString();
|
||||
final String value = args.get(i + 1).toString();
|
||||
switch (param) {
|
||||
case "EX" -> {
|
||||
return currentTimeMillis + Double.valueOf(value).longValue() * 1000;
|
||||
}
|
||||
case "PX" -> {
|
||||
return currentTimeMillis + Double.valueOf(value).longValue();
|
||||
}
|
||||
case "EXAT" -> {
|
||||
return Double.valueOf(value).longValue() * 1000;
|
||||
}
|
||||
case "PXAT" -> {
|
||||
return Double.valueOf(value).longValue();
|
||||
}
|
||||
}
|
||||
}
|
||||
return Long.MAX_VALUE;
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue