Simplify rate limiters by making them all dynamic
This commit is contained in:
parent
aafcd63a9f
commit
35604cf151
|
@ -407,10 +407,6 @@ public class WhisperServerConfiguration extends Configuration {
|
||||||
return rateLimitersCluster;
|
return rateLimitersCluster;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Map<String, RateLimiterConfig> getLimitsConfiguration() {
|
|
||||||
return limits;
|
|
||||||
}
|
|
||||||
|
|
||||||
public FcmConfiguration getFcmConfiguration() {
|
public FcmConfiguration getFcmConfiguration() {
|
||||||
return fcm;
|
return fcm;
|
||||||
}
|
}
|
||||||
|
|
|
@ -639,8 +639,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
||||||
new PushNotificationManager(accountsManager, apnSender, fcmSender, pushNotificationScheduler, experimentEnrollmentManager);
|
new PushNotificationManager(accountsManager, apnSender, fcmSender, pushNotificationScheduler, experimentEnrollmentManager);
|
||||||
WebSocketConnectionEventManager webSocketConnectionEventManager =
|
WebSocketConnectionEventManager webSocketConnectionEventManager =
|
||||||
new WebSocketConnectionEventManager(accountsManager, pushNotificationManager, messagesCluster, clientEventExecutor, asyncOperationQueueingExecutor);
|
new WebSocketConnectionEventManager(accountsManager, pushNotificationManager, messagesCluster, clientEventExecutor, asyncOperationQueueingExecutor);
|
||||||
RateLimiters rateLimiters = RateLimiters.createAndValidate(config.getLimitsConfiguration(),
|
RateLimiters rateLimiters = RateLimiters.create(dynamicConfigurationManager, rateLimitersCluster);
|
||||||
dynamicConfigurationManager, rateLimitersCluster);
|
|
||||||
ProvisioningManager provisioningManager = new ProvisioningManager(pubsubClient);
|
ProvisioningManager provisioningManager = new ProvisioningManager(pubsubClient);
|
||||||
IssuedReceiptsManager issuedReceiptsManager = new IssuedReceiptsManager(
|
IssuedReceiptsManager issuedReceiptsManager = new IssuedReceiptsManager(
|
||||||
config.getDynamoDbTables().getIssuedReceipts().getTableName(),
|
config.getDynamoDbTables().getIssuedReceipts().getTableName(),
|
||||||
|
|
|
@ -10,16 +10,12 @@ import static java.util.Objects.requireNonNull;
|
||||||
import io.lettuce.core.ScriptOutputType;
|
import io.lettuce.core.ScriptOutputType;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.UncheckedIOException;
|
import java.io.UncheckedIOException;
|
||||||
import java.lang.invoke.MethodHandles;
|
|
||||||
import java.time.Clock;
|
import java.time.Clock;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
|
||||||
import java.util.function.Supplier;
|
import java.util.function.Supplier;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
import org.slf4j.Logger;
|
|
||||||
import org.slf4j.LoggerFactory;
|
|
||||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||||
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
||||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
|
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
|
||||||
|
@ -27,25 +23,18 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
|
||||||
|
|
||||||
public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
|
public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
|
||||||
|
|
||||||
private final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
|
|
||||||
|
|
||||||
private final Map<T, RateLimiter> rateLimiterByDescriptor;
|
private final Map<T, RateLimiter> rateLimiterByDescriptor;
|
||||||
|
|
||||||
private final Map<String, RateLimiterConfig> configs;
|
|
||||||
|
|
||||||
|
|
||||||
protected BaseRateLimiters(
|
protected BaseRateLimiters(
|
||||||
final T[] values,
|
final T[] values,
|
||||||
final Map<String, RateLimiterConfig> configs,
|
|
||||||
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
||||||
final ClusterLuaScript validateScript,
|
final ClusterLuaScript validateScript,
|
||||||
final FaultTolerantRedisClusterClient cacheCluster,
|
final FaultTolerantRedisClusterClient cacheCluster,
|
||||||
final Clock clock) {
|
final Clock clock) {
|
||||||
this.configs = configs;
|
|
||||||
this.rateLimiterByDescriptor = Arrays.stream(values)
|
this.rateLimiterByDescriptor = Arrays.stream(values)
|
||||||
.map(descriptor -> Pair.of(
|
.map(descriptor -> Pair.of(
|
||||||
descriptor,
|
descriptor,
|
||||||
createForDescriptor(descriptor, configs, dynamicConfigurationManager, validateScript, cacheCluster, clock)))
|
createForDescriptor(descriptor, dynamicConfigurationManager, validateScript, cacheCluster, clock)))
|
||||||
.collect(Collectors.toUnmodifiableMap(Pair::getKey, Pair::getValue));
|
.collect(Collectors.toUnmodifiableMap(Pair::getKey, Pair::getValue));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -53,22 +42,6 @@ public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
|
||||||
return requireNonNull(rateLimiterByDescriptor.get(handle));
|
return requireNonNull(rateLimiterByDescriptor.get(handle));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void validateValuesAndConfigs() {
|
|
||||||
final Set<String> ids = rateLimiterByDescriptor.keySet().stream()
|
|
||||||
.map(RateLimiterDescriptor::id)
|
|
||||||
.collect(Collectors.toSet());
|
|
||||||
for (final String key: configs.keySet()) {
|
|
||||||
if (!ids.contains(key)) {
|
|
||||||
final String message = String.format(
|
|
||||||
"Static configuration has an unexpected field '%s' that doesn't match any RateLimiterDescriptor",
|
|
||||||
key
|
|
||||||
);
|
|
||||||
logger.error(message);
|
|
||||||
throw new IllegalArgumentException(message);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
protected static ClusterLuaScript defaultScript(final FaultTolerantRedisClusterClient cacheCluster) {
|
protected static ClusterLuaScript defaultScript(final FaultTolerantRedisClusterClient cacheCluster) {
|
||||||
try {
|
try {
|
||||||
return ClusterLuaScript.fromResource(
|
return ClusterLuaScript.fromResource(
|
||||||
|
@ -80,21 +53,12 @@ public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
|
||||||
|
|
||||||
private static RateLimiter createForDescriptor(
|
private static RateLimiter createForDescriptor(
|
||||||
final RateLimiterDescriptor descriptor,
|
final RateLimiterDescriptor descriptor,
|
||||||
final Map<String, RateLimiterConfig> configs,
|
|
||||||
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
||||||
final ClusterLuaScript validateScript,
|
final ClusterLuaScript validateScript,
|
||||||
final FaultTolerantRedisClusterClient cacheCluster,
|
final FaultTolerantRedisClusterClient cacheCluster,
|
||||||
final Clock clock) {
|
final Clock clock) {
|
||||||
if (descriptor.isDynamic()) {
|
final Supplier<RateLimiterConfig> configResolver =
|
||||||
final Supplier<RateLimiterConfig> configResolver = () -> {
|
() -> dynamicConfigurationManager.getConfiguration().getLimits().getOrDefault(descriptor.id(), descriptor.defaultConfig());
|
||||||
final RateLimiterConfig config = dynamicConfigurationManager.getConfiguration().getLimits().get(descriptor.id());
|
return new DynamicRateLimiter(descriptor.id(), configResolver, validateScript, cacheCluster, clock);
|
||||||
return config != null
|
|
||||||
? config
|
|
||||||
: configs.getOrDefault(descriptor.id(), descriptor.defaultConfig());
|
|
||||||
};
|
|
||||||
return new DynamicRateLimiter(descriptor.id(), dynamicConfigurationManager, configResolver, validateScript, cacheCluster, clock);
|
|
||||||
}
|
|
||||||
final RateLimiterConfig cfg = configs.getOrDefault(descriptor.id(), descriptor.defaultConfig());
|
|
||||||
return new StaticRateLimiter(descriptor.id(), cfg, validateScript, cacheCluster, clock);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,87 +7,167 @@ package org.whispersystems.textsecuregcm.limits;
|
||||||
|
|
||||||
import static java.util.Objects.requireNonNull;
|
import static java.util.Objects.requireNonNull;
|
||||||
|
|
||||||
|
import io.micrometer.core.instrument.Counter;
|
||||||
|
import io.micrometer.core.instrument.Metrics;
|
||||||
import java.time.Clock;
|
import java.time.Clock;
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.CompletableFuture;
|
||||||
import java.util.concurrent.CompletionStage;
|
import java.util.concurrent.CompletionStage;
|
||||||
import java.util.concurrent.atomic.AtomicReference;
|
|
||||||
import java.util.function.Supplier;
|
import java.util.function.Supplier;
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
|
||||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
|
||||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||||
|
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||||
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
||||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
|
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
|
||||||
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
|
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||||
|
import org.whispersystems.textsecuregcm.util.Util;
|
||||||
|
|
||||||
public class DynamicRateLimiter implements RateLimiter {
|
public class DynamicRateLimiter implements RateLimiter {
|
||||||
|
|
||||||
private final String name;
|
private final String name;
|
||||||
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
|
|
||||||
private final Supplier<RateLimiterConfig> configResolver;
|
private final Supplier<RateLimiterConfig> configResolver;
|
||||||
|
|
||||||
private final ClusterLuaScript validateScript;
|
private final ClusterLuaScript validateScript;
|
||||||
|
|
||||||
private final FaultTolerantRedisClusterClient cluster;
|
private final FaultTolerantRedisClusterClient cluster;
|
||||||
|
|
||||||
private final Clock clock;
|
private final Counter limitExceededCounter;
|
||||||
|
|
||||||
private final AtomicReference<Pair<RateLimiterConfig, RateLimiter>> currentHolder = new AtomicReference<>();
|
private final Clock clock;
|
||||||
|
|
||||||
|
|
||||||
public DynamicRateLimiter(
|
public DynamicRateLimiter(
|
||||||
final String name,
|
final String name,
|
||||||
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
|
||||||
final Supplier<RateLimiterConfig> configResolver,
|
final Supplier<RateLimiterConfig> configResolver,
|
||||||
final ClusterLuaScript validateScript,
|
final ClusterLuaScript validateScript,
|
||||||
final FaultTolerantRedisClusterClient cluster,
|
final FaultTolerantRedisClusterClient cluster,
|
||||||
final Clock clock) {
|
final Clock clock) {
|
||||||
this.name = requireNonNull(name);
|
this.name = requireNonNull(name);
|
||||||
this.dynamicConfigurationManager = dynamicConfigurationManager;
|
|
||||||
this.configResolver = requireNonNull(configResolver);
|
this.configResolver = requireNonNull(configResolver);
|
||||||
this.validateScript = requireNonNull(validateScript);
|
this.validateScript = requireNonNull(validateScript);
|
||||||
this.cluster = requireNonNull(cluster);
|
this.cluster = requireNonNull(cluster);
|
||||||
this.clock = requireNonNull(clock);
|
this.clock = requireNonNull(clock);
|
||||||
|
this.limitExceededCounter = Metrics.counter(MetricsUtil.name(getClass(), "exceeded"), "rateLimiterName", name);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void validate(final String key, final int amount) throws RateLimitExceededException {
|
public void validate(final String key, final int amount) throws RateLimitExceededException {
|
||||||
current().getRight().validate(key, amount);
|
final RateLimiterConfig config = config();
|
||||||
|
try {
|
||||||
|
final long deficitPermitsAmount = executeValidateScript(config, key, amount, true);
|
||||||
|
if (deficitPermitsAmount > 0) {
|
||||||
|
limitExceededCounter.increment();
|
||||||
|
final Duration retryAfter = Duration.ofMillis(
|
||||||
|
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
|
||||||
|
throw new RateLimitExceededException(retryAfter);
|
||||||
|
}
|
||||||
|
} catch (final Exception e) {
|
||||||
|
if (e instanceof RateLimitExceededException rateLimitExceededException) {
|
||||||
|
throw rateLimitExceededException;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!config.failOpen()) {
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public CompletionStage<Void> validateAsync(final String key, final int amount) {
|
public CompletionStage<Void> validateAsync(final String key, final int amount) {
|
||||||
return current().getRight().validateAsync(key, amount);
|
final RateLimiterConfig config = config();
|
||||||
|
|
||||||
|
return executeValidateScriptAsync(config, key, amount, true)
|
||||||
|
.thenCompose(deficitPermitsAmount -> {
|
||||||
|
if (deficitPermitsAmount == 0) {
|
||||||
|
return CompletableFuture.completedFuture((Void) null);
|
||||||
|
}
|
||||||
|
limitExceededCounter.increment();
|
||||||
|
final Duration retryAfter = Duration.ofMillis(
|
||||||
|
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
|
||||||
|
return CompletableFuture.failedFuture(new RateLimitExceededException(retryAfter));
|
||||||
|
})
|
||||||
|
.exceptionally(throwable -> {
|
||||||
|
if (ExceptionUtils.unwrap(throwable) instanceof RateLimitExceededException rateLimitExceededException) {
|
||||||
|
throw ExceptionUtils.wrap(rateLimitExceededException);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (config.failOpen()) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
throw ExceptionUtils.wrap(throwable);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean hasAvailablePermits(final String key, final int permits) {
|
public boolean hasAvailablePermits(final String key, final int permits) {
|
||||||
return current().getRight().hasAvailablePermits(key, permits);
|
final RateLimiterConfig config = config();
|
||||||
|
try {
|
||||||
|
final long deficitPermitsAmount = executeValidateScript(config, key, permits, false);
|
||||||
|
return deficitPermitsAmount == 0;
|
||||||
|
} catch (final Exception e) {
|
||||||
|
if (config.failOpen()) {
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public CompletionStage<Boolean> hasAvailablePermitsAsync(final String key, final int amount) {
|
public CompletionStage<Boolean> hasAvailablePermitsAsync(final String key, final int amount) {
|
||||||
return current().getRight().hasAvailablePermitsAsync(key, amount);
|
final RateLimiterConfig config = config();
|
||||||
|
return executeValidateScriptAsync(config, key, amount, false)
|
||||||
|
.thenApply(deficitPermitsAmount -> deficitPermitsAmount == 0)
|
||||||
|
.exceptionally(throwable -> {
|
||||||
|
if (config.failOpen()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
throw ExceptionUtils.wrap(throwable);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void clear(final String key) {
|
public void clear(final String key) {
|
||||||
current().getRight().clear(key);
|
cluster.useCluster(connection -> connection.sync().del(bucketName(name, key)));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public CompletionStage<Void> clearAsync(final String key) {
|
public CompletionStage<Void> clearAsync(final String key) {
|
||||||
return current().getRight().clearAsync(key);
|
return cluster.withCluster(connection -> connection.async().del(bucketName(name, key)))
|
||||||
|
.thenRun(Util.NOOP);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public RateLimiterConfig config() {
|
public RateLimiterConfig config() {
|
||||||
return current().getLeft();
|
return configResolver.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
private Pair<RateLimiterConfig, RateLimiter> current() {
|
private long executeValidateScript(final RateLimiterConfig config, final String key, final int amount, final boolean applyChanges) {
|
||||||
final RateLimiterConfig cfg = configResolver.get();
|
final List<String> keys = List.of(bucketName(name, key));
|
||||||
return currentHolder.updateAndGet(p -> p != null && p.getLeft().equals(cfg)
|
final List<String> arguments = List.of(
|
||||||
? p
|
String.valueOf(config.bucketSize()),
|
||||||
: Pair.of(cfg, new StaticRateLimiter(name, cfg, validateScript, cluster, clock))
|
String.valueOf(config.leakRatePerMillis()),
|
||||||
|
String.valueOf(clock.millis()),
|
||||||
|
String.valueOf(amount),
|
||||||
|
String.valueOf(applyChanges)
|
||||||
);
|
);
|
||||||
|
return (Long) validateScript.execute(keys, arguments);
|
||||||
|
}
|
||||||
|
|
||||||
|
private CompletionStage<Long> executeValidateScriptAsync(final RateLimiterConfig config, final String key, final int amount, final boolean applyChanges) {
|
||||||
|
final List<String> keys = List.of(bucketName(name, key));
|
||||||
|
final List<String> arguments = List.of(
|
||||||
|
String.valueOf(config.bucketSize()),
|
||||||
|
String.valueOf(config.leakRatePerMillis()),
|
||||||
|
String.valueOf(clock.millis()),
|
||||||
|
String.valueOf(amount),
|
||||||
|
String.valueOf(applyChanges)
|
||||||
|
);
|
||||||
|
return validateScript.executeAsync(keys, arguments).thenApply(o -> (Long) o);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static String bucketName(final String name, final String key) {
|
||||||
|
return "leaky_bucket::" + name + "::" + key;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,14 +15,9 @@ public interface RateLimiterDescriptor {
|
||||||
*/
|
*/
|
||||||
String id();
|
String id();
|
||||||
|
|
||||||
/**
|
|
||||||
* @return {@code true} if this rate limiter needs to watch for dynamic configuration changes.
|
|
||||||
*/
|
|
||||||
boolean isDynamic();
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @return an instance of {@link RateLimiterConfig} to be used by default,
|
* @return an instance of {@link RateLimiterConfig} to be used by default,
|
||||||
* i.e. if there is no overrides in the application configuration files (static or dynamic).
|
* i.e. if there is no override in the application dynamic configuration.
|
||||||
*/
|
*/
|
||||||
RateLimiterConfig defaultConfig();
|
RateLimiterConfig defaultConfig();
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,11 +4,9 @@
|
||||||
*/
|
*/
|
||||||
package org.whispersystems.textsecuregcm.limits;
|
package org.whispersystems.textsecuregcm.limits;
|
||||||
|
|
||||||
|
|
||||||
import com.google.common.annotations.VisibleForTesting;
|
import com.google.common.annotations.VisibleForTesting;
|
||||||
import java.time.Clock;
|
import java.time.Clock;
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
import java.util.Map;
|
|
||||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||||
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
||||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
|
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
|
||||||
|
@ -17,57 +15,54 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
|
||||||
public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
|
public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
|
||||||
|
|
||||||
public enum For implements RateLimiterDescriptor {
|
public enum For implements RateLimiterDescriptor {
|
||||||
BACKUP_AUTH_CHECK("backupAuthCheck", false, new RateLimiterConfig(100, Duration.ofMinutes(15), true)),
|
BACKUP_AUTH_CHECK("backupAuthCheck", new RateLimiterConfig(100, Duration.ofMinutes(15), true)),
|
||||||
PIN("pin", false, new RateLimiterConfig(10, Duration.ofDays(1), false)),
|
PIN("pin", new RateLimiterConfig(10, Duration.ofDays(1), false)),
|
||||||
ATTACHMENT("attachmentCreate", false, new RateLimiterConfig(50, Duration.ofMillis(1200), true)),
|
ATTACHMENT("attachmentCreate", new RateLimiterConfig(50, Duration.ofMillis(1200), true)),
|
||||||
BACKUP_ATTACHMENT("backupAttachmentCreate", true, new RateLimiterConfig(10_000, Duration.ofSeconds(1), true)),
|
BACKUP_ATTACHMENT("backupAttachmentCreate", new RateLimiterConfig(10_000, Duration.ofSeconds(1), true)),
|
||||||
PRE_KEYS("prekeys", false, new RateLimiterConfig(6, Duration.ofMinutes(10), false)),
|
PRE_KEYS("prekeys", new RateLimiterConfig(6, Duration.ofMinutes(10), false)),
|
||||||
MESSAGES("messages", false, new RateLimiterConfig(60, Duration.ofSeconds(1), true)),
|
MESSAGES("messages", new RateLimiterConfig(60, Duration.ofSeconds(1), true)),
|
||||||
STORIES("stories", false, new RateLimiterConfig(5_000, Duration.ofSeconds(8), true)),
|
STORIES("stories", new RateLimiterConfig(5_000, Duration.ofSeconds(8), true)),
|
||||||
ALLOCATE_DEVICE("allocateDevice", false, new RateLimiterConfig(6, Duration.ofMinutes(2), false)),
|
ALLOCATE_DEVICE("allocateDevice", new RateLimiterConfig(6, Duration.ofMinutes(2), false)),
|
||||||
VERIFY_DEVICE("verifyDevice", false, new RateLimiterConfig(6, Duration.ofMinutes(2), false)),
|
VERIFY_DEVICE("verifyDevice", new RateLimiterConfig(6, Duration.ofMinutes(2), false)),
|
||||||
PROFILE("profile", false, new RateLimiterConfig(4320, Duration.ofSeconds(20), true)),
|
PROFILE("profile", new RateLimiterConfig(4320, Duration.ofSeconds(20), true)),
|
||||||
STICKER_PACK("stickerPack", false, new RateLimiterConfig(50, Duration.ofMinutes(72), false)),
|
STICKER_PACK("stickerPack", new RateLimiterConfig(50, Duration.ofMinutes(72), false)),
|
||||||
USERNAME_LOOKUP("usernameLookup", false, new RateLimiterConfig(100, Duration.ofMinutes(15), true)),
|
USERNAME_LOOKUP("usernameLookup", new RateLimiterConfig(100, Duration.ofMinutes(15), true)),
|
||||||
USERNAME_SET("usernameSet", false, new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
|
USERNAME_SET("usernameSet", new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
|
||||||
USERNAME_RESERVE("usernameReserve", false, new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
|
USERNAME_RESERVE("usernameReserve", new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
|
||||||
USERNAME_LINK_OPERATION("usernameLinkOperation", false, new RateLimiterConfig(10, Duration.ofMinutes(1), false)),
|
USERNAME_LINK_OPERATION("usernameLinkOperation", new RateLimiterConfig(10, Duration.ofMinutes(1), false)),
|
||||||
USERNAME_LINK_LOOKUP_PER_IP("usernameLinkLookupPerIp", false, new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
|
USERNAME_LINK_LOOKUP_PER_IP("usernameLinkLookupPerIp", new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
|
||||||
CHECK_ACCOUNT_EXISTENCE("checkAccountExistence", false, new RateLimiterConfig(1000, Duration.ofSeconds(4), true)),
|
CHECK_ACCOUNT_EXISTENCE("checkAccountExistence", new RateLimiterConfig(1000, Duration.ofSeconds(4), true)),
|
||||||
REGISTRATION("registration", false, new RateLimiterConfig(6, Duration.ofSeconds(30), false)),
|
REGISTRATION("registration", new RateLimiterConfig(6, Duration.ofSeconds(30), false)),
|
||||||
VERIFICATION_PUSH_CHALLENGE("verificationPushChallenge", false, new RateLimiterConfig(5, Duration.ofSeconds(30), false)),
|
VERIFICATION_PUSH_CHALLENGE("verificationPushChallenge", new RateLimiterConfig(5, Duration.ofSeconds(30), false)),
|
||||||
VERIFICATION_CAPTCHA("verificationCaptcha", false, new RateLimiterConfig(10, Duration.ofSeconds(30), false)),
|
VERIFICATION_CAPTCHA("verificationCaptcha", new RateLimiterConfig(10, Duration.ofSeconds(30), false)),
|
||||||
RATE_LIMIT_RESET("rateLimitReset", true, new RateLimiterConfig(2, Duration.ofHours(12), false)),
|
RATE_LIMIT_RESET("rateLimitReset", new RateLimiterConfig(2, Duration.ofHours(12), false)),
|
||||||
CAPTCHA_CHALLENGE_ATTEMPT("captchaChallengeAttempt", true, new RateLimiterConfig(10, Duration.ofMinutes(144), false)),
|
CAPTCHA_CHALLENGE_ATTEMPT("captchaChallengeAttempt", new RateLimiterConfig(10, Duration.ofMinutes(144), false)),
|
||||||
CAPTCHA_CHALLENGE_SUCCESS("captchaChallengeSuccess", true, new RateLimiterConfig(2, Duration.ofHours(12), false)),
|
CAPTCHA_CHALLENGE_SUCCESS("captchaChallengeSuccess", new RateLimiterConfig(2, Duration.ofHours(12), false)),
|
||||||
SET_BACKUP_ID("setBackupId", true, new RateLimiterConfig(10, Duration.ofHours(1), false)),
|
SET_BACKUP_ID("setBackupId", new RateLimiterConfig(10, Duration.ofHours(1), false)),
|
||||||
SET_PAID_MEDIA_BACKUP_ID("setPaidMediaBackupId", true, new RateLimiterConfig(5, Duration.ofDays(7), false)),
|
SET_PAID_MEDIA_BACKUP_ID("setPaidMediaBackupId", new RateLimiterConfig(5, Duration.ofDays(7), false)),
|
||||||
PUSH_CHALLENGE_ATTEMPT("pushChallengeAttempt", true, new RateLimiterConfig(10, Duration.ofMinutes(144), false)),
|
PUSH_CHALLENGE_ATTEMPT("pushChallengeAttempt", new RateLimiterConfig(10, Duration.ofMinutes(144), false)),
|
||||||
PUSH_CHALLENGE_SUCCESS("pushChallengeSuccess", true, new RateLimiterConfig(2, Duration.ofHours(12), false)),
|
PUSH_CHALLENGE_SUCCESS("pushChallengeSuccess", new RateLimiterConfig(2, Duration.ofHours(12), false)),
|
||||||
GET_CALLING_RELAYS("getCallingRelays", false, new RateLimiterConfig(100, Duration.ofMinutes(10), false)),
|
GET_CALLING_RELAYS("getCallingRelays", new RateLimiterConfig(100, Duration.ofMinutes(10), false)),
|
||||||
CREATE_CALL_LINK("createCallLink", false, new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
|
CREATE_CALL_LINK("createCallLink", new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
|
||||||
INBOUND_MESSAGE_BYTES("inboundMessageBytes", true, new RateLimiterConfig(128 * 1024 * 1024, Duration.ofNanos(500_000), true)),
|
INBOUND_MESSAGE_BYTES("inboundMessageBytes", new RateLimiterConfig(128 * 1024 * 1024, Duration.ofNanos(500_000), true)),
|
||||||
EXTERNAL_SERVICE_CREDENTIALS("externalServiceCredentials", true, new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
|
EXTERNAL_SERVICE_CREDENTIALS("externalServiceCredentials", new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
|
||||||
KEY_TRANSPARENCY_DISTINGUISHED_PER_IP("keyTransparencyDistinguished", true, new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
|
KEY_TRANSPARENCY_DISTINGUISHED_PER_IP("keyTransparencyDistinguished", new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
|
||||||
KEY_TRANSPARENCY_SEARCH_PER_IP("keyTransparencySearch", true, new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
|
KEY_TRANSPARENCY_SEARCH_PER_IP("keyTransparencySearch", new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
|
||||||
KEY_TRANSPARENCY_MONITOR_PER_IP("keyTransparencyMonitor", true, new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
|
KEY_TRANSPARENCY_MONITOR_PER_IP("keyTransparencyMonitor", new RateLimiterConfig(100, Duration.ofSeconds(15), true)),
|
||||||
WAIT_FOR_LINKED_DEVICE("waitForLinkedDevice", true, new RateLimiterConfig(10, Duration.ofSeconds(30), false)),
|
WAIT_FOR_LINKED_DEVICE("waitForLinkedDevice", new RateLimiterConfig(10, Duration.ofSeconds(30), false)),
|
||||||
UPLOAD_TRANSFER_ARCHIVE("uploadTransferArchive", true, new RateLimiterConfig(10, Duration.ofMinutes(1), false)),
|
UPLOAD_TRANSFER_ARCHIVE("uploadTransferArchive", new RateLimiterConfig(10, Duration.ofMinutes(1), false)),
|
||||||
WAIT_FOR_TRANSFER_ARCHIVE("waitForTransferArchive", true, new RateLimiterConfig(10, Duration.ofSeconds(30), false)),
|
WAIT_FOR_TRANSFER_ARCHIVE("waitForTransferArchive", new RateLimiterConfig(10, Duration.ofSeconds(30), false)),
|
||||||
RECORD_DEVICE_TRANSFER_REQUEST("recordDeviceTransferRequest", true, new RateLimiterConfig(10, Duration.ofMillis(100), true)),
|
RECORD_DEVICE_TRANSFER_REQUEST("recordDeviceTransferRequest", new RateLimiterConfig(10, Duration.ofMillis(100), true)),
|
||||||
WAIT_FOR_DEVICE_TRANSFER_REQUEST("waitForDeviceTransferRequest", true, new RateLimiterConfig(10, Duration.ofMillis(100), true)),
|
WAIT_FOR_DEVICE_TRANSFER_REQUEST("waitForDeviceTransferRequest", new RateLimiterConfig(10, Duration.ofMillis(100), true)),
|
||||||
DEVICE_CHECK_CHALLENGE("deviceCheckChallenge", true, new RateLimiterConfig(10, Duration.ofMinutes(1), false)),
|
DEVICE_CHECK_CHALLENGE("deviceCheckChallenge", new RateLimiterConfig(10, Duration.ofMinutes(1), false)),
|
||||||
;
|
;
|
||||||
|
|
||||||
private final String id;
|
private final String id;
|
||||||
|
|
||||||
private final boolean dynamic;
|
|
||||||
|
|
||||||
private final RateLimiterConfig defaultConfig;
|
private final RateLimiterConfig defaultConfig;
|
||||||
|
|
||||||
For(final String id, final boolean dynamic, final RateLimiterConfig defaultConfig) {
|
For(final String id, final RateLimiterConfig defaultConfig) {
|
||||||
this.id = id;
|
this.id = id;
|
||||||
this.dynamic = dynamic;
|
|
||||||
this.defaultConfig = defaultConfig;
|
this.defaultConfig = defaultConfig;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -75,34 +70,25 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
|
||||||
return id;
|
return id;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isDynamic() {
|
|
||||||
return dynamic;
|
|
||||||
}
|
|
||||||
|
|
||||||
public RateLimiterConfig defaultConfig() {
|
public RateLimiterConfig defaultConfig() {
|
||||||
return defaultConfig;
|
return defaultConfig;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static RateLimiters createAndValidate(
|
public static RateLimiters create(
|
||||||
final Map<String, RateLimiterConfig> configs,
|
|
||||||
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
||||||
final FaultTolerantRedisClusterClient cacheCluster) {
|
final FaultTolerantRedisClusterClient cacheCluster) {
|
||||||
final RateLimiters rateLimiters = new RateLimiters(
|
return new RateLimiters(
|
||||||
configs, dynamicConfigurationManager, defaultScript(cacheCluster), cacheCluster, Clock.systemUTC());
|
dynamicConfigurationManager, defaultScript(cacheCluster), cacheCluster, Clock.systemUTC());
|
||||||
rateLimiters.validateValuesAndConfigs();
|
|
||||||
return rateLimiters;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@VisibleForTesting
|
@VisibleForTesting
|
||||||
RateLimiters(
|
RateLimiters(
|
||||||
final Map<String, RateLimiterConfig> configs,
|
|
||||||
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
||||||
final ClusterLuaScript validateScript,
|
final ClusterLuaScript validateScript,
|
||||||
final FaultTolerantRedisClusterClient cacheCluster,
|
final FaultTolerantRedisClusterClient cacheCluster,
|
||||||
final Clock clock) {
|
final Clock clock) {
|
||||||
super(For.values(), configs, dynamicConfigurationManager, validateScript, cacheCluster, clock);
|
super(For.values(), dynamicConfigurationManager, validateScript, cacheCluster, clock);
|
||||||
}
|
}
|
||||||
|
|
||||||
public RateLimiter getAllocateDeviceLimiter() {
|
public RateLimiter getAllocateDeviceLimiter() {
|
||||||
|
|
|
@ -1,170 +0,0 @@
|
||||||
/*
|
|
||||||
* Copyright 2013 Signal Messenger, LLC
|
|
||||||
* SPDX-License-Identifier: AGPL-3.0-only
|
|
||||||
*/
|
|
||||||
package org.whispersystems.textsecuregcm.limits;
|
|
||||||
|
|
||||||
import static java.util.Objects.requireNonNull;
|
|
||||||
import static java.util.concurrent.CompletableFuture.completedFuture;
|
|
||||||
import static java.util.concurrent.CompletableFuture.failedFuture;
|
|
||||||
|
|
||||||
import com.google.common.annotations.VisibleForTesting;
|
|
||||||
import io.micrometer.core.instrument.Counter;
|
|
||||||
import io.micrometer.core.instrument.Metrics;
|
|
||||||
import java.time.Clock;
|
|
||||||
import java.time.Duration;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.concurrent.CompletionStage;
|
|
||||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
|
||||||
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
|
||||||
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
|
||||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
|
|
||||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
|
||||||
import org.whispersystems.textsecuregcm.util.Util;
|
|
||||||
|
|
||||||
public class StaticRateLimiter implements RateLimiter {
|
|
||||||
|
|
||||||
protected final String name;
|
|
||||||
|
|
||||||
private final RateLimiterConfig config;
|
|
||||||
|
|
||||||
private final Counter limitExceededCounter;
|
|
||||||
|
|
||||||
private final ClusterLuaScript validateScript;
|
|
||||||
|
|
||||||
private final FaultTolerantRedisClusterClient cacheCluster;
|
|
||||||
|
|
||||||
private final Clock clock;
|
|
||||||
|
|
||||||
|
|
||||||
public StaticRateLimiter(
|
|
||||||
final String name,
|
|
||||||
final RateLimiterConfig config,
|
|
||||||
final ClusterLuaScript validateScript,
|
|
||||||
final FaultTolerantRedisClusterClient cacheCluster,
|
|
||||||
final Clock clock) {
|
|
||||||
this.name = requireNonNull(name);
|
|
||||||
this.config = requireNonNull(config);
|
|
||||||
this.validateScript = requireNonNull(validateScript);
|
|
||||||
this.cacheCluster = requireNonNull(cacheCluster);
|
|
||||||
this.clock = requireNonNull(clock);
|
|
||||||
this.limitExceededCounter = Metrics.counter(MetricsUtil.name(getClass(), "exceeded"), "rateLimiterName", name);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void validate(final String key, final int amount) throws RateLimitExceededException {
|
|
||||||
try {
|
|
||||||
final long deficitPermitsAmount = executeValidateScript(key, amount, true);
|
|
||||||
if (deficitPermitsAmount > 0) {
|
|
||||||
limitExceededCounter.increment();
|
|
||||||
final Duration retryAfter = Duration.ofMillis(
|
|
||||||
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
|
|
||||||
throw new RateLimitExceededException(retryAfter);
|
|
||||||
}
|
|
||||||
} catch (final Exception e) {
|
|
||||||
if (e instanceof RateLimitExceededException rateLimitExceededException) {
|
|
||||||
throw rateLimitExceededException;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!config.failOpen()) {
|
|
||||||
throw e;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public CompletionStage<Void> validateAsync(final String key, final int amount) {
|
|
||||||
return executeValidateScriptAsync(key, amount, true)
|
|
||||||
.thenCompose(deficitPermitsAmount -> {
|
|
||||||
if (deficitPermitsAmount == 0) {
|
|
||||||
return completedFuture((Void) null);
|
|
||||||
}
|
|
||||||
limitExceededCounter.increment();
|
|
||||||
final Duration retryAfter = Duration.ofMillis(
|
|
||||||
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
|
|
||||||
return failedFuture(new RateLimitExceededException(retryAfter));
|
|
||||||
})
|
|
||||||
.exceptionally(throwable -> {
|
|
||||||
if (ExceptionUtils.unwrap(throwable) instanceof RateLimitExceededException rateLimitExceededException) {
|
|
||||||
throw ExceptionUtils.wrap(rateLimitExceededException);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (config.failOpen()) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
throw ExceptionUtils.wrap(throwable);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean hasAvailablePermits(final String key, final int amount) {
|
|
||||||
try {
|
|
||||||
final long deficitPermitsAmount = executeValidateScript(key, amount, false);
|
|
||||||
return deficitPermitsAmount == 0;
|
|
||||||
} catch (final Exception e) {
|
|
||||||
if (config.failOpen()) {
|
|
||||||
return true;
|
|
||||||
} else {
|
|
||||||
throw e;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public CompletionStage<Boolean> hasAvailablePermitsAsync(final String key, final int amount) {
|
|
||||||
return executeValidateScriptAsync(key, amount, false)
|
|
||||||
.thenApply(deficitPermitsAmount -> deficitPermitsAmount == 0)
|
|
||||||
.exceptionally(throwable -> {
|
|
||||||
if (config.failOpen()) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
throw ExceptionUtils.wrap(throwable);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void clear(final String key) {
|
|
||||||
cacheCluster.useCluster(connection -> connection.sync().del(bucketName(name, key)));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public CompletionStage<Void> clearAsync(final String key) {
|
|
||||||
return cacheCluster.withCluster(connection -> connection.async().del(bucketName(name, key)))
|
|
||||||
.thenRun(Util.NOOP);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public RateLimiterConfig config() {
|
|
||||||
return config;
|
|
||||||
}
|
|
||||||
|
|
||||||
private long executeValidateScript(final String key, final int amount, final boolean applyChanges) {
|
|
||||||
final List<String> keys = List.of(bucketName(name, key));
|
|
||||||
final List<String> arguments = List.of(
|
|
||||||
String.valueOf(config.bucketSize()),
|
|
||||||
String.valueOf(config.leakRatePerMillis()),
|
|
||||||
String.valueOf(clock.millis()),
|
|
||||||
String.valueOf(amount),
|
|
||||||
String.valueOf(applyChanges)
|
|
||||||
);
|
|
||||||
return (Long) validateScript.execute(keys, arguments);
|
|
||||||
}
|
|
||||||
|
|
||||||
private CompletionStage<Long> executeValidateScriptAsync(final String key, final int amount, final boolean applyChanges) {
|
|
||||||
final List<String> keys = List.of(bucketName(name, key));
|
|
||||||
final List<String> arguments = List.of(
|
|
||||||
String.valueOf(config.bucketSize()),
|
|
||||||
String.valueOf(config.leakRatePerMillis()),
|
|
||||||
String.valueOf(clock.millis()),
|
|
||||||
String.valueOf(amount),
|
|
||||||
String.valueOf(applyChanges)
|
|
||||||
);
|
|
||||||
return validateScript.executeAsync(keys, arguments).thenApply(o -> (Long) o);
|
|
||||||
}
|
|
||||||
|
|
||||||
@VisibleForTesting
|
|
||||||
protected static String bucketName(final String name, final String key) {
|
|
||||||
return "leaky_bucket::" + name + "::" + key;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -241,8 +241,7 @@ record CommandDependencies(
|
||||||
secureStorageClient, secureValueRecovery2Client, disconnectionRequestManager,
|
secureStorageClient, secureValueRecovery2Client, disconnectionRequestManager,
|
||||||
registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, messagePollExecutor,
|
registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, messagePollExecutor,
|
||||||
clock, configuration.getLinkDeviceSecretConfiguration().secret().value(), dynamicConfigurationManager);
|
clock, configuration.getLinkDeviceSecretConfiguration().secret().value(), dynamicConfigurationManager);
|
||||||
RateLimiters rateLimiters = RateLimiters.createAndValidate(configuration.getLimitsConfiguration(),
|
RateLimiters rateLimiters = RateLimiters.create(dynamicConfigurationManager, rateLimitersCluster);
|
||||||
dynamicConfigurationManager, rateLimitersCluster);
|
|
||||||
final BackupsDb backupsDb =
|
final BackupsDb backupsDb =
|
||||||
new BackupsDb(dynamoDbAsyncClient, configuration.getDynamoDbTables().getBackups().getTableName(), clock);
|
new BackupsDb(dynamoDbAsyncClient, configuration.getDynamoDbTables().getBackups().getTableName(), clock);
|
||||||
final GenericServerSecretParams backupsGenericZkSecretParams;
|
final GenericServerSecretParams backupsGenericZkSecretParams;
|
||||||
|
|
|
@ -0,0 +1,273 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.whispersystems.textsecuregcm.limits;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
import io.lettuce.core.ScriptOutputType;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.time.Instant;
|
||||||
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
import java.util.concurrent.CompletionException;
|
||||||
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
|
import org.apache.commons.lang3.RandomStringUtils;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||||
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
|
import org.junit.jupiter.params.provider.ValueSource;
|
||||||
|
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||||
|
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
||||||
|
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
|
||||||
|
import org.whispersystems.textsecuregcm.util.TestClock;
|
||||||
|
|
||||||
|
class DynamicRateLimiterTest {
|
||||||
|
|
||||||
|
private ClusterLuaScript validateRateLimitScript;
|
||||||
|
|
||||||
|
private static final TestClock CLOCK = TestClock.pinned(Instant.now());
|
||||||
|
|
||||||
|
@RegisterExtension
|
||||||
|
private static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() throws IOException {
|
||||||
|
validateRateLimitScript = ClusterLuaScript.fromResource(
|
||||||
|
REDIS_CLUSTER_EXTENSION.getRedisCluster(), "lua/validate_rate_limit.lua", ScriptOutputType.INTEGER);
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@ValueSource(booleans = {true, false})
|
||||||
|
void validate(final boolean failOpen) {
|
||||||
|
final DynamicRateLimiter rateLimiter = new DynamicRateLimiter(
|
||||||
|
"test",
|
||||||
|
() -> new RateLimiterConfig(1, Duration.ofHours(1), failOpen),
|
||||||
|
validateRateLimitScript,
|
||||||
|
REDIS_CLUSTER_EXTENSION.getRedisCluster(),
|
||||||
|
CLOCK);
|
||||||
|
|
||||||
|
final String key = RandomStringUtils.insecure().nextAlphanumeric(16);
|
||||||
|
|
||||||
|
assertDoesNotThrow(() -> rateLimiter.validate(key));
|
||||||
|
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key));
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@ValueSource(booleans = {true, false})
|
||||||
|
void validateAsync(final boolean failOpen) {
|
||||||
|
final DynamicRateLimiter rateLimiter = new DynamicRateLimiter(
|
||||||
|
"test",
|
||||||
|
() -> new RateLimiterConfig(1, Duration.ofHours(1), failOpen),
|
||||||
|
validateRateLimitScript,
|
||||||
|
REDIS_CLUSTER_EXTENSION.getRedisCluster(),
|
||||||
|
CLOCK);
|
||||||
|
|
||||||
|
final String key = RandomStringUtils.insecure().nextAlphanumeric(16);
|
||||||
|
|
||||||
|
assertDoesNotThrow(() -> rateLimiter.validateAsync(key).toCompletableFuture().join());
|
||||||
|
final CompletionException completionException =
|
||||||
|
assertThrows(CompletionException.class, () -> rateLimiter.validateAsync(key).toCompletableFuture().join());
|
||||||
|
|
||||||
|
assertInstanceOf(RateLimitExceededException.class, completionException.getCause());
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@ValueSource(booleans = {true, false})
|
||||||
|
void validateFailOpen(final boolean failOpen) {
|
||||||
|
final ClusterLuaScript failingScript = mock(ClusterLuaScript.class);
|
||||||
|
when(failingScript.execute(any(), any())).thenThrow(new RuntimeException("OH NO"));
|
||||||
|
|
||||||
|
final DynamicRateLimiter rateLimiter = new DynamicRateLimiter(
|
||||||
|
"test",
|
||||||
|
() -> new RateLimiterConfig(1, Duration.ofHours(1), failOpen),
|
||||||
|
failingScript,
|
||||||
|
REDIS_CLUSTER_EXTENSION.getRedisCluster(),
|
||||||
|
CLOCK);
|
||||||
|
|
||||||
|
final String key = RandomStringUtils.insecure().nextAlphanumeric(16);
|
||||||
|
|
||||||
|
if (failOpen) {
|
||||||
|
assertDoesNotThrow(() -> rateLimiter.validate(key));
|
||||||
|
} else {
|
||||||
|
assertThrows(RuntimeException.class, () -> rateLimiter.validate(key));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@ValueSource(booleans = {true, false})
|
||||||
|
void validateFailOpenAsync(final boolean failOpen) {
|
||||||
|
final ClusterLuaScript failingScript = mock(ClusterLuaScript.class);
|
||||||
|
when(failingScript.executeAsync(any(), any())).thenReturn(CompletableFuture.failedFuture(new RuntimeException("OH NO")));
|
||||||
|
|
||||||
|
final DynamicRateLimiter rateLimiter = new DynamicRateLimiter(
|
||||||
|
"test",
|
||||||
|
() -> new RateLimiterConfig(1, Duration.ofHours(1), failOpen),
|
||||||
|
failingScript,
|
||||||
|
REDIS_CLUSTER_EXTENSION.getRedisCluster(),
|
||||||
|
CLOCK);
|
||||||
|
|
||||||
|
final String key = RandomStringUtils.insecure().nextAlphanumeric(16);
|
||||||
|
|
||||||
|
if (failOpen) {
|
||||||
|
assertDoesNotThrow(() -> rateLimiter.validate(key));
|
||||||
|
} else {
|
||||||
|
final CompletionException completionException =
|
||||||
|
assertThrows(CompletionException.class, () -> rateLimiter.validateAsync(key).toCompletableFuture().join());
|
||||||
|
|
||||||
|
assertInstanceOf(RuntimeException.class, completionException.getCause());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void configChange_ReduceRefillRate() {
|
||||||
|
final AtomicReference<Duration> refillRate = new AtomicReference<>(Duration.ofMinutes(5));
|
||||||
|
final DynamicRateLimiter rateLimiter = new DynamicRateLimiter(
|
||||||
|
"test",
|
||||||
|
() -> new RateLimiterConfig(1, refillRate.get(), false),
|
||||||
|
validateRateLimitScript,
|
||||||
|
REDIS_CLUSTER_EXTENSION.getRedisCluster(),
|
||||||
|
CLOCK);
|
||||||
|
|
||||||
|
final String key = RandomStringUtils.insecure().nextAlphanumeric(16);
|
||||||
|
|
||||||
|
assertDoesNotThrow(() -> rateLimiter.validate(key));
|
||||||
|
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key));
|
||||||
|
|
||||||
|
CLOCK.pin(CLOCK.instant().plus(Duration.ofMinutes(1)));
|
||||||
|
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key));
|
||||||
|
|
||||||
|
refillRate.set(Duration.ofMinutes(1));
|
||||||
|
assertDoesNotThrow(() -> rateLimiter.validate(key));
|
||||||
|
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void configChange_IncreaseRefillRate() {
|
||||||
|
final AtomicReference<Duration> refillRate = new AtomicReference<>(Duration.ofMinutes(5));
|
||||||
|
final DynamicRateLimiter rateLimiter = new DynamicRateLimiter(
|
||||||
|
"test",
|
||||||
|
() -> new RateLimiterConfig(1, refillRate.get(), false),
|
||||||
|
validateRateLimitScript,
|
||||||
|
REDIS_CLUSTER_EXTENSION.getRedisCluster(),
|
||||||
|
CLOCK);
|
||||||
|
|
||||||
|
final String key = RandomStringUtils.insecure().nextAlphanumeric(16);
|
||||||
|
|
||||||
|
assertDoesNotThrow(() -> rateLimiter.validate(key));
|
||||||
|
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key));
|
||||||
|
|
||||||
|
CLOCK.pin(CLOCK.instant().plus(Duration.ofMinutes(5)));
|
||||||
|
assertTrue(rateLimiter.hasAvailablePermits(key, 1));
|
||||||
|
|
||||||
|
refillRate.set(Duration.ofMinutes(10));
|
||||||
|
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key));
|
||||||
|
|
||||||
|
CLOCK.pin(CLOCK.instant().plus(Duration.ofMinutes(5)));
|
||||||
|
assertDoesNotThrow(() -> rateLimiter.validate(key));
|
||||||
|
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void configChange_ReduceBucketSize() {
|
||||||
|
final AtomicInteger bucketSize = new AtomicInteger(5);
|
||||||
|
final DynamicRateLimiter rateLimiter = new DynamicRateLimiter(
|
||||||
|
"test",
|
||||||
|
() -> new RateLimiterConfig(bucketSize.get(), Duration.ofMinutes(1), false),
|
||||||
|
validateRateLimitScript,
|
||||||
|
REDIS_CLUSTER_EXTENSION.getRedisCluster(),
|
||||||
|
CLOCK);
|
||||||
|
|
||||||
|
final String key = RandomStringUtils.insecure().nextAlphanumeric(16);
|
||||||
|
|
||||||
|
assertDoesNotThrow(() -> rateLimiter.validate(key));
|
||||||
|
assertTrue(rateLimiter.hasAvailablePermits(key, 4));
|
||||||
|
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key, 5));
|
||||||
|
|
||||||
|
bucketSize.set(1);
|
||||||
|
// Changing the bucket size doesn't spend the tokens remaining in existing buckets, but does
|
||||||
|
// effectively make those buckets overflow if it got smaller. There were 4 tokens available
|
||||||
|
// before, so changing the bucket size to 1 effectively means there is 1 token left, not 0
|
||||||
|
assertTrue(rateLimiter.hasAvailablePermits(key, 1));
|
||||||
|
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key, 2));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void configChange_IncreaseBucketSize() {
|
||||||
|
final AtomicInteger bucketSize = new AtomicInteger(5);
|
||||||
|
final DynamicRateLimiter rateLimiter = new DynamicRateLimiter(
|
||||||
|
"test",
|
||||||
|
() -> new RateLimiterConfig(bucketSize.get(), Duration.ofMinutes(1), false),
|
||||||
|
validateRateLimitScript,
|
||||||
|
REDIS_CLUSTER_EXTENSION.getRedisCluster(),
|
||||||
|
CLOCK);
|
||||||
|
|
||||||
|
final String key = RandomStringUtils.insecure().nextAlphanumeric(16);
|
||||||
|
|
||||||
|
assertDoesNotThrow(() -> rateLimiter.validate(key));
|
||||||
|
assertTrue(rateLimiter.hasAvailablePermits(key, 4));
|
||||||
|
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key, 5));
|
||||||
|
|
||||||
|
bucketSize.set(10);
|
||||||
|
// Increasing the bucket size doesn't retroactively refill buckets in redis, so we have to wait
|
||||||
|
// until the bucket fills up
|
||||||
|
CLOCK.pin(CLOCK.instant().plus(Duration.ofMinutes(10)));
|
||||||
|
assertTrue(rateLimiter.hasAvailablePermits(key, 10));
|
||||||
|
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key, 11));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void configChange_enableFailOpen() {
|
||||||
|
final ClusterLuaScript failingScript = mock(ClusterLuaScript.class);
|
||||||
|
when(failingScript.execute(any(), any())).thenThrow(new RuntimeException("OH NO"));
|
||||||
|
|
||||||
|
final AtomicBoolean failOpen = new AtomicBoolean(false);
|
||||||
|
final DynamicRateLimiter rateLimiter = new DynamicRateLimiter(
|
||||||
|
"test",
|
||||||
|
() -> new RateLimiterConfig(1, Duration.ofMinutes(1), failOpen.get()),
|
||||||
|
failingScript,
|
||||||
|
REDIS_CLUSTER_EXTENSION.getRedisCluster(),
|
||||||
|
CLOCK);
|
||||||
|
|
||||||
|
final String key = RandomStringUtils.insecure().nextAlphanumeric(16);
|
||||||
|
|
||||||
|
assertThrows(RuntimeException.class, () -> rateLimiter.validate(key));
|
||||||
|
|
||||||
|
failOpen.set(true);
|
||||||
|
|
||||||
|
assertDoesNotThrow(() -> rateLimiter.validate(key));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void configChange_disableFailOpen() {
|
||||||
|
final ClusterLuaScript failingScript = mock(ClusterLuaScript.class);
|
||||||
|
when(failingScript.execute(any(), any())).thenThrow(new RuntimeException("OH NO"));
|
||||||
|
|
||||||
|
final AtomicBoolean failOpen = new AtomicBoolean(true);
|
||||||
|
final DynamicRateLimiter rateLimiter = new DynamicRateLimiter(
|
||||||
|
"test",
|
||||||
|
() -> new RateLimiterConfig(1, Duration.ofMinutes(1), failOpen.get()),
|
||||||
|
failingScript,
|
||||||
|
REDIS_CLUSTER_EXTENSION.getRedisCluster(),
|
||||||
|
CLOCK);
|
||||||
|
|
||||||
|
final String key = RandomStringUtils.insecure().nextAlphanumeric(16);
|
||||||
|
|
||||||
|
assertDoesNotThrow(() -> rateLimiter.validate(key));
|
||||||
|
|
||||||
|
failOpen.set(false);
|
||||||
|
|
||||||
|
assertThrows(RuntimeException.class, () -> rateLimiter.validate(key));
|
||||||
|
}
|
||||||
|
}
|
|
@ -57,9 +57,11 @@ public class RateLimitersLuaScriptTest {
|
||||||
@Test
|
@Test
|
||||||
public void testWithEmbeddedRedis() throws Exception {
|
public void testWithEmbeddedRedis() throws Exception {
|
||||||
final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION;
|
final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION;
|
||||||
|
final Map<String, RateLimiterConfig> limiterConfig = Map.of(descriptor.id(), new RateLimiterConfig(60, Duration.ofSeconds(1), false));
|
||||||
|
when(configuration.getLimits()).thenReturn(limiterConfig);
|
||||||
|
|
||||||
final FaultTolerantRedisClusterClient redisCluster = REDIS_CLUSTER_EXTENSION.getRedisCluster();
|
final FaultTolerantRedisClusterClient redisCluster = REDIS_CLUSTER_EXTENSION.getRedisCluster();
|
||||||
final RateLimiters limiters = new RateLimiters(
|
final RateLimiters limiters = new RateLimiters(
|
||||||
Map.of(descriptor.id(), new RateLimiterConfig(60, Duration.ofSeconds(1), false)),
|
|
||||||
dynamicConfig,
|
dynamicConfig,
|
||||||
RateLimiters.defaultScript(redisCluster),
|
RateLimiters.defaultScript(redisCluster),
|
||||||
redisCluster,
|
redisCluster,
|
||||||
|
@ -74,9 +76,11 @@ public class RateLimitersLuaScriptTest {
|
||||||
@Test
|
@Test
|
||||||
public void testTtl() throws Exception {
|
public void testTtl() throws Exception {
|
||||||
final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION;
|
final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION;
|
||||||
|
final Map<String, RateLimiterConfig> limiterConfig = Map.of(descriptor.id(), new RateLimiterConfig(1000, Duration.ofSeconds(1), false));
|
||||||
|
when(configuration.getLimits()).thenReturn(limiterConfig);
|
||||||
|
|
||||||
final FaultTolerantRedisClusterClient redisCluster = REDIS_CLUSTER_EXTENSION.getRedisCluster();
|
final FaultTolerantRedisClusterClient redisCluster = REDIS_CLUSTER_EXTENSION.getRedisCluster();
|
||||||
final RateLimiters limiters = new RateLimiters(
|
final RateLimiters limiters = new RateLimiters(
|
||||||
Map.of(descriptor.id(), new RateLimiterConfig(1000, Duration.ofSeconds(1), false)),
|
|
||||||
dynamicConfig,
|
dynamicConfig,
|
||||||
RateLimiters.defaultScript(redisCluster),
|
RateLimiters.defaultScript(redisCluster),
|
||||||
redisCluster,
|
redisCluster,
|
||||||
|
@ -126,8 +130,11 @@ public class RateLimitersLuaScriptTest {
|
||||||
public void testFailOpen(final boolean failOpen) {
|
public void testFailOpen(final boolean failOpen) {
|
||||||
final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION;
|
final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION;
|
||||||
final FaultTolerantRedisClusterClient redisCluster = mock(FaultTolerantRedisClusterClient.class);
|
final FaultTolerantRedisClusterClient redisCluster = mock(FaultTolerantRedisClusterClient.class);
|
||||||
|
|
||||||
|
final Map<String, RateLimiterConfig> limiterConfig = Map.of(descriptor.id(), new RateLimiterConfig(1, Duration.ofSeconds(1), failOpen));
|
||||||
|
when(configuration.getLimits()).thenReturn(limiterConfig);
|
||||||
|
|
||||||
final RateLimiters limiters = new RateLimiters(
|
final RateLimiters limiters = new RateLimiters(
|
||||||
Map.of(descriptor.id(), new RateLimiterConfig(1000, Duration.ofSeconds(1), failOpen)),
|
|
||||||
dynamicConfig,
|
dynamicConfig,
|
||||||
RateLimiters.defaultScript(redisCluster),
|
RateLimiters.defaultScript(redisCluster),
|
||||||
redisCluster,
|
redisCluster,
|
||||||
|
|
|
@ -5,17 +5,12 @@
|
||||||
|
|
||||||
package org.whispersystems.textsecuregcm.limits;
|
package org.whispersystems.textsecuregcm.limits;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||||
import static org.mockito.Mockito.mock;
|
import static org.mockito.Mockito.mock;
|
||||||
import static org.mockito.Mockito.when;
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
||||||
import jakarta.validation.Valid;
|
|
||||||
import jakarta.validation.constraints.NotNull;
|
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
@ -40,48 +35,6 @@ public class RateLimitersTest {
|
||||||
|
|
||||||
private final MutableClock clock = MockUtils.mutableClock(0);
|
private final MutableClock clock = MockUtils.mutableClock(0);
|
||||||
|
|
||||||
private static final String BAD_YAML = """
|
|
||||||
limits:
|
|
||||||
prekeys:
|
|
||||||
bucketSize: 150
|
|
||||||
permitRegenerationDuration: PT6S
|
|
||||||
unexpected:
|
|
||||||
bucketSize: 4
|
|
||||||
permitRegenerationDuration: PT30S
|
|
||||||
""";
|
|
||||||
|
|
||||||
private static final String GOOD_YAML = """
|
|
||||||
limits:
|
|
||||||
prekeys:
|
|
||||||
bucketSize: 150
|
|
||||||
permitRegenerationDuration: PT6S
|
|
||||||
failOpen: true
|
|
||||||
attachmentCreate:
|
|
||||||
bucketSize: 4
|
|
||||||
permitRegenerationDuration: PT30S
|
|
||||||
failOpen: true
|
|
||||||
""";
|
|
||||||
|
|
||||||
public record SimpleDynamicConfiguration(@Valid @NotNull @JsonProperty Map<String, RateLimiterConfig> limits) {
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testValidateConfigs() throws Exception {
|
|
||||||
assertThrows(IllegalArgumentException.class, () -> {
|
|
||||||
final SimpleDynamicConfiguration dynamicConfiguration =
|
|
||||||
DynamicConfigurationManager.parseConfiguration(BAD_YAML, SimpleDynamicConfiguration.class).orElseThrow();
|
|
||||||
|
|
||||||
final RateLimiters rateLimiters = new RateLimiters(dynamicConfiguration.limits(), dynamicConfig, validateScript, redisCluster, clock);
|
|
||||||
rateLimiters.validateValuesAndConfigs();
|
|
||||||
});
|
|
||||||
|
|
||||||
final SimpleDynamicConfiguration dynamicConfiguration =
|
|
||||||
DynamicConfigurationManager.parseConfiguration(GOOD_YAML, SimpleDynamicConfiguration.class).orElseThrow();
|
|
||||||
|
|
||||||
final RateLimiters rateLimiters = new RateLimiters(dynamicConfiguration.limits(), dynamicConfig, validateScript, redisCluster, clock);
|
|
||||||
assertDoesNotThrow(rateLimiters::validateValuesAndConfigs);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testValidateDuplicates() throws Exception {
|
public void testValidateDuplicates() throws Exception {
|
||||||
final TestDescriptor td1 = new TestDescriptor("id1");
|
final TestDescriptor td1 = new TestDescriptor("id1");
|
||||||
|
@ -91,7 +44,6 @@ public class RateLimitersTest {
|
||||||
|
|
||||||
assertThrows(IllegalStateException.class, () -> new BaseRateLimiters<>(
|
assertThrows(IllegalStateException.class, () -> new BaseRateLimiters<>(
|
||||||
new TestDescriptor[] { td1, td2, td3, tdDup },
|
new TestDescriptor[] { td1, td2, td3, tdDup },
|
||||||
Collections.emptyMap(),
|
|
||||||
dynamicConfig,
|
dynamicConfig,
|
||||||
validateScript,
|
validateScript,
|
||||||
redisCluster,
|
redisCluster,
|
||||||
|
@ -99,7 +51,6 @@ public class RateLimitersTest {
|
||||||
|
|
||||||
new BaseRateLimiters<>(
|
new BaseRateLimiters<>(
|
||||||
new TestDescriptor[] { td1, td2, td3 },
|
new TestDescriptor[] { td1, td2, td3 },
|
||||||
Collections.emptyMap(),
|
|
||||||
dynamicConfig,
|
dynamicConfig,
|
||||||
validateScript,
|
validateScript,
|
||||||
redisCluster,
|
redisCluster,
|
||||||
|
@ -108,10 +59,10 @@ public class RateLimitersTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testUnchangingConfiguration() {
|
void testUnchangingConfiguration() {
|
||||||
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, validateScript, redisCluster, clock);
|
final RateLimiters rateLimiters = new RateLimiters(dynamicConfig, validateScript, redisCluster, clock);
|
||||||
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
|
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
|
||||||
final RateLimiterConfig expected = RateLimiters.For.RATE_LIMIT_RESET.defaultConfig();
|
final RateLimiterConfig expected = RateLimiters.For.RATE_LIMIT_RESET.defaultConfig();
|
||||||
assertEquals(expected, config(limiter));
|
assertEquals(expected, limiter.config());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -127,78 +78,49 @@ public class RateLimitersTest {
|
||||||
|
|
||||||
when(configuration.getLimits()).thenReturn(limitsConfigMap);
|
when(configuration.getLimits()).thenReturn(limitsConfigMap);
|
||||||
|
|
||||||
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, validateScript, redisCluster, clock);
|
final RateLimiters rateLimiters = new RateLimiters(dynamicConfig, validateScript, redisCluster, clock);
|
||||||
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
|
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
|
||||||
|
|
||||||
limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), initialRateLimiterConfig);
|
limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), initialRateLimiterConfig);
|
||||||
assertEquals(initialRateLimiterConfig, config(limiter));
|
assertEquals(initialRateLimiterConfig, limiter.config());
|
||||||
|
|
||||||
assertEquals(baseConfig, config(rateLimiters.getCaptchaChallengeAttemptLimiter()));
|
assertEquals(baseConfig, rateLimiters.getCaptchaChallengeAttemptLimiter().config());
|
||||||
assertEquals(baseConfig, config(rateLimiters.getCaptchaChallengeSuccessLimiter()));
|
assertEquals(baseConfig, rateLimiters.getCaptchaChallengeSuccessLimiter().config());
|
||||||
|
|
||||||
limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), updatedRateLimiterCongig);
|
limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), updatedRateLimiterCongig);
|
||||||
assertEquals(updatedRateLimiterCongig, config(limiter));
|
assertEquals(updatedRateLimiterCongig, limiter.config());
|
||||||
|
|
||||||
assertEquals(baseConfig, config(rateLimiters.getCaptchaChallengeAttemptLimiter()));
|
assertEquals(baseConfig, rateLimiters.getCaptchaChallengeAttemptLimiter().config());
|
||||||
assertEquals(baseConfig, config(rateLimiters.getCaptchaChallengeSuccessLimiter()));
|
assertEquals(baseConfig, rateLimiters.getCaptchaChallengeSuccessLimiter().config());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRateLimiterHasItsPrioritiesStraight() throws Exception {
|
public void testRateLimiterHasItsPrioritiesStraight() throws Exception {
|
||||||
final RateLimiters.For descriptor = RateLimiters.For.CAPTCHA_CHALLENGE_ATTEMPT;
|
final RateLimiters.For descriptor = RateLimiters.For.CAPTCHA_CHALLENGE_ATTEMPT;
|
||||||
final RateLimiterConfig configForDynamic = new RateLimiterConfig(1, Duration.ofMinutes(1), false);
|
final RateLimiterConfig configForDynamic = new RateLimiterConfig(1, Duration.ofMinutes(1), false);
|
||||||
final RateLimiterConfig configForStatic = new RateLimiterConfig(2, Duration.ofSeconds(30), false);
|
|
||||||
final RateLimiterConfig defaultConfig = descriptor.defaultConfig();
|
final RateLimiterConfig defaultConfig = descriptor.defaultConfig();
|
||||||
|
|
||||||
final Map<String, RateLimiterConfig> mapForDynamic = new HashMap<>();
|
final Map<String, RateLimiterConfig> mapForDynamic = new HashMap<>();
|
||||||
final Map<String, RateLimiterConfig> mapForStatic = new HashMap<>();
|
|
||||||
|
|
||||||
when(configuration.getLimits()).thenReturn(mapForDynamic);
|
when(configuration.getLimits()).thenReturn(mapForDynamic);
|
||||||
|
|
||||||
final RateLimiters rateLimiters = new RateLimiters(mapForStatic, dynamicConfig, validateScript, redisCluster, clock);
|
final RateLimiters rateLimiters = new RateLimiters(dynamicConfig, validateScript, redisCluster, clock);
|
||||||
final RateLimiter limiter = rateLimiters.forDescriptor(descriptor);
|
final RateLimiter limiter = rateLimiters.forDescriptor(descriptor);
|
||||||
|
|
||||||
// test only default is present
|
// test only default is present
|
||||||
mapForDynamic.remove(descriptor.id());
|
mapForDynamic.remove(descriptor.id());
|
||||||
mapForStatic.remove(descriptor.id());
|
assertEquals(defaultConfig, limiter.config());
|
||||||
assertEquals(defaultConfig, config(limiter));
|
|
||||||
|
|
||||||
// test dynamic and no static
|
// test dynamic config is present
|
||||||
mapForDynamic.put(descriptor.id(), configForDynamic);
|
mapForDynamic.put(descriptor.id(), configForDynamic);
|
||||||
mapForStatic.remove(descriptor.id());
|
assertEquals(configForDynamic, limiter.config());
|
||||||
assertEquals(configForDynamic, config(limiter));
|
|
||||||
|
|
||||||
// test dynamic and static
|
|
||||||
mapForDynamic.put(descriptor.id(), configForDynamic);
|
|
||||||
mapForStatic.put(descriptor.id(), configForStatic);
|
|
||||||
assertEquals(configForDynamic, config(limiter));
|
|
||||||
|
|
||||||
// test static, but no dynamic
|
|
||||||
mapForDynamic.remove(descriptor.id());
|
|
||||||
mapForStatic.put(descriptor.id(), configForStatic);
|
|
||||||
assertEquals(configForStatic, config(limiter));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private record TestDescriptor(String id) implements RateLimiterDescriptor {
|
private record TestDescriptor(String id) implements RateLimiterDescriptor {
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isDynamic() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public RateLimiterConfig defaultConfig() {
|
public RateLimiterConfig defaultConfig() {
|
||||||
return new RateLimiterConfig(1, Duration.ofMinutes(1), false);
|
return new RateLimiterConfig(1, Duration.ofMinutes(1), false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static RateLimiterConfig config(final RateLimiter rateLimiter) {
|
|
||||||
if (rateLimiter instanceof StaticRateLimiter rm) {
|
|
||||||
return rm.config();
|
|
||||||
}
|
|
||||||
if (rateLimiter instanceof DynamicRateLimiter rm) {
|
|
||||||
return rm.config();
|
|
||||||
}
|
|
||||||
throw new IllegalArgumentException("Rate limiter is of an unexpected type: " + rateLimiter.getClass().getName());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,74 +0,0 @@
|
||||||
/*
|
|
||||||
* Copyright 2025 Signal Messenger, LLC
|
|
||||||
* SPDX-License-Identifier: AGPL-3.0-only
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.whispersystems.textsecuregcm.limits;
|
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
|
||||||
|
|
||||||
import io.lettuce.core.ScriptOutputType;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.time.Duration;
|
|
||||||
import java.time.Instant;
|
|
||||||
import java.util.concurrent.CompletionException;
|
|
||||||
import org.apache.commons.lang3.RandomStringUtils;
|
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
|
||||||
import org.junit.jupiter.api.extension.RegisterExtension;
|
|
||||||
import org.junit.jupiter.params.ParameterizedTest;
|
|
||||||
import org.junit.jupiter.params.provider.ValueSource;
|
|
||||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
|
||||||
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
|
|
||||||
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
|
|
||||||
import org.whispersystems.textsecuregcm.util.TestClock;
|
|
||||||
|
|
||||||
class StaticRateLimiterTest {
|
|
||||||
|
|
||||||
private ClusterLuaScript validateRateLimitScript;
|
|
||||||
|
|
||||||
private static final TestClock CLOCK = TestClock.pinned(Instant.now());
|
|
||||||
|
|
||||||
@RegisterExtension
|
|
||||||
private static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
|
|
||||||
|
|
||||||
@BeforeEach
|
|
||||||
void setUp() throws IOException {
|
|
||||||
validateRateLimitScript = ClusterLuaScript.fromResource(
|
|
||||||
REDIS_CLUSTER_EXTENSION.getRedisCluster(), "lua/validate_rate_limit.lua", ScriptOutputType.INTEGER);
|
|
||||||
}
|
|
||||||
|
|
||||||
@ParameterizedTest
|
|
||||||
@ValueSource(booleans = {true, false})
|
|
||||||
void validate(final boolean failOpen) {
|
|
||||||
final StaticRateLimiter rateLimiter = new StaticRateLimiter("test",
|
|
||||||
new RateLimiterConfig(1, Duration.ofHours(1), failOpen),
|
|
||||||
validateRateLimitScript,
|
|
||||||
REDIS_CLUSTER_EXTENSION.getRedisCluster(),
|
|
||||||
CLOCK);
|
|
||||||
|
|
||||||
final String key = RandomStringUtils.insecure().nextAlphanumeric(16);
|
|
||||||
|
|
||||||
assertDoesNotThrow(() -> rateLimiter.validate(key));
|
|
||||||
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key));
|
|
||||||
}
|
|
||||||
|
|
||||||
@ParameterizedTest
|
|
||||||
@ValueSource(booleans = {true, false})
|
|
||||||
void validateAsync(final boolean failOpen) {
|
|
||||||
final StaticRateLimiter rateLimiter = new StaticRateLimiter("test",
|
|
||||||
new RateLimiterConfig(1, Duration.ofHours(1), failOpen),
|
|
||||||
validateRateLimitScript,
|
|
||||||
REDIS_CLUSTER_EXTENSION.getRedisCluster(),
|
|
||||||
CLOCK);
|
|
||||||
|
|
||||||
final String key = RandomStringUtils.insecure().nextAlphanumeric(16);
|
|
||||||
|
|
||||||
assertDoesNotThrow(() -> rateLimiter.validateAsync(key).toCompletableFuture().join());
|
|
||||||
final CompletionException completionException =
|
|
||||||
assertThrows(CompletionException.class, () -> rateLimiter.validateAsync(key).toCompletableFuture().join());
|
|
||||||
|
|
||||||
assertInstanceOf(RateLimitExceededException.class, completionException.getCause());
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue