From af2a8548c373b743f9d09b596b4e6ac50e91c695 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Fri, 5 Mar 2021 12:06:26 -0500 Subject: [PATCH] Use Durations everywhere, drop unused constructors, and add tests. --- .../RateLimitExceededException.java | 13 +++------ .../controllers/RetryLaterException.java | 8 +---- .../limits/CardinalityRateLimiter.java | 3 +- .../textsecuregcm/limits/LeakyBucket.java | 9 +++--- .../limits/LockingRateLimiter.java | 4 ++- .../textsecuregcm/limits/RateLimiter.java | 9 +++--- .../controllers/AccountControllerTest.java | 13 +++++---- .../tests/limits/LeakyBucketTest.java | 29 +++++++++++++++++++ 8 files changed, 56 insertions(+), 32 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RateLimitExceededException.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RateLimitExceededException.java index 9c7e4f1da..9ca296ad3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RateLimitExceededException.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RateLimitExceededException.java @@ -10,19 +10,14 @@ public class RateLimitExceededException extends Exception { private final Duration retryDuration; - public RateLimitExceededException() { + public RateLimitExceededException(final Duration retryDuration) { super(); - retryDuration = Duration.ZERO; + this.retryDuration = retryDuration; } - public RateLimitExceededException(String message) { + public RateLimitExceededException(final String message, final Duration retryDuration) { super(message); - retryDuration = Duration.ZERO; - } - - public RateLimitExceededException(String message, long retryAfterMillis) { - super(message); - retryDuration = Duration.ofMillis(retryAfterMillis); + this.retryDuration = retryDuration; } public Duration getRetryDuration() { return retryDuration; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RetryLaterException.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RetryLaterException.java index b577aee9d..c7c0afffa 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RetryLaterException.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RetryLaterException.java @@ -10,14 +10,8 @@ import java.time.Duration; public class RetryLaterException extends Exception { private final Duration backoffDuration; - public RetryLaterException() { - backoffDuration = Duration.ZERO; - } - public RetryLaterException(int retryLaterMillis) { - backoffDuration = Duration.ofMillis(retryLaterMillis); - } - public RetryLaterException(RateLimitExceededException e) { + super(e); this.backoffDuration = e.getRetryDuration(); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/CardinalityRateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/CardinalityRateLimiter.java index f9dc918e5..522d37261 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/CardinalityRateLimiter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/CardinalityRateLimiter.java @@ -59,7 +59,8 @@ public class CardinalityRateLimiter { }); if (rateLimitExceeded) { - throw new RateLimitExceededException(); + // Using the TTL as the "retry after" time isn't EXACTLY right, but it's a reasonable approximation + throw new RateLimitExceededException(ttl); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LeakyBucket.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/LeakyBucket.java index 6f76b98fd..51b60ea8d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LeakyBucket.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/LeakyBucket.java @@ -9,6 +9,7 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import java.io.IOException; +import java.time.Duration; public class LeakyBucket { @@ -48,15 +49,15 @@ public class LeakyBucket { (int)Math.floor(this.spaceRemaining + (elapsedTime * this.leakRatePerMillis))); } - public long getMillisUntilSpace(double amount) { + public Duration getTimeUntilSpaceAvailable(int amount) { int currentSpaceRemaining = getUpdatedSpaceRemaining(); if (currentSpaceRemaining >= amount) { - return 0; + return Duration.ZERO; } else if (amount > this.bucketSize) { // This shouldn't happen today but if so we should bubble this to the clients somehow - return -1; + throw new IllegalArgumentException("Requested permits exceed maximum bucket size"); } else { - return (long)Math.ceil(amount - currentSpaceRemaining / this.leakRatePerMillis); + return Duration.ofMillis((long)Math.ceil((double)(amount - currentSpaceRemaining) / this.leakRatePerMillis)); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java index 067b7fb7f..935cabfa3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java @@ -13,6 +13,8 @@ import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.util.Constants; +import java.time.Duration; + import static com.codahale.metrics.MetricRegistry.name; public class LockingRateLimiter extends RateLimiter { @@ -30,7 +32,7 @@ public class LockingRateLimiter extends RateLimiter { public void validate(String key, int amount) throws RateLimitExceededException { if (!acquireLock(key)) { meter.mark(); - throw new RateLimitExceededException("Locked"); + throw new RateLimitExceededException("Locked", Duration.ZERO); } try { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java index 8006c40d1..4c4793670 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java @@ -19,6 +19,7 @@ import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.SystemMapper; import java.io.IOException; +import java.time.Duration; import static com.codahale.metrics.MetricRegistry.name; @@ -28,7 +29,7 @@ public class RateLimiter { private final ObjectMapper mapper = SystemMapper.getMapper(); private final Meter meter; - protected final Timer validateTimer; + private final Timer validateTimer; protected final FaultTolerantRedisCluster cacheCluster; protected final String name; private final int bucketSize; @@ -66,7 +67,7 @@ public class RateLimiter { setBucket(key, bucket); } else { meter.mark(); - throw new RateLimitExceededException(key + " , " + amount, bucket.getMillisUntilSpace(amount)); + throw new RateLimitExceededException(key + " , " + amount, bucket.getTimeUntilSpaceAvailable(amount)); } } } @@ -87,7 +88,7 @@ public class RateLimiter { return leakRatePerMinute; } - protected void setBucket(String key, LeakyBucket bucket) { + private void setBucket(String key, LeakyBucket bucket) { try { final String serialized = bucket.serialize(mapper); @@ -97,7 +98,7 @@ public class RateLimiter { } } - protected LeakyBucket getBucket(String key) { + private LeakyBucket getBucket(String key) { try { final String serialized = cacheCluster.withCluster(connection -> connection.sync().get(getBucketName(key))); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java index 26ae7164c..275df6a98 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java @@ -25,6 +25,7 @@ import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; import io.dropwizard.testing.junit.ResourceTestRule; import java.io.IOException; import java.security.SecureRandom; +import java.time.Duration; import java.util.Collections; import java.util.HashMap; import java.util.Optional; @@ -209,14 +210,14 @@ public class AccountControllerTest { when(recaptchaClient.verify(eq(INVALID_CAPTCHA_TOKEN), anyString())).thenReturn(false); when(recaptchaClient.verify(eq(VALID_CAPTCHA_TOKEN), anyString())).thenReturn(true); - doThrow(new RateLimitExceededException(SENDER_OVER_PIN)).when(pinLimiter).validate(eq(SENDER_OVER_PIN)); + doThrow(new RateLimitExceededException(SENDER_OVER_PIN, Duration.ZERO)).when(pinLimiter).validate(eq(SENDER_OVER_PIN)); - doThrow(new RateLimitExceededException(RATE_LIMITED_PREFIX_HOST)).when(autoBlockLimiter).validate(eq(RATE_LIMITED_PREFIX_HOST)); - doThrow(new RateLimitExceededException(RATE_LIMITED_IP_HOST)).when(autoBlockLimiter).validate(eq(RATE_LIMITED_IP_HOST)); + doThrow(new RateLimitExceededException(RATE_LIMITED_PREFIX_HOST, Duration.ZERO)).when(autoBlockLimiter).validate(eq(RATE_LIMITED_PREFIX_HOST)); + doThrow(new RateLimitExceededException(RATE_LIMITED_IP_HOST, Duration.ZERO)).when(autoBlockLimiter).validate(eq(RATE_LIMITED_IP_HOST)); - doThrow(new RateLimitExceededException(SENDER_OVER_PREFIX)).when(smsVoicePrefixLimiter).validate(SENDER_OVER_PREFIX.substring(0, 4+2)); - doThrow(new RateLimitExceededException(RATE_LIMITED_IP_HOST)).when(smsVoiceIpLimiter).validate(RATE_LIMITED_IP_HOST); - doThrow(new RateLimitExceededException(RATE_LIMITED_HOST2)).when(smsVoiceIpLimiter).validate(RATE_LIMITED_HOST2); + doThrow(new RateLimitExceededException(SENDER_OVER_PREFIX, Duration.ZERO)).when(smsVoicePrefixLimiter).validate(SENDER_OVER_PREFIX.substring(0, 4+2)); + doThrow(new RateLimitExceededException(RATE_LIMITED_IP_HOST, Duration.ZERO)).when(smsVoiceIpLimiter).validate(RATE_LIMITED_IP_HOST); + doThrow(new RateLimitExceededException(RATE_LIMITED_HOST2, Duration.ZERO)).when(smsVoiceIpLimiter).validate(RATE_LIMITED_HOST2); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/limits/LeakyBucketTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/limits/LeakyBucketTest.java index a826b8e4a..236832fa4 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/limits/LeakyBucketTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/limits/LeakyBucketTest.java @@ -10,9 +10,12 @@ import org.junit.Test; import org.whispersystems.textsecuregcm.limits.LeakyBucket; import java.io.IOException; +import java.time.Duration; import java.util.concurrent.TimeUnit; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; public class LeakyBucketTest { @@ -54,4 +57,30 @@ public class LeakyBucketTest { LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized); assertFalse(leakyBucket.add(1)); } + + @Test + public void testGetTimeUntilSpaceAvailable() throws Exception { + ObjectMapper mapper = new ObjectMapper(); + + { + String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":2,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}"; + + LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized); + + assertEquals(Duration.ZERO, leakyBucket.getTimeUntilSpaceAvailable(1)); + assertThrows(IllegalArgumentException.class, () -> leakyBucket.getTimeUntilSpaceAvailable(5000)); + } + + { + String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}"; + + LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized); + + Duration timeUntilSpaceAvailable = leakyBucket.getTimeUntilSpaceAvailable(1); + + // TODO Refactor LeakyBucket to be more test-friendly and accept a Clock + assertTrue(timeUntilSpaceAvailable.compareTo(Duration.ofMillis(119_000)) > 0); + assertTrue(timeUntilSpaceAvailable.compareTo(Duration.ofMinutes(2)) <= 0); + } + } }