Return a Retry-After on rate-limited responses
Previously, only endpoints throwing a RetryLaterException would include a Retry-After header in the 413 response. Now, by default, all RateLimitExceededExceptions will be marshalled into a 413 with a Retry-After included if possible.
This commit is contained in:
		
							parent
							
								
									43792e2426
								
							
						
					
					
						commit
						ae3a5c5f5e
					
				| 
						 | 
					@ -123,7 +123,6 @@ import org.whispersystems.textsecuregcm.mappers.InvalidWebsocketAddressException
 | 
				
			||||||
import org.whispersystems.textsecuregcm.mappers.NonNormalizedPhoneNumberExceptionMapper;
 | 
					import org.whispersystems.textsecuregcm.mappers.NonNormalizedPhoneNumberExceptionMapper;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.mappers.RateLimitChallengeExceptionMapper;
 | 
					import org.whispersystems.textsecuregcm.mappers.RateLimitChallengeExceptionMapper;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
 | 
					import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.mappers.RetryLaterExceptionMapper;
 | 
					 | 
				
			||||||
import org.whispersystems.textsecuregcm.mappers.ServerRejectedExceptionMapper;
 | 
					import org.whispersystems.textsecuregcm.mappers.ServerRejectedExceptionMapper;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.metrics.ApplicationShutdownMonitor;
 | 
					import org.whispersystems.textsecuregcm.metrics.ApplicationShutdownMonitor;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.metrics.BufferPoolGauges;
 | 
					import org.whispersystems.textsecuregcm.metrics.BufferPoolGauges;
 | 
				
			||||||
| 
						 | 
					@ -758,7 +757,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
 | 
				
			||||||
        new RateLimitExceededExceptionMapper(),
 | 
					        new RateLimitExceededExceptionMapper(),
 | 
				
			||||||
        new InvalidWebsocketAddressExceptionMapper(),
 | 
					        new InvalidWebsocketAddressExceptionMapper(),
 | 
				
			||||||
        new DeviceLimitExceededExceptionMapper(),
 | 
					        new DeviceLimitExceededExceptionMapper(),
 | 
				
			||||||
        new RetryLaterExceptionMapper(),
 | 
					 | 
				
			||||||
        new ServerRejectedExceptionMapper(),
 | 
					        new ServerRejectedExceptionMapper(),
 | 
				
			||||||
        new ImpossiblePhoneNumberExceptionMapper(),
 | 
					        new ImpossiblePhoneNumberExceptionMapper(),
 | 
				
			||||||
        new NonNormalizedPhoneNumberExceptionMapper()
 | 
					        new NonNormalizedPhoneNumberExceptionMapper()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -213,7 +213,7 @@ public class AccountController {
 | 
				
			||||||
                                @QueryParam("client")           Optional<String> client,
 | 
					                                @QueryParam("client")           Optional<String> client,
 | 
				
			||||||
                                @QueryParam("captcha")          Optional<String> captcha,
 | 
					                                @QueryParam("captcha")          Optional<String> captcha,
 | 
				
			||||||
                                @QueryParam("challenge")        Optional<String> pushChallenge)
 | 
					                                @QueryParam("challenge")        Optional<String> pushChallenge)
 | 
				
			||||||
      throws RateLimitExceededException, RetryLaterException, ImpossiblePhoneNumberException, NonNormalizedPhoneNumberException {
 | 
					      throws RateLimitExceededException, ImpossiblePhoneNumberException, NonNormalizedPhoneNumberException {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Util.requireNormalizedNumber(number);
 | 
					    Util.requireNormalizedNumber(number);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -234,24 +234,16 @@ public class AccountController {
 | 
				
			||||||
      return Response.status(402).build();
 | 
					      return Response.status(402).build();
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    try {
 | 
					    switch (transport) {
 | 
				
			||||||
      switch (transport) {
 | 
					      case "sms":
 | 
				
			||||||
        case "sms":
 | 
					        rateLimiters.getSmsDestinationLimiter().validate(number);
 | 
				
			||||||
          rateLimiters.getSmsDestinationLimiter().validate(number);
 | 
					        break;
 | 
				
			||||||
          break;
 | 
					      case "voice":
 | 
				
			||||||
        case "voice":
 | 
					        rateLimiters.getVoiceDestinationLimiter().validate(number);
 | 
				
			||||||
          rateLimiters.getVoiceDestinationLimiter().validate(number);
 | 
					        rateLimiters.getVoiceDestinationDailyLimiter().validate(number);
 | 
				
			||||||
          rateLimiters.getVoiceDestinationDailyLimiter().validate(number);
 | 
					        break;
 | 
				
			||||||
          break;
 | 
					      default:
 | 
				
			||||||
        default:
 | 
					        throw new WebApplicationException(Response.status(422).build());
 | 
				
			||||||
          throw new WebApplicationException(Response.status(422).build());
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
    } catch (RateLimitExceededException e) {
 | 
					 | 
				
			||||||
      if (!e.getRetryDuration().isNegative()) {
 | 
					 | 
				
			||||||
        throw new RetryLaterException(e);
 | 
					 | 
				
			||||||
      } else {
 | 
					 | 
				
			||||||
        throw e;
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    VerificationCode       verificationCode       = generateVerificationCode(number);
 | 
					    VerificationCode       verificationCode       = generateVerificationCode(number);
 | 
				
			||||||
| 
						 | 
					@ -643,7 +635,9 @@ public class AccountController {
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    final String mostRecentProxy = ForwardedIpUtil.getMostRecentProxy(forwardedFor)
 | 
					    final String mostRecentProxy = ForwardedIpUtil.getMostRecentProxy(forwardedFor)
 | 
				
			||||||
        .orElseThrow(() -> new RateLimitExceededException(Duration.ofHours(1)));
 | 
					        // Missing/malformed Forwarded-For, cannot calculate a reasonable backoff
 | 
				
			||||||
 | 
					        // duration
 | 
				
			||||||
 | 
					        .orElseThrow(() -> new RateLimitExceededException(Duration.ofHours(-1)));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    rateLimiters.getCheckAccountExistenceLimiter().validate(mostRecentProxy);
 | 
					    rateLimiters.getCheckAccountExistenceLimiter().validate(mostRecentProxy);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -51,7 +51,7 @@ public class ChallengeController {
 | 
				
			||||||
  public Response handleChallengeResponse(@Auth final AuthenticatedAccount auth,
 | 
					  public Response handleChallengeResponse(@Auth final AuthenticatedAccount auth,
 | 
				
			||||||
      @Valid final AnswerChallengeRequest answerRequest,
 | 
					      @Valid final AnswerChallengeRequest answerRequest,
 | 
				
			||||||
      @HeaderParam("X-Forwarded-For") final String forwardedFor,
 | 
					      @HeaderParam("X-Forwarded-For") final String forwardedFor,
 | 
				
			||||||
      @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) throws RetryLaterException {
 | 
					      @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) throws RateLimitExceededException {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(userAgent));
 | 
					    Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(userAgent));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -76,8 +76,6 @@ public class ChallengeController {
 | 
				
			||||||
      } else {
 | 
					      } else {
 | 
				
			||||||
        tags = tags.and(CHALLENGE_TYPE_TAG, "unrecognized");
 | 
					        tags = tags.and(CHALLENGE_TYPE_TAG, "unrecognized");
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    } catch (final RateLimitExceededException e) {
 | 
					 | 
				
			||||||
      throw new RetryLaterException(e);
 | 
					 | 
				
			||||||
    } finally {
 | 
					    } finally {
 | 
				
			||||||
      Metrics.counter(CHALLENGE_RESPONSE_COUNTER_NAME, tags).increment();
 | 
					      Metrics.counter(CHALLENGE_RESPONSE_COUNTER_NAME, tags).increment();
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -5,10 +5,11 @@
 | 
				
			||||||
package org.whispersystems.textsecuregcm.controllers;
 | 
					package org.whispersystems.textsecuregcm.controllers;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import java.time.Duration;
 | 
					import java.time.Duration;
 | 
				
			||||||
 | 
					import java.util.Optional;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
public class RateLimitExceededException extends Exception {
 | 
					public class RateLimitExceededException extends Exception {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  private final Duration retryDuration;
 | 
					  private final Optional<Duration> retryDuration;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  public RateLimitExceededException(final Duration retryDuration) {
 | 
					  public RateLimitExceededException(final Duration retryDuration) {
 | 
				
			||||||
    this(null, retryDuration);
 | 
					    this(null, retryDuration);
 | 
				
			||||||
| 
						 | 
					@ -16,8 +17,9 @@ public class RateLimitExceededException extends Exception {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  public RateLimitExceededException(final String message, final Duration retryDuration) {
 | 
					  public RateLimitExceededException(final String message, final Duration retryDuration) {
 | 
				
			||||||
    super(message, null, true, false);
 | 
					    super(message, null, true, false);
 | 
				
			||||||
    this.retryDuration = retryDuration;
 | 
					    // we won't provide a backoff in the case the duration is negative
 | 
				
			||||||
 | 
					    this.retryDuration = retryDuration.isNegative() ? Optional.empty() : Optional.of(retryDuration);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  public Duration getRetryDuration() { return retryDuration; }
 | 
					  public Optional<Duration> getRetryDuration() { return retryDuration; }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,19 +0,0 @@
 | 
				
			||||||
/*
 | 
					 | 
				
			||||||
 * Copyright 2013-2021 Signal Messenger, LLC
 | 
					 | 
				
			||||||
 * SPDX-License-Identifier: AGPL-3.0-only
 | 
					 | 
				
			||||||
 */
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
package org.whispersystems.textsecuregcm.controllers;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import java.time.Duration;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
public class RetryLaterException extends Exception {
 | 
					 | 
				
			||||||
  private final Duration backoffDuration;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  public RetryLaterException(RateLimitExceededException e) {
 | 
					 | 
				
			||||||
    super(null, e, true, false);
 | 
					 | 
				
			||||||
    this.backoffDuration = e.getRetryDuration();
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  public Duration getBackoffDuration() { return backoffDuration; }
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
| 
						 | 
					@ -12,8 +12,18 @@ import javax.ws.rs.ext.Provider;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@Provider
 | 
					@Provider
 | 
				
			||||||
public class RateLimitExceededExceptionMapper implements ExceptionMapper<RateLimitExceededException> {
 | 
					public class RateLimitExceededExceptionMapper implements ExceptionMapper<RateLimitExceededException> {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * Convert a RateLimitExceededException to a 413 response with a
 | 
				
			||||||
 | 
					   * Retry-After header.
 | 
				
			||||||
 | 
					   *
 | 
				
			||||||
 | 
					   * @param e A RateLimitExceededException potentially containing a reccomended retry duration
 | 
				
			||||||
 | 
					   * @return the response
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
  @Override
 | 
					  @Override
 | 
				
			||||||
  public Response toResponse(RateLimitExceededException e) {
 | 
					  public Response toResponse(RateLimitExceededException e) {
 | 
				
			||||||
    return Response.status(413).build();
 | 
					    return e.getRetryDuration()
 | 
				
			||||||
 | 
					        .map(d -> Response.status(413).header("Retry-After", d.toSeconds()))
 | 
				
			||||||
 | 
					        .orElseGet(() -> Response.status(413)).build();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,23 +0,0 @@
 | 
				
			||||||
/*
 | 
					 | 
				
			||||||
 * Copyright 2013-2020 Signal Messenger, LLC
 | 
					 | 
				
			||||||
 * SPDX-License-Identifier: AGPL-3.0-only
 | 
					 | 
				
			||||||
 */
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
package org.whispersystems.textsecuregcm.mappers;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import org.whispersystems.textsecuregcm.controllers.RetryLaterException;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import javax.ws.rs.core.Response;
 | 
					 | 
				
			||||||
import javax.ws.rs.ext.ExceptionMapper;
 | 
					 | 
				
			||||||
import javax.ws.rs.ext.Provider;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@Provider
 | 
					 | 
				
			||||||
public class RetryLaterExceptionMapper implements ExceptionMapper<RetryLaterException> {
 | 
					 | 
				
			||||||
  @Override
 | 
					 | 
				
			||||||
  public Response toResponse(RetryLaterException e) {
 | 
					 | 
				
			||||||
    return Response.status(413)
 | 
					 | 
				
			||||||
                   .header("Retry-After", e.getBackoffDuration().toSeconds())
 | 
					 | 
				
			||||||
                   .build();
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
| 
						 | 
					@ -27,7 +27,7 @@ import org.junit.jupiter.api.extension.ExtendWith;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
 | 
					import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
 | 
					import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
 | 
					import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.mappers.RetryLaterExceptionMapper;
 | 
					import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
 | 
					import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
 | 
					import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
 | 
					import org.whispersystems.textsecuregcm.util.SystemMapper;
 | 
				
			||||||
| 
						 | 
					@ -45,7 +45,7 @@ class ChallengeControllerTest {
 | 
				
			||||||
          Set.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
 | 
					          Set.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
 | 
				
			||||||
      .setMapper(SystemMapper.getMapper())
 | 
					      .setMapper(SystemMapper.getMapper())
 | 
				
			||||||
      .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
 | 
					      .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
 | 
				
			||||||
      .addResource(new RetryLaterExceptionMapper())
 | 
					      .addResource(new RateLimitExceededExceptionMapper())
 | 
				
			||||||
      .addResource(challengeController)
 | 
					      .addResource(challengeController)
 | 
				
			||||||
      .build();
 | 
					      .build();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1804,6 +1804,38 @@ class AccountControllerTest {
 | 
				
			||||||
        .getStatus()).isEqualTo(404);
 | 
					        .getStatus()).isEqualTo(404);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @Test
 | 
				
			||||||
 | 
					  void testAccountExistsRateLimited() throws RateLimitExceededException {
 | 
				
			||||||
 | 
					    final Account account = mock(Account.class);
 | 
				
			||||||
 | 
					    final UUID accountIdentifier = UUID.randomUUID();
 | 
				
			||||||
 | 
					    when(accountsManager.getByAccountIdentifier(accountIdentifier)).thenReturn(Optional.of(account));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    final RateLimiter checkAccountLimiter = mock(RateLimiter.class);
 | 
				
			||||||
 | 
					    when(rateLimiters.getCheckAccountExistenceLimiter()).thenReturn(checkAccountLimiter);
 | 
				
			||||||
 | 
					    doThrow(new RateLimitExceededException(Duration.ofSeconds(13))).when(checkAccountLimiter).validate("127.0.0.1");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    final Response response = resources.getJerseyTest()
 | 
				
			||||||
 | 
					        .target(String.format("/v1/accounts/account/%s", accountIdentifier))
 | 
				
			||||||
 | 
					        .request()
 | 
				
			||||||
 | 
					        .header("X-Forwarded-For", "127.0.0.1")
 | 
				
			||||||
 | 
					        .head();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assertThat(response.getStatus()).isEqualTo(413);
 | 
				
			||||||
 | 
					    assertThat(response.getHeaderString("Retry-After")).isEqualTo(String.valueOf(Duration.ofSeconds(13).toSeconds()));
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @Test
 | 
				
			||||||
 | 
					  void testAccountExistsNoForwardedFor() throws RateLimitExceededException {
 | 
				
			||||||
 | 
					    final Response response = resources.getJerseyTest()
 | 
				
			||||||
 | 
					        .target(String.format("/v1/accounts/account/%s", UUID.randomUUID()))
 | 
				
			||||||
 | 
					        .request()
 | 
				
			||||||
 | 
					        .header("X-Forwarded-For", "")
 | 
				
			||||||
 | 
					        .head();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assertThat(response.getStatus()).isEqualTo(413);
 | 
				
			||||||
 | 
					    assertThat(response.getHeaderString("Retry-After")).isNull();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  @Test
 | 
					  @Test
 | 
				
			||||||
  void testAccountExistsAuthenticated() {
 | 
					  void testAccountExistsAuthenticated() {
 | 
				
			||||||
    assertThat(resources.getJerseyTest()
 | 
					    assertThat(resources.getJerseyTest()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.tests.controllers;
 | 
				
			||||||
import static org.assertj.core.api.Assertions.assertThat;
 | 
					import static org.assertj.core.api.Assertions.assertThat;
 | 
				
			||||||
import static org.mockito.ArgumentMatchers.any;
 | 
					import static org.mockito.ArgumentMatchers.any;
 | 
				
			||||||
import static org.mockito.ArgumentMatchers.anyLong;
 | 
					import static org.mockito.ArgumentMatchers.anyLong;
 | 
				
			||||||
 | 
					import static org.mockito.ArgumentMatchers.anyString;
 | 
				
			||||||
import static org.mockito.Mockito.clearInvocations;
 | 
					import static org.mockito.Mockito.clearInvocations;
 | 
				
			||||||
import static org.mockito.Mockito.doThrow;
 | 
					import static org.mockito.Mockito.doThrow;
 | 
				
			||||||
import static org.mockito.Mockito.eq;
 | 
					import static org.mockito.Mockito.eq;
 | 
				
			||||||
| 
						 | 
					@ -38,8 +39,6 @@ import org.junit.jupiter.api.AfterEach;
 | 
				
			||||||
import org.junit.jupiter.api.BeforeEach;
 | 
					import org.junit.jupiter.api.BeforeEach;
 | 
				
			||||||
import org.junit.jupiter.api.Test;
 | 
					import org.junit.jupiter.api.Test;
 | 
				
			||||||
import org.junit.jupiter.api.extension.ExtendWith;
 | 
					import org.junit.jupiter.api.extension.ExtendWith;
 | 
				
			||||||
import org.junit.jupiter.params.ParameterizedTest;
 | 
					 | 
				
			||||||
import org.junit.jupiter.params.provider.ValueSource;
 | 
					 | 
				
			||||||
import org.mockito.ArgumentCaptor;
 | 
					import org.mockito.ArgumentCaptor;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
 | 
					import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
 | 
					import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
 | 
				
			||||||
| 
						 | 
					@ -50,11 +49,11 @@ import org.whispersystems.textsecuregcm.entities.PreKey;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.entities.PreKeyCount;
 | 
					import org.whispersystems.textsecuregcm.entities.PreKeyCount;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.entities.PreKeyResponse;
 | 
					import org.whispersystems.textsecuregcm.entities.PreKeyResponse;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.entities.PreKeyState;
 | 
					import org.whispersystems.textsecuregcm.entities.PreKeyState;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.entities.RateLimitChallenge;
 | 
					 | 
				
			||||||
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
 | 
					import org.whispersystems.textsecuregcm.entities.SignedPreKey;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
 | 
					 | 
				
			||||||
import org.whispersystems.textsecuregcm.limits.RateLimiter;
 | 
					import org.whispersystems.textsecuregcm.limits.RateLimiter;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.limits.RateLimiters;
 | 
					import org.whispersystems.textsecuregcm.limits.RateLimiters;
 | 
				
			||||||
 | 
					import org.whispersystems.textsecuregcm.mappers.RateLimitChallengeExceptionMapper;
 | 
				
			||||||
 | 
					import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.mappers.ServerRejectedExceptionMapper;
 | 
					import org.whispersystems.textsecuregcm.mappers.ServerRejectedExceptionMapper;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.storage.Account;
 | 
					import org.whispersystems.textsecuregcm.storage.Account;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
 | 
					import org.whispersystems.textsecuregcm.storage.AccountsManager;
 | 
				
			||||||
| 
						 | 
					@ -107,6 +106,7 @@ class KeysControllerTest {
 | 
				
			||||||
      .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
 | 
					      .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
 | 
				
			||||||
      .addResource(new ServerRejectedExceptionMapper())
 | 
					      .addResource(new ServerRejectedExceptionMapper())
 | 
				
			||||||
      .addResource(new KeysController(rateLimiters, KEYS, accounts))
 | 
					      .addResource(new KeysController(rateLimiters, KEYS, accounts))
 | 
				
			||||||
 | 
					      .addResource(new RateLimitExceededExceptionMapper())
 | 
				
			||||||
      .build();
 | 
					      .build();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  @BeforeEach
 | 
					  @BeforeEach
 | 
				
			||||||
| 
						 | 
					@ -316,6 +316,21 @@ class KeysControllerTest {
 | 
				
			||||||
    verifyNoMoreInteractions(KEYS);
 | 
					    verifyNoMoreInteractions(KEYS);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @Test
 | 
				
			||||||
 | 
					  void testGetKeysRateLimited() throws RateLimitExceededException {
 | 
				
			||||||
 | 
					    Duration retryAfter = Duration.ofSeconds(31);
 | 
				
			||||||
 | 
					    doThrow(new RateLimitExceededException(retryAfter)).when(rateLimiter).validate(anyString());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Response result = resources.getJerseyTest()
 | 
				
			||||||
 | 
					        .target(String.format("/v2/keys/%s/*", EXISTS_PNI))
 | 
				
			||||||
 | 
					        .request()
 | 
				
			||||||
 | 
					        .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
 | 
				
			||||||
 | 
					        .get();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assertThat(result.getStatus()).isEqualTo(413);
 | 
				
			||||||
 | 
					    assertThat(result.getHeaderString("Retry-After")).isEqualTo(String.valueOf(retryAfter.toSeconds()));
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  @Test
 | 
					  @Test
 | 
				
			||||||
  void testUnidentifiedRequest() {
 | 
					  void testUnidentifiedRequest() {
 | 
				
			||||||
    PreKeyResponse result = resources.getJerseyTest()
 | 
					    PreKeyResponse result = resources.getJerseyTest()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -10,6 +10,7 @@ import static org.assertj.core.api.Assertions.assertThatNoException;
 | 
				
			||||||
import static org.mockito.ArgumentMatchers.refEq;
 | 
					import static org.mockito.ArgumentMatchers.refEq;
 | 
				
			||||||
import static org.mockito.Mockito.any;
 | 
					import static org.mockito.Mockito.any;
 | 
				
			||||||
import static org.mockito.Mockito.clearInvocations;
 | 
					import static org.mockito.Mockito.clearInvocations;
 | 
				
			||||||
 | 
					import static org.mockito.Mockito.doThrow;
 | 
				
			||||||
import static org.mockito.Mockito.eq;
 | 
					import static org.mockito.Mockito.eq;
 | 
				
			||||||
import static org.mockito.Mockito.mock;
 | 
					import static org.mockito.Mockito.mock;
 | 
				
			||||||
import static org.mockito.Mockito.never;
 | 
					import static org.mockito.Mockito.never;
 | 
				
			||||||
| 
						 | 
					@ -26,6 +27,7 @@ import io.dropwizard.testing.junit5.ResourceExtension;
 | 
				
			||||||
import java.nio.charset.StandardCharsets;
 | 
					import java.nio.charset.StandardCharsets;
 | 
				
			||||||
import java.security.SecureRandom;
 | 
					import java.security.SecureRandom;
 | 
				
			||||||
import java.time.Clock;
 | 
					import java.time.Clock;
 | 
				
			||||||
 | 
					import java.time.Duration;
 | 
				
			||||||
import java.time.Instant;
 | 
					import java.time.Instant;
 | 
				
			||||||
import java.util.Base64;
 | 
					import java.util.Base64;
 | 
				
			||||||
import java.util.Collections;
 | 
					import java.util.Collections;
 | 
				
			||||||
| 
						 | 
					@ -82,6 +84,7 @@ import org.whispersystems.textsecuregcm.entities.ProfileKeyCredentialProfileResp
 | 
				
			||||||
import org.whispersystems.textsecuregcm.entities.VersionedProfileResponse;
 | 
					import org.whispersystems.textsecuregcm.entities.VersionedProfileResponse;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.limits.RateLimiter;
 | 
					import org.whispersystems.textsecuregcm.limits.RateLimiter;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.limits.RateLimiters;
 | 
					import org.whispersystems.textsecuregcm.limits.RateLimiters;
 | 
				
			||||||
 | 
					import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.s3.PolicySigner;
 | 
					import org.whispersystems.textsecuregcm.s3.PolicySigner;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator;
 | 
					import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator;
 | 
				
			||||||
import org.whispersystems.textsecuregcm.storage.Account;
 | 
					import org.whispersystems.textsecuregcm.storage.Account;
 | 
				
			||||||
| 
						 | 
					@ -123,6 +126,7 @@ class ProfileControllerTest {
 | 
				
			||||||
      .addProvider(AuthHelper.getAuthFilter())
 | 
					      .addProvider(AuthHelper.getAuthFilter())
 | 
				
			||||||
      .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
 | 
					      .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
 | 
				
			||||||
          ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
 | 
					          ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
 | 
				
			||||||
 | 
					      .addProvider(new RateLimitExceededExceptionMapper())
 | 
				
			||||||
      .setMapper(SystemMapper.getMapper())
 | 
					      .setMapper(SystemMapper.getMapper())
 | 
				
			||||||
      .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
 | 
					      .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
 | 
				
			||||||
      .addResource(new ProfileController(
 | 
					      .addResource(new ProfileController(
 | 
				
			||||||
| 
						 | 
					@ -210,6 +214,7 @@ class ProfileControllerTest {
 | 
				
			||||||
  @AfterEach
 | 
					  @AfterEach
 | 
				
			||||||
  void teardown() {
 | 
					  void teardown() {
 | 
				
			||||||
    reset(accountsManager);
 | 
					    reset(accountsManager);
 | 
				
			||||||
 | 
					    reset(rateLimiter);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  @Test
 | 
					  @Test
 | 
				
			||||||
| 
						 | 
					@ -228,6 +233,20 @@ class ProfileControllerTest {
 | 
				
			||||||
    verify(rateLimiter, times(1)).validate(AuthHelper.VALID_UUID);
 | 
					    verify(rateLimiter, times(1)).validate(AuthHelper.VALID_UUID);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @Test
 | 
				
			||||||
 | 
					  void testProfileGetByUuidRateLimited() throws RateLimitExceededException {
 | 
				
			||||||
 | 
					    doThrow(new RateLimitExceededException(Duration.ofSeconds(13))).when(rateLimiter).validate(AuthHelper.VALID_UUID);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Response response= resources.getJerseyTest()
 | 
				
			||||||
 | 
					        .target("/v1/profile/" + AuthHelper.VALID_UUID_TWO)
 | 
				
			||||||
 | 
					        .request()
 | 
				
			||||||
 | 
					        .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
 | 
				
			||||||
 | 
					        .get();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assertThat(response.getStatus()).isEqualTo(413);
 | 
				
			||||||
 | 
					    assertThat(response.getHeaderString("Retry-After")).isEqualTo(String.valueOf(Duration.ofSeconds(13).toSeconds()));
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  @Test
 | 
					  @Test
 | 
				
			||||||
  void testProfileGetByUuidUnidentified() throws RateLimitExceededException {
 | 
					  void testProfileGetByUuidUnidentified() throws RateLimitExceededException {
 | 
				
			||||||
    BaseProfileResponse profile = resources.getJerseyTest()
 | 
					    BaseProfileResponse profile = resources.getJerseyTest()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue