From 0fe6485038db88d72a8610de5329d77621f5d4f8 Mon Sep 17 00:00:00 2001 From: ravi-signal <99042880+ravi-signal@users.noreply.github.com> Date: Fri, 14 Apr 2023 13:08:14 -0500 Subject: [PATCH] Add a configuration to make rate limiters fail open --- .../dynamic/DynamicConfiguration.java | 9 +++ .../dynamic/DynamicRateLimitPolicy.java | 8 +++ .../limits/BaseRateLimiters.java | 4 +- .../limits/DynamicRateLimiter.java | 8 ++- .../limits/StaticRateLimiter.java | 59 +++++++++++++++---- .../limits/RateLimitersLuaScriptTest.java | 19 ++++++ .../limits/RateLimitersTest.java | 9 ++- 7 files changed, 100 insertions(+), 16 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicRateLimitPolicy.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicConfiguration.java index cbff47240..f28f74287 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicConfiguration.java @@ -62,6 +62,11 @@ public class DynamicConfiguration { @Valid DynamicPushNotificationConfiguration pushNotifications = new DynamicPushNotificationConfiguration(); + + @JsonProperty + @Valid + DynamicRateLimitPolicy rateLimitPolicy = new DynamicRateLimitPolicy(false); + public Optional getExperimentEnrollmentConfiguration( final String experimentName) { return Optional.ofNullable(experiments.get(experimentName)); @@ -111,4 +116,8 @@ public class DynamicConfiguration { public DynamicPushNotificationConfiguration getPushNotificationConfiguration() { return pushNotifications; } + + public DynamicRateLimitPolicy getRateLimitPolicy() { + return rateLimitPolicy; + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicRateLimitPolicy.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicRateLimitPolicy.java new file mode 100644 index 000000000..3fb0172ea --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/dynamic/DynamicRateLimitPolicy.java @@ -0,0 +1,8 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.configuration.dynamic; + +public record DynamicRateLimitPolicy(boolean failOpen) {} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/BaseRateLimiters.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/BaseRateLimiters.java index 65ad263ca..4f87761ae 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/BaseRateLimiters.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/BaseRateLimiters.java @@ -92,9 +92,9 @@ public abstract class BaseRateLimiters { ? config : configs.getOrDefault(descriptor.id(), descriptor.defaultConfig()); }; - return new DynamicRateLimiter(descriptor.id(), configResolver, validateScript, cacheCluster, clock); + return new DynamicRateLimiter(descriptor.id(), dynamicConfigurationManager, configResolver, validateScript, cacheCluster, clock); } final RateLimiterConfig cfg = configs.getOrDefault(descriptor.id(), descriptor.defaultConfig()); - return new StaticRateLimiter(descriptor.id(), cfg, validateScript, cacheCluster, clock); + return new StaticRateLimiter(descriptor.id(), cfg, validateScript, cacheCluster, clock, dynamicConfigurationManager); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/DynamicRateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/DynamicRateLimiter.java index 8f7b0ec82..2b38483e2 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/DynamicRateLimiter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/DynamicRateLimiter.java @@ -12,14 +12,16 @@ import java.util.concurrent.CompletionStage; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; import org.apache.commons.lang3.tuple.Pair; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; +import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; public class DynamicRateLimiter implements RateLimiter { private final String name; - + private final DynamicConfigurationManager dynamicConfigurationManager; private final Supplier configResolver; private final ClusterLuaScript validateScript; @@ -33,11 +35,13 @@ public class DynamicRateLimiter implements RateLimiter { public DynamicRateLimiter( final String name, + final DynamicConfigurationManager dynamicConfigurationManager, final Supplier configResolver, final ClusterLuaScript validateScript, final FaultTolerantRedisCluster cluster, final Clock clock) { this.name = requireNonNull(name); + this.dynamicConfigurationManager = dynamicConfigurationManager; this.configResolver = requireNonNull(configResolver); this.validateScript = requireNonNull(validateScript); this.cluster = requireNonNull(cluster); @@ -83,7 +87,7 @@ public class DynamicRateLimiter implements RateLimiter { final RateLimiterConfig cfg = configResolver.get(); return currentHolder.updateAndGet(p -> p != null && p.getLeft().equals(cfg) ? p - : Pair.of(cfg, new StaticRateLimiter(name, cfg, validateScript, cluster, clock)) + : Pair.of(cfg, new StaticRateLimiter(name, cfg, validateScript, cluster, clock, dynamicConfigurationManager)) ); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiter.java index e5a992048..5d3560ead 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiter.java @@ -9,16 +9,20 @@ import static java.util.concurrent.CompletableFuture.completedFuture; import static java.util.concurrent.CompletableFuture.failedFuture; import com.google.common.annotations.VisibleForTesting; +import io.lettuce.core.RedisException; import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Metrics; import java.time.Clock; import java.time.Duration; import java.util.List; import java.util.concurrent.CompletionStage; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; +import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; +import org.whispersystems.textsecuregcm.util.ExceptionUtils; import org.whispersystems.textsecuregcm.util.Util; public class StaticRateLimiter implements RateLimiter { @@ -28,6 +32,7 @@ public class StaticRateLimiter implements RateLimiter { private final RateLimiterConfig config; private final Counter counter; + private final DynamicConfigurationManager dynamicConfigurationManager; private final ClusterLuaScript validateScript; @@ -41,23 +46,31 @@ public class StaticRateLimiter implements RateLimiter { final RateLimiterConfig config, final ClusterLuaScript validateScript, final FaultTolerantRedisCluster cacheCluster, - final Clock clock) { + final Clock clock, + final DynamicConfigurationManager dynamicConfigurationManager) { this.name = requireNonNull(name); this.config = requireNonNull(config); this.validateScript = requireNonNull(validateScript); this.cacheCluster = requireNonNull(cacheCluster); this.clock = requireNonNull(clock); this.counter = Metrics.counter(MetricsUtil.name(getClass(), "exceeded"), "name", name); + this.dynamicConfigurationManager = dynamicConfigurationManager; } @Override public void validate(final String key, final int amount) throws RateLimitExceededException { - final long deficitPermitsAmount = executeValidateScript(key, amount, true); - if (deficitPermitsAmount > 0) { - counter.increment(); - final Duration retryAfter = Duration.ofMillis( - (long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis())); - throw new RateLimitExceededException(retryAfter, true); + try { + final long deficitPermitsAmount = executeValidateScript(key, amount, true); + if (deficitPermitsAmount > 0) { + counter.increment(); + final Duration retryAfter = Duration.ofMillis( + (long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis())); + throw new RateLimitExceededException(retryAfter, true); + } + } catch (RedisException e) { + if (!failOpen()) { + throw e; + } } } @@ -66,25 +79,45 @@ public class StaticRateLimiter implements RateLimiter { return executeValidateScriptAsync(key, amount, true) .thenCompose(deficitPermitsAmount -> { if (deficitPermitsAmount == 0) { - return completedFuture(null); + return completedFuture((Void) null); } counter.increment(); final Duration retryAfter = Duration.ofMillis( (long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis())); return failedFuture(new RateLimitExceededException(retryAfter, true)); + }) + .exceptionally(throwable -> { + if (ExceptionUtils.unwrap(throwable) instanceof RedisException && failOpen()) { + return null; + } + throw ExceptionUtils.wrap(throwable); }); } @Override public boolean hasAvailablePermits(final String key, final int amount) { - final long deficitPermitsAmount = executeValidateScript(key, amount, false); - return deficitPermitsAmount == 0; + try { + final long deficitPermitsAmount = executeValidateScript(key, amount, false); + return deficitPermitsAmount == 0; + } catch (RedisException e) { + if (failOpen()) { + return true; + } else { + throw e; + } + } } @Override public CompletionStage hasAvailablePermitsAsync(final String key, final int amount) { return executeValidateScriptAsync(key, amount, false) - .thenApply(deficitPermitsAmount -> deficitPermitsAmount == 0); + .thenApply(deficitPermitsAmount -> deficitPermitsAmount == 0) + .exceptionally(throwable -> { + if (ExceptionUtils.unwrap(throwable) instanceof RedisException && failOpen()) { + return true; + } + throw ExceptionUtils.wrap(throwable); + }); } @Override @@ -103,6 +136,10 @@ public class StaticRateLimiter implements RateLimiter { return config; } + private boolean failOpen() { + return this.dynamicConfigurationManager.getConfiguration().getRateLimitPolicy().failOpen(); + } + private long executeValidateScript(final String key, final int amount, final boolean applyChanges) { final List keys = List.of(bucketName(name, key)); final List arguments = List.of( diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersLuaScriptTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersLuaScriptTest.java index 30ec78ea7..33601538c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersLuaScriptTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersLuaScriptTest.java @@ -8,10 +8,12 @@ package org.whispersystems.textsecuregcm.limits; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import com.fasterxml.jackson.core.JsonProcessingException; +import io.lettuce.core.RedisException; import io.lettuce.core.ScriptOutputType; import java.time.Clock; import java.util.List; @@ -20,6 +22,7 @@ import java.util.Optional; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitPolicy; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; @@ -159,6 +162,22 @@ public class RateLimitersLuaScriptTest { assertEquals(750L, decodeBucket(key).orElseThrow().tokensRemaining); } + @Test + public void testFailOpen() throws Exception { + when(configuration.getRateLimitPolicy()).thenReturn(new DynamicRateLimitPolicy(true)); + final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION; + final FaultTolerantRedisCluster redisCluster = mock(FaultTolerantRedisCluster.class); + final RateLimiters limiters = new RateLimiters( + Map.of(descriptor.id(), new RateLimiterConfig(1000, 60)), + dynamicConfig, + RateLimiters.defaultScript(redisCluster), + redisCluster, + Clock.systemUTC()); + when(redisCluster.withCluster(any())).thenThrow(new RedisException("fail")); + final RateLimiter rateLimiter = limiters.forDescriptor(descriptor); + rateLimiter.validate("test", 200); + } + private String serializeToOldBucketValueFormat( final long bucketSize, final long leakRatePerMillis, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersTest.java index d1705e7f9..a07db7781 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitersTest.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.limits; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -18,6 +19,7 @@ import javax.validation.Valid; import javax.validation.constraints.NotNull; import org.junit.jupiter.api.Test; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitPolicy; import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; @@ -56,9 +58,13 @@ public class RateLimitersTest { attachmentCreate: bucketSize: 4 leakRatePerMinute: 2 + rateLimitPolicy: + failOpen: true """; - public record GenericHolder(@Valid @NotNull @JsonProperty Map limits) { + public record GenericHolder( + @Valid @NotNull @JsonProperty Map limits, + @Valid @JsonProperty DynamicRateLimitPolicy rateLimitPolicy) { } @Test @@ -70,6 +76,7 @@ public class RateLimitersTest { }); final GenericHolder cfg = DynamicConfigurationManager.parseConfiguration(GOOD_YAML, GenericHolder.class).orElseThrow(); + assertTrue(cfg.rateLimitPolicy.failOpen()); final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, validateScript, redisCluster, clock); rateLimiters.validateValuesAndConfigs(); }