From 30c194c5576520998e798e093a952d9c2ab7d9c3 Mon Sep 17 00:00:00 2001 From: Jon Chambers <63609320+jon-signal@users.noreply.github.com> Date: Mon, 12 May 2025 18:24:57 -0400 Subject: [PATCH] Exclude `RateLimitExceededException` from fail-open checks --- .../limits/StaticRateLimiter.java | 11 ++- .../limits/StaticRateLimiterTest.java | 74 +++++++++++++++++++ 2 files changed, 84 insertions(+), 1 deletion(-) create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiterTest.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiter.java index 28b7c31b4..5aa611b84 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiter.java @@ -62,6 +62,10 @@ public class StaticRateLimiter implements RateLimiter { throw new RateLimitExceededException(retryAfter); } } catch (final Exception e) { + if (e instanceof RateLimitExceededException rateLimitExceededException) { + throw rateLimitExceededException; + } + if (!config.failOpen()) { throw e; } @@ -81,10 +85,15 @@ public class StaticRateLimiter implements RateLimiter { return failedFuture(new RateLimitExceededException(retryAfter)); }) .exceptionally(throwable -> { + if (ExceptionUtils.unwrap(throwable) instanceof RateLimitExceededException rateLimitExceededException) { + throw ExceptionUtils.wrap(rateLimitExceededException); + } + if (config.failOpen()) { return null; } - throw ExceptionUtils.wrap(new RateLimitExceededException(null)); + + throw ExceptionUtils.wrap(throwable); }); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiterTest.java new file mode 100644 index 000000000..1f355ca0f --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/StaticRateLimiterTest.java @@ -0,0 +1,74 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.limits; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import io.lettuce.core.ScriptOutputType; +import java.io.IOException; +import java.time.Duration; +import java.time.Instant; +import java.util.concurrent.CompletionException; +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; +import org.whispersystems.textsecuregcm.util.TestClock; + +class StaticRateLimiterTest { + + private ClusterLuaScript validateRateLimitScript; + + private static final TestClock CLOCK = TestClock.pinned(Instant.now()); + + @RegisterExtension + private static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); + + @BeforeEach + void setUp() throws IOException { + validateRateLimitScript = ClusterLuaScript.fromResource( + REDIS_CLUSTER_EXTENSION.getRedisCluster(), "lua/validate_rate_limit.lua", ScriptOutputType.INTEGER); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void validate(final boolean failOpen) { + final StaticRateLimiter rateLimiter = new StaticRateLimiter("test", + new RateLimiterConfig(1, Duration.ofHours(1), failOpen), + validateRateLimitScript, + REDIS_CLUSTER_EXTENSION.getRedisCluster(), + CLOCK); + + final String key = RandomStringUtils.insecure().nextAlphanumeric(16); + + assertDoesNotThrow(() -> rateLimiter.validate(key)); + assertThrows(RateLimitExceededException.class, () -> rateLimiter.validate(key)); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void validateAsync(final boolean failOpen) { + final StaticRateLimiter rateLimiter = new StaticRateLimiter("test", + new RateLimiterConfig(1, Duration.ofHours(1), failOpen), + validateRateLimitScript, + REDIS_CLUSTER_EXTENSION.getRedisCluster(), + CLOCK); + + final String key = RandomStringUtils.insecure().nextAlphanumeric(16); + + assertDoesNotThrow(() -> rateLimiter.validateAsync(key).toCompletableFuture().join()); + final CompletionException completionException = + assertThrows(CompletionException.class, () -> rateLimiter.validateAsync(key).toCompletableFuture().join()); + + assertInstanceOf(RateLimitExceededException.class, completionException.getCause()); + } +}