Add a configuration to make rate limiters fail open

This commit is contained in:
ravi-signal 2023-04-14 13:08:14 -05:00 committed by GitHub
parent a553093046
commit 0fe6485038
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 100 additions and 16 deletions

View File

@ -62,6 +62,11 @@ public class DynamicConfiguration {
@Valid @Valid
DynamicPushNotificationConfiguration pushNotifications = new DynamicPushNotificationConfiguration(); DynamicPushNotificationConfiguration pushNotifications = new DynamicPushNotificationConfiguration();
@JsonProperty
@Valid
DynamicRateLimitPolicy rateLimitPolicy = new DynamicRateLimitPolicy(false);
public Optional<DynamicExperimentEnrollmentConfiguration> getExperimentEnrollmentConfiguration( public Optional<DynamicExperimentEnrollmentConfiguration> getExperimentEnrollmentConfiguration(
final String experimentName) { final String experimentName) {
return Optional.ofNullable(experiments.get(experimentName)); return Optional.ofNullable(experiments.get(experimentName));
@ -111,4 +116,8 @@ public class DynamicConfiguration {
public DynamicPushNotificationConfiguration getPushNotificationConfiguration() { public DynamicPushNotificationConfiguration getPushNotificationConfiguration() {
return pushNotifications; return pushNotifications;
} }
public DynamicRateLimitPolicy getRateLimitPolicy() {
return rateLimitPolicy;
}
} }

View File

@ -0,0 +1,8 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.configuration.dynamic;
public record DynamicRateLimitPolicy(boolean failOpen) {}

View File

@ -92,9 +92,9 @@ public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
? config ? config
: configs.getOrDefault(descriptor.id(), descriptor.defaultConfig()); : configs.getOrDefault(descriptor.id(), descriptor.defaultConfig());
}; };
return new DynamicRateLimiter(descriptor.id(), configResolver, validateScript, cacheCluster, clock); return new DynamicRateLimiter(descriptor.id(), dynamicConfigurationManager, configResolver, validateScript, cacheCluster, clock);
} }
final RateLimiterConfig cfg = configs.getOrDefault(descriptor.id(), descriptor.defaultConfig()); final RateLimiterConfig cfg = configs.getOrDefault(descriptor.id(), descriptor.defaultConfig());
return new StaticRateLimiter(descriptor.id(), cfg, validateScript, cacheCluster, clock); return new StaticRateLimiter(descriptor.id(), cfg, validateScript, cacheCluster, clock, dynamicConfigurationManager);
} }
} }

View File

@ -12,14 +12,16 @@ import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier; import java.util.function.Supplier;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
public class DynamicRateLimiter implements RateLimiter { public class DynamicRateLimiter implements RateLimiter {
private final String name; private final String name;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private final Supplier<RateLimiterConfig> configResolver; private final Supplier<RateLimiterConfig> configResolver;
private final ClusterLuaScript validateScript; private final ClusterLuaScript validateScript;
@ -33,11 +35,13 @@ public class DynamicRateLimiter implements RateLimiter {
public DynamicRateLimiter( public DynamicRateLimiter(
final String name, final String name,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final Supplier<RateLimiterConfig> configResolver, final Supplier<RateLimiterConfig> configResolver,
final ClusterLuaScript validateScript, final ClusterLuaScript validateScript,
final FaultTolerantRedisCluster cluster, final FaultTolerantRedisCluster cluster,
final Clock clock) { final Clock clock) {
this.name = requireNonNull(name); this.name = requireNonNull(name);
this.dynamicConfigurationManager = dynamicConfigurationManager;
this.configResolver = requireNonNull(configResolver); this.configResolver = requireNonNull(configResolver);
this.validateScript = requireNonNull(validateScript); this.validateScript = requireNonNull(validateScript);
this.cluster = requireNonNull(cluster); this.cluster = requireNonNull(cluster);
@ -83,7 +87,7 @@ public class DynamicRateLimiter implements RateLimiter {
final RateLimiterConfig cfg = configResolver.get(); final RateLimiterConfig cfg = configResolver.get();
return currentHolder.updateAndGet(p -> p != null && p.getLeft().equals(cfg) return currentHolder.updateAndGet(p -> p != null && p.getLeft().equals(cfg)
? p ? p
: Pair.of(cfg, new StaticRateLimiter(name, cfg, validateScript, cluster, clock)) : Pair.of(cfg, new StaticRateLimiter(name, cfg, validateScript, cluster, clock, dynamicConfigurationManager))
); );
} }
} }

View File

@ -9,16 +9,20 @@ import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.concurrent.CompletableFuture.failedFuture; import static java.util.concurrent.CompletableFuture.failedFuture;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.lettuce.core.RedisException;
import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import java.time.Clock; import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.util.List; import java.util.List;
import java.util.concurrent.CompletionStage; import java.util.concurrent.CompletionStage;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
public class StaticRateLimiter implements RateLimiter { public class StaticRateLimiter implements RateLimiter {
@ -28,6 +32,7 @@ public class StaticRateLimiter implements RateLimiter {
private final RateLimiterConfig config; private final RateLimiterConfig config;
private final Counter counter; private final Counter counter;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private final ClusterLuaScript validateScript; private final ClusterLuaScript validateScript;
@ -41,23 +46,31 @@ public class StaticRateLimiter implements RateLimiter {
final RateLimiterConfig config, final RateLimiterConfig config,
final ClusterLuaScript validateScript, final ClusterLuaScript validateScript,
final FaultTolerantRedisCluster cacheCluster, final FaultTolerantRedisCluster cacheCluster,
final Clock clock) { final Clock clock,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager) {
this.name = requireNonNull(name); this.name = requireNonNull(name);
this.config = requireNonNull(config); this.config = requireNonNull(config);
this.validateScript = requireNonNull(validateScript); this.validateScript = requireNonNull(validateScript);
this.cacheCluster = requireNonNull(cacheCluster); this.cacheCluster = requireNonNull(cacheCluster);
this.clock = requireNonNull(clock); this.clock = requireNonNull(clock);
this.counter = Metrics.counter(MetricsUtil.name(getClass(), "exceeded"), "name", name); this.counter = Metrics.counter(MetricsUtil.name(getClass(), "exceeded"), "name", name);
this.dynamicConfigurationManager = dynamicConfigurationManager;
} }
@Override @Override
public void validate(final String key, final int amount) throws RateLimitExceededException { public void validate(final String key, final int amount) throws RateLimitExceededException {
final long deficitPermitsAmount = executeValidateScript(key, amount, true); try {
if (deficitPermitsAmount > 0) { final long deficitPermitsAmount = executeValidateScript(key, amount, true);
counter.increment(); if (deficitPermitsAmount > 0) {
final Duration retryAfter = Duration.ofMillis( counter.increment();
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis())); final Duration retryAfter = Duration.ofMillis(
throw new RateLimitExceededException(retryAfter, true); (long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
throw new RateLimitExceededException(retryAfter, true);
}
} catch (RedisException e) {
if (!failOpen()) {
throw e;
}
} }
} }
@ -66,25 +79,45 @@ public class StaticRateLimiter implements RateLimiter {
return executeValidateScriptAsync(key, amount, true) return executeValidateScriptAsync(key, amount, true)
.thenCompose(deficitPermitsAmount -> { .thenCompose(deficitPermitsAmount -> {
if (deficitPermitsAmount == 0) { if (deficitPermitsAmount == 0) {
return completedFuture(null); return completedFuture((Void) null);
} }
counter.increment(); counter.increment();
final Duration retryAfter = Duration.ofMillis( final Duration retryAfter = Duration.ofMillis(
(long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis())); (long) Math.ceil((double) deficitPermitsAmount / config.leakRatePerMillis()));
return failedFuture(new RateLimitExceededException(retryAfter, true)); return failedFuture(new RateLimitExceededException(retryAfter, true));
})
.exceptionally(throwable -> {
if (ExceptionUtils.unwrap(throwable) instanceof RedisException && failOpen()) {
return null;
}
throw ExceptionUtils.wrap(throwable);
}); });
} }
@Override @Override
public boolean hasAvailablePermits(final String key, final int amount) { public boolean hasAvailablePermits(final String key, final int amount) {
final long deficitPermitsAmount = executeValidateScript(key, amount, false); try {
return deficitPermitsAmount == 0; final long deficitPermitsAmount = executeValidateScript(key, amount, false);
return deficitPermitsAmount == 0;
} catch (RedisException e) {
if (failOpen()) {
return true;
} else {
throw e;
}
}
} }
@Override @Override
public CompletionStage<Boolean> hasAvailablePermitsAsync(final String key, final int amount) { public CompletionStage<Boolean> hasAvailablePermitsAsync(final String key, final int amount) {
return executeValidateScriptAsync(key, amount, false) return executeValidateScriptAsync(key, amount, false)
.thenApply(deficitPermitsAmount -> deficitPermitsAmount == 0); .thenApply(deficitPermitsAmount -> deficitPermitsAmount == 0)
.exceptionally(throwable -> {
if (ExceptionUtils.unwrap(throwable) instanceof RedisException && failOpen()) {
return true;
}
throw ExceptionUtils.wrap(throwable);
});
} }
@Override @Override
@ -103,6 +136,10 @@ public class StaticRateLimiter implements RateLimiter {
return config; return config;
} }
private boolean failOpen() {
return this.dynamicConfigurationManager.getConfiguration().getRateLimitPolicy().failOpen();
}
private long executeValidateScript(final String key, final int amount, final boolean applyChanges) { private long executeValidateScript(final String key, final int amount, final boolean applyChanges) {
final List<String> keys = List.of(bucketName(name, key)); final List<String> keys = List.of(bucketName(name, key));
final List<String> arguments = List.of( final List<String> arguments = List.of(

View File

@ -8,10 +8,12 @@ package org.whispersystems.textsecuregcm.limits;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import io.lettuce.core.RedisException;
import io.lettuce.core.ScriptOutputType; import io.lettuce.core.ScriptOutputType;
import java.time.Clock; import java.time.Clock;
import java.util.List; import java.util.List;
@ -20,6 +22,7 @@ import java.util.Optional;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitPolicy;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
@ -159,6 +162,22 @@ public class RateLimitersLuaScriptTest {
assertEquals(750L, decodeBucket(key).orElseThrow().tokensRemaining); assertEquals(750L, decodeBucket(key).orElseThrow().tokensRemaining);
} }
@Test
public void testFailOpen() throws Exception {
when(configuration.getRateLimitPolicy()).thenReturn(new DynamicRateLimitPolicy(true));
final RateLimiters.For descriptor = RateLimiters.For.REGISTRATION;
final FaultTolerantRedisCluster redisCluster = mock(FaultTolerantRedisCluster.class);
final RateLimiters limiters = new RateLimiters(
Map.of(descriptor.id(), new RateLimiterConfig(1000, 60)),
dynamicConfig,
RateLimiters.defaultScript(redisCluster),
redisCluster,
Clock.systemUTC());
when(redisCluster.withCluster(any())).thenThrow(new RedisException("fail"));
final RateLimiter rateLimiter = limiters.forDescriptor(descriptor);
rateLimiter.validate("test", 200);
}
private String serializeToOldBucketValueFormat( private String serializeToOldBucketValueFormat(
final long bucketSize, final long bucketSize,
final long leakRatePerMillis, final long leakRatePerMillis,

View File

@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.limits;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -18,6 +19,7 @@ import javax.validation.Valid;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitPolicy;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
@ -56,9 +58,13 @@ public class RateLimitersTest {
attachmentCreate: attachmentCreate:
bucketSize: 4 bucketSize: 4
leakRatePerMinute: 2 leakRatePerMinute: 2
rateLimitPolicy:
failOpen: true
"""; """;
public record GenericHolder(@Valid @NotNull @JsonProperty Map<String, RateLimiterConfig> limits) { public record GenericHolder(
@Valid @NotNull @JsonProperty Map<String, RateLimiterConfig> limits,
@Valid @JsonProperty DynamicRateLimitPolicy rateLimitPolicy) {
} }
@Test @Test
@ -70,6 +76,7 @@ public class RateLimitersTest {
}); });
final GenericHolder cfg = DynamicConfigurationManager.parseConfiguration(GOOD_YAML, GenericHolder.class).orElseThrow(); final GenericHolder cfg = DynamicConfigurationManager.parseConfiguration(GOOD_YAML, GenericHolder.class).orElseThrow();
assertTrue(cfg.rateLimitPolicy.failOpen());
final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, validateScript, redisCluster, clock); final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, validateScript, redisCluster, clock);
rateLimiters.validateValuesAndConfigs(); rateLimiters.validateValuesAndConfigs();
} }