From c16006dc4b8cbd3141ae9d888b87a919f38fba75 Mon Sep 17 00:00:00 2001 From: Chris Eager Date: Tue, 24 Jan 2023 15:33:48 -0600 Subject: [PATCH] Add `PUT /v2/account/number` --- .../textsecuregcm/WhisperServerService.java | 26 +- .../auth/PhoneVerificationTokenManager.java | 98 +++++ .../controllers/AccountController.java | 2 +- .../controllers/AccountControllerV2.java | 132 ++++++ .../RateLimitExceededException.java | 13 +- .../controllers/RegistrationController.java | 72 +--- .../entities/ChangeNumberRequest.java | 27 ++ .../entities/PhoneVerificationRequest.java | 45 ++ .../entities/RegistrationRequest.java | 30 +- .../limits/LockingRateLimiter.java | 2 +- .../limits/RateLimitByIpFilter.java | 3 +- .../textsecuregcm/limits/RateLimiter.java | 20 +- .../RateLimitExceededExceptionMapper.java | 23 +- .../RegistrationServiceClient.java | 17 +- .../controllers/AccountControllerTest.java | 15 +- .../controllers/AccountControllerV2Test.java | 396 ++++++++++++++++++ .../controllers/ChallengeControllerTest.java | 6 +- .../controllers/ProfileControllerTest.java | 6 +- .../ProvisioningControllerTest.java | 38 +- .../RegistrationControllerTest.java | 58 +-- .../limits/RateLimitedByIpTest.java | 6 +- .../tests/controllers/KeysControllerTest.java | 2 +- .../textsecuregcm/util/MockUtils.java | 5 +- 23 files changed, 856 insertions(+), 186 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/auth/PhoneVerificationTokenManager.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneVerificationRequest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index f7f262472..d05b85ba7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -76,6 +76,7 @@ import org.whispersystems.textsecuregcm.auth.CertificateGenerator; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccountAuthenticator; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator; +import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager; import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager; import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator; import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventListener; @@ -87,6 +88,7 @@ import org.whispersystems.textsecuregcm.captcha.RecaptchaClient; import org.whispersystems.textsecuregcm.configuration.DirectoryServerConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.controllers.AccountController; +import org.whispersystems.textsecuregcm.controllers.AccountControllerV2; import org.whispersystems.textsecuregcm.controllers.ArtController; import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV2; import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV3; @@ -524,16 +526,22 @@ public class WhisperServerService extends Application commonControllers = Lists.newArrayList( + new AccountControllerV2(accountsManager, changeNumberManager, phoneVerificationTokenManager, + registrationLockVerificationManager, rateLimiters), new ArtController(rateLimiters, artCredentialsGenerator), new AttachmentControllerV2(rateLimiters, config.getAwsAttachmentsConfiguration().getAccessKey(), config.getAwsAttachmentsConfiguration().getAccessSecret(), config.getAwsAttachmentsConfiguration().getRegion(), config.getAwsAttachmentsConfiguration().getBucket()), new AttachmentControllerV3(rateLimiters, config.getGcpAttachmentsConfiguration().getDomain(), config.getGcpAttachmentsConfiguration().getEmail(), config.getGcpAttachmentsConfiguration().getMaxSizeInBytes(), config.getGcpAttachmentsConfiguration().getPathPrefix(), config.getGcpAttachmentsConfiguration().getRsaSigningKey()), @@ -748,8 +758,8 @@ public class WhisperServerService extends Application verifyBySessionId(number, request.decodeSessionId()); + case RECOVERY_PASSWORD -> verifyByRecoveryPassword(number, request.recoveryPassword()); + } + + return verificationType; + } + + private void verifyBySessionId(final String number, final byte[] sessionId) throws InterruptedException { + try { + final RegistrationSession session = registrationServiceClient + .getSession(sessionId, REGISTRATION_RPC_TIMEOUT) + .get(VERIFICATION_TIMEOUT_SECONDS, TimeUnit.SECONDS) + .orElseThrow(() -> new NotAuthorizedException("session not verified")); + + if (!MessageDigest.isEqual(number.getBytes(), session.number().getBytes())) { + throw new BadRequestException("number does not match session"); + } + if (!session.verified()) { + throw new NotAuthorizedException("session not verified"); + } + } catch (final CancellationException | ExecutionException | TimeoutException e) { + logger.error("Registration service failure", e); + throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE); + } + } + + private void verifyByRecoveryPassword(final String number, final byte[] recoveryPassword) + throws InterruptedException { + try { + final boolean verified = registrationRecoveryPasswordsManager.verify(number, recoveryPassword) + .get(VERIFICATION_TIMEOUT_SECONDS, TimeUnit.SECONDS); + if (!verified) { + throw new ForbiddenException("recoveryPassword couldn't be verified"); + } + } catch (final ExecutionException | TimeoutException e) { + throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE); + } + } + +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java index 74862484d..f8d81ee0c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java @@ -798,7 +798,7 @@ public class AccountController { // This shouldn't happen, so conservatively assume we're over the rate-limit // and indicate that the client should retry logger.error("Missing/bad Forwarded-For: {}", forwardedFor); - return new RateLimitExceededException(Duration.ofHours(1)); + return new RateLimitExceededException(Duration.ofHours(1), true); }); rateLimiter.validate(mostRecentProxy); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2.java new file mode 100644 index 000000000..c694a080f --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2.java @@ -0,0 +1,132 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.controllers; + +import static com.codahale.metrics.MetricRegistry.name; + +import com.codahale.metrics.annotation.Timed; +import com.google.common.net.HttpHeaders; +import io.dropwizard.auth.Auth; +import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Tag; +import io.micrometer.core.instrument.Tags; +import java.util.Optional; +import javax.validation.Valid; +import javax.validation.constraints.NotNull; +import javax.ws.rs.BadRequestException; +import javax.ws.rs.Consumes; +import javax.ws.rs.ForbiddenException; +import javax.ws.rs.HeaderParam; +import javax.ws.rs.PUT; +import javax.ws.rs.Path; +import javax.ws.rs.Produces; +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager; +import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager; +import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse; +import org.whispersystems.textsecuregcm.entities.ChangeNumberRequest; +import org.whispersystems.textsecuregcm.entities.MismatchedDevices; +import org.whispersystems.textsecuregcm.entities.PhoneVerificationRequest; +import org.whispersystems.textsecuregcm.entities.StaleDevices; +import org.whispersystems.textsecuregcm.limits.RateLimiter; +import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.ChangeNumberManager; + +@Path("/v2/accounts") +public class AccountControllerV2 { + + private static final String CHANGE_NUMBER_COUNTER_NAME = name(AccountControllerV2.class, "create"); + private static final String VERIFICATION_TYPE_TAG_NAME = "verification"; + + private final AccountsManager accountsManager; + private final ChangeNumberManager changeNumberManager; + private final PhoneVerificationTokenManager phoneVerificationTokenManager; + private final RegistrationLockVerificationManager registrationLockVerificationManager; + private final RateLimiters rateLimiters; + + public AccountControllerV2(final AccountsManager accountsManager, final ChangeNumberManager changeNumberManager, + final PhoneVerificationTokenManager phoneVerificationTokenManager, + final RegistrationLockVerificationManager registrationLockVerificationManager, final RateLimiters rateLimiters) { + this.accountsManager = accountsManager; + this.changeNumberManager = changeNumberManager; + this.phoneVerificationTokenManager = phoneVerificationTokenManager; + this.registrationLockVerificationManager = registrationLockVerificationManager; + this.rateLimiters = rateLimiters; + } + + @Timed + @PUT + @Path("/number") + @Consumes(MediaType.APPLICATION_JSON) + @Produces(MediaType.APPLICATION_JSON) + public AccountIdentityResponse changeNumber(@Auth final AuthenticatedAccount authenticatedAccount, + @NotNull @Valid final ChangeNumberRequest request, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) + throws RateLimitExceededException, InterruptedException { + + if (!authenticatedAccount.getAuthenticatedDevice().isMaster()) { + throw new ForbiddenException(); + } + + final String number = request.number(); + + // Only verify and check reglock if there's a data change to be made... + if (!authenticatedAccount.getAccount().getNumber().equals(number)) { + + RateLimiter.adaptLegacyException(() -> rateLimiters.getRegistrationLimiter().validate(number)); + + final PhoneVerificationRequest.VerificationType verificationType = phoneVerificationTokenManager.verify(number, + request); + + final Optional existingAccount = accountsManager.getByE164(number); + + if (existingAccount.isPresent()) { + registrationLockVerificationManager.verifyRegistrationLock(existingAccount.get(), request.registrationLock()); + } + + Metrics.counter(CHANGE_NUMBER_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), + Tag.of(VERIFICATION_TYPE_TAG_NAME, verificationType.name()))) + .increment(); + } + + // ...but always attempt to make the change in case a client retries and needs to re-send messages + try { + final Account updatedAccount = changeNumberManager.changeNumber( + authenticatedAccount.getAccount(), + request.number(), + request.pniIdentityKey(), + request.devicePniSignedPrekeys(), + request.deviceMessages(), + request.pniRegistrationIds()); + + return new AccountIdentityResponse( + updatedAccount.getUuid(), + updatedAccount.getNumber(), + updatedAccount.getPhoneNumberIdentifier(), + updatedAccount.getUsernameHash().orElse(null), + updatedAccount.isStorageSupported()); + } catch (MismatchedDevicesException e) { + throw new WebApplicationException(Response.status(409) + .type(MediaType.APPLICATION_JSON_TYPE) + .entity(new MismatchedDevices(e.getMissingDevices(), + e.getExtraDevices())) + .build()); + } catch (StaleDevicesException e) { + throw new WebApplicationException(Response.status(410) + .type(MediaType.APPLICATION_JSON) + .entity(new StaleDevices(e.getStaleDevices())) + .build()); + } catch (IllegalArgumentException e) { + throw new BadRequestException(e); + } + } + +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RateLimitExceededException.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RateLimitExceededException.java index b6c2fd3f1..41cf2b8ff 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RateLimitExceededException.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RateLimitExceededException.java @@ -10,20 +10,27 @@ import javax.annotation.Nullable; public class RateLimitExceededException extends Exception { - private final @Nullable - Duration retryDuration; + @Nullable + private final Duration retryDuration; + private final boolean legacy; /** * Constructs a new exception indicating when it may become safe to retry * * @param retryDuration A duration to wait before retrying, null if no duration can be indicated + * @param legacy whether to use a legacy status code when mapping the exception to an HTTP response */ - public RateLimitExceededException(final @Nullable Duration retryDuration) { + public RateLimitExceededException(@Nullable final Duration retryDuration, final boolean legacy) { super(null, null, true, false); this.retryDuration = retryDuration; + this.legacy = legacy; } public Optional getRetryDuration() { return Optional.ofNullable(retryDuration); } + + public boolean isLegacy() { + return legacy; + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java index cc54f522a..b1380b81e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java @@ -13,42 +13,33 @@ import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tags; -import java.security.MessageDigest; import java.time.Duration; import java.time.Instant; import java.util.ArrayList; import java.util.Optional; -import java.util.concurrent.CancellationException; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; import javax.validation.Valid; import javax.validation.constraints.NotNull; -import javax.ws.rs.BadRequestException; import javax.ws.rs.Consumes; -import javax.ws.rs.ForbiddenException; import javax.ws.rs.HeaderParam; -import javax.ws.rs.NotAuthorizedException; import javax.ws.rs.POST; import javax.ws.rs.Path; import javax.ws.rs.Produces; -import javax.ws.rs.ServerErrorException; import javax.ws.rs.WebApplicationException; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.BasicAuthorizationHeader; +import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager; import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager; import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse; +import org.whispersystems.textsecuregcm.entities.PhoneVerificationRequest; import org.whispersystems.textsecuregcm.entities.RegistrationRequest; -import org.whispersystems.textsecuregcm.entities.RegistrationSession; +import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; -import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; -import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager; import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.Util; @@ -68,24 +59,17 @@ public class RegistrationController { private static final String REGION_CODE_TAG_NAME = "regionCode"; private static final String VERIFICATION_TYPE_TAG_NAME = "verification"; - private static final Duration REGISTRATION_RPC_TIMEOUT = Duration.ofSeconds(15); - private static final long VERIFICATION_TIMEOUT_SECONDS = REGISTRATION_RPC_TIMEOUT.plusSeconds(1).getSeconds(); - private final AccountsManager accounts; - private final RegistrationServiceClient registrationServiceClient; + private final PhoneVerificationTokenManager phoneVerificationTokenManager; private final RegistrationLockVerificationManager registrationLockVerificationManager; - private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager; private final RateLimiters rateLimiters; public RegistrationController(final AccountsManager accounts, - final RegistrationServiceClient registrationServiceClient, - final RegistrationLockVerificationManager registrationLockVerificationManager, - final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager, - final RateLimiters rateLimiters) { + final PhoneVerificationTokenManager phoneVerificationTokenManager, + final RegistrationLockVerificationManager registrationLockVerificationManager, final RateLimiters rateLimiters) { this.accounts = accounts; - this.registrationServiceClient = registrationServiceClient; + this.phoneVerificationTokenManager = phoneVerificationTokenManager; this.registrationLockVerificationManager = registrationLockVerificationManager; - this.registrationRecoveryPasswordsManager = registrationRecoveryPasswordsManager; this.rateLimiters = rateLimiters; } @@ -99,17 +83,13 @@ public class RegistrationController { @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent, @NotNull @Valid final RegistrationRequest registrationRequest) throws RateLimitExceededException, InterruptedException { - rateLimiters.getRegistrationLimiter().validate(registrationRequest.sessionId()); - final String number = authorizationHeader.getUsername(); final String password = authorizationHeader.getPassword(); - // decide on the method of verification based on the registration request parameters and verify - final RegistrationRequest.VerificationType verificationType = registrationRequest.verificationType(); - switch (verificationType) { - case SESSION -> verifyBySessionId(number, registrationRequest.decodeSessionId()); - case RECOVERY_PASSWORD -> verifyByRecoveryPassword(number, registrationRequest.recoveryPassword()); - } + RateLimiter.adaptLegacyException(() -> rateLimiters.getRegistrationLimiter().validate(number)); + + final PhoneVerificationRequest.VerificationType verificationType = phoneVerificationTokenManager.verify(number, + registrationRequest); final Optional existingAccount = accounts.getByE164(number); @@ -146,34 +126,4 @@ public class RegistrationController { existingAccount.map(Account::isStorageSupported).orElse(false)); } - private void verifyBySessionId(final String number, final byte[] sessionId) throws InterruptedException { - try { - final RegistrationSession session = registrationServiceClient - .getSession(sessionId, REGISTRATION_RPC_TIMEOUT) - .get(VERIFICATION_TIMEOUT_SECONDS, TimeUnit.SECONDS) - .orElseThrow(() -> new NotAuthorizedException("session not verified")); - - if (!MessageDigest.isEqual(number.getBytes(), session.number().getBytes())) { - throw new BadRequestException("number does not match session"); - } - if (!session.verified()) { - throw new NotAuthorizedException("session not verified"); - } - } catch (final CancellationException | ExecutionException | TimeoutException e) { - logger.error("Registration service failure", e); - throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE); - } - } - - private void verifyByRecoveryPassword(final String number, final byte[] recoveryPassword) throws InterruptedException { - try { - final boolean verified = registrationRecoveryPasswordsManager.verify(number, recoveryPassword) - .get(VERIFICATION_TIMEOUT_SECONDS, TimeUnit.SECONDS); - if (!verified) { - throw new ForbiddenException("recoveryPassword couldn't be verified"); - } - } catch (final ExecutionException | TimeoutException e) { - throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE); - } - } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java new file mode 100644 index 000000000..3d0d63042 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java @@ -0,0 +1,27 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.entities; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; +import javax.validation.Valid; +import javax.validation.constraints.NotBlank; +import javax.validation.constraints.NotNull; +import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; + +public record ChangeNumberRequest(String sessionId, + @JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) byte[] recoveryPassword, + @NotBlank String number, + @JsonProperty("reglock") @Nullable String registrationLock, + @NotBlank String pniIdentityKey, + @NotNull @Valid List<@NotNull @Valid IncomingMessage> deviceMessages, + @NotNull @Valid Map devicePniSignedPrekeys, + @NotNull Map pniRegistrationIds) implements PhoneVerificationRequest { + +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneVerificationRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneVerificationRequest.java new file mode 100644 index 000000000..c471936cf --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneVerificationRequest.java @@ -0,0 +1,45 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.entities; + +import static org.apache.commons.lang3.StringUtils.isNotBlank; + +import java.util.Base64; +import javax.validation.constraints.AssertTrue; +import javax.ws.rs.ClientErrorException; +import org.apache.http.HttpStatus; + +public interface PhoneVerificationRequest { + + enum VerificationType { + SESSION, + RECOVERY_PASSWORD + } + + String sessionId(); + + byte[] recoveryPassword(); + + // for the @AssertTrue to work with bean validation, method name must follow 'isSmth()'/'getSmth()' naming convention + @AssertTrue + default boolean isValid() { + // checking that exactly one of sessionId/recoveryPassword is non-empty + return isNotBlank(sessionId()) ^ (recoveryPassword() != null && recoveryPassword().length > 0); + } + + default PhoneVerificationRequest.VerificationType verificationType() { + return isNotBlank(sessionId()) ? PhoneVerificationRequest.VerificationType.SESSION + : PhoneVerificationRequest.VerificationType.RECOVERY_PASSWORD; + } + + default byte[] decodeSessionId() { + try { + return Base64.getUrlDecoder().decode(sessionId()); + } catch (final IllegalArgumentException e) { + throw new ClientErrorException("Malformed session ID", HttpStatus.SC_UNPROCESSABLE_ENTITY); + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationRequest.java index e0bb4105d..8c15abc93 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationRequest.java @@ -5,43 +5,15 @@ package org.whispersystems.textsecuregcm.entities; -import static org.apache.commons.lang3.StringUtils.isNotBlank; - import com.fasterxml.jackson.databind.annotation.JsonDeserialize; -import java.util.Base64; import javax.validation.Valid; -import javax.validation.constraints.AssertTrue; import javax.validation.constraints.NotNull; -import javax.ws.rs.ClientErrorException; -import org.apache.http.HttpStatus; import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; public record RegistrationRequest(String sessionId, @JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) byte[] recoveryPassword, @NotNull @Valid AccountAttributes accountAttributes, - boolean skipDeviceTransfer) { + boolean skipDeviceTransfer) implements PhoneVerificationRequest { - public enum VerificationType { - SESSION, - RECOVERY_PASSWORD - } - // for the @AssertTrue to work with bean validation, method name must follow 'isSmth()'/'getSmth()' naming convention - @AssertTrue - public boolean isValid() { - // checking that exactly one of sessionId/recoveryPassword is non-empty - return isNotBlank(sessionId) ^ (recoveryPassword != null && recoveryPassword.length > 0); - } - - public VerificationType verificationType() { - return isNotBlank(sessionId) ? VerificationType.SESSION : VerificationType.RECOVERY_PASSWORD; - } - - public byte[] decodeSessionId() { - try { - return Base64.getUrlDecoder().decode(sessionId()); - } catch (final IllegalArgumentException e) { - throw new ClientErrorException("Malformed session ID", HttpStatus.SC_UNPROCESSABLE_ENTITY); - } - } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java index 81651c395..f40522127 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java @@ -31,7 +31,7 @@ public class LockingRateLimiter extends RateLimiter { public void validate(String key, int amount) throws RateLimitExceededException { if (!acquireLock(key)) { meter.mark(); - throw new RateLimitExceededException(Duration.ZERO); + throw new RateLimitExceededException(Duration.ZERO, true); } try { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimitByIpFilter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimitByIpFilter.java index 3147073b7..6e6b230f4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimitByIpFilter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimitByIpFilter.java @@ -29,7 +29,8 @@ public class RateLimitByIpFilter implements ContainerRequestFilter { private static final Logger logger = LoggerFactory.getLogger(RateLimitByIpFilter.class); @VisibleForTesting - static final RateLimitExceededException INVALID_HEADER_EXCEPTION = new RateLimitExceededException(Duration.ofHours(1)); + static final RateLimitExceededException INVALID_HEADER_EXCEPTION = new RateLimitExceededException(Duration.ofHours(1), + true); private static final ExceptionMapper EXCEPTION_MAPPER = new RateLimitExceededExceptionMapper(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java index 23c5af012..17486a63d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java @@ -57,7 +57,7 @@ public class RateLimiter { setBucket(key, bucket); } else { meter.mark(); - throw new RateLimitExceededException(bucket.getTimeUntilSpaceAvailable(amount)); + throw new RateLimitExceededException(bucket.getTimeUntilSpaceAvailable(amount), true); } } } @@ -132,4 +132,22 @@ public class RateLimiter { public boolean hasConfiguration(final RateLimitConfiguration configuration) { return bucketSize == configuration.getBucketSize() && leakRatePerMinute == configuration.getLeakRatePerMinute(); } + + /** + * If the wrapped {@code validate()} call throws a {@link RateLimitExceededException}, it will adapt it to ensure that + * {@link RateLimitExceededException#isLegacy()} returns {@code true} + */ + public static void adaptLegacyException(final RateLimitValidator validator) throws RateLimitExceededException { + try { + validator.validate(); + } catch (final RateLimitExceededException e) { + throw new RateLimitExceededException(e.getRetryDuration().orElse(null), false); + } + } + + @FunctionalInterface + public interface RateLimitValidator { + + void validate() throws RateLimitExceededException; + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/mappers/RateLimitExceededExceptionMapper.java b/service/src/main/java/org/whispersystems/textsecuregcm/mappers/RateLimitExceededExceptionMapper.java index e634d4b33..3202933b8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/mappers/RateLimitExceededExceptionMapper.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/mappers/RateLimitExceededExceptionMapper.java @@ -4,36 +4,41 @@ */ package org.whispersystems.textsecuregcm.mappers; +import javax.ws.rs.core.Response; +import javax.ws.rs.ext.ExceptionMapper; +import javax.ws.rs.ext.Provider; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; -import javax.ws.rs.core.Response; -import javax.ws.rs.ext.ExceptionMapper; -import javax.ws.rs.ext.Provider; - @Provider public class RateLimitExceededExceptionMapper implements ExceptionMapper { + private static final Logger logger = LoggerFactory.getLogger(RateLimitExceededExceptionMapper.class); + private static final int LEGACY_STATUS_CODE = 413; + private static final int STATUS_CODE = 429; + /** - * Convert a RateLimitExceededException to a 413 response with a - * Retry-After header. + * Convert a RateLimitExceededException to a {@value STATUS_CODE} (or legacy {@value LEGACY_STATUS_CODE}) response + * with a Retry-After header. * * @param e A RateLimitExceededException potentially containing a recommended retry duration * @return the response */ @Override public Response toResponse(RateLimitExceededException e) { + final int statusCode = e.isLegacy() ? LEGACY_STATUS_CODE : STATUS_CODE; return e.getRetryDuration() .filter(d -> { if (d.isNegative()) { - logger.warn("Encountered a negative retry duration: {}, will not include a Retry-After header in response", d); + logger.warn("Encountered a negative retry duration: {}, will not include a Retry-After header in response", + d); } // only include non-negative durations in retry headers return !d.isNegative(); }) - .map(d -> Response.status(413).header("Retry-After", d.toSeconds())) - .orElseGet(() -> Response.status(413)).build(); + .map(d -> Response.status(statusCode).header("Retry-After", d.toSeconds())) + .orElseGet(() -> Response.status(statusCode)).build(); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/registration/RegistrationServiceClient.java b/service/src/main/java/org/whispersystems/textsecuregcm/registration/RegistrationServiceClient.java index 926161763..df6d30f2d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/registration/RegistrationServiceClient.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/registration/RegistrationServiceClient.java @@ -89,9 +89,12 @@ public class RegistrationServiceClient implements Managed { case ERROR -> { switch (response.getError().getErrorType()) { - case CREATE_REGISTRATION_SESSION_ERROR_TYPE_RATE_LIMITED -> throw new CompletionException(new RateLimitExceededException(Duration.ofSeconds(response.getError().getRetryAfterSeconds()))); + case CREATE_REGISTRATION_SESSION_ERROR_TYPE_RATE_LIMITED -> throw new CompletionException( + new RateLimitExceededException(Duration.ofSeconds(response.getError().getRetryAfterSeconds()), + true)); case CREATE_REGISTRATION_SESSION_ERROR_TYPE_ILLEGAL_PHONE_NUMBER -> throw new IllegalArgumentException(); - default -> throw new RuntimeException("Unrecognized error type from registration service: " + response.getError().getErrorType()); + default -> throw new RuntimeException( + "Unrecognized error type from registration service: " + response.getError().getErrorType()); } } @@ -119,8 +122,9 @@ public class RegistrationServiceClient implements Managed { .thenApply(response -> { if (response.hasError()) { switch (response.getError().getErrorType()) { - case SEND_VERIFICATION_CODE_ERROR_TYPE_RATE_LIMITED -> - throw new CompletionException(new RateLimitExceededException(Duration.ofSeconds(response.getError().getRetryAfterSeconds()))); + case SEND_VERIFICATION_CODE_ERROR_TYPE_RATE_LIMITED -> throw new CompletionException( + new RateLimitExceededException(Duration.ofSeconds(response.getError().getRetryAfterSeconds()), + true)); default -> throw new CompletionException(new RuntimeException("Failed to send verification code: " + response.getError().getErrorType())); } @@ -142,8 +146,9 @@ public class RegistrationServiceClient implements Managed { .thenApply(response -> { if (response.hasError()) { switch (response.getError().getErrorType()) { - case CHECK_VERIFICATION_CODE_ERROR_TYPE_RATE_LIMITED -> - throw new CompletionException(new RateLimitExceededException(Duration.ofSeconds(response.getError().getRetryAfterSeconds()))); + case CHECK_VERIFICATION_CODE_ERROR_TYPE_RATE_LIMITED -> throw new CompletionException( + new RateLimitExceededException(Duration.ofSeconds(response.getError().getRetryAfterSeconds()), + true)); default -> throw new CompletionException(new RuntimeException("Failed to check verification code: " + response.getError().getErrorType())); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java index fb549b092..98cd6ce22 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java @@ -331,11 +331,12 @@ class AccountControllerTest { when(captchaChecker.verify(eq(VALID_CAPTCHA_TOKEN), anyString())) .thenReturn(new AssessmentResult(true, "")); - doThrow(new RateLimitExceededException(Duration.ZERO)).when(pinLimiter).validate(eq(SENDER_OVER_PIN)); + doThrow(new RateLimitExceededException(Duration.ZERO, true)).when(pinLimiter).validate(eq(SENDER_OVER_PIN)); - doThrow(new RateLimitExceededException(Duration.ZERO)).when(smsVoicePrefixLimiter).validate(SENDER_OVER_PREFIX.substring(0, 4+2)); - doThrow(new RateLimitExceededException(Duration.ZERO)).when(smsVoiceIpLimiter).validate(RATE_LIMITED_IP_HOST); - doThrow(new RateLimitExceededException(Duration.ZERO)).when(smsVoiceIpLimiter).validate(RATE_LIMITED_HOST2); + doThrow(new RateLimitExceededException(Duration.ZERO, true)).when(smsVoicePrefixLimiter) + .validate(SENDER_OVER_PREFIX.substring(0, 4 + 2)); + doThrow(new RateLimitExceededException(Duration.ZERO, true)).when(smsVoiceIpLimiter).validate(RATE_LIMITED_IP_HOST); + doThrow(new RateLimitExceededException(Duration.ZERO, true)).when(smsVoiceIpLimiter).validate(RATE_LIMITED_HOST2); } @AfterEach @@ -571,7 +572,7 @@ class AccountControllerTest { @Test void testSendCodeRateLimited() { when(registrationServiceClient.createRegistrationSession(any(), any())) - .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(Duration.ofMinutes(10)))); + .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(Duration.ofMinutes(10), true))); Response response = resources.getJerseyTest() @@ -2050,7 +2051,7 @@ class AccountControllerTest { when(accountsManager.getByAccountIdentifier(accountIdentifier)).thenReturn(Optional.of(account)); MockUtils.updateRateLimiterResponseToFail( - rateLimiters, RateLimiters.Handle.CHECK_ACCOUNT_EXISTENCE, "127.0.0.1", expectedRetryAfter); + rateLimiters, RateLimiters.Handle.CHECK_ACCOUNT_EXISTENCE, "127.0.0.1", expectedRetryAfter, true); final Response response = resources.getJerseyTest() .target(String.format("/v1/accounts/account/%s", accountIdentifier)) @@ -2115,7 +2116,7 @@ class AccountControllerTest { void testLookupUsernameRateLimited() throws RateLimitExceededException { final Duration expectedRetryAfter = Duration.ofSeconds(13); MockUtils.updateRateLimiterResponseToFail( - rateLimiters, RateLimiters.Handle.USERNAME_LOOKUP, "127.0.0.1", expectedRetryAfter); + rateLimiters, RateLimiters.Handle.USERNAME_LOOKUP, "127.0.0.1", expectedRetryAfter, true); final Response response = resources.getJerseyTest() .target(String.format("v1/accounts/username_hash/%s", BASE_64_URL_USERNAME_HASH_1)) .request() diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java new file mode 100644 index 000000000..f336076ca --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java @@ -0,0 +1,396 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.controllers; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableSet; +import com.google.i18n.phonenumbers.PhoneNumberUtil; +import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; +import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; +import io.dropwizard.testing.junit5.ResourceExtension; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Stream; +import javax.annotation.Nullable; +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.client.Entity; +import javax.ws.rs.client.Invocation; +import javax.ws.rs.core.HttpHeaders; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; +import org.apache.http.HttpStatus; +import org.glassfish.jersey.server.ServerProperties; +import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.EnumSource; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.stubbing.Answer; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; +import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager; +import org.whispersystems.textsecuregcm.auth.RegistrationLockError; +import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager; +import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse; +import org.whispersystems.textsecuregcm.entities.ChangeNumberRequest; +import org.whispersystems.textsecuregcm.entities.RegistrationSession; +import org.whispersystems.textsecuregcm.limits.RateLimiter; +import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.mappers.ImpossiblePhoneNumberExceptionMapper; +import org.whispersystems.textsecuregcm.mappers.NonNormalizedPhoneNumberExceptionMapper; +import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; +import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.ChangeNumberManager; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager; +import org.whispersystems.textsecuregcm.tests.util.AuthHelper; +import org.whispersystems.textsecuregcm.util.SystemMapper; + +@ExtendWith(DropwizardExtensionsSupport.class) +class AccountControllerV2Test { + + public static final String NEW_NUMBER = PhoneNumberUtil.getInstance().format( + PhoneNumberUtil.getInstance().getExampleNumber("US"), + PhoneNumberUtil.PhoneNumberFormat.E164); + + private final AccountsManager accountsManager = mock(AccountsManager.class); + private final ChangeNumberManager changeNumberManager = mock(ChangeNumberManager.class); + private final RegistrationServiceClient registrationServiceClient = mock(RegistrationServiceClient.class); + private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager = mock( + RegistrationRecoveryPasswordsManager.class); + private final RegistrationLockVerificationManager registrationLockVerificationManager = mock( + RegistrationLockVerificationManager.class); + private final RateLimiters rateLimiters = mock(RateLimiters.class); + private final RateLimiter registrationLimiter = mock(RateLimiter.class); + + private final ResourceExtension resources = ResourceExtension.builder() + .addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE) + .addProvider(AuthHelper.getAuthFilter()) + .addProvider( + new PolymorphicAuthValueFactoryProvider.Binder<>( + ImmutableSet.of(AuthenticatedAccount.class, + DisabledPermittedAuthenticatedAccount.class))) + .addProvider(new RateLimitExceededExceptionMapper()) + .addProvider(new ImpossiblePhoneNumberExceptionMapper()) + .addProvider(new NonNormalizedPhoneNumberExceptionMapper()) + .setMapper(SystemMapper.getMapper()) + .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) + .addResource( + new AccountControllerV2(accountsManager, changeNumberManager, + new PhoneVerificationTokenManager(registrationServiceClient, registrationRecoveryPasswordsManager), + registrationLockVerificationManager, rateLimiters)) + .build(); + + @Nested + class ChangeNumber { + + @BeforeEach + void setUp() throws Exception { + when(rateLimiters.getRegistrationLimiter()).thenReturn(registrationLimiter); + + when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any())).thenAnswer( + (Answer) invocation -> { + final Account account = invocation.getArgument(0, Account.class); + final String number = invocation.getArgument(1, String.class); + final String pniIdentityKey = invocation.getArgument(2, String.class); + + final UUID uuid = account.getUuid(); + final List devices = account.getDevices(); + + final Account updatedAccount = mock(Account.class); + when(updatedAccount.getUuid()).thenReturn(uuid); + when(updatedAccount.getNumber()).thenReturn(number); + when(updatedAccount.getPhoneNumberIdentityKey()).thenReturn(pniIdentityKey); + when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(UUID.randomUUID()); + when(updatedAccount.getDevices()).thenReturn(devices); + + for (long i = 1; i <= 3; i++) { + final Optional d = account.getDevice(i); + when(updatedAccount.getDevice(i)).thenReturn(d); + } + + return updatedAccount; + }); + } + + @Test + void changeNumberSuccess() throws Exception { + + when(registrationServiceClient.getSession(any(), any())) + .thenReturn(CompletableFuture.completedFuture(Optional.of(new RegistrationSession(NEW_NUMBER, true)))); + + final AccountIdentityResponse accountIdentityResponse = + resources.getJerseyTest() + .target("/v2/accounts/number") + .request() + .header(HttpHeaders.AUTHORIZATION, + AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .put(Entity.entity( + new ChangeNumberRequest(encodeSessionId("session"), null, NEW_NUMBER, "123", "123", + Collections.emptyList(), + Collections.emptyMap(), Collections.emptyMap()), + MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class); + + verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(NEW_NUMBER), any(), any(), any(), + any()); + + assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid()); + assertEquals(NEW_NUMBER, accountIdentityResponse.number()); + assertNotEquals(AuthHelper.VALID_PNI, accountIdentityResponse.pni()); + } + + @Test + void unprocessableRequestJson() { + final Invocation.Builder request = resources.getJerseyTest() + .target("/v2/accounts/number") + .request() + .header(HttpHeaders.AUTHORIZATION, + AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)); + try (Response response = request.put(Entity.json(unprocessableJson()))) { + assertEquals(400, response.getStatus()); + } + } + + @Test + void missingBasicAuthorization() { + final Invocation.Builder request = resources.getJerseyTest() + .target("/v2/accounts/number") + .request(); + try (Response response = request.put(Entity.json(requestJson("sessionId", NEW_NUMBER)))) { + assertEquals(401, response.getStatus()); + } + } + + @Test + void invalidBasicAuthorization() { + final Invocation.Builder request = resources.getJerseyTest() + .target("/v2/accounts/number") + .request() + .header(HttpHeaders.AUTHORIZATION, "Basic but-invalid"); + try (Response response = request.put(Entity.json(invalidRequestJson()))) { + assertEquals(401, response.getStatus()); + } + } + + @Test + void invalidRequestBody() { + final Invocation.Builder request = resources.getJerseyTest() + .target("/v2/accounts/number") + .request() + .header(HttpHeaders.AUTHORIZATION, + AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)); + try (Response response = request.put(Entity.json(invalidRequestJson()))) { + assertEquals(422, response.getStatus()); + } + } + + @Test + void rateLimitedNumber() throws Exception { + doThrow(new RateLimitExceededException(null, true)) + .when(registrationLimiter).validate(anyString()); + + final Invocation.Builder request = resources.getJerseyTest() + .target("/v2/accounts/number") + .request() + .header(HttpHeaders.AUTHORIZATION, + AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)); + try (Response response = request.put(Entity.json(requestJson("sessionId", NEW_NUMBER)))) { + assertEquals(429, response.getStatus()); + } + } + + @Test + void registrationServiceTimeout() { + when(registrationServiceClient.getSession(any(), any())) + .thenReturn(CompletableFuture.failedFuture(new RuntimeException())); + + final Invocation.Builder request = resources.getJerseyTest() + .target("/v2/accounts/number") + .request() + .header(HttpHeaders.AUTHORIZATION, + AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)); + try (Response response = request.put(Entity.json(requestJson("sessionId", NEW_NUMBER)))) { + assertEquals(HttpStatus.SC_SERVICE_UNAVAILABLE, response.getStatus()); + } + } + + @ParameterizedTest + @MethodSource + void registrationServiceSessionCheck(@Nullable final RegistrationSession session, final int expectedStatus, + final String message) { + when(registrationServiceClient.getSession(any(), any())) + .thenReturn(CompletableFuture.completedFuture(Optional.ofNullable(session))); + + final Invocation.Builder request = resources.getJerseyTest() + .target("/v2/accounts/number") + .request() + .header(HttpHeaders.AUTHORIZATION, + AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)); + try (Response response = request.put(Entity.json(requestJson("sessionId", NEW_NUMBER)))) { + assertEquals(expectedStatus, response.getStatus(), message); + } + } + + static Stream registrationServiceSessionCheck() { + return Stream.of( + Arguments.of(null, 401, "session not found"), + Arguments.of(new RegistrationSession("+18005551234", false), 400, "session number mismatch"), + Arguments.of(new RegistrationSession(NEW_NUMBER, false), 401, "session not verified") + ); + } + + @ParameterizedTest + @EnumSource(RegistrationLockError.class) + void registrationLock(final RegistrationLockError error) throws Exception { + when(registrationServiceClient.getSession(any(), any())) + .thenReturn( + CompletableFuture.completedFuture(Optional.of(new RegistrationSession(NEW_NUMBER, true)))); + + when(accountsManager.getByE164(any())).thenReturn(Optional.of(mock(Account.class))); + + final Exception e = switch (error) { + case MISMATCH -> new WebApplicationException(error.getExpectedStatus()); + case RATE_LIMITED -> new RateLimitExceededException(null, true); + }; + doThrow(e) + .when(registrationLockVerificationManager).verifyRegistrationLock(any(), any()); + + final Invocation.Builder request = resources.getJerseyTest() + .target("/v2/accounts/number") + .request() + .header(HttpHeaders.AUTHORIZATION, + AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)); + try (Response response = request.put(Entity.json(requestJson("sessionId", NEW_NUMBER)))) { + assertEquals(error.getExpectedStatus(), response.getStatus()); + } + } + + @Test + void recoveryPasswordManagerVerificationTrue() throws Exception { + when(registrationRecoveryPasswordsManager.verify(any(), any())) + .thenReturn(CompletableFuture.completedFuture(true)); + + final Invocation.Builder request = resources.getJerseyTest() + .target("/v2/accounts/number") + .request() + .header(HttpHeaders.AUTHORIZATION, + AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)); + final byte[] recoveryPassword = new byte[32]; + try (Response response = request.put(Entity.json(requestJsonRecoveryPassword(recoveryPassword, NEW_NUMBER)))) { + assertEquals(200, response.getStatus()); + + final AccountIdentityResponse accountIdentityResponse = response.readEntity(AccountIdentityResponse.class); + + verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(NEW_NUMBER), any(), any(), any(), + any()); + + assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid()); + assertEquals(NEW_NUMBER, accountIdentityResponse.number()); + assertNotEquals(AuthHelper.VALID_PNI, accountIdentityResponse.pni()); + } + } + + @Test + void recoveryPasswordManagerVerificationFalse() { + when(registrationRecoveryPasswordsManager.verify(any(), any())) + .thenReturn(CompletableFuture.completedFuture(false)); + + final Invocation.Builder request = resources.getJerseyTest() + .target("/v2/accounts/number") + .request() + .header(HttpHeaders.AUTHORIZATION, + AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)); + try (Response response = request.put(Entity.json(requestJsonRecoveryPassword(new byte[32], NEW_NUMBER)))) { + assertEquals(403, response.getStatus()); + } + } + + /** + * Valid request JSON with the given Recovery Password + */ + private static String requestJsonRecoveryPassword(final byte[] recoveryPassword, final String newNumber) { + return requestJson("", recoveryPassword, newNumber); + } + + /** + * Valid request JSON with the give session ID and recovery password + */ + private static String requestJson(final String sessionId, final byte[] recoveryPassword, final String newNumber) { + return String.format(""" + { + "sessionId": "%s", + "recoveryPassword": "%s", + "number": "%s", + "reglock": "1234", + "pniIdentityKey": "5678", + "deviceMessages": [], + "devicePniSignedPrekeys": {}, + "pniRegistrationIds": {} + } + """, encodeSessionId(sessionId), encodeRecoveryPassword(recoveryPassword), newNumber); + } + + /** + * Valid request JSON with the give session ID + */ + private static String requestJson(final String sessionId, final String newNumber) { + return requestJson(sessionId, new byte[0], newNumber); + } + + /** + * Request JSON in the shape of {@link org.whispersystems.textsecuregcm.entities.ChangeNumberRequest}, but that + * fails validation + */ + private static String invalidRequestJson() { + return """ + { + "sessionId": null + } + """; + } + + /** + * Request JSON that cannot be marshalled into + * {@link org.whispersystems.textsecuregcm.entities.ChangeNumberRequest} + */ + private static String unprocessableJson() { + return """ + { + "sessionId": [] + } + """; + } + + private static String encodeSessionId(final String sessionId) { + return Base64.getUrlEncoder().encodeToString(sessionId.getBytes(StandardCharsets.UTF_8)); + } + + private static String encodeRecoveryPassword(final byte[] recoveryPassword) { + return Base64.getEncoder().encodeToString(recoveryPassword); + } + } + +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ChallengeControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ChallengeControllerTest.java index ea3d0e4f3..5cb4244c4 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ChallengeControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ChallengeControllerTest.java @@ -86,7 +86,8 @@ class ChallengeControllerTest { """; final Duration retryAfter = Duration.ofMinutes(17); - doThrow(new RateLimitExceededException(retryAfter)).when(rateLimitChallengeManager).answerPushChallenge(any(), any()); + doThrow(new RateLimitExceededException(retryAfter, true)).when(rateLimitChallengeManager) + .answerPushChallenge(any(), any()); final Response response = EXTENSION.target("/v1/challenge") .request() @@ -128,7 +129,8 @@ class ChallengeControllerTest { """; final Duration retryAfter = Duration.ofMinutes(17); - doThrow(new RateLimitExceededException(retryAfter)).when(rateLimitChallengeManager).answerRecaptchaChallenge(any(), any(), any(), any()); + doThrow(new RateLimitExceededException(retryAfter, true)).when(rateLimitChallengeManager) + .answerRecaptchaChallenge(any(), any(), any(), any()); final Response response = EXTENSION.target("/v1/challenge") .request() diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java index 53e9c54de..96565ea01 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java @@ -255,7 +255,8 @@ class ProfileControllerTest { @Test void testProfileGetByAciRateLimited() throws RateLimitExceededException { - doThrow(new RateLimitExceededException(Duration.ofSeconds(13))).when(rateLimiter).validate(AuthHelper.VALID_UUID); + doThrow(new RateLimitExceededException(Duration.ofSeconds(13), true)).when(rateLimiter) + .validate(AuthHelper.VALID_UUID); Response response= resources.getJerseyTest() .target("/v1/profile/" + AuthHelper.VALID_UUID_TWO) @@ -326,7 +327,8 @@ class ProfileControllerTest { @Test void testProfileGetByPniRateLimited() throws RateLimitExceededException { - doThrow(new RateLimitExceededException(Duration.ofSeconds(13))).when(rateLimiter).validate(AuthHelper.VALID_UUID); + doThrow(new RateLimitExceededException(Duration.ofSeconds(13), true)).when(rateLimiter) + .validate(AuthHelper.VALID_UUID); Response response= resources.getJerseyTest() .target("/v1/profile/" + AuthHelper.VALID_PNI_TWO) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProvisioningControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProvisioningControllerTest.java index 31c6c7ad8..4384c1d16 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProvisioningControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProvisioningControllerTest.java @@ -1,9 +1,26 @@ package org.whispersystems.textsecuregcm.controllers; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + import com.google.common.collect.ImmutableSet; import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.ResourceExtension; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.Base64; +import java.util.UUID; +import javax.ws.rs.client.Entity; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -20,25 +37,6 @@ import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress; -import javax.ws.rs.client.Entity; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.Response; - -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.util.Base64; -import java.util.UUID; - -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.reset; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - @ExtendWith(DropwizardExtensionsSupport.class) class ProvisioningControllerTest { @@ -101,7 +99,7 @@ class ProvisioningControllerTest { final String destination = UUID.randomUUID().toString(); final byte[] messageBody = "test".getBytes(StandardCharsets.UTF_8); - doThrow(new RateLimitExceededException(Duration.ZERO)) + doThrow(new RateLimitExceededException(Duration.ZERO, true)) .when(messagesRateLimiter).validate(AuthHelper.VALID_UUID); try (final Response response = RESOURCE_EXTENSION.getJerseyTest() diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java index d2810f884..720b8bac7 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java @@ -13,6 +13,7 @@ import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import com.google.i18n.phonenumbers.PhoneNumberUtil; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.ResourceExtension; import java.nio.charset.StandardCharsets; @@ -37,6 +38,7 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.MethodSource; +import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager; import org.whispersystems.textsecuregcm.auth.RegistrationLockError; import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager; import org.whispersystems.textsecuregcm.entities.AccountAttributes; @@ -51,12 +53,18 @@ import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager; +import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.util.SystemMapper; @ExtendWith(DropwizardExtensionsSupport.class) class RegistrationControllerTest { - private static final String NUMBER = "+18005551212"; + private static final String NUMBER = PhoneNumberUtil.getInstance().format( + PhoneNumberUtil.getInstance().getExampleNumber("US"), + PhoneNumberUtil.PhoneNumberFormat.E164); + + public static final String PASSWORD = "password"; + private final AccountsManager accountsManager = mock(AccountsManager.class); private final RegistrationServiceClient registrationServiceClient = mock(RegistrationServiceClient.class); private final RegistrationLockVerificationManager registrationLockVerificationManager = mock( @@ -66,7 +74,6 @@ class RegistrationControllerTest { private final RateLimiters rateLimiters = mock(RateLimiters.class); private final RateLimiter registrationLimiter = mock(RateLimiter.class); - private final RateLimiter pinLimiter = mock(RateLimiter.class); private final ResourceExtension resources = ResourceExtension.builder() .addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE) @@ -76,14 +83,14 @@ class RegistrationControllerTest { .setMapper(SystemMapper.getMapper()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .addResource( - new RegistrationController(accountsManager, registrationServiceClient, registrationLockVerificationManager, - registrationRecoveryPasswordsManager, rateLimiters)) + new RegistrationController(accountsManager, + new PhoneVerificationTokenManager(registrationServiceClient, registrationRecoveryPasswordsManager), + registrationLockVerificationManager, rateLimiters)) .build(); @BeforeEach void setUp() { when(rateLimiters.getRegistrationLimiter()).thenReturn(registrationLimiter); - when(rateLimiters.getPinLimiter()).thenReturn(pinLimiter); } @Test @@ -130,25 +137,23 @@ class RegistrationControllerTest { final Invocation.Builder request = resources.getJerseyTest() .target("/v1/registration") .request() - .header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD)); try (Response response = request.post(Entity.json(invalidRequestJson()))) { assertEquals(422, response.getStatus()); } } @Test - void rateLimitedSession() throws Exception { - final String sessionId = "sessionId"; + void rateLimitedNumber() throws Exception { doThrow(RateLimitExceededException.class) - .when(registrationLimiter).validate(encodeSessionId(sessionId)); + .when(registrationLimiter).validate(NUMBER); final Invocation.Builder request = resources.getJerseyTest() .target("/v1/registration") .request() - .header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); - try (Response response = request.post(Entity.json(requestJson(sessionId)))) { - assertEquals(413, response.getStatus()); - // In the future, change to assertEquals(429, response.getStatus()); + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD)); + try (Response response = request.post(Entity.json(requestJson("sessionId")))) { + assertEquals(429, response.getStatus()); } } @@ -160,7 +165,7 @@ class RegistrationControllerTest { final Invocation.Builder request = resources.getJerseyTest() .target("/v1/registration") .request() - .header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD)); try (Response response = request.post(Entity.json(requestJson("sessionId")))) { assertEquals(HttpStatus.SC_SERVICE_UNAVAILABLE, response.getStatus()); } @@ -174,7 +179,7 @@ class RegistrationControllerTest { final Invocation.Builder request = resources.getJerseyTest() .target("/v1/registration") .request() - .header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD)); try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(new byte[32])))) { assertEquals(HttpStatus.SC_SERVICE_UNAVAILABLE, response.getStatus()); } @@ -190,7 +195,7 @@ class RegistrationControllerTest { final Invocation.Builder request = resources.getJerseyTest() .target("/v1/registration") .request() - .header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD)); try (Response response = request.post(Entity.json(requestJson("sessionId")))) { assertEquals(expectedStatus, response.getStatus(), message); } @@ -214,7 +219,7 @@ class RegistrationControllerTest { final Invocation.Builder request = resources.getJerseyTest() .target("/v1/registration") .request() - .header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD)); final byte[] recoveryPassword = new byte[32]; try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(recoveryPassword)))) { assertEquals(200, response.getStatus()); @@ -229,7 +234,7 @@ class RegistrationControllerTest { final Invocation.Builder request = resources.getJerseyTest() .target("/v1/registration") .request() - .header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD)); try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(new byte[32])))) { assertEquals(403, response.getStatus()); } @@ -245,7 +250,7 @@ class RegistrationControllerTest { final Exception e = switch (error) { case MISMATCH -> new WebApplicationException(error.getExpectedStatus()); - case RATE_LIMITED -> new RateLimitExceededException(null); + case RATE_LIMITED -> new RateLimitExceededException(null, true); }; doThrow(e) .when(registrationLockVerificationManager).verifyRegistrationLock(any(), any()); @@ -253,7 +258,7 @@ class RegistrationControllerTest { final Invocation.Builder request = resources.getJerseyTest() .target("/v1/registration") .request() - .header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD)); try (Response response = request.post(Entity.json(requestJson("sessionId")))) { assertEquals(error.getExpectedStatus(), response.getStatus()); } @@ -286,7 +291,7 @@ class RegistrationControllerTest { final Invocation.Builder request = resources.getJerseyTest() .target("/v1/registration") .request() - .header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD)); try (Response response = request.post(Entity.json(requestJson("sessionId", new byte[0], skipDeviceTransfer)))) { assertEquals(expectedStatus, response.getStatus()); } @@ -294,7 +299,7 @@ class RegistrationControllerTest { // this is functionally the same as deviceTransferAvailable(existingAccount=false) @Test - void success() throws Exception { + void registrationSuccess() throws Exception { when(registrationServiceClient.getSession(any(), any())) .thenReturn(CompletableFuture.completedFuture(Optional.of(new RegistrationSession(NUMBER, true)))); when(accountsManager.create(any(), any(), any(), any(), any())) @@ -303,7 +308,7 @@ class RegistrationControllerTest { final Invocation.Builder request = resources.getJerseyTest() .target("/v1/registration") .request() - .header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD)); try (Response response = request.post(Entity.json(requestJson("sessionId")))) { assertEquals(200, response.getStatus()); } @@ -365,13 +370,8 @@ class RegistrationControllerTest { """; } - private static String authorizationHeader(final String number) { - return "Basic " + Base64.getEncoder().encodeToString( - String.format("%s:password", number).getBytes(StandardCharsets.UTF_8)); - } - private static String encodeSessionId(final String sessionId) { - return Base64.getEncoder().encodeToString(sessionId.getBytes(StandardCharsets.UTF_8)); + return Base64.getUrlEncoder().encodeToString(sessionId.getBytes(StandardCharsets.UTF_8)); } private static String encodeRecoveryPassword(final byte[] recoveryPassword) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitedByIpTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitedByIpTest.java index be1268f68..98dc855c9 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitedByIpTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitedByIpTest.java @@ -72,11 +72,11 @@ public class RateLimitedByIpTest { public void testRateLimits() throws Exception { Mockito.doNothing().when(RATE_LIMITER).validate(Mockito.eq(IP)); validateSuccess("/test/strict", VALID_X_FORWARDED_FOR); - Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER)).when(RATE_LIMITER).validate(Mockito.eq(IP)); + Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER, true)).when(RATE_LIMITER).validate(Mockito.eq(IP)); validateFailure("/test/strict", VALID_X_FORWARDED_FOR, RETRY_AFTER); Mockito.doNothing().when(RATE_LIMITER).validate(Mockito.eq(IP)); validateSuccess("/test/strict", VALID_X_FORWARDED_FOR); - Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER)).when(RATE_LIMITER).validate(Mockito.eq(IP)); + Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER, true)).when(RATE_LIMITER).validate(Mockito.eq(IP)); validateFailure("/test/strict", VALID_X_FORWARDED_FOR, RETRY_AFTER); } @@ -92,7 +92,7 @@ public class RateLimitedByIpTest { validateSuccess("/test/loose", ""); // also checking that even if rate limiter is failing -- it doesn't matter in the case of invalid IP - Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER)).when(RATE_LIMITER).validate(Mockito.anyString()); + Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER, true)).when(RATE_LIMITER).validate(Mockito.anyString()); validateFailure("/test/loose", VALID_X_FORWARDED_FOR, RETRY_AFTER); validateSuccess("/test/loose", INVALID_X_FORWARDED_FOR); validateSuccess("/test/loose", ""); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java index 12066b77a..49604832c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java @@ -340,7 +340,7 @@ class KeysControllerTest { @Test void testGetKeysRateLimited() throws RateLimitExceededException { Duration retryAfter = Duration.ofSeconds(31); - doThrow(new RateLimitExceededException(retryAfter)).when(rateLimiter).validate(anyString()); + doThrow(new RateLimitExceededException(retryAfter, true)).when(rateLimiter).validate(anyString()); Response result = resources.getJerseyTest() .target(String.format("/v2/keys/%s/*", EXISTS_PNI)) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java index d39f43785..def0353ea 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/MockUtils.java @@ -60,11 +60,12 @@ public final class MockUtils { final RateLimiters rateLimitersMock, final RateLimiters.Handle handle, final String input, - final Duration retryAfter) { + final Duration retryAfter, + final boolean legacyStatusCode) { final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class); doReturn(Optional.of(mockRateLimiter)).when(rateLimitersMock).byHandle(eq(handle)); try { - doThrow(new RateLimitExceededException(retryAfter)).when(mockRateLimiter).validate(eq(input)); + doThrow(new RateLimitExceededException(retryAfter, legacyStatusCode)).when(mockRateLimiter).validate(eq(input)); } catch (final RateLimitExceededException e) { throw new RuntimeException(e); }