Return Retry-After time to clients when they are rate limited (#421)

* Return Retry-After time to clients when they are rate limited

* Update based on feedback

- New exception type that is mapped differently
- Always report time until allowed on rate limits
- Consume and transform into a differnt exception if we think it will be
  allowed later
This commit is contained in:
brock-signal 2021-03-05 10:23:03 -07:00 committed by GitHub
parent f57a4171ba
commit 1faedd3870
7 changed files with 109 additions and 17 deletions

View File

@ -100,6 +100,7 @@ import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapp
import org.whispersystems.textsecuregcm.mappers.IOExceptionMapper; import org.whispersystems.textsecuregcm.mappers.IOExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.InvalidWebsocketAddressExceptionMapper; import org.whispersystems.textsecuregcm.mappers.InvalidWebsocketAddressExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RetryLaterExceptionMapper;
import org.whispersystems.textsecuregcm.metrics.BufferPoolGauges; import org.whispersystems.textsecuregcm.metrics.BufferPoolGauges;
import org.whispersystems.textsecuregcm.metrics.CpuUsageGauge; import org.whispersystems.textsecuregcm.metrics.CpuUsageGauge;
import org.whispersystems.textsecuregcm.metrics.FileDescriptorGauge; import org.whispersystems.textsecuregcm.metrics.FileDescriptorGauge;
@ -498,16 +499,19 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.jersey().register(new RateLimitExceededExceptionMapper()); environment.jersey().register(new RateLimitExceededExceptionMapper());
environment.jersey().register(new InvalidWebsocketAddressExceptionMapper()); environment.jersey().register(new InvalidWebsocketAddressExceptionMapper());
environment.jersey().register(new DeviceLimitExceededExceptionMapper()); environment.jersey().register(new DeviceLimitExceededExceptionMapper());
environment.jersey().register(new RetryLaterExceptionMapper());
webSocketEnvironment.jersey().register(new IOExceptionMapper()); webSocketEnvironment.jersey().register(new IOExceptionMapper());
webSocketEnvironment.jersey().register(new RateLimitExceededExceptionMapper()); webSocketEnvironment.jersey().register(new RateLimitExceededExceptionMapper());
webSocketEnvironment.jersey().register(new InvalidWebsocketAddressExceptionMapper()); webSocketEnvironment.jersey().register(new InvalidWebsocketAddressExceptionMapper());
webSocketEnvironment.jersey().register(new DeviceLimitExceededExceptionMapper()); webSocketEnvironment.jersey().register(new DeviceLimitExceededExceptionMapper());
webSocketEnvironment.jersey().register(new RetryLaterExceptionMapper());
provisioningEnvironment.jersey().register(new IOExceptionMapper()); provisioningEnvironment.jersey().register(new IOExceptionMapper());
provisioningEnvironment.jersey().register(new RateLimitExceededExceptionMapper()); provisioningEnvironment.jersey().register(new RateLimitExceededExceptionMapper());
provisioningEnvironment.jersey().register(new InvalidWebsocketAddressExceptionMapper()); provisioningEnvironment.jersey().register(new InvalidWebsocketAddressExceptionMapper());
provisioningEnvironment.jersey().register(new DeviceLimitExceededExceptionMapper()); provisioningEnvironment.jersey().register(new DeviceLimitExceededExceptionMapper());
provisioningEnvironment.jersey().register(new RetryLaterExceptionMapper());
} }
private void registerCorsFilter(Environment environment) { private void registerCorsFilter(Environment environment) {

View File

@ -187,7 +187,7 @@ public class AccountController {
@QueryParam("client") Optional<String> client, @QueryParam("client") Optional<String> client,
@QueryParam("captcha") Optional<String> captcha, @QueryParam("captcha") Optional<String> captcha,
@QueryParam("challenge") Optional<String> pushChallenge) @QueryParam("challenge") Optional<String> pushChallenge)
throws RateLimitExceededException throws RateLimitExceededException, RetryLaterException
{ {
if (!Util.isValidNumber(number)) { if (!Util.isValidNumber(number)) {
logger.info("Invalid number: " + number); logger.info("Invalid number: " + number);
@ -217,6 +217,7 @@ public class AccountController {
return Response.status(402).build(); return Response.status(402).build();
} }
try {
switch (transport) { switch (transport) {
case "sms": case "sms":
rateLimiters.getSmsDestinationLimiter().validate(number); rateLimiters.getSmsDestinationLimiter().validate(number);
@ -228,6 +229,13 @@ public class AccountController {
default: default:
throw new WebApplicationException(Response.status(422).build()); throw new WebApplicationException(Response.status(422).build());
} }
} catch (RateLimitExceededException e) {
if (!e.getRetryDuration().isNegative()) {
throw new RetryLaterException(e);
} else {
throw e;
}
}
VerificationCode verificationCode = generateVerificationCode(number); VerificationCode verificationCode = generateVerificationCode(number);
StoredVerificationCode storedVerificationCode = new StoredVerificationCode(verificationCode.getVerificationCode(), StoredVerificationCode storedVerificationCode = new StoredVerificationCode(verificationCode.getVerificationCode(),

View File

@ -4,12 +4,26 @@
*/ */
package org.whispersystems.textsecuregcm.controllers; package org.whispersystems.textsecuregcm.controllers;
import java.time.Duration;
public class RateLimitExceededException extends Exception { public class RateLimitExceededException extends Exception {
private final Duration retryDuration;
public RateLimitExceededException() { public RateLimitExceededException() {
super(); super();
retryDuration = Duration.ZERO;
} }
public RateLimitExceededException(String number) { public RateLimitExceededException(String message) {
super(number); super(message);
retryDuration = Duration.ZERO;
} }
public RateLimitExceededException(String message, long retryAfterMillis) {
super(message);
retryDuration = Duration.ofMillis(retryAfterMillis);
}
public Duration getRetryDuration() { return retryDuration; }
} }

View File

@ -0,0 +1,25 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import java.time.Duration;
public class RetryLaterException extends Exception {
private final Duration backoffDuration;
public RetryLaterException() {
backoffDuration = Duration.ZERO;
}
public RetryLaterException(int retryLaterMillis) {
backoffDuration = Duration.ofMillis(retryLaterMillis);
}
public RetryLaterException(RateLimitExceededException e) {
this.backoffDuration = e.getRetryDuration();
}
public Duration getBackoffDuration() { return backoffDuration; }
}

View File

@ -48,6 +48,18 @@ public class LeakyBucket {
(int)Math.floor(this.spaceRemaining + (elapsedTime * this.leakRatePerMillis))); (int)Math.floor(this.spaceRemaining + (elapsedTime * this.leakRatePerMillis)));
} }
public long getMillisUntilSpace(double amount) {
int currentSpaceRemaining = getUpdatedSpaceRemaining();
if (currentSpaceRemaining >= amount) {
return 0;
} else if (amount > this.bucketSize) {
// This shouldn't happen today but if so we should bubble this to the clients somehow
return -1;
} else {
return (long)Math.ceil(amount - currentSpaceRemaining / this.leakRatePerMillis);
}
}
public String serialize(ObjectMapper mapper) throws JsonProcessingException { public String serialize(ObjectMapper mapper) throws JsonProcessingException {
return mapper.writeValueAsString(new LeakyBucketEntity(bucketSize, leakRatePerMillis, spaceRemaining, lastUpdateTimeMillis)); return mapper.writeValueAsString(new LeakyBucketEntity(bucketSize, leakRatePerMillis, spaceRemaining, lastUpdateTimeMillis));
} }

View File

@ -28,7 +28,7 @@ public class RateLimiter {
private final ObjectMapper mapper = SystemMapper.getMapper(); private final ObjectMapper mapper = SystemMapper.getMapper();
private final Meter meter; private final Meter meter;
private final Timer validateTimer; protected final Timer validateTimer;
protected final FaultTolerantRedisCluster cacheCluster; protected final FaultTolerantRedisCluster cacheCluster;
protected final String name; protected final String name;
private final int bucketSize; private final int bucketSize;
@ -66,7 +66,7 @@ public class RateLimiter {
setBucket(key, bucket); setBucket(key, bucket);
} else { } else {
meter.mark(); meter.mark();
throw new RateLimitExceededException(key + " , " + amount); throw new RateLimitExceededException(key + " , " + amount, bucket.getMillisUntilSpace(amount));
} }
} }
} }
@ -87,7 +87,7 @@ public class RateLimiter {
return leakRatePerMinute; return leakRatePerMinute;
} }
private void setBucket(String key, LeakyBucket bucket) { protected void setBucket(String key, LeakyBucket bucket) {
try { try {
final String serialized = bucket.serialize(mapper); final String serialized = bucket.serialize(mapper);
@ -97,7 +97,7 @@ public class RateLimiter {
} }
} }
private LeakyBucket getBucket(String key) { protected LeakyBucket getBucket(String key) {
try { try {
final String serialized = cacheCluster.withCluster(connection -> connection.sync().get(getBucketName(key))); final String serialized = cacheCluster.withCluster(connection -> connection.sync().get(getBucketName(key)));

View File

@ -0,0 +1,29 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.mappers;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import javax.ws.rs.core.Response;
import javax.ws.rs.ext.ExceptionMapper;
import javax.ws.rs.ext.Provider;
import org.whispersystems.textsecuregcm.controllers.RetryLaterException;
import javax.ws.rs.core.Response;
import javax.ws.rs.ext.ExceptionMapper;
import javax.ws.rs.ext.Provider;
import java.time.Duration;
@Provider
public class RetryLaterExceptionMapper implements ExceptionMapper<RetryLaterException> {
@Override
public Response toResponse(RetryLaterException e) {
return Response.status(413)
.header("Retry-After", e.getBackoffDuration().toSeconds())
.build();
}
}