Exclude `RateLimitExceededException` from fail-open checks

This commit is contained in:
Jon Chambers 2025-05-12 18:24:57 -04:00 committed by GitHub
parent cc7b030a41
commit 30c194c557
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 84 additions and 1 deletions

View File

@ -62,6 +62,10 @@ public class StaticRateLimiter implements RateLimiter {
throw new RateLimitExceededException(retryAfter);
}
} catch (final Exception e) {
if (e instanceof RateLimitExceededException rateLimitExceededException) {
throw rateLimitExceededException;
}
if (!config.failOpen()) {
throw e;
}
@ -81,10 +85,15 @@ public class StaticRateLimiter implements RateLimiter {
return failedFuture(new RateLimitExceededException(retryAfter));
})
.exceptionally(throwable -> {
if (ExceptionUtils.unwrap(throwable) instanceof RateLimitExceededException rateLimitExceededException) {
throw ExceptionUtils.wrap(rateLimitExceededException);
}
if (config.failOpen()) {
return null;
}
throw ExceptionUtils.wrap(new RateLimitExceededException(null));
throw ExceptionUtils.wrap(throwable);
});
}

View File

@ -0,0 +1,74 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertThrows;
import io.lettuce.core.ScriptOutputType;
import java.io.IOException;
import java.time.Duration;
import java.time.Instant;
import java.util.concurrent.CompletionException;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.util.TestClock;
class StaticRateLimiterTest {
private ClusterLuaScript validateRateLimitScript;
private static final TestClock CLOCK = TestClock.pinned(Instant.now());
@RegisterExtension
private static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@BeforeEach
void setUp() throws IOException {
validateRateLimitScript = ClusterLuaScript.fromResource(
REDIS_CLUSTER_EXTENSION.getRedisCluster(), "lua/validate_rate_limit.lua", ScriptOutputType.INTEGER);
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void validate(final boolean failOpen) {
final StaticRateLimiter rateLimiter = new StaticRateLimiter("test",
new RateLimiterConfig(1, Duration.ofHours(1), failOpen),
validateRateLimitScript,
REDIS_CLUSTER_EXTENSION.getRedisCluster(),
CLOCK);
final String key = RandomStringUtils.insecure().nextAlphanumeric(16);
assertDoesNotThrow(() -> rateLimiter.validate(key));
assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key));
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void validateAsync(final boolean failOpen) {
final StaticRateLimiter rateLimiter = new StaticRateLimiter("test",
new RateLimiterConfig(1, Duration.ofHours(1), failOpen),
validateRateLimitScript,
REDIS_CLUSTER_EXTENSION.getRedisCluster(),
CLOCK);
final String key = RandomStringUtils.insecure().nextAlphanumeric(16);
assertDoesNotThrow(() -> rateLimiter.validateAsync(key).toCompletableFuture().join());
final CompletionException completionException =
assertThrows(CompletionException.class, () -> rateLimiter.validateAsync(key).toCompletableFuture().join());
assertInstanceOf(RateLimitExceededException.class, completionException.getCause());
}
}