Configure fail-open policy on individual rate limiters

This commit is contained in:
Jon Chambers 2025-03-28 16:45:05 -04:00 committed by Jon Chambers
parent e9bd5da2c3
commit 771a700acd
10 changed files with 93 additions and 113 deletions

View File

@ -46,10 +46,6 @@ public class DynamicConfiguration {
@Valid @Valid
DynamicMessagePersisterConfiguration messagePersister = new DynamicMessagePersisterConfiguration(); DynamicMessagePersisterConfiguration messagePersister = new DynamicMessagePersisterConfiguration();
@JsonProperty
@Valid
DynamicRateLimitPolicy rateLimitPolicy = new DynamicRateLimitPolicy(false);
@JsonProperty @JsonProperty
@Valid @Valid
DynamicRegistrationConfiguration registrationConfiguration = new DynamicRegistrationConfiguration(false); DynamicRegistrationConfiguration registrationConfiguration = new DynamicRegistrationConfiguration(false);
@ -100,10 +96,6 @@ public class DynamicConfiguration {
return messagePersister; return messagePersister;
} }
public DynamicRateLimitPolicy getRateLimitPolicy() {
return rateLimitPolicy;
}
public DynamicRegistrationConfiguration getRegistrationConfiguration() { public DynamicRegistrationConfiguration getRegistrationConfiguration() {
return registrationConfiguration; return registrationConfiguration;
} }

View File

@ -1,8 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.configuration.dynamic;
public record DynamicRateLimitPolicy(boolean failOpen) {}

View File

@ -95,6 +95,6 @@ public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
return new DynamicRateLimiter(descriptor.id(), dynamicConfigurationManager, configResolver, validateScript, cacheCluster, clock); return new DynamicRateLimiter(descriptor.id(), dynamicConfigurationManager, configResolver, validateScript, cacheCluster, clock);
} }
final RateLimiterConfig cfg = configs.getOrDefault(descriptor.id(), descriptor.defaultConfig()); final RateLimiterConfig cfg = configs.getOrDefault(descriptor.id(), descriptor.defaultConfig());
return new StaticRateLimiter(descriptor.id(), cfg, validateScript, cacheCluster, clock, dynamicConfigurationManager); return new StaticRateLimiter(descriptor.id(), cfg, validateScript, cacheCluster, clock);
} }
} }

View File

@ -87,7 +87,7 @@ public class DynamicRateLimiter implements RateLimiter {
final RateLimiterConfig cfg = configResolver.get(); final RateLimiterConfig cfg = configResolver.get();
return currentHolder.updateAndGet(p -> p != null && p.getLeft().equals(cfg) return currentHolder.updateAndGet(p -> p != null && p.getLeft().equals(cfg)
? p ? p
: Pair.of(cfg, new StaticRateLimiter(name, cfg, validateScript, cluster, clock, dynamicConfigurationManager)) : Pair.of(cfg, new StaticRateLimiter(name, cfg, validateScript, cluster, clock))
); );
} }
} }

View File

@ -8,7 +8,7 @@ package org.whispersystems.textsecuregcm.limits;
import jakarta.validation.constraints.AssertTrue; import jakarta.validation.constraints.AssertTrue;
import java.time.Duration; import java.time.Duration;
public record RateLimiterConfig(int bucketSize, Duration permitRegenerationDuration) { public record RateLimiterConfig(int bucketSize, Duration permitRegenerationDuration, boolean failOpen) {
public double leakRatePerMillis() { public double leakRatePerMillis() {
return 1.0 / (permitRegenerationDuration.toNanos() / 1e6); return 1.0 / (permitRegenerationDuration.toNanos() / 1e6);

View File

@ -17,47 +17,46 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
public class RateLimiters extends BaseRateLimiters<RateLimiters.For> { public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
public enum For implements RateLimiterDescriptor { public enum For implements RateLimiterDescriptor {
BACKUP_AUTH_CHECK("backupAuthCheck", false, new RateLimiterConfig(100, Duration.ofMinutes(15))), BACKUP_AUTH_CHECK("backupAuthCheck", false, new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
PIN("pin", false, new RateLimiterConfig(10, Duration.ofDays(1))), PIN("pin", false, new RateLimiterConfig(10, Duration.ofDays(1), false)),
ATTACHMENT("attachmentCreate", false, new RateLimiterConfig(50, Duration.ofMillis(1200))), ATTACHMENT("attachmentCreate", false, new RateLimiterConfig(50, Duration.ofMillis(1200), false)),
BACKUP_ATTACHMENT("backupAttachmentCreate", true, new RateLimiterConfig(10_000, Duration.ofSeconds(1))), BACKUP_ATTACHMENT("backupAttachmentCreate", true, new RateLimiterConfig(10_000, Duration.ofSeconds(1), false)),
PRE_KEYS("prekeys", false, new RateLimiterConfig(6, Duration.ofMinutes(10))), PRE_KEYS("prekeys", false, new RateLimiterConfig(6, Duration.ofMinutes(10), false)),
MESSAGES("messages", false, new RateLimiterConfig(60, Duration.ofSeconds(1))), MESSAGES("messages", false, new RateLimiterConfig(60, Duration.ofSeconds(1), false)),
STORIES("stories", false, new RateLimiterConfig(5_000, Duration.ofSeconds(8))), STORIES("stories", false, new RateLimiterConfig(5_000, Duration.ofSeconds(8), false)),
ALLOCATE_DEVICE("allocateDevice", false, new RateLimiterConfig(6, Duration.ofMinutes(2))), ALLOCATE_DEVICE("allocateDevice", false, new RateLimiterConfig(6, Duration.ofMinutes(2), false)),
VERIFY_DEVICE("verifyDevice", false, new RateLimiterConfig(6, Duration.ofMinutes(2))), VERIFY_DEVICE("verifyDevice", false, new RateLimiterConfig(6, Duration.ofMinutes(2), false)),
PROFILE("profile", false, new RateLimiterConfig(4320, Duration.ofSeconds(20))), PROFILE("profile", false, new RateLimiterConfig(4320, Duration.ofSeconds(20), false)),
STICKER_PACK("stickerPack", false, new RateLimiterConfig(50, Duration.ofMinutes(72))), STICKER_PACK("stickerPack", false, new RateLimiterConfig(50, Duration.ofMinutes(72), false)),
USERNAME_LOOKUP("usernameLookup", false, new RateLimiterConfig(100, Duration.ofMinutes(15))), USERNAME_LOOKUP("usernameLookup", false, new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
USERNAME_SET("usernameSet", false, new RateLimiterConfig(100, Duration.ofMinutes(15))), USERNAME_SET("usernameSet", false, new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
USERNAME_RESERVE("usernameReserve", false, new RateLimiterConfig(100, Duration.ofMinutes(15))), USERNAME_RESERVE("usernameReserve", false, new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
USERNAME_LINK_OPERATION("usernameLinkOperation", false, new RateLimiterConfig(10, Duration.ofMinutes(1))), USERNAME_LINK_OPERATION("usernameLinkOperation", false, new RateLimiterConfig(10, Duration.ofMinutes(1), false)),
USERNAME_LINK_LOOKUP_PER_IP("usernameLinkLookupPerIp", false, new RateLimiterConfig(100, Duration.ofSeconds(15))), USERNAME_LINK_LOOKUP_PER_IP("usernameLinkLookupPerIp", false, new RateLimiterConfig(100, Duration.ofSeconds(15), false)),
CHECK_ACCOUNT_EXISTENCE("checkAccountExistence", false, new RateLimiterConfig(1000, Duration.ofSeconds(4))), CHECK_ACCOUNT_EXISTENCE("checkAccountExistence", false, new RateLimiterConfig(1000, Duration.ofSeconds(4), false)),
REGISTRATION("registration", false, new RateLimiterConfig(6, Duration.ofSeconds(30))), REGISTRATION("registration", false, new RateLimiterConfig(6, Duration.ofSeconds(30), false)),
VERIFICATION_PUSH_CHALLENGE("verificationPushChallenge", false, new RateLimiterConfig(5, Duration.ofSeconds(30))), VERIFICATION_PUSH_CHALLENGE("verificationPushChallenge", false, new RateLimiterConfig(5, Duration.ofSeconds(30), false)),
VERIFICATION_CAPTCHA("verificationCaptcha", false, new RateLimiterConfig(10, Duration.ofSeconds(30))), VERIFICATION_CAPTCHA("verificationCaptcha", false, new RateLimiterConfig(10, Duration.ofSeconds(30), false)),
RATE_LIMIT_RESET("rateLimitReset", true, new RateLimiterConfig(2, Duration.ofHours(12))), RATE_LIMIT_RESET("rateLimitReset", true, new RateLimiterConfig(2, Duration.ofHours(12), false)),
CAPTCHA_CHALLENGE_ATTEMPT("captchaChallengeAttempt", true, new RateLimiterConfig(10, Duration.ofMinutes(144))), CAPTCHA_CHALLENGE_ATTEMPT("captchaChallengeAttempt", true, new RateLimiterConfig(10, Duration.ofMinutes(144), false)),
CAPTCHA_CHALLENGE_SUCCESS("captchaChallengeSuccess", true, new RateLimiterConfig(2, Duration.ofHours(12))), CAPTCHA_CHALLENGE_SUCCESS("captchaChallengeSuccess", true, new RateLimiterConfig(2, Duration.ofHours(12), false)),
SET_BACKUP_ID("setBackupId", true, new RateLimiterConfig(10, Duration.ofHours(1))), SET_BACKUP_ID("setBackupId", true, new RateLimiterConfig(10, Duration.ofHours(1), false)),
SET_PAID_MEDIA_BACKUP_ID("setPaidMediaBackupId", true, new RateLimiterConfig(5, Duration.ofDays(7))), SET_PAID_MEDIA_BACKUP_ID("setPaidMediaBackupId", true, new RateLimiterConfig(5, Duration.ofDays(7), false)),
PUSH_CHALLENGE_ATTEMPT("pushChallengeAttempt", true, new RateLimiterConfig(10, Duration.ofMinutes(144))), PUSH_CHALLENGE_ATTEMPT("pushChallengeAttempt", true, new RateLimiterConfig(10, Duration.ofMinutes(144), false)),
PUSH_CHALLENGE_SUCCESS("pushChallengeSuccess", true, new RateLimiterConfig(2, Duration.ofHours(12))), PUSH_CHALLENGE_SUCCESS("pushChallengeSuccess", true, new RateLimiterConfig(2, Duration.ofHours(12), false)),
GET_CALLING_RELAYS("getCallingRelays", false, new RateLimiterConfig(100, Duration.ofMinutes(10))), GET_CALLING_RELAYS("getCallingRelays", false, new RateLimiterConfig(100, Duration.ofMinutes(10), false)),
CREATE_CALL_LINK("createCallLink", false, new RateLimiterConfig(100, Duration.ofMinutes(15))), CREATE_CALL_LINK("createCallLink", false, new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
INBOUND_MESSAGE_BYTES("inboundMessageBytes", true, new RateLimiterConfig(128 * 1024 * 1024, Duration.ofNanos(500_000))), INBOUND_MESSAGE_BYTES("inboundMessageBytes", true, new RateLimiterConfig(128 * 1024 * 1024, Duration.ofNanos(500_000), false)),
EXTERNAL_SERVICE_CREDENTIALS("externalServiceCredentials", true, new RateLimiterConfig(100, Duration.ofMinutes(15))), EXTERNAL_SERVICE_CREDENTIALS("externalServiceCredentials", true, new RateLimiterConfig(100, Duration.ofMinutes(15), false)),
KEY_TRANSPARENCY_DISTINGUISHED_PER_IP("keyTransparencyDistinguished", true, KEY_TRANSPARENCY_DISTINGUISHED_PER_IP("keyTransparencyDistinguished", true, new RateLimiterConfig(100, Duration.ofSeconds(15), false)),
new RateLimiterConfig(100, Duration.ofSeconds(15))), KEY_TRANSPARENCY_SEARCH_PER_IP("keyTransparencySearch", true, new RateLimiterConfig(100, Duration.ofSeconds(15), false)),
KEY_TRANSPARENCY_SEARCH_PER_IP("keyTransparencySearch", true, new RateLimiterConfig(100, Duration.ofSeconds(15))), KEY_TRANSPARENCY_MONITOR_PER_IP("keyTransparencyMonitor", true, new RateLimiterConfig(100, Duration.ofSeconds(15), false)),
KEY_TRANSPARENCY_MONITOR_PER_IP("keyTransparencyMonitor", true, new RateLimiterConfig(100, Duration.ofSeconds(15))), WAIT_FOR_LINKED_DEVICE("waitForLinkedDevice", true, new RateLimiterConfig(10, Duration.ofSeconds(30), false)),
WAIT_FOR_LINKED_DEVICE("waitForLinkedDevice", true, new RateLimiterConfig(10, Duration.ofSeconds(30))), UPLOAD_TRANSFER_ARCHIVE("uploadTransferArchive", true, new RateLimiterConfig(10, Duration.ofMinutes(1), false)),
UPLOAD_TRANSFER_ARCHIVE("uploadTransferArchive", true, new RateLimiterConfig(10, Duration.ofMinutes(1))), WAIT_FOR_TRANSFER_ARCHIVE("waitForTransferArchive", true, new RateLimiterConfig(10, Duration.ofSeconds(30), false)),
WAIT_FOR_TRANSFER_ARCHIVE("waitForTransferArchive", true, new RateLimiterConfig(10, Duration.ofSeconds(30))), RECORD_DEVICE_TRANSFER_REQUEST("recordDeviceTransferRequest", true, new RateLimiterConfig(10, Duration.ofMillis(100), false)),
RECORD_DEVICE_TRANSFER_REQUEST("recordDeviceTransferRequest", true, new RateLimiterConfig(10, Duration.ofMillis(100))), WAIT_FOR_DEVICE_TRANSFER_REQUEST("waitForDeviceTransferRequest", true, new RateLimiterConfig(10, Duration.ofMillis(100), false)),
WAIT_FOR_DEVICE_TRANSFER_REQUEST("waitForDeviceTransferRequest", true, new RateLimiterConfig(10, Duration.ofMillis(100))), DEVICE_CHECK_CHALLENGE("deviceCheckChallenge", true, new RateLimiterConfig(10, Duration.ofMinutes(1), false)),
DEVICE_CHECK_CHALLENGE("deviceCheckChallenge", true, new RateLimiterConfig(10, Duration.ofMinutes(1))),
; ;
private final String id; private final String id;

View File

@ -15,12 +15,10 @@ import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.util.List; import java.util.List;
import java.util.concurrent.CompletionStage; import java.util.concurrent.CompletionStage;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.ExceptionUtils; import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
@ -30,8 +28,7 @@ public class StaticRateLimiter implements RateLimiter {
private final RateLimiterConfig config; private final RateLimiterConfig config;
private final Counter counter; private final Counter limitExceededCounter;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private final ClusterLuaScript validateScript; private final ClusterLuaScript validateScript;
@ -45,15 +42,13 @@ public class StaticRateLimiter implements RateLimiter {
final RateLimiterConfig config, final RateLimiterConfig config,
final ClusterLuaScript validateScript, final ClusterLuaScript validateScript,
final FaultTolerantRedisClusterClient cacheCluster, final FaultTolerantRedisClusterClient cacheCluster,
final Clock clock, final Clock clock) {
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager) {
this.name = requireNonNull(name); this.name = requireNonNull(name);
this.config = requireNonNull(config); this.config = requireNonNull(config);
this.validateScript = requireNonNull(validateScript); this.validateScript = requireNonNull(validateScript);
this.cacheCluster = requireNonNull(cacheCluster); this.cacheCluster = requireNonNull(cacheCluster);
this.clock = requireNonNull(clock); this.clock = requireNonNull(clock);
this.counter = Metrics.counter(MetricsUtil.name(getClass(), "exceeded"), "rateLimiterName", name); this.limitExceededCounter = Metrics.counter(MetricsUtil.name(getClass(), "exceeded"), "rateLimiterName", name);
this.dynamicConfigurationManager = dynamicConfigurationManager;
} }
@Override @Override
@ -61,13 +56,13 @@ public class StaticRateLimiter implements RateLimiter {
try { try {
final long deficitPermitsAmount = executeValidateScript(key, amount, true); final long deficitPermitsAmount = executeValidateScript(key, amount, true);
if (deficitPermitsAmount > 0) { if (deficitPermitsAmount > 0) {
counter.increment(); limitExceededCounter.increment();
final Duration retryAfter = Duration.ofMillis( final Duration retryAfter = Duration.ofMillis(
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis())); (long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
throw new RateLimitExceededException(retryAfter); throw new RateLimitExceededException(retryAfter);
} }
} catch (final Exception e) { } catch (final Exception e) {
if (!failOpen()) { if (!config.failOpen()) {
throw e; throw e;
} }
} }
@ -80,16 +75,16 @@ public class StaticRateLimiter implements RateLimiter {
if (deficitPermitsAmount == 0) { if (deficitPermitsAmount == 0) {
return completedFuture((Void) null); return completedFuture((Void) null);
} }
counter.increment(); limitExceededCounter.increment();
final Duration retryAfter = Duration.ofMillis( final Duration retryAfter = Duration.ofMillis(
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis())); (long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
return failedFuture(new RateLimitExceededException(retryAfter)); return failedFuture(new RateLimitExceededException(retryAfter));
}) })
.exceptionally(throwable -> { .exceptionally(throwable -> {
if (failOpen()) { if (config.failOpen()) {
return null; return null;
} }
throw ExceptionUtils.wrap(throwable); throw ExceptionUtils.wrap(new RateLimitExceededException(null));
}); });
} }
@ -99,7 +94,7 @@ public class StaticRateLimiter implements RateLimiter {
final long deficitPermitsAmount = executeValidateScript(key, amount, false); final long deficitPermitsAmount = executeValidateScript(key, amount, false);
return deficitPermitsAmount == 0; return deficitPermitsAmount == 0;
} catch (final Exception e) { } catch (final Exception e) {
if (failOpen()) { if (config.failOpen()) {
return true; return true;
} else { } else {
throw e; throw e;
@ -112,7 +107,7 @@ public class StaticRateLimiter implements RateLimiter {
return executeValidateScriptAsync(key, amount, false) return executeValidateScriptAsync(key, amount, false)
.thenApply(deficitPermitsAmount -> deficitPermitsAmount == 0) .thenApply(deficitPermitsAmount -> deficitPermitsAmount == 0)
.exceptionally(throwable -> { .exceptionally(throwable -> {
if (failOpen()) { if (config.failOpen()) {
return true; return true;
} }
throw ExceptionUtils.wrap(throwable); throw ExceptionUtils.wrap(throwable);
@ -135,10 +130,6 @@ public class StaticRateLimiter implements RateLimiter {
return config; return config;
} }
private boolean failOpen() {
return this.dynamicConfigurationManager.getConfiguration().getRateLimitPolicy().failOpen();
}
private long executeValidateScript(final String key, final int amount, final boolean applyChanges) { private long executeValidateScript(final String key, final int amount, final boolean applyChanges) {
final List<String> keys = List.of(bucketName(name, key)); final List<String> keys = List.of(bucketName(name, key));
final List<String> arguments = List.of( final List<String> arguments = List.of(

View File

@ -8,7 +8,6 @@ package org.whispersystems.textsecuregcm.limits;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import java.time.Duration; import java.time.Duration;
import java.util.Optional;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@ -16,15 +15,15 @@ class RateLimiterConfigTest {
@Test @Test
void leakRatePerMillis() { void leakRatePerMillis() {
assertEquals(0.001, new RateLimiterConfig(1, Duration.ofSeconds(1)).leakRatePerMillis()); assertEquals(0.001, new RateLimiterConfig(1, Duration.ofSeconds(1), false).leakRatePerMillis());
assertEquals(1e6, new RateLimiterConfig(1, Duration.ofNanos(1)).leakRatePerMillis()); assertEquals(1e6, new RateLimiterConfig(1, Duration.ofNanos(1), false).leakRatePerMillis());
} }
@Test @Test
void isRegenerationRatePositive() { void isRegenerationRatePositive() {
assertTrue(new RateLimiterConfig(1, Duration.ofSeconds(1)).hasPositiveRegenerationRate()); assertTrue(new RateLimiterConfig(1, Duration.ofSeconds(1), false).hasPositiveRegenerationRate());
assertTrue(new RateLimiterConfig(1, Duration.ofNanos(1)).hasPositiveRegenerationRate()); assertTrue(new RateLimiterConfig(1, Duration.ofNanos(1), false).hasPositiveRegenerationRate());
assertFalse(new RateLimiterConfig(1, Duration.ZERO).hasPositiveRegenerationRate()); assertFalse(new RateLimiterConfig(1, Duration.ZERO, false).hasPositiveRegenerationRate());
assertFalse(new RateLimiterConfig(1, Duration.ofSeconds(-1)).hasPositiveRegenerationRate()); assertFalse(new RateLimiterConfig(1, Duration.ofSeconds(-1), false).hasPositiveRegenerationRate());
} }
} }

View File

@ -5,6 +5,7 @@
package org.whispersystems.textsecuregcm.limits; package org.whispersystems.textsecuregcm.limits;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@ -22,8 +23,9 @@ import java.util.Map;
import java.util.Optional; import java.util.Optional;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitPolicy;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
@ -57,7 +59,7 @@ public class RateLimitersLuaScriptTest {
final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION; final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION;
final FaultTolerantRedisClusterClient redisCluster = REDIS_CLUSTER_EXTENSION.getRedisCluster(); final FaultTolerantRedisClusterClient redisCluster = REDIS_CLUSTER_EXTENSION.getRedisCluster();
final RateLimiters limiters = new RateLimiters( final RateLimiters limiters = new RateLimiters(
Map.of(descriptor.id(), new RateLimiterConfig(60, Duration.ofSeconds(1))), Map.of(descriptor.id(), new RateLimiterConfig(60, Duration.ofSeconds(1), false)),
dynamicConfig, dynamicConfig,
RateLimiters.defaultScript(redisCluster), RateLimiters.defaultScript(redisCluster),
redisCluster, redisCluster,
@ -74,7 +76,7 @@ public class RateLimitersLuaScriptTest {
final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION; final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION;
final FaultTolerantRedisClusterClient redisCluster = REDIS_CLUSTER_EXTENSION.getRedisCluster(); final FaultTolerantRedisClusterClient redisCluster = REDIS_CLUSTER_EXTENSION.getRedisCluster();
final RateLimiters limiters = new RateLimiters( final RateLimiters limiters = new RateLimiters(
Map.of(descriptor.id(), new RateLimiterConfig(1000, Duration.ofSeconds(1))), Map.of(descriptor.id(), new RateLimiterConfig(1000, Duration.ofSeconds(1), false)),
dynamicConfig, dynamicConfig,
RateLimiters.defaultScript(redisCluster), RateLimiters.defaultScript(redisCluster),
redisCluster, redisCluster,
@ -119,20 +121,25 @@ public class RateLimitersLuaScriptTest {
assertEquals(750L, decodeBucket(key).orElseThrow().tokensRemaining); assertEquals(750L, decodeBucket(key).orElseThrow().tokensRemaining);
} }
@Test @ParameterizedTest
public void testFailOpen() throws Exception { @ValueSource(booleans = {true, false})
when(configuration.getRateLimitPolicy()).thenReturn(new DynamicRateLimitPolicy(true)); public void testFailOpen(final boolean failOpen) {
final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION; final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION;
final FaultTolerantRedisClusterClient redisCluster = mock(FaultTolerantRedisClusterClient.class); final FaultTolerantRedisClusterClient redisCluster = mock(FaultTolerantRedisClusterClient.class);
final RateLimiters limiters = new RateLimiters( final RateLimiters limiters = new RateLimiters(
Map.of(descriptor.id(), new RateLimiterConfig(1000, Duration.ofSeconds(1))), Map.of(descriptor.id(), new RateLimiterConfig(1000, Duration.ofSeconds(1), failOpen)),
dynamicConfig, dynamicConfig,
RateLimiters.defaultScript(redisCluster), RateLimiters.defaultScript(redisCluster),
redisCluster, redisCluster,
Clock.systemUTC()); Clock.systemUTC());
when(redisCluster.withCluster(any())).thenThrow(new RedisException("fail")); when(redisCluster.withCluster(any())).thenThrow(new RedisException("fail"));
final RateLimiter rateLimiter = limiters.forDescriptor(descriptor); final RateLimiter rateLimiter = limiters.forDescriptor(descriptor);
rateLimiter.validate("test", 200);
if (failOpen) {
assertDoesNotThrow(() -> rateLimiter.validate("test", 200));
} else {
assertThrows(RedisException.class, () -> rateLimiter.validate("test", 200));
}
} }
private String serializeToOldBucketValueFormat( private String serializeToOldBucketValueFormat(

View File

@ -5,9 +5,9 @@
package org.whispersystems.textsecuregcm.limits; package org.whispersystems.textsecuregcm.limits;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -20,7 +20,6 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitPolicy;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
@ -56,30 +55,31 @@ public class RateLimitersTest {
prekeys: prekeys:
bucketSize: 150 bucketSize: 150
permitRegenerationDuration: PT6S permitRegenerationDuration: PT6S
failOpen: true
attachmentCreate: attachmentCreate:
bucketSize: 4 bucketSize: 4
permitRegenerationDuration: PT30S permitRegenerationDuration: PT30S
rateLimitPolicy:
failOpen: true failOpen: true
"""; """;
public record GenericHolder( public record SimpleDynamicConfiguration(@Valid @NotNull @JsonProperty Map<String, RateLimiterConfig> limits) {
@Valid @NotNull @JsonProperty Map<String, RateLimiterConfig> limits,
@Valid @JsonProperty DynamicRateLimitPolicy rateLimitPolicy) {
} }
@Test @Test
public void testValidateConfigs() throws Exception { public void testValidateConfigs() throws Exception {
assertThrows(IllegalArgumentException.class, () -> { assertThrows(IllegalArgumentException.class, () -> {
final GenericHolder cfg = DynamicConfigurationManager.parseConfiguration(BAD_YAML, GenericHolder.class).orElseThrow(); final SimpleDynamicConfiguration dynamicConfiguration =
final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, validateScript, redisCluster, clock); DynamicConfigurationManager.parseConfiguration(BAD_YAML, SimpleDynamicConfiguration.class).orElseThrow();
final RateLimiters rateLimiters = new RateLimiters(dynamicConfiguration.limits(), dynamicConfig, validateScript, redisCluster, clock);
rateLimiters.validateValuesAndConfigs(); rateLimiters.validateValuesAndConfigs();
}); });
final GenericHolder cfg = DynamicConfigurationManager.parseConfiguration(GOOD_YAML, GenericHolder.class).orElseThrow(); final SimpleDynamicConfiguration dynamicConfiguration =
assertTrue(cfg.rateLimitPolicy.failOpen()); DynamicConfigurationManager.parseConfiguration(GOOD_YAML, SimpleDynamicConfiguration.class).orElseThrow();
final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, validateScript, redisCluster, clock);
rateLimiters.validateValuesAndConfigs(); final RateLimiters rateLimiters = new RateLimiters(dynamicConfiguration.limits(), dynamicConfig, validateScript, redisCluster, clock);
assertDoesNotThrow(rateLimiters::validateValuesAndConfigs);
} }
@Test @Test
@ -116,9 +116,9 @@ public class RateLimitersTest {
@Test @Test
void testChangingConfiguration() { void testChangingConfiguration() {
final RateLimiterConfig initialRateLimiterConfig = new RateLimiterConfig(4, Duration.ofMinutes(1)); final RateLimiterConfig initialRateLimiterConfig = new RateLimiterConfig(4, Duration.ofMinutes(1), false);
final RateLimiterConfig updatedRateLimiterCongig = new RateLimiterConfig(17, Duration.ofSeconds(3)); final RateLimiterConfig updatedRateLimiterCongig = new RateLimiterConfig(17, Duration.ofSeconds(3), false);
final RateLimiterConfig baseConfig = new RateLimiterConfig(1, Duration.ofMinutes(1)); final RateLimiterConfig baseConfig = new RateLimiterConfig(1, Duration.ofMinutes(1), false);
final Map<String, RateLimiterConfig> limitsConfigMap = new HashMap<>(); final Map<String, RateLimiterConfig> limitsConfigMap = new HashMap<>();
@ -146,8 +146,8 @@ public class RateLimitersTest {
@Test @Test
public void testRateLimiterHasItsPrioritiesStraight() throws Exception { public void testRateLimiterHasItsPrioritiesStraight() throws Exception {
final RateLimiters.For descriptor = RateLimiters.For.CAPTCHA_CHALLENGE_ATTEMPT; final RateLimiters.For descriptor = RateLimiters.For.CAPTCHA_CHALLENGE_ATTEMPT;
final RateLimiterConfig configForDynamic = new RateLimiterConfig(1, Duration.ofMinutes(1)); final RateLimiterConfig configForDynamic = new RateLimiterConfig(1, Duration.ofMinutes(1), false);
final RateLimiterConfig configForStatic = new RateLimiterConfig(2, Duration.ofSeconds(30)); final RateLimiterConfig configForStatic = new RateLimiterConfig(2, Duration.ofSeconds(30), false);
final RateLimiterConfig defaultConfig = descriptor.defaultConfig(); final RateLimiterConfig defaultConfig = descriptor.defaultConfig();
final Map<String, RateLimiterConfig> mapForDynamic = new HashMap<>(); final Map<String, RateLimiterConfig> mapForDynamic = new HashMap<>();
@ -188,7 +188,7 @@ public class RateLimitersTest {
@Override @Override
public RateLimiterConfig defaultConfig() { public RateLimiterConfig defaultConfig() {
return new RateLimiterConfig(1, Duration.ofMinutes(1)); return new RateLimiterConfig(1, Duration.ofMinutes(1), false);
} }
} }