From 35604cf15159f091a1b6875e2df284f8d7eca559 Mon Sep 17 00:00:00 2001 From: Jonathan Klabunde Tomer <125505367+jkt-signal@users.noreply.github.com> Date: Wed, 21 May 2025 10:29:26 -0700 Subject: [PATCH] Simplify rate limiters by making them all dynamic --- .../WhisperServerConfiguration.java | 4 - .../textsecuregcm/WhisperServerService.java | 3 +- .../limits/BaseRateLimiters.java | 44 +-- .../limits/DynamicRateLimiter.java | 122 ++++++-- .../limits/RateLimiterDescriptor.java | 7 +- .../textsecuregcm/limits/RateLimiters.java | 104 +++---- .../limits/StaticRateLimiter.java | 170 ----------- .../workers/CommandDependencies.java | 3 +- .../limits/DynamicRateLimiterTest.java | 273 ++++++++++++++++++ .../limits/RateLimitersLuaScriptTest.java | 13 +- .../limits/RateLimitersTest.java | 104 +------ .../limits/StaticRateLimiterTest.java | 74 ----- 12 files changed, 449 insertions(+), 472 deletions(-) delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiter.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/limits/DynamicRateLimiterTest.java delete mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiterTest.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java index 8e8702b76..3e40550be 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java @@ -407,10 +407,6 @@ public class WhisperServerConfiguration extends Configuration { return rateLimitersCluster; } - public Map getLimitsConfiguration() { - return limits; - } - public FcmConfiguration getFcmConfiguration() { return fcm; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 2f52c7dcb..1282339d5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -639,8 +639,7 @@ public class WhisperServerService extends Application { - private final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); - private final Map rateLimiterByDescriptor; - private final Map configs; - - protected BaseRateLimiters( final T[] values, - final Map configs, final DynamicConfigurationManager dynamicConfigurationManager, final ClusterLuaScript validateScript, final FaultTolerantRedisClusterClient cacheCluster, final Clock clock) { - this.configs = configs; this.rateLimiterByDescriptor = Arrays.stream(values) .map(descriptor -> Pair.of( descriptor, - createForDescriptor(descriptor, configs, dynamicConfigurationManager, validateScript, cacheCluster, clock))) + createForDescriptor(descriptor, dynamicConfigurationManager, validateScript, cacheCluster, clock))) .collect(Collectors.toUnmodifiableMap(Pair::getKey, Pair::getValue)); } @@ -53,22 +42,6 @@ public abstract class BaseRateLimiters { return requireNonNull(rateLimiterByDescriptor.get(handle)); } - public void validateValuesAndConfigs() { - final Set ids = rateLimiterByDescriptor.keySet().stream() - .map(RateLimiterDescriptor::id) - .collect(Collectors.toSet()); - for (final String key: configs.keySet()) { - if (!ids.contains(key)) { - final String message = String.format( - "Static configuration has an unexpected field '%s' that doesn't match any RateLimiterDescriptor", - key - ); - logger.error(message); - throw new IllegalArgumentException(message); - } - } - } - protected static ClusterLuaScript defaultScript(final FaultTolerantRedisClusterClient cacheCluster) { try { return ClusterLuaScript.fromResource( @@ -80,21 +53,12 @@ public abstract class BaseRateLimiters { private static RateLimiter createForDescriptor( final RateLimiterDescriptor descriptor, - final Map configs, final DynamicConfigurationManager dynamicConfigurationManager, final ClusterLuaScript validateScript, final FaultTolerantRedisClusterClient cacheCluster, final Clock clock) { - if (descriptor.isDynamic()) { - final Supplier configResolver = () -> { - final RateLimiterConfig config = dynamicConfigurationManager.getConfiguration().getLimits().get(descriptor.id()); - return config != null - ? config - : configs.getOrDefault(descriptor.id(), descriptor.defaultConfig()); - }; - return new DynamicRateLimiter(descriptor.id(), dynamicConfigurationManager, configResolver, validateScript, cacheCluster, clock); - } - final RateLimiterConfig cfg = configs.getOrDefault(descriptor.id(), descriptor.defaultConfig()); - return new StaticRateLimiter(descriptor.id(), cfg, validateScript, cacheCluster, clock); + final Supplier configResolver = + () -> dynamicConfigurationManager.getConfiguration().getLimits().getOrDefault(descriptor.id(), descriptor.defaultConfig()); + return new DynamicRateLimiter(descriptor.id(), configResolver, validateScript, cacheCluster, clock); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/DynamicRateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/DynamicRateLimiter.java index 1f60550a6..24c788c08 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/DynamicRateLimiter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/DynamicRateLimiter.java @@ -7,87 +7,167 @@ package org.whispersystems.textsecuregcm.limits; import static java.util.Objects.requireNonNull; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Metrics; import java.time.Clock; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; -import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; -import org.apache.commons.lang3.tuple.Pair; -import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; -import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; +import org.whispersystems.textsecuregcm.util.ExceptionUtils; +import org.whispersystems.textsecuregcm.util.Util; public class DynamicRateLimiter implements RateLimiter { private final String name; - private final DynamicConfigurationManager dynamicConfigurationManager; private final Supplier configResolver; private final ClusterLuaScript validateScript; private final FaultTolerantRedisClusterClient cluster; - private final Clock clock; + private final Counter limitExceededCounter; - private final AtomicReference> currentHolder = new AtomicReference<>(); + private final Clock clock; public DynamicRateLimiter( final String name, - final DynamicConfigurationManager dynamicConfigurationManager, final Supplier configResolver, final ClusterLuaScript validateScript, final FaultTolerantRedisClusterClient cluster, final Clock clock) { this.name = requireNonNull(name); - this.dynamicConfigurationManager = dynamicConfigurationManager; this.configResolver = requireNonNull(configResolver); this.validateScript = requireNonNull(validateScript); this.cluster = requireNonNull(cluster); this.clock = requireNonNull(clock); + this.limitExceededCounter = Metrics.counter(MetricsUtil.name(getClass(), "exceeded"), "rateLimiterName", name); } @Override public void validate(final String key, final int amount) throws RateLimitExceededException { - current().getRight().validate(key, amount); + final RateLimiterConfig config = config(); + try { + final long deficitPermitsAmount = executeValidateScript(config, key, amount, true); + if (deficitPermitsAmount > 0) { + limitExceededCounter.increment(); + final Duration retryAfter = Duration.ofMillis( + (long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis())); + throw new RateLimitExceededException(retryAfter); + } + } catch (final Exception e) { + if (e instanceof RateLimitExceededException rateLimitExceededException) { + throw rateLimitExceededException; + } + + if (!config.failOpen()) { + throw e; + } + } } @Override public CompletionStage validateAsync(final String key, final int amount) { - return current().getRight().validateAsync(key, amount); + final RateLimiterConfig config = config(); + + return executeValidateScriptAsync(config, key, amount, true) + .thenCompose(deficitPermitsAmount -> { + if (deficitPermitsAmount == 0) { + return CompletableFuture.completedFuture((Void) null); + } + limitExceededCounter.increment(); + final Duration retryAfter = Duration.ofMillis( + (long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis())); + return CompletableFuture.failedFuture(new RateLimitExceededException(retryAfter)); + }) + .exceptionally(throwable -> { + if (ExceptionUtils.unwrap(throwable) instanceof RateLimitExceededException rateLimitExceededException) { + throw ExceptionUtils.wrap(rateLimitExceededException); + } + + if (config.failOpen()) { + return null; + } + + throw ExceptionUtils.wrap(throwable); + }); } @Override public boolean hasAvailablePermits(final String key, final int permits) { - return current().getRight().hasAvailablePermits(key, permits); + final RateLimiterConfig config = config(); + try { + final long deficitPermitsAmount = executeValidateScript(config, key, permits, false); + return deficitPermitsAmount == 0; + } catch (final Exception e) { + if (config.failOpen()) { + return true; + } else { + throw e; + } + } } @Override public CompletionStage hasAvailablePermitsAsync(final String key, final int amount) { - return current().getRight().hasAvailablePermitsAsync(key, amount); + final RateLimiterConfig config = config(); + return executeValidateScriptAsync(config, key, amount, false) + .thenApply(deficitPermitsAmount -> deficitPermitsAmount == 0) + .exceptionally(throwable -> { + if (config.failOpen()) { + return true; + } + throw ExceptionUtils.wrap(throwable); + }); } @Override public void clear(final String key) { - current().getRight().clear(key); + cluster.useCluster(connection -> connection.sync().del(bucketName(name, key))); } @Override public CompletionStage clearAsync(final String key) { - return current().getRight().clearAsync(key); + return cluster.withCluster(connection -> connection.async().del(bucketName(name, key))) + .thenRun(Util.NOOP); } @Override public RateLimiterConfig config() { - return current().getLeft(); + return configResolver.get(); } - private Pair current() { - final RateLimiterConfig cfg = configResolver.get(); - return currentHolder.updateAndGet(p -> p != null && p.getLeft().equals(cfg) - ? p - : Pair.of(cfg, new StaticRateLimiter(name, cfg, validateScript, cluster, clock)) + private long executeValidateScript(final RateLimiterConfig config, final String key, final int amount, final boolean applyChanges) { + final List keys = List.of(bucketName(name, key)); + final List arguments = List.of( + String.valueOf(config.bucketSize()), + String.valueOf(config.leakRatePerMillis()), + String.valueOf(clock.millis()), + String.valueOf(amount), + String.valueOf(applyChanges) ); + return (Long) validateScript.execute(keys, arguments); + } + + private CompletionStage executeValidateScriptAsync(final RateLimiterConfig config, final String key, final int amount, final boolean applyChanges) { + final List keys = List.of(bucketName(name, key)); + final List arguments = List.of( + String.valueOf(config.bucketSize()), + String.valueOf(config.leakRatePerMillis()), + String.valueOf(clock.millis()), + String.valueOf(amount), + String.valueOf(applyChanges) + ); + return validateScript.executeAsync(keys, arguments).thenApply(o -> (Long) o); + } + + private static String bucketName(final String name, final String key) { + return "leaky_bucket::" + name + "::" + key; } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiterDescriptor.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiterDescriptor.java index 38ae87b80..c27e21622 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiterDescriptor.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiterDescriptor.java @@ -15,14 +15,9 @@ public interface RateLimiterDescriptor { */ String id(); - /** - * @return {@code true} if this rate limiter needs to watch for dynamic configuration changes. - */ - boolean isDynamic(); - /** * @return an instance of {@link RateLimiterConfig} to be used by default, - * i.e. if there is no overrides in the application configuration files (static or dynamic). + * i.e. if there is no override in the application dynamic configuration. */ RateLimiterConfig defaultConfig(); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java index 052f89340..e42584913 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java @@ -4,11 +4,9 @@ */ package org.whispersystems.textsecuregcm.limits; - import com.google.common.annotations.VisibleForTesting; import java.time.Clock; import java.time.Duration; -import java.util.Map; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; @@ -17,57 +15,54 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; public class RateLimiters extends BaseRateLimiters { public enum For implements RateLimiterDescriptor { - BACKUP_AUTH_CHECK("backupAuthCheck", false, new RateLimiterConfig(100, Duration.ofMinutes(15), true)), - PIN("pin", false, new RateLimiterConfig(10, Duration.ofDays(1), false)), - ATTACHMENT("attachmentCreate", false, new RateLimiterConfig(50, Duration.ofMillis(1200), true)), - BACKUP_ATTACHMENT("backupAttachmentCreate", true, new RateLimiterConfig(10_000, Duration.ofSeconds(1), true)), - PRE_KEYS("prekeys", false, new RateLimiterConfig(6, Duration.ofMinutes(10), false)), - MESSAGES("messages", false, new RateLimiterConfig(60, Duration.ofSeconds(1), true)), - STORIES("stories", false, new RateLimiterConfig(5_000, Duration.ofSeconds(8), true)), - ALLOCATE_DEVICE("allocateDevice", false, new RateLimiterConfig(6, Duration.ofMinutes(2), false)), - VERIFY_DEVICE("verifyDevice", false, new RateLimiterConfig(6, Duration.ofMinutes(2), false)), - PROFILE("profile", false, new RateLimiterConfig(4320, Duration.ofSeconds(20), true)), - STICKER_PACK("stickerPack", false, new RateLimiterConfig(50, Duration.ofMinutes(72), false)), - USERNAME_LOOKUP("usernameLookup", false, new RateLimiterConfig(100, Duration.ofMinutes(15), true)), - USERNAME_SET("usernameSet", false, new RateLimiterConfig(100, Duration.ofMinutes(15), false)), - USERNAME_RESERVE("usernameReserve", false, new RateLimiterConfig(100, Duration.ofMinutes(15), false)), - USERNAME_LINK_OPERATION("usernameLinkOperation", false, new RateLimiterConfig(10, Duration.ofMinutes(1), false)), - USERNAME_LINK_LOOKUP_PER_IP("usernameLinkLookupPerIp", false, new RateLimiterConfig(100, Duration.ofSeconds(15), true)), - CHECK_ACCOUNT_EXISTENCE("checkAccountExistence", false, new RateLimiterConfig(1000, Duration.ofSeconds(4), true)), - REGISTRATION("registration", false, new RateLimiterConfig(6, Duration.ofSeconds(30), false)), - VERIFICATION_PUSH_CHALLENGE("verificationPushChallenge", false, new RateLimiterConfig(5, Duration.ofSeconds(30), false)), - VERIFICATION_CAPTCHA("verificationCaptcha", false, new RateLimiterConfig(10, Duration.ofSeconds(30), false)), - RATE_LIMIT_RESET("rateLimitReset", true, new RateLimiterConfig(2, Duration.ofHours(12), false)), - CAPTCHA_CHALLENGE_ATTEMPT("captchaChallengeAttempt", true, new RateLimiterConfig(10, Duration.ofMinutes(144), false)), - CAPTCHA_CHALLENGE_SUCCESS("captchaChallengeSuccess", true, new RateLimiterConfig(2, Duration.ofHours(12), false)), - SET_BACKUP_ID("setBackupId", true, new RateLimiterConfig(10, Duration.ofHours(1), false)), - SET_PAID_MEDIA_BACKUP_ID("setPaidMediaBackupId", true, new RateLimiterConfig(5, Duration.ofDays(7), false)), - PUSH_CHALLENGE_ATTEMPT("pushChallengeAttempt", true, new RateLimiterConfig(10, Duration.ofMinutes(144), false)), - PUSH_CHALLENGE_SUCCESS("pushChallengeSuccess", true, new RateLimiterConfig(2, Duration.ofHours(12), false)), - GET_CALLING_RELAYS("getCallingRelays", false, new RateLimiterConfig(100, Duration.ofMinutes(10), false)), - CREATE_CALL_LINK("createCallLink", false, new RateLimiterConfig(100, Duration.ofMinutes(15), false)), - INBOUND_MESSAGE_BYTES("inboundMessageBytes", true, new RateLimiterConfig(128 * 1024 * 1024, Duration.ofNanos(500_000), true)), - EXTERNAL_SERVICE_CREDENTIALS("externalServiceCredentials", true, new RateLimiterConfig(100, Duration.ofMinutes(15), false)), - KEY_TRANSPARENCY_DISTINGUISHED_PER_IP("keyTransparencyDistinguished", true, new RateLimiterConfig(100, Duration.ofSeconds(15), true)), - KEY_TRANSPARENCY_SEARCH_PER_IP("keyTransparencySearch", true, new RateLimiterConfig(100, Duration.ofSeconds(15), true)), - KEY_TRANSPARENCY_MONITOR_PER_IP("keyTransparencyMonitor", true, new RateLimiterConfig(100, Duration.ofSeconds(15), true)), - WAIT_FOR_LINKED_DEVICE("waitForLinkedDevice", true, new RateLimiterConfig(10, Duration.ofSeconds(30), false)), - UPLOAD_TRANSFER_ARCHIVE("uploadTransferArchive", true, new RateLimiterConfig(10, Duration.ofMinutes(1), false)), - WAIT_FOR_TRANSFER_ARCHIVE("waitForTransferArchive", true, new RateLimiterConfig(10, Duration.ofSeconds(30), false)), - RECORD_DEVICE_TRANSFER_REQUEST("recordDeviceTransferRequest", true, new RateLimiterConfig(10, Duration.ofMillis(100), true)), - WAIT_FOR_DEVICE_TRANSFER_REQUEST("waitForDeviceTransferRequest", true, new RateLimiterConfig(10, Duration.ofMillis(100), true)), - DEVICE_CHECK_CHALLENGE("deviceCheckChallenge", true, new RateLimiterConfig(10, Duration.ofMinutes(1), false)), + BACKUP_AUTH_CHECK("backupAuthCheck", new RateLimiterConfig(100, Duration.ofMinutes(15), true)), + PIN("pin", new RateLimiterConfig(10, Duration.ofDays(1), false)), + ATTACHMENT("attachmentCreate", new RateLimiterConfig(50, Duration.ofMillis(1200), true)), + BACKUP_ATTACHMENT("backupAttachmentCreate", new RateLimiterConfig(10_000, Duration.ofSeconds(1), true)), + PRE_KEYS("prekeys", new RateLimiterConfig(6, Duration.ofMinutes(10), false)), + MESSAGES("messages", new RateLimiterConfig(60, Duration.ofSeconds(1), true)), + STORIES("stories", new RateLimiterConfig(5_000, Duration.ofSeconds(8), true)), + ALLOCATE_DEVICE("allocateDevice", new RateLimiterConfig(6, Duration.ofMinutes(2), false)), + VERIFY_DEVICE("verifyDevice", new RateLimiterConfig(6, Duration.ofMinutes(2), false)), + PROFILE("profile", new RateLimiterConfig(4320, Duration.ofSeconds(20), true)), + STICKER_PACK("stickerPack", new RateLimiterConfig(50, Duration.ofMinutes(72), false)), + USERNAME_LOOKUP("usernameLookup", new RateLimiterConfig(100, Duration.ofMinutes(15), true)), + USERNAME_SET("usernameSet", new RateLimiterConfig(100, Duration.ofMinutes(15), false)), + USERNAME_RESERVE("usernameReserve", new RateLimiterConfig(100, Duration.ofMinutes(15), false)), + USERNAME_LINK_OPERATION("usernameLinkOperation", new RateLimiterConfig(10, Duration.ofMinutes(1), false)), + USERNAME_LINK_LOOKUP_PER_IP("usernameLinkLookupPerIp", new RateLimiterConfig(100, Duration.ofSeconds(15), true)), + CHECK_ACCOUNT_EXISTENCE("checkAccountExistence", new RateLimiterConfig(1000, Duration.ofSeconds(4), true)), + REGISTRATION("registration", new RateLimiterConfig(6, Duration.ofSeconds(30), false)), + VERIFICATION_PUSH_CHALLENGE("verificationPushChallenge", new RateLimiterConfig(5, Duration.ofSeconds(30), false)), + VERIFICATION_CAPTCHA("verificationCaptcha", new RateLimiterConfig(10, Duration.ofSeconds(30), false)), + RATE_LIMIT_RESET("rateLimitReset", new RateLimiterConfig(2, Duration.ofHours(12), false)), + CAPTCHA_CHALLENGE_ATTEMPT("captchaChallengeAttempt", new RateLimiterConfig(10, Duration.ofMinutes(144), false)), + CAPTCHA_CHALLENGE_SUCCESS("captchaChallengeSuccess", new RateLimiterConfig(2, Duration.ofHours(12), false)), + SET_BACKUP_ID("setBackupId", new RateLimiterConfig(10, Duration.ofHours(1), false)), + SET_PAID_MEDIA_BACKUP_ID("setPaidMediaBackupId", new RateLimiterConfig(5, Duration.ofDays(7), false)), + PUSH_CHALLENGE_ATTEMPT("pushChallengeAttempt", new RateLimiterConfig(10, Duration.ofMinutes(144), false)), + PUSH_CHALLENGE_SUCCESS("pushChallengeSuccess", new RateLimiterConfig(2, Duration.ofHours(12), false)), + GET_CALLING_RELAYS("getCallingRelays", new RateLimiterConfig(100, Duration.ofMinutes(10), false)), + CREATE_CALL_LINK("createCallLink", new RateLimiterConfig(100, Duration.ofMinutes(15), false)), + INBOUND_MESSAGE_BYTES("inboundMessageBytes", new RateLimiterConfig(128 * 1024 * 1024, Duration.ofNanos(500_000), true)), + EXTERNAL_SERVICE_CREDENTIALS("externalServiceCredentials", new RateLimiterConfig(100, Duration.ofMinutes(15), false)), + KEY_TRANSPARENCY_DISTINGUISHED_PER_IP("keyTransparencyDistinguished", new RateLimiterConfig(100, Duration.ofSeconds(15), true)), + KEY_TRANSPARENCY_SEARCH_PER_IP("keyTransparencySearch", new RateLimiterConfig(100, Duration.ofSeconds(15), true)), + KEY_TRANSPARENCY_MONITOR_PER_IP("keyTransparencyMonitor", new RateLimiterConfig(100, Duration.ofSeconds(15), true)), + WAIT_FOR_LINKED_DEVICE("waitForLinkedDevice", new RateLimiterConfig(10, Duration.ofSeconds(30), false)), + UPLOAD_TRANSFER_ARCHIVE("uploadTransferArchive", new RateLimiterConfig(10, Duration.ofMinutes(1), false)), + WAIT_FOR_TRANSFER_ARCHIVE("waitForTransferArchive", new RateLimiterConfig(10, Duration.ofSeconds(30), false)), + RECORD_DEVICE_TRANSFER_REQUEST("recordDeviceTransferRequest", new RateLimiterConfig(10, Duration.ofMillis(100), true)), + WAIT_FOR_DEVICE_TRANSFER_REQUEST("waitForDeviceTransferRequest", new RateLimiterConfig(10, Duration.ofMillis(100), true)), + DEVICE_CHECK_CHALLENGE("deviceCheckChallenge", new RateLimiterConfig(10, Duration.ofMinutes(1), false)), ; private final String id; - private final boolean dynamic; - private final RateLimiterConfig defaultConfig; - For(final String id, final boolean dynamic, final RateLimiterConfig defaultConfig) { + For(final String id, final RateLimiterConfig defaultConfig) { this.id = id; - this.dynamic = dynamic; this.defaultConfig = defaultConfig; } @@ -75,34 +70,25 @@ public class RateLimiters extends BaseRateLimiters { return id; } - @Override - public boolean isDynamic() { - return dynamic; - } - public RateLimiterConfig defaultConfig() { return defaultConfig; } } - public static RateLimiters createAndValidate( - final Map configs, + public static RateLimiters create( final DynamicConfigurationManager dynamicConfigurationManager, final FaultTolerantRedisClusterClient cacheCluster) { - final RateLimiters rateLimiters = new RateLimiters( - configs, dynamicConfigurationManager, defaultScript(cacheCluster), cacheCluster, Clock.systemUTC()); - rateLimiters.validateValuesAndConfigs(); - return rateLimiters; + return new RateLimiters( + dynamicConfigurationManager, defaultScript(cacheCluster), cacheCluster, Clock.systemUTC()); } @VisibleForTesting RateLimiters( - final Map configs, final DynamicConfigurationManager dynamicConfigurationManager, final ClusterLuaScript validateScript, final FaultTolerantRedisClusterClient cacheCluster, final Clock clock) { - super(For.values(), configs, dynamicConfigurationManager, validateScript, cacheCluster, clock); + super(For.values(), dynamicConfigurationManager, validateScript, cacheCluster, clock); } public RateLimiter getAllocateDeviceLimiter() { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiter.java deleted file mode 100644 index 5aa611b84..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiter.java +++ /dev/null @@ -1,170 +0,0 @@ -/* - * Copyright 2013 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.textsecuregcm.limits; - -import static java.util.Objects.requireNonNull; -import static java.util.concurrent.CompletableFuture.completedFuture; -import static java.util.concurrent.CompletableFuture.failedFuture; - -import com.google.common.annotations.VisibleForTesting; -import io.micrometer.core.instrument.Counter; -import io.micrometer.core.instrument.Metrics; -import java.time.Clock; -import java.time.Duration; -import java.util.List; -import java.util.concurrent.CompletionStage; -import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; -import org.whispersystems.textsecuregcm.metrics.MetricsUtil; -import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; -import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; -import org.whispersystems.textsecuregcm.util.ExceptionUtils; -import org.whispersystems.textsecuregcm.util.Util; - -public class StaticRateLimiter implements RateLimiter { - - protected final String name; - - private final RateLimiterConfig config; - - private final Counter limitExceededCounter; - - private final ClusterLuaScript validateScript; - - private final FaultTolerantRedisClusterClient cacheCluster; - - private final Clock clock; - - - public StaticRateLimiter( - final String name, - final RateLimiterConfig config, - final ClusterLuaScript validateScript, - final FaultTolerantRedisClusterClient cacheCluster, - final Clock clock) { - this.name = requireNonNull(name); - this.config = requireNonNull(config); - this.validateScript = requireNonNull(validateScript); - this.cacheCluster = requireNonNull(cacheCluster); - this.clock = requireNonNull(clock); - this.limitExceededCounter = Metrics.counter(MetricsUtil.name(getClass(), "exceeded"), "rateLimiterName", name); - } - - @Override - public void validate(final String key, final int amount) throws RateLimitExceededException { - try { - final long deficitPermitsAmount = executeValidateScript(key, amount, true); - if (deficitPermitsAmount > 0) { - limitExceededCounter.increment(); - final Duration retryAfter = Duration.ofMillis( - (long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis())); - throw new RateLimitExceededException(retryAfter); - } - } catch (final Exception e) { - if (e instanceof RateLimitExceededException rateLimitExceededException) { - throw rateLimitExceededException; - } - - if (!config.failOpen()) { - throw e; - } - } - } - - @Override - public CompletionStage validateAsync(final String key, final int amount) { - return executeValidateScriptAsync(key, amount, true) - .thenCompose(deficitPermitsAmount -> { - if (deficitPermitsAmount == 0) { - return completedFuture((Void) null); - } - limitExceededCounter.increment(); - final Duration retryAfter = Duration.ofMillis( - (long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis())); - return failedFuture(new RateLimitExceededException(retryAfter)); - }) - .exceptionally(throwable -> { - if (ExceptionUtils.unwrap(throwable) instanceof RateLimitExceededException rateLimitExceededException) { - throw ExceptionUtils.wrap(rateLimitExceededException); - } - - if (config.failOpen()) { - return null; - } - - throw ExceptionUtils.wrap(throwable); - }); - } - - @Override - public boolean hasAvailablePermits(final String key, final int amount) { - try { - final long deficitPermitsAmount = executeValidateScript(key, amount, false); - return deficitPermitsAmount == 0; - } catch (final Exception e) { - if (config.failOpen()) { - return true; - } else { - throw e; - } - } - } - - @Override - public CompletionStage hasAvailablePermitsAsync(final String key, final int amount) { - return executeValidateScriptAsync(key, amount, false) - .thenApply(deficitPermitsAmount -> deficitPermitsAmount == 0) - .exceptionally(throwable -> { - if (config.failOpen()) { - return true; - } - throw ExceptionUtils.wrap(throwable); - }); - } - - @Override - public void clear(final String key) { - cacheCluster.useCluster(connection -> connection.sync().del(bucketName(name, key))); - } - - @Override - public CompletionStage clearAsync(final String key) { - return cacheCluster.withCluster(connection -> connection.async().del(bucketName(name, key))) - .thenRun(Util.NOOP); - } - - @Override - public RateLimiterConfig config() { - return config; - } - - private long executeValidateScript(final String key, final int amount, final boolean applyChanges) { - final List keys = List.of(bucketName(name, key)); - final List arguments = List.of( - String.valueOf(config.bucketSize()), - String.valueOf(config.leakRatePerMillis()), - String.valueOf(clock.millis()), - String.valueOf(amount), - String.valueOf(applyChanges) - ); - return (Long) validateScript.execute(keys, arguments); - } - - private CompletionStage executeValidateScriptAsync(final String key, final int amount, final boolean applyChanges) { - final List keys = List.of(bucketName(name, key)); - final List arguments = List.of( - String.valueOf(config.bucketSize()), - String.valueOf(config.leakRatePerMillis()), - String.valueOf(clock.millis()), - String.valueOf(amount), - String.valueOf(applyChanges) - ); - return validateScript.executeAsync(keys, arguments).thenApply(o -> (Long) o); - } - - @VisibleForTesting - protected static String bucketName(final String name, final String key) { - return "leaky_bucket::" + name + "::" + key; - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java index 7d3d6fdda..c836dd0bb 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java @@ -241,8 +241,7 @@ record CommandDependencies( secureStorageClient, secureValueRecovery2Client, disconnectionRequestManager, registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, messagePollExecutor, clock, configuration.getLinkDeviceSecretConfiguration().secret().value(), dynamicConfigurationManager); - RateLimiters rateLimiters = RateLimiters.createAndValidate(configuration.getLimitsConfiguration(), - dynamicConfigurationManager, rateLimitersCluster); + RateLimiters rateLimiters = RateLimiters.create(dynamicConfigurationManager, rateLimitersCluster); final BackupsDb backupsDb = new BackupsDb(dynamoDbAsyncClient, configuration.getDynamoDbTables().getBackups().getTableName(), clock); final GenericServerSecretParams backupsGenericZkSecretParams; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/DynamicRateLimiterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/DynamicRateLimiterTest.java new file mode 100644 index 000000000..0a731a67e --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/DynamicRateLimiterTest.java @@ -0,0 +1,273 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.limits; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import io.lettuce.core.ScriptOutputType; +import java.io.IOException; +import java.time.Duration; +import java.time.Instant; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; +import org.whispersystems.textsecuregcm.util.TestClock; + +class DynamicRateLimiterTest { + + private ClusterLuaScript validateRateLimitScript; + + private static final TestClock CLOCK = TestClock.pinned(Instant.now()); + + @RegisterExtension + private static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); + + @BeforeEach + void setUp() throws IOException { + validateRateLimitScript = ClusterLuaScript.fromResource( + REDIS_CLUSTER_EXTENSION.getRedisCluster(), "lua/validate_rate_limit.lua", ScriptOutputType.INTEGER); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void validate(final boolean failOpen) { + final DynamicRateLimiter rateLimiter = new DynamicRateLimiter( + "test", + () -> new RateLimiterConfig(1, Duration.ofHours(1), failOpen), + validateRateLimitScript, + REDIS_CLUSTER_EXTENSION.getRedisCluster(), + CLOCK); + + final String key = RandomStringUtils.insecure().nextAlphanumeric(16); + + assertDoesNotThrow(() -> rateLimiter.validate(key)); + assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key)); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void validateAsync(final boolean failOpen) { + final DynamicRateLimiter rateLimiter = new DynamicRateLimiter( + "test", + () -> new RateLimiterConfig(1, Duration.ofHours(1), failOpen), + validateRateLimitScript, + REDIS_CLUSTER_EXTENSION.getRedisCluster(), + CLOCK); + + final String key = RandomStringUtils.insecure().nextAlphanumeric(16); + + assertDoesNotThrow(() -> rateLimiter.validateAsync(key).toCompletableFuture().join()); + final CompletionException completionException = + assertThrows(CompletionException.class, () -> rateLimiter.validateAsync(key).toCompletableFuture().join()); + + assertInstanceOf(RateLimitExceededException.class, completionException.getCause()); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void validateFailOpen(final boolean failOpen) { + final ClusterLuaScript failingScript = mock(ClusterLuaScript.class); + when(failingScript.execute(any(), any())).thenThrow(new RuntimeException("OH NO")); + + final DynamicRateLimiter rateLimiter = new DynamicRateLimiter( + "test", + () -> new RateLimiterConfig(1, Duration.ofHours(1), failOpen), + failingScript, + REDIS_CLUSTER_EXTENSION.getRedisCluster(), + CLOCK); + + final String key = RandomStringUtils.insecure().nextAlphanumeric(16); + + if (failOpen) { + assertDoesNotThrow(() -> rateLimiter.validate(key)); + } else { + assertThrows(RuntimeException.class, () -> rateLimiter.validate(key)); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void validateFailOpenAsync(final boolean failOpen) { + final ClusterLuaScript failingScript = mock(ClusterLuaScript.class); + when(failingScript.executeAsync(any(), any())).thenReturn(CompletableFuture.failedFuture(new RuntimeException("OH NO"))); + + final DynamicRateLimiter rateLimiter = new DynamicRateLimiter( + "test", + () -> new RateLimiterConfig(1, Duration.ofHours(1), failOpen), + failingScript, + REDIS_CLUSTER_EXTENSION.getRedisCluster(), + CLOCK); + + final String key = RandomStringUtils.insecure().nextAlphanumeric(16); + + if (failOpen) { + assertDoesNotThrow(() -> rateLimiter.validate(key)); + } else { + final CompletionException completionException = + assertThrows(CompletionException.class, () -> rateLimiter.validateAsync(key).toCompletableFuture().join()); + + assertInstanceOf(RuntimeException.class, completionException.getCause()); + } + } + + @Test + void configChange_ReduceRefillRate() { + final AtomicReference refillRate = new AtomicReference<>(Duration.ofMinutes(5)); + final DynamicRateLimiter rateLimiter = new DynamicRateLimiter( + "test", + () -> new RateLimiterConfig(1, refillRate.get(), false), + validateRateLimitScript, + REDIS_CLUSTER_EXTENSION.getRedisCluster(), + CLOCK); + + final String key = RandomStringUtils.insecure().nextAlphanumeric(16); + + assertDoesNotThrow(() -> rateLimiter.validate(key)); + assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key)); + + CLOCK.pin(CLOCK.instant().plus(Duration.ofMinutes(1))); + assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key)); + + refillRate.set(Duration.ofMinutes(1)); + assertDoesNotThrow(() -> rateLimiter.validate(key)); + assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key)); + } + + @Test + void configChange_IncreaseRefillRate() { + final AtomicReference refillRate = new AtomicReference<>(Duration.ofMinutes(5)); + final DynamicRateLimiter rateLimiter = new DynamicRateLimiter( + "test", + () -> new RateLimiterConfig(1, refillRate.get(), false), + validateRateLimitScript, + REDIS_CLUSTER_EXTENSION.getRedisCluster(), + CLOCK); + + final String key = RandomStringUtils.insecure().nextAlphanumeric(16); + + assertDoesNotThrow(() -> rateLimiter.validate(key)); + assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key)); + + CLOCK.pin(CLOCK.instant().plus(Duration.ofMinutes(5))); + assertTrue(rateLimiter.hasAvailablePermits(key, 1)); + + refillRate.set(Duration.ofMinutes(10)); + assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key)); + + CLOCK.pin(CLOCK.instant().plus(Duration.ofMinutes(5))); + assertDoesNotThrow(() -> rateLimiter.validate(key)); + assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key)); + } + + @Test + void configChange_ReduceBucketSize() { + final AtomicInteger bucketSize = new AtomicInteger(5); + final DynamicRateLimiter rateLimiter = new DynamicRateLimiter( + "test", + () -> new RateLimiterConfig(bucketSize.get(), Duration.ofMinutes(1), false), + validateRateLimitScript, + REDIS_CLUSTER_EXTENSION.getRedisCluster(), + CLOCK); + + final String key = RandomStringUtils.insecure().nextAlphanumeric(16); + + assertDoesNotThrow(() -> rateLimiter.validate(key)); + assertTrue(rateLimiter.hasAvailablePermits(key, 4)); + assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key, 5)); + + bucketSize.set(1); + // Changing the bucket size doesn't spend the tokens remaining in existing buckets, but does + // effectively make those buckets overflow if it got smaller. There were 4 tokens available + // before, so changing the bucket size to 1 effectively means there is 1 token left, not 0 + assertTrue(rateLimiter.hasAvailablePermits(key, 1)); + assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key, 2)); + } + + @Test + void configChange_IncreaseBucketSize() { + final AtomicInteger bucketSize = new AtomicInteger(5); + final DynamicRateLimiter rateLimiter = new DynamicRateLimiter( + "test", + () -> new RateLimiterConfig(bucketSize.get(), Duration.ofMinutes(1), false), + validateRateLimitScript, + REDIS_CLUSTER_EXTENSION.getRedisCluster(), + CLOCK); + + final String key = RandomStringUtils.insecure().nextAlphanumeric(16); + + assertDoesNotThrow(() -> rateLimiter.validate(key)); + assertTrue(rateLimiter.hasAvailablePermits(key, 4)); + assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key, 5)); + + bucketSize.set(10); + // Increasing the bucket size doesn't retroactively refill buckets in redis, so we have to wait + // until the bucket fills up + CLOCK.pin(CLOCK.instant().plus(Duration.ofMinutes(10))); + assertTrue(rateLimiter.hasAvailablePermits(key, 10)); + assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key, 11)); + } + + @Test + void configChange_enableFailOpen() { + final ClusterLuaScript failingScript = mock(ClusterLuaScript.class); + when(failingScript.execute(any(), any())).thenThrow(new RuntimeException("OH NO")); + + final AtomicBoolean failOpen = new AtomicBoolean(false); + final DynamicRateLimiter rateLimiter = new DynamicRateLimiter( + "test", + () -> new RateLimiterConfig(1, Duration.ofMinutes(1), failOpen.get()), + failingScript, + REDIS_CLUSTER_EXTENSION.getRedisCluster(), + CLOCK); + + final String key = RandomStringUtils.insecure().nextAlphanumeric(16); + + assertThrows(RuntimeException.class, () -> rateLimiter.validate(key)); + + failOpen.set(true); + + assertDoesNotThrow(() -> rateLimiter.validate(key)); + } + + @Test + void configChange_disableFailOpen() { + final ClusterLuaScript failingScript = mock(ClusterLuaScript.class); + when(failingScript.execute(any(), any())).thenThrow(new RuntimeException("OH NO")); + + final AtomicBoolean failOpen = new AtomicBoolean(true); + final DynamicRateLimiter rateLimiter = new DynamicRateLimiter( + "test", + () -> new RateLimiterConfig(1, Duration.ofMinutes(1), failOpen.get()), + failingScript, + REDIS_CLUSTER_EXTENSION.getRedisCluster(), + CLOCK); + + final String key = RandomStringUtils.insecure().nextAlphanumeric(16); + + assertDoesNotThrow(() -> rateLimiter.validate(key)); + + failOpen.set(false); + + assertThrows(RuntimeException.class, () -> rateLimiter.validate(key)); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersLuaScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersLuaScriptTest.java index c1346bccc..c3336d3c1 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersLuaScriptTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersLuaScriptTest.java @@ -57,9 +57,11 @@ public class RateLimitersLuaScriptTest { @Test public void testWithEmbeddedRedis() throws Exception { final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION; + final Map limiterConfig = Map.of(descriptor.id(), new RateLimiterConfig(60, Duration.ofSeconds(1), false)); + when(configuration.getLimits()).thenReturn(limiterConfig); + final FaultTolerantRedisClusterClient redisCluster = REDIS_CLUSTER_EXTENSION.getRedisCluster(); final RateLimiters limiters = new RateLimiters( - Map.of(descriptor.id(), new RateLimiterConfig(60, Duration.ofSeconds(1), false)), dynamicConfig, RateLimiters.defaultScript(redisCluster), redisCluster, @@ -74,9 +76,11 @@ public class RateLimitersLuaScriptTest { @Test public void testTtl() throws Exception { final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION; + final Map limiterConfig = Map.of(descriptor.id(), new RateLimiterConfig(1000, Duration.ofSeconds(1), false)); + when(configuration.getLimits()).thenReturn(limiterConfig); + final FaultTolerantRedisClusterClient redisCluster = REDIS_CLUSTER_EXTENSION.getRedisCluster(); final RateLimiters limiters = new RateLimiters( - Map.of(descriptor.id(), new RateLimiterConfig(1000, Duration.ofSeconds(1), false)), dynamicConfig, RateLimiters.defaultScript(redisCluster), redisCluster, @@ -126,8 +130,11 @@ public class RateLimitersLuaScriptTest { public void testFailOpen(final boolean failOpen) { final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION; final FaultTolerantRedisClusterClient redisCluster = mock(FaultTolerantRedisClusterClient.class); + + final Map limiterConfig = Map.of(descriptor.id(), new RateLimiterConfig(1, Duration.ofSeconds(1), failOpen)); + when(configuration.getLimits()).thenReturn(limiterConfig); + final RateLimiters limiters = new RateLimiters( - Map.of(descriptor.id(), new RateLimiterConfig(1000, Duration.ofSeconds(1), failOpen)), dynamicConfig, RateLimiters.defaultScript(redisCluster), redisCluster, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersTest.java index fcf783df5..608b67f82 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersTest.java @@ -5,17 +5,12 @@ package org.whispersystems.textsecuregcm.limits; -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import com.fasterxml.jackson.annotation.JsonProperty; -import jakarta.validation.Valid; -import jakarta.validation.constraints.NotNull; import java.time.Duration; -import java.util.Collections; import java.util.HashMap; import java.util.Map; import org.junit.jupiter.api.Test; @@ -40,48 +35,6 @@ public class RateLimitersTest { private final MutableClock clock = MockUtils.mutableClock(0); - private static final String BAD_YAML = """ - limits: - prekeys: - bucketSize: 150 - permitRegenerationDuration: PT6S - unexpected: - bucketSize: 4 - permitRegenerationDuration: PT30S - """; - - private static final String GOOD_YAML = """ - limits: - prekeys: - bucketSize: 150 - permitRegenerationDuration: PT6S - failOpen: true - attachmentCreate: - bucketSize: 4 - permitRegenerationDuration: PT30S - failOpen: true - """; - - public record SimpleDynamicConfiguration(@Valid @NotNull @JsonProperty Map limits) { - } - - @Test - public void testValidateConfigs() throws Exception { - assertThrows(IllegalArgumentException.class, () -> { - final SimpleDynamicConfiguration dynamicConfiguration = - DynamicConfigurationManager.parseConfiguration(BAD_YAML, SimpleDynamicConfiguration.class).orElseThrow(); - - final RateLimiters rateLimiters = new RateLimiters(dynamicConfiguration.limits(), dynamicConfig, validateScript, redisCluster, clock); - rateLimiters.validateValuesAndConfigs(); - }); - - final SimpleDynamicConfiguration dynamicConfiguration = - DynamicConfigurationManager.parseConfiguration(GOOD_YAML, SimpleDynamicConfiguration.class).orElseThrow(); - - final RateLimiters rateLimiters = new RateLimiters(dynamicConfiguration.limits(), dynamicConfig, validateScript, redisCluster, clock); - assertDoesNotThrow(rateLimiters::validateValuesAndConfigs); - } - @Test public void testValidateDuplicates() throws Exception { final TestDescriptor td1 = new TestDescriptor("id1"); @@ -91,7 +44,6 @@ public class RateLimitersTest { assertThrows(IllegalStateException.class, () -> new BaseRateLimiters<>( new TestDescriptor[] { td1, td2, td3, tdDup }, - Collections.emptyMap(), dynamicConfig, validateScript, redisCluster, @@ -99,7 +51,6 @@ public class RateLimitersTest { new BaseRateLimiters<>( new TestDescriptor[] { td1, td2, td3 }, - Collections.emptyMap(), dynamicConfig, validateScript, redisCluster, @@ -108,10 +59,10 @@ public class RateLimitersTest { @Test void testUnchangingConfiguration() { - final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, validateScript, redisCluster, clock); + final RateLimiters rateLimiters = new RateLimiters(dynamicConfig, validateScript, redisCluster, clock); final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter(); final RateLimiterConfig expected = RateLimiters.For.RATE_LIMIT_RESET.defaultConfig(); - assertEquals(expected, config(limiter)); + assertEquals(expected, limiter.config()); } @Test @@ -127,78 +78,49 @@ public class RateLimitersTest { when(configuration.getLimits()).thenReturn(limitsConfigMap); - final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, validateScript, redisCluster, clock); + final RateLimiters rateLimiters = new RateLimiters(dynamicConfig, validateScript, redisCluster, clock); final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter(); limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), initialRateLimiterConfig); - assertEquals(initialRateLimiterConfig, config(limiter)); + assertEquals(initialRateLimiterConfig, limiter.config()); - assertEquals(baseConfig, config(rateLimiters.getCaptchaChallengeAttemptLimiter())); - assertEquals(baseConfig, config(rateLimiters.getCaptchaChallengeSuccessLimiter())); + assertEquals(baseConfig, rateLimiters.getCaptchaChallengeAttemptLimiter().config()); + assertEquals(baseConfig, rateLimiters.getCaptchaChallengeSuccessLimiter().config()); limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), updatedRateLimiterCongig); - assertEquals(updatedRateLimiterCongig, config(limiter)); + assertEquals(updatedRateLimiterCongig, limiter.config()); - assertEquals(baseConfig, config(rateLimiters.getCaptchaChallengeAttemptLimiter())); - assertEquals(baseConfig, config(rateLimiters.getCaptchaChallengeSuccessLimiter())); + assertEquals(baseConfig, rateLimiters.getCaptchaChallengeAttemptLimiter().config()); + assertEquals(baseConfig, rateLimiters.getCaptchaChallengeSuccessLimiter().config()); } @Test public void testRateLimiterHasItsPrioritiesStraight() throws Exception { final RateLimiters.For descriptor = RateLimiters.For.CAPTCHA_CHALLENGE_ATTEMPT; final RateLimiterConfig configForDynamic = new RateLimiterConfig(1, Duration.ofMinutes(1), false); - final RateLimiterConfig configForStatic = new RateLimiterConfig(2, Duration.ofSeconds(30), false); final RateLimiterConfig defaultConfig = descriptor.defaultConfig(); final Map mapForDynamic = new HashMap<>(); - final Map mapForStatic = new HashMap<>(); when(configuration.getLimits()).thenReturn(mapForDynamic); - final RateLimiters rateLimiters = new RateLimiters(mapForStatic, dynamicConfig, validateScript, redisCluster, clock); + final RateLimiters rateLimiters = new RateLimiters(dynamicConfig, validateScript, redisCluster, clock); final RateLimiter limiter = rateLimiters.forDescriptor(descriptor); // test only default is present mapForDynamic.remove(descriptor.id()); - mapForStatic.remove(descriptor.id()); - assertEquals(defaultConfig, config(limiter)); + assertEquals(defaultConfig, limiter.config()); - // test dynamic and no static + // test dynamic config is present mapForDynamic.put(descriptor.id(), configForDynamic); - mapForStatic.remove(descriptor.id()); - assertEquals(configForDynamic, config(limiter)); - - // test dynamic and static - mapForDynamic.put(descriptor.id(), configForDynamic); - mapForStatic.put(descriptor.id(), configForStatic); - assertEquals(configForDynamic, config(limiter)); - - // test static, but no dynamic - mapForDynamic.remove(descriptor.id()); - mapForStatic.put(descriptor.id(), configForStatic); - assertEquals(configForStatic, config(limiter)); + assertEquals(configForDynamic, limiter.config()); } private record TestDescriptor(String id) implements RateLimiterDescriptor { - @Override - public boolean isDynamic() { - return false; - } - @Override public RateLimiterConfig defaultConfig() { return new RateLimiterConfig(1, Duration.ofMinutes(1), false); } } - - private static RateLimiterConfig config(final RateLimiter rateLimiter) { - if (rateLimiter instanceof StaticRateLimiter rm) { - return rm.config(); - } - if (rateLimiter instanceof DynamicRateLimiter rm) { - return rm.config(); - } - throw new IllegalArgumentException("Rate limiter is of an unexpected type: " + rateLimiter.getClass().getName()); - } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiterTest.java deleted file mode 100644 index 1f355ca0f..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiterTest.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright 2025 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.limits; - -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; -import static org.junit.jupiter.api.Assertions.assertThrows; - -import io.lettuce.core.ScriptOutputType; -import java.io.IOException; -import java.time.Duration; -import java.time.Instant; -import java.util.concurrent.CompletionException; -import org.apache.commons.lang3.RandomStringUtils; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.extension.RegisterExtension; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; -import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; -import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; -import org.whispersystems.textsecuregcm.util.TestClock; - -class StaticRateLimiterTest { - - private ClusterLuaScript validateRateLimitScript; - - private static final TestClock CLOCK = TestClock.pinned(Instant.now()); - - @RegisterExtension - private static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); - - @BeforeEach - void setUp() throws IOException { - validateRateLimitScript = ClusterLuaScript.fromResource( - REDIS_CLUSTER_EXTENSION.getRedisCluster(), "lua/validate_rate_limit.lua", ScriptOutputType.INTEGER); - } - - @ParameterizedTest - @ValueSource(booleans = {true, false}) - void validate(final boolean failOpen) { - final StaticRateLimiter rateLimiter = new StaticRateLimiter("test", - new RateLimiterConfig(1, Duration.ofHours(1), failOpen), - validateRateLimitScript, - REDIS_CLUSTER_EXTENSION.getRedisCluster(), - CLOCK); - - final String key = RandomStringUtils.insecure().nextAlphanumeric(16); - - assertDoesNotThrow(() -> rateLimiter.validate(key)); - assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key)); - } - - @ParameterizedTest - @ValueSource(booleans = {true, false}) - void validateAsync(final boolean failOpen) { - final StaticRateLimiter rateLimiter = new StaticRateLimiter("test", - new RateLimiterConfig(1, Duration.ofHours(1), failOpen), - validateRateLimitScript, - REDIS_CLUSTER_EXTENSION.getRedisCluster(), - CLOCK); - - final String key = RandomStringUtils.insecure().nextAlphanumeric(16); - - assertDoesNotThrow(() -> rateLimiter.validateAsync(key).toCompletableFuture().join()); - final CompletionException completionException = - assertThrows(CompletionException.class, () -> rateLimiter.validateAsync(key).toCompletableFuture().join()); - - assertInstanceOf(RateLimitExceededException.class, completionException.getCause()); - } -}