Return a Retry-After on rate-limited responses

Previously, only endpoints throwing a RetryLaterException would include
a Retry-After header in the 413 response. Now, by default, all
RateLimitExceededExceptions will be marshalled into a 413 with a
Retry-After included if possible.
This commit is contained in:
Ravi Khadiwala 2022-02-16 14:34:25 -06:00 committed by ravi-signal
parent 43792e2426
commit ae3a5c5f5e
11 changed files with 103 additions and 77 deletions

View File

@ -123,7 +123,6 @@ import org.whispersystems.textsecuregcm.mappers.InvalidWebsocketAddressException
import org.whispersystems.textsecuregcm.mappers.NonNormalizedPhoneNumberExceptionMapper; import org.whispersystems.textsecuregcm.mappers.NonNormalizedPhoneNumberExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RateLimitChallengeExceptionMapper; import org.whispersystems.textsecuregcm.mappers.RateLimitChallengeExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RetryLaterExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.ServerRejectedExceptionMapper; import org.whispersystems.textsecuregcm.mappers.ServerRejectedExceptionMapper;
import org.whispersystems.textsecuregcm.metrics.ApplicationShutdownMonitor; import org.whispersystems.textsecuregcm.metrics.ApplicationShutdownMonitor;
import org.whispersystems.textsecuregcm.metrics.BufferPoolGauges; import org.whispersystems.textsecuregcm.metrics.BufferPoolGauges;
@ -758,7 +757,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
new RateLimitExceededExceptionMapper(), new RateLimitExceededExceptionMapper(),
new InvalidWebsocketAddressExceptionMapper(), new InvalidWebsocketAddressExceptionMapper(),
new DeviceLimitExceededExceptionMapper(), new DeviceLimitExceededExceptionMapper(),
new RetryLaterExceptionMapper(),
new ServerRejectedExceptionMapper(), new ServerRejectedExceptionMapper(),
new ImpossiblePhoneNumberExceptionMapper(), new ImpossiblePhoneNumberExceptionMapper(),
new NonNormalizedPhoneNumberExceptionMapper() new NonNormalizedPhoneNumberExceptionMapper()

View File

@ -213,7 +213,7 @@ public class AccountController {
@QueryParam("client") Optional<String> client, @QueryParam("client") Optional<String> client,
@QueryParam("captcha") Optional<String> captcha, @QueryParam("captcha") Optional<String> captcha,
@QueryParam("challenge") Optional<String> pushChallenge) @QueryParam("challenge") Optional<String> pushChallenge)
throws RateLimitExceededException, RetryLaterException, ImpossiblePhoneNumberException, NonNormalizedPhoneNumberException { throws RateLimitExceededException, ImpossiblePhoneNumberException, NonNormalizedPhoneNumberException {
Util.requireNormalizedNumber(number); Util.requireNormalizedNumber(number);
@ -234,24 +234,16 @@ public class AccountController {
return Response.status(402).build(); return Response.status(402).build();
} }
try { switch (transport) {
switch (transport) { case "sms":
case "sms": rateLimiters.getSmsDestinationLimiter().validate(number);
rateLimiters.getSmsDestinationLimiter().validate(number); break;
break; case "voice":
case "voice": rateLimiters.getVoiceDestinationLimiter().validate(number);
rateLimiters.getVoiceDestinationLimiter().validate(number); rateLimiters.getVoiceDestinationDailyLimiter().validate(number);
rateLimiters.getVoiceDestinationDailyLimiter().validate(number); break;
break; default:
default: throw new WebApplicationException(Response.status(422).build());
throw new WebApplicationException(Response.status(422).build());
}
} catch (RateLimitExceededException e) {
if (!e.getRetryDuration().isNegative()) {
throw new RetryLaterException(e);
} else {
throw e;
}
} }
VerificationCode verificationCode = generateVerificationCode(number); VerificationCode verificationCode = generateVerificationCode(number);
@ -643,7 +635,9 @@ public class AccountController {
} }
final String mostRecentProxy = ForwardedIpUtil.getMostRecentProxy(forwardedFor) final String mostRecentProxy = ForwardedIpUtil.getMostRecentProxy(forwardedFor)
.orElseThrow(() -> new RateLimitExceededException(Duration.ofHours(1))); // Missing/malformed Forwarded-For, cannot calculate a reasonable backoff
// duration
.orElseThrow(() -> new RateLimitExceededException(Duration.ofHours(-1)));
rateLimiters.getCheckAccountExistenceLimiter().validate(mostRecentProxy); rateLimiters.getCheckAccountExistenceLimiter().validate(mostRecentProxy);

View File

@ -51,7 +51,7 @@ public class ChallengeController {
public Response handleChallengeResponse(@Auth final AuthenticatedAccount auth, public Response handleChallengeResponse(@Auth final AuthenticatedAccount auth,
@Valid final AnswerChallengeRequest answerRequest, @Valid final AnswerChallengeRequest answerRequest,
@HeaderParam("X-Forwarded-For") final String forwardedFor, @HeaderParam("X-Forwarded-For") final String forwardedFor,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) throws RetryLaterException { @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) throws RateLimitExceededException {
Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(userAgent)); Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(userAgent));
@ -76,8 +76,6 @@ public class ChallengeController {
} else { } else {
tags = tags.and(CHALLENGE_TYPE_TAG, "unrecognized"); tags = tags.and(CHALLENGE_TYPE_TAG, "unrecognized");
} }
} catch (final RateLimitExceededException e) {
throw new RetryLaterException(e);
} finally { } finally {
Metrics.counter(CHALLENGE_RESPONSE_COUNTER_NAME, tags).increment(); Metrics.counter(CHALLENGE_RESPONSE_COUNTER_NAME, tags).increment();
} }

View File

@ -5,10 +5,11 @@
package org.whispersystems.textsecuregcm.controllers; package org.whispersystems.textsecuregcm.controllers;
import java.time.Duration; import java.time.Duration;
import java.util.Optional;
public class RateLimitExceededException extends Exception { public class RateLimitExceededException extends Exception {
private final Duration retryDuration; private final Optional<Duration> retryDuration;
public RateLimitExceededException(final Duration retryDuration) { public RateLimitExceededException(final Duration retryDuration) {
this(null, retryDuration); this(null, retryDuration);
@ -16,8 +17,9 @@ public class RateLimitExceededException extends Exception {
public RateLimitExceededException(final String message, final Duration retryDuration) { public RateLimitExceededException(final String message, final Duration retryDuration) {
super(message, null, true, false); super(message, null, true, false);
this.retryDuration = retryDuration; // we won't provide a backoff in the case the duration is negative
this.retryDuration = retryDuration.isNegative() ? Optional.empty() : Optional.of(retryDuration);
} }
public Duration getRetryDuration() { return retryDuration; } public Optional<Duration> getRetryDuration() { return retryDuration; }
} }

View File

@ -1,19 +0,0 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import java.time.Duration;
public class RetryLaterException extends Exception {
private final Duration backoffDuration;
public RetryLaterException(RateLimitExceededException e) {
super(null, e, true, false);
this.backoffDuration = e.getRetryDuration();
}
public Duration getBackoffDuration() { return backoffDuration; }
}

View File

@ -12,8 +12,18 @@ import javax.ws.rs.ext.Provider;
@Provider @Provider
public class RateLimitExceededExceptionMapper implements ExceptionMapper<RateLimitExceededException> { public class RateLimitExceededExceptionMapper implements ExceptionMapper<RateLimitExceededException> {
/**
* Convert a RateLimitExceededException to a 413 response with a
* Retry-After header.
*
* @param e A RateLimitExceededException potentially containing a reccomended retry duration
* @return the response
*/
@Override @Override
public Response toResponse(RateLimitExceededException e) { public Response toResponse(RateLimitExceededException e) {
return Response.status(413).build(); return e.getRetryDuration()
.map(d -> Response.status(413).header("Retry-After", d.toSeconds()))
.orElseGet(() -> Response.status(413)).build();
} }
} }

View File

@ -1,23 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.mappers;
import org.whispersystems.textsecuregcm.controllers.RetryLaterException;
import javax.ws.rs.core.Response;
import javax.ws.rs.ext.ExceptionMapper;
import javax.ws.rs.ext.Provider;
@Provider
public class RetryLaterExceptionMapper implements ExceptionMapper<RetryLaterException> {
@Override
public Response toResponse(RetryLaterException e) {
return Response.status(413)
.header("Retry-After", e.getBackoffDuration().toSeconds())
.build();
}
}

View File

@ -27,7 +27,7 @@ import org.junit.jupiter.api.extension.ExtendWith;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager; import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.mappers.RetryLaterExceptionMapper; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -45,7 +45,7 @@ class ChallengeControllerTest {
Set.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) Set.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.setMapper(SystemMapper.getMapper()) .setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new RetryLaterExceptionMapper()) .addResource(new RateLimitExceededExceptionMapper())
.addResource(challengeController) .addResource(challengeController)
.build(); .build();

View File

@ -1804,6 +1804,38 @@ class AccountControllerTest {
.getStatus()).isEqualTo(404); .getStatus()).isEqualTo(404);
} }
@Test
void testAccountExistsRateLimited() throws RateLimitExceededException {
final Account account = mock(Account.class);
final UUID accountIdentifier = UUID.randomUUID();
when(accountsManager.getByAccountIdentifier(accountIdentifier)).thenReturn(Optional.of(account));
final RateLimiter checkAccountLimiter = mock(RateLimiter.class);
when(rateLimiters.getCheckAccountExistenceLimiter()).thenReturn(checkAccountLimiter);
doThrow(new RateLimitExceededException(Duration.ofSeconds(13))).when(checkAccountLimiter).validate("127.0.0.1");
final Response response = resources.getJerseyTest()
.target(String.format("/v1/accounts/account/%s", accountIdentifier))
.request()
.header("X-Forwarded-For", "127.0.0.1")
.head();
assertThat(response.getStatus()).isEqualTo(413);
assertThat(response.getHeaderString("Retry-After")).isEqualTo(String.valueOf(Duration.ofSeconds(13).toSeconds()));
}
@Test
void testAccountExistsNoForwardedFor() throws RateLimitExceededException {
final Response response = resources.getJerseyTest()
.target(String.format("/v1/accounts/account/%s", UUID.randomUUID()))
.request()
.header("X-Forwarded-For", "")
.head();
assertThat(response.getStatus()).isEqualTo(413);
assertThat(response.getHeaderString("Retry-After")).isNull();
}
@Test @Test
void testAccountExistsAuthenticated() { void testAccountExistsAuthenticated() {
assertThat(resources.getJerseyTest() assertThat(resources.getJerseyTest()

View File

@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.tests.controllers;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.eq; import static org.mockito.Mockito.eq;
@ -38,8 +39,6 @@ import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
@ -50,11 +49,11 @@ import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.PreKeyCount; import org.whispersystems.textsecuregcm.entities.PreKeyCount;
import org.whispersystems.textsecuregcm.entities.PreKeyResponse; import org.whispersystems.textsecuregcm.entities.PreKeyResponse;
import org.whispersystems.textsecuregcm.entities.PreKeyState; import org.whispersystems.textsecuregcm.entities.PreKeyState;
import org.whispersystems.textsecuregcm.entities.RateLimitChallenge;
import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.RateLimitChallengeExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.ServerRejectedExceptionMapper; import org.whispersystems.textsecuregcm.mappers.ServerRejectedExceptionMapper;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
@ -107,6 +106,7 @@ class KeysControllerTest {
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new ServerRejectedExceptionMapper()) .addResource(new ServerRejectedExceptionMapper())
.addResource(new KeysController(rateLimiters, KEYS, accounts)) .addResource(new KeysController(rateLimiters, KEYS, accounts))
.addResource(new RateLimitExceededExceptionMapper())
.build(); .build();
@BeforeEach @BeforeEach
@ -316,6 +316,21 @@ class KeysControllerTest {
verifyNoMoreInteractions(KEYS); verifyNoMoreInteractions(KEYS);
} }
@Test
void testGetKeysRateLimited() throws RateLimitExceededException {
Duration retryAfter = Duration.ofSeconds(31);
doThrow(new RateLimitExceededException(retryAfter)).when(rateLimiter).validate(anyString());
Response result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_PNI))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get();
assertThat(result.getStatus()).isEqualTo(413);
assertThat(result.getHeaderString("Retry-After")).isEqualTo(String.valueOf(retryAfter.toSeconds()));
}
@Test @Test
void testUnidentifiedRequest() { void testUnidentifiedRequest() {
PreKeyResponse result = resources.getJerseyTest() PreKeyResponse result = resources.getJerseyTest()

View File

@ -10,6 +10,7 @@ import static org.assertj.core.api.Assertions.assertThatNoException;
import static org.mockito.ArgumentMatchers.refEq; import static org.mockito.ArgumentMatchers.refEq;
import static org.mockito.Mockito.any; import static org.mockito.Mockito.any;
import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.eq; import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
@ -26,6 +27,7 @@ import io.dropwizard.testing.junit5.ResourceExtension;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.security.SecureRandom; import java.security.SecureRandom;
import java.time.Clock; import java.time.Clock;
import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.util.Base64; import java.util.Base64;
import java.util.Collections; import java.util.Collections;
@ -82,6 +84,7 @@ import org.whispersystems.textsecuregcm.entities.ProfileKeyCredentialProfileResp
import org.whispersystems.textsecuregcm.entities.VersionedProfileResponse; import org.whispersystems.textsecuregcm.entities.VersionedProfileResponse;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.s3.PolicySigner; import org.whispersystems.textsecuregcm.s3.PolicySigner;
import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator; import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
@ -123,6 +126,7 @@ class ProfileControllerTest {
.addProvider(AuthHelper.getAuthFilter()) .addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.addProvider(new RateLimitExceededExceptionMapper())
.setMapper(SystemMapper.getMapper()) .setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new ProfileController( .addResource(new ProfileController(
@ -210,6 +214,7 @@ class ProfileControllerTest {
@AfterEach @AfterEach
void teardown() { void teardown() {
reset(accountsManager); reset(accountsManager);
reset(rateLimiter);
} }
@Test @Test
@ -228,6 +233,20 @@ class ProfileControllerTest {
verify(rateLimiter, times(1)).validate(AuthHelper.VALID_UUID); verify(rateLimiter, times(1)).validate(AuthHelper.VALID_UUID);
} }
@Test
void testProfileGetByUuidRateLimited() throws RateLimitExceededException {
doThrow(new RateLimitExceededException(Duration.ofSeconds(13))).when(rateLimiter).validate(AuthHelper.VALID_UUID);
Response response= resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_UUID_TWO)
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get();
assertThat(response.getStatus()).isEqualTo(413);
assertThat(response.getHeaderString("Retry-After")).isEqualTo(String.valueOf(Duration.ofSeconds(13).toSeconds()));
}
@Test @Test
void testProfileGetByUuidUnidentified() throws RateLimitExceededException { void testProfileGetByUuidUnidentified() throws RateLimitExceededException {
BaseProfileResponse profile = resources.getJerseyTest() BaseProfileResponse profile = resources.getJerseyTest()