Use Durations everywhere, drop unused constructors, and add tests.
This commit is contained in:
parent
1faedd3870
commit
af2a8548c3
|
@ -10,19 +10,14 @@ public class RateLimitExceededException extends Exception {
|
||||||
|
|
||||||
private final Duration retryDuration;
|
private final Duration retryDuration;
|
||||||
|
|
||||||
public RateLimitExceededException() {
|
public RateLimitExceededException(final Duration retryDuration) {
|
||||||
super();
|
super();
|
||||||
retryDuration = Duration.ZERO;
|
this.retryDuration = retryDuration;
|
||||||
}
|
}
|
||||||
|
|
||||||
public RateLimitExceededException(String message) {
|
public RateLimitExceededException(final String message, final Duration retryDuration) {
|
||||||
super(message);
|
super(message);
|
||||||
retryDuration = Duration.ZERO;
|
this.retryDuration = retryDuration;
|
||||||
}
|
|
||||||
|
|
||||||
public RateLimitExceededException(String message, long retryAfterMillis) {
|
|
||||||
super(message);
|
|
||||||
retryDuration = Duration.ofMillis(retryAfterMillis);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public Duration getRetryDuration() { return retryDuration; }
|
public Duration getRetryDuration() { return retryDuration; }
|
||||||
|
|
|
@ -10,14 +10,8 @@ import java.time.Duration;
|
||||||
public class RetryLaterException extends Exception {
|
public class RetryLaterException extends Exception {
|
||||||
private final Duration backoffDuration;
|
private final Duration backoffDuration;
|
||||||
|
|
||||||
public RetryLaterException() {
|
|
||||||
backoffDuration = Duration.ZERO;
|
|
||||||
}
|
|
||||||
public RetryLaterException(int retryLaterMillis) {
|
|
||||||
backoffDuration = Duration.ofMillis(retryLaterMillis);
|
|
||||||
}
|
|
||||||
|
|
||||||
public RetryLaterException(RateLimitExceededException e) {
|
public RetryLaterException(RateLimitExceededException e) {
|
||||||
|
super(e);
|
||||||
this.backoffDuration = e.getRetryDuration();
|
this.backoffDuration = e.getRetryDuration();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -59,7 +59,8 @@ public class CardinalityRateLimiter {
|
||||||
});
|
});
|
||||||
|
|
||||||
if (rateLimitExceeded) {
|
if (rateLimitExceeded) {
|
||||||
throw new RateLimitExceededException();
|
// Using the TTL as the "retry after" time isn't EXACTLY right, but it's a reasonable approximation
|
||||||
|
throw new RateLimitExceededException(ttl);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.time.Duration;
|
||||||
|
|
||||||
public class LeakyBucket {
|
public class LeakyBucket {
|
||||||
|
|
||||||
|
@ -48,15 +49,15 @@ public class LeakyBucket {
|
||||||
(int)Math.floor(this.spaceRemaining + (elapsedTime * this.leakRatePerMillis)));
|
(int)Math.floor(this.spaceRemaining + (elapsedTime * this.leakRatePerMillis)));
|
||||||
}
|
}
|
||||||
|
|
||||||
public long getMillisUntilSpace(double amount) {
|
public Duration getTimeUntilSpaceAvailable(int amount) {
|
||||||
int currentSpaceRemaining = getUpdatedSpaceRemaining();
|
int currentSpaceRemaining = getUpdatedSpaceRemaining();
|
||||||
if (currentSpaceRemaining >= amount) {
|
if (currentSpaceRemaining >= amount) {
|
||||||
return 0;
|
return Duration.ZERO;
|
||||||
} else if (amount > this.bucketSize) {
|
} else if (amount > this.bucketSize) {
|
||||||
// This shouldn't happen today but if so we should bubble this to the clients somehow
|
// This shouldn't happen today but if so we should bubble this to the clients somehow
|
||||||
return -1;
|
throw new IllegalArgumentException("Requested permits exceed maximum bucket size");
|
||||||
} else {
|
} else {
|
||||||
return (long)Math.ceil(amount - currentSpaceRemaining / this.leakRatePerMillis);
|
return Duration.ofMillis((long)Math.ceil((double)(amount - currentSpaceRemaining) / this.leakRatePerMillis));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,8 @@ import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
||||||
import org.whispersystems.textsecuregcm.util.Constants;
|
import org.whispersystems.textsecuregcm.util.Constants;
|
||||||
|
|
||||||
|
import java.time.Duration;
|
||||||
|
|
||||||
import static com.codahale.metrics.MetricRegistry.name;
|
import static com.codahale.metrics.MetricRegistry.name;
|
||||||
|
|
||||||
public class LockingRateLimiter extends RateLimiter {
|
public class LockingRateLimiter extends RateLimiter {
|
||||||
|
@ -30,7 +32,7 @@ public class LockingRateLimiter extends RateLimiter {
|
||||||
public void validate(String key, int amount) throws RateLimitExceededException {
|
public void validate(String key, int amount) throws RateLimitExceededException {
|
||||||
if (!acquireLock(key)) {
|
if (!acquireLock(key)) {
|
||||||
meter.mark();
|
meter.mark();
|
||||||
throw new RateLimitExceededException("Locked");
|
throw new RateLimitExceededException("Locked", Duration.ZERO);
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
|
|
@ -19,6 +19,7 @@ import org.whispersystems.textsecuregcm.util.Constants;
|
||||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.time.Duration;
|
||||||
|
|
||||||
import static com.codahale.metrics.MetricRegistry.name;
|
import static com.codahale.metrics.MetricRegistry.name;
|
||||||
|
|
||||||
|
@ -28,7 +29,7 @@ public class RateLimiter {
|
||||||
private final ObjectMapper mapper = SystemMapper.getMapper();
|
private final ObjectMapper mapper = SystemMapper.getMapper();
|
||||||
|
|
||||||
private final Meter meter;
|
private final Meter meter;
|
||||||
protected final Timer validateTimer;
|
private final Timer validateTimer;
|
||||||
protected final FaultTolerantRedisCluster cacheCluster;
|
protected final FaultTolerantRedisCluster cacheCluster;
|
||||||
protected final String name;
|
protected final String name;
|
||||||
private final int bucketSize;
|
private final int bucketSize;
|
||||||
|
@ -66,7 +67,7 @@ public class RateLimiter {
|
||||||
setBucket(key, bucket);
|
setBucket(key, bucket);
|
||||||
} else {
|
} else {
|
||||||
meter.mark();
|
meter.mark();
|
||||||
throw new RateLimitExceededException(key + " , " + amount, bucket.getMillisUntilSpace(amount));
|
throw new RateLimitExceededException(key + " , " + amount, bucket.getTimeUntilSpaceAvailable(amount));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -87,7 +88,7 @@ public class RateLimiter {
|
||||||
return leakRatePerMinute;
|
return leakRatePerMinute;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void setBucket(String key, LeakyBucket bucket) {
|
private void setBucket(String key, LeakyBucket bucket) {
|
||||||
try {
|
try {
|
||||||
final String serialized = bucket.serialize(mapper);
|
final String serialized = bucket.serialize(mapper);
|
||||||
|
|
||||||
|
@ -97,7 +98,7 @@ public class RateLimiter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
protected LeakyBucket getBucket(String key) {
|
private LeakyBucket getBucket(String key) {
|
||||||
try {
|
try {
|
||||||
final String serialized = cacheCluster.withCluster(connection -> connection.sync().get(getBucketName(key)));
|
final String serialized = cacheCluster.withCluster(connection -> connection.sync().get(getBucketName(key)));
|
||||||
|
|
||||||
|
|
|
@ -25,6 +25,7 @@ import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
|
||||||
import io.dropwizard.testing.junit.ResourceTestRule;
|
import io.dropwizard.testing.junit.ResourceTestRule;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.security.SecureRandom;
|
import java.security.SecureRandom;
|
||||||
|
import java.time.Duration;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
|
@ -209,14 +210,14 @@ public class AccountControllerTest {
|
||||||
when(recaptchaClient.verify(eq(INVALID_CAPTCHA_TOKEN), anyString())).thenReturn(false);
|
when(recaptchaClient.verify(eq(INVALID_CAPTCHA_TOKEN), anyString())).thenReturn(false);
|
||||||
when(recaptchaClient.verify(eq(VALID_CAPTCHA_TOKEN), anyString())).thenReturn(true);
|
when(recaptchaClient.verify(eq(VALID_CAPTCHA_TOKEN), anyString())).thenReturn(true);
|
||||||
|
|
||||||
doThrow(new RateLimitExceededException(SENDER_OVER_PIN)).when(pinLimiter).validate(eq(SENDER_OVER_PIN));
|
doThrow(new RateLimitExceededException(SENDER_OVER_PIN, Duration.ZERO)).when(pinLimiter).validate(eq(SENDER_OVER_PIN));
|
||||||
|
|
||||||
doThrow(new RateLimitExceededException(RATE_LIMITED_PREFIX_HOST)).when(autoBlockLimiter).validate(eq(RATE_LIMITED_PREFIX_HOST));
|
doThrow(new RateLimitExceededException(RATE_LIMITED_PREFIX_HOST, Duration.ZERO)).when(autoBlockLimiter).validate(eq(RATE_LIMITED_PREFIX_HOST));
|
||||||
doThrow(new RateLimitExceededException(RATE_LIMITED_IP_HOST)).when(autoBlockLimiter).validate(eq(RATE_LIMITED_IP_HOST));
|
doThrow(new RateLimitExceededException(RATE_LIMITED_IP_HOST, Duration.ZERO)).when(autoBlockLimiter).validate(eq(RATE_LIMITED_IP_HOST));
|
||||||
|
|
||||||
doThrow(new RateLimitExceededException(SENDER_OVER_PREFIX)).when(smsVoicePrefixLimiter).validate(SENDER_OVER_PREFIX.substring(0, 4+2));
|
doThrow(new RateLimitExceededException(SENDER_OVER_PREFIX, Duration.ZERO)).when(smsVoicePrefixLimiter).validate(SENDER_OVER_PREFIX.substring(0, 4+2));
|
||||||
doThrow(new RateLimitExceededException(RATE_LIMITED_IP_HOST)).when(smsVoiceIpLimiter).validate(RATE_LIMITED_IP_HOST);
|
doThrow(new RateLimitExceededException(RATE_LIMITED_IP_HOST, Duration.ZERO)).when(smsVoiceIpLimiter).validate(RATE_LIMITED_IP_HOST);
|
||||||
doThrow(new RateLimitExceededException(RATE_LIMITED_HOST2)).when(smsVoiceIpLimiter).validate(RATE_LIMITED_HOST2);
|
doThrow(new RateLimitExceededException(RATE_LIMITED_HOST2, Duration.ZERO)).when(smsVoiceIpLimiter).validate(RATE_LIMITED_HOST2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -10,9 +10,12 @@ import org.junit.Test;
|
||||||
import org.whispersystems.textsecuregcm.limits.LeakyBucket;
|
import org.whispersystems.textsecuregcm.limits.LeakyBucket;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.time.Duration;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertFalse;
|
import static org.junit.Assert.assertFalse;
|
||||||
|
import static org.junit.Assert.assertThrows;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
public class LeakyBucketTest {
|
public class LeakyBucketTest {
|
||||||
|
@ -54,4 +57,30 @@ public class LeakyBucketTest {
|
||||||
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
|
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
|
||||||
assertFalse(leakyBucket.add(1));
|
assertFalse(leakyBucket.add(1));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testGetTimeUntilSpaceAvailable() throws Exception {
|
||||||
|
ObjectMapper mapper = new ObjectMapper();
|
||||||
|
|
||||||
|
{
|
||||||
|
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":2,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}";
|
||||||
|
|
||||||
|
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
|
||||||
|
|
||||||
|
assertEquals(Duration.ZERO, leakyBucket.getTimeUntilSpaceAvailable(1));
|
||||||
|
assertThrows(IllegalArgumentException.class, () -> leakyBucket.getTimeUntilSpaceAvailable(5000));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
String serialized = "{\"bucketSize\":2,\"leakRatePerMillis\":8.333333333333334E-6,\"spaceRemaining\":0,\"lastUpdateTimeMillis\":" + (System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(1)) + "}";
|
||||||
|
|
||||||
|
LeakyBucket leakyBucket = LeakyBucket.fromSerialized(mapper, serialized);
|
||||||
|
|
||||||
|
Duration timeUntilSpaceAvailable = leakyBucket.getTimeUntilSpaceAvailable(1);
|
||||||
|
|
||||||
|
// TODO Refactor LeakyBucket to be more test-friendly and accept a Clock
|
||||||
|
assertTrue(timeUntilSpaceAvailable.compareTo(Duration.ofMillis(119_000)) > 0);
|
||||||
|
assertTrue(timeUntilSpaceAvailable.compareTo(Duration.ofMinutes(2)) <= 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue