Add `PUT /v2/account/number`

This commit is contained in:
Chris Eager 2023-01-24 15:33:48 -06:00 committed by Chris Eager
parent 8fc465b3e8
commit c16006dc4b
23 changed files with 856 additions and 186 deletions

View File

@ -76,6 +76,7 @@ import org.whispersystems.textsecuregcm.auth.CertificateGenerator;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccountAuthenticator; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator;
import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager;
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager; import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator; import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator;
import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventListener; import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventListener;
@ -87,6 +88,7 @@ import org.whispersystems.textsecuregcm.captcha.RecaptchaClient;
import org.whispersystems.textsecuregcm.configuration.DirectoryServerConfiguration; import org.whispersystems.textsecuregcm.configuration.DirectoryServerConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.AccountController; import org.whispersystems.textsecuregcm.controllers.AccountController;
import org.whispersystems.textsecuregcm.controllers.AccountControllerV2;
import org.whispersystems.textsecuregcm.controllers.ArtController; import org.whispersystems.textsecuregcm.controllers.ArtController;
import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV2; import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV2;
import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV3; import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV3;
@ -524,16 +526,22 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
final RegistrationLockVerificationManager registrationLockVerificationManager = new RegistrationLockVerificationManager( final RegistrationLockVerificationManager registrationLockVerificationManager = new RegistrationLockVerificationManager(
accountsManager, clientPresenceManager, backupCredentialsGenerator, rateLimiters); accountsManager, clientPresenceManager, backupCredentialsGenerator, rateLimiters);
final PhoneVerificationTokenManager phoneVerificationTokenManager = new PhoneVerificationTokenManager(
registrationServiceClient, registrationRecoveryPasswordsManager);
ReportedMessageMetricsListener reportedMessageMetricsListener = new ReportedMessageMetricsListener(accountsManager); final ReportedMessageMetricsListener reportedMessageMetricsListener = new ReportedMessageMetricsListener(
accountsManager);
reportMessageManager.addListener(reportedMessageMetricsListener); reportMessageManager.addListener(reportedMessageMetricsListener);
AccountAuthenticator accountAuthenticator = new AccountAuthenticator(accountsManager); final AccountAuthenticator accountAuthenticator = new AccountAuthenticator(accountsManager);
DisabledPermittedAccountAuthenticator disabledPermittedAccountAuthenticator = new DisabledPermittedAccountAuthenticator(accountsManager); final DisabledPermittedAccountAuthenticator disabledPermittedAccountAuthenticator = new DisabledPermittedAccountAuthenticator(
accountsManager);
MessageSender messageSender = new MessageSender(clientPresenceManager, messagesManager, pushNotificationManager, pushLatencyManager); final MessageSender messageSender = new MessageSender(clientPresenceManager, messagesManager,
ReceiptSender receiptSender = new ReceiptSender(accountsManager, messageSender, receiptSenderExecutor); pushNotificationManager,
TurnTokenGenerator turnTokenGenerator = new TurnTokenGenerator(dynamicConfigurationManager); pushLatencyManager);
final ReceiptSender receiptSender = new ReceiptSender(accountsManager, messageSender, receiptSenderExecutor);
final TurnTokenGenerator turnTokenGenerator = new TurnTokenGenerator(dynamicConfigurationManager);
RecaptchaClient recaptchaClient = new RecaptchaClient( RecaptchaClient recaptchaClient = new RecaptchaClient(
config.getRecaptchaConfiguration().getProjectPath(), config.getRecaptchaConfiguration().getProjectPath(),
@ -731,6 +739,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
} }
final List<Object> commonControllers = Lists.newArrayList( final List<Object> commonControllers = Lists.newArrayList(
new AccountControllerV2(accountsManager, changeNumberManager, phoneVerificationTokenManager,
registrationLockVerificationManager, rateLimiters),
new ArtController(rateLimiters, artCredentialsGenerator), new ArtController(rateLimiters, artCredentialsGenerator),
new AttachmentControllerV2(rateLimiters, config.getAwsAttachmentsConfiguration().getAccessKey(), config.getAwsAttachmentsConfiguration().getAccessSecret(), config.getAwsAttachmentsConfiguration().getRegion(), config.getAwsAttachmentsConfiguration().getBucket()), new AttachmentControllerV2(rateLimiters, config.getAwsAttachmentsConfiguration().getAccessKey(), config.getAwsAttachmentsConfiguration().getAccessSecret(), config.getAwsAttachmentsConfiguration().getRegion(), config.getAwsAttachmentsConfiguration().getBucket()),
new AttachmentControllerV3(rateLimiters, config.getGcpAttachmentsConfiguration().getDomain(), config.getGcpAttachmentsConfiguration().getEmail(), config.getGcpAttachmentsConfiguration().getMaxSizeInBytes(), config.getGcpAttachmentsConfiguration().getPathPrefix(), config.getGcpAttachmentsConfiguration().getRsaSigningKey()), new AttachmentControllerV3(rateLimiters, config.getGcpAttachmentsConfiguration().getDomain(), config.getGcpAttachmentsConfiguration().getEmail(), config.getGcpAttachmentsConfiguration().getMaxSizeInBytes(), config.getGcpAttachmentsConfiguration().getPathPrefix(), config.getGcpAttachmentsConfiguration().getRsaSigningKey()),
@ -748,8 +758,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
profileBadgeConverter, config.getBadges(), cdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner, profileBadgeConverter, config.getBadges(), cdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner,
config.getCdnConfiguration().getBucket(), zkProfileOperations, batchIdentityCheckExecutor), config.getCdnConfiguration().getBucket(), zkProfileOperations, batchIdentityCheckExecutor),
new ProvisioningController(rateLimiters, provisioningManager), new ProvisioningController(rateLimiters, provisioningManager),
new RegistrationController(accountsManager, registrationServiceClient, registrationLockVerificationManager, new RegistrationController(accountsManager, phoneVerificationTokenManager, registrationLockVerificationManager,
registrationRecoveryPasswordsManager, rateLimiters), rateLimiters),
new RemoteConfigController(remoteConfigsManager, adminEventLogger, new RemoteConfigController(remoteConfigsManager, adminEventLogger,
config.getRemoteConfigConfiguration().getAuthorizedTokens(), config.getRemoteConfigConfiguration().getAuthorizedTokens(),
config.getRemoteConfigConfiguration().getGlobalConfig()), config.getRemoteConfigConfiguration().getGlobalConfig()),

View File

@ -0,0 +1,98 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import java.security.MessageDigest;
import java.time.Duration;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import javax.ws.rs.BadRequestException;
import javax.ws.rs.ForbiddenException;
import javax.ws.rs.NotAuthorizedException;
import javax.ws.rs.ServerErrorException;
import javax.ws.rs.core.Response;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.PhoneVerificationRequest;
import org.whispersystems.textsecuregcm.entities.RegistrationSession;
import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
public class PhoneVerificationTokenManager {
private static final Logger logger = LoggerFactory.getLogger(PhoneVerificationTokenManager.class);
private static final Duration REGISTRATION_RPC_TIMEOUT = Duration.ofSeconds(15);
private static final long VERIFICATION_TIMEOUT_SECONDS = REGISTRATION_RPC_TIMEOUT.plusSeconds(1).getSeconds();
private final RegistrationServiceClient registrationServiceClient;
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager;
public PhoneVerificationTokenManager(final RegistrationServiceClient registrationServiceClient,
final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager) {
this.registrationServiceClient = registrationServiceClient;
this.registrationRecoveryPasswordsManager = registrationRecoveryPasswordsManager;
}
/**
* Checks if a {@link PhoneVerificationRequest} has a token that verifies the caller has confirmed access to the e164
* number
*
* @param number the e164 presented for verification
* @param request the request with exactly one verification token (RegistrationService sessionId or registration
* recovery password)
* @return if verification was successful, returns the verification type
* @throws BadRequestException if the number does not match the sessionIds number
* @throws NotAuthorizedException if the session is not verified
* @throws ForbiddenException if the recovery password is not valid
* @throws InterruptedException if verification did not complete before a timeout
*/
public PhoneVerificationRequest.VerificationType verify(final String number, final PhoneVerificationRequest request)
throws InterruptedException {
final PhoneVerificationRequest.VerificationType verificationType = request.verificationType();
switch (verificationType) {
case SESSION -> verifyBySessionId(number, request.decodeSessionId());
case RECOVERY_PASSWORD -> verifyByRecoveryPassword(number, request.recoveryPassword());
}
return verificationType;
}
private void verifyBySessionId(final String number, final byte[] sessionId) throws InterruptedException {
try {
final RegistrationSession session = registrationServiceClient
.getSession(sessionId, REGISTRATION_RPC_TIMEOUT)
.get(VERIFICATION_TIMEOUT_SECONDS, TimeUnit.SECONDS)
.orElseThrow(() -> new NotAuthorizedException("session not verified"));
if (!MessageDigest.isEqual(number.getBytes(), session.number().getBytes())) {
throw new BadRequestException("number does not match session");
}
if (!session.verified()) {
throw new NotAuthorizedException("session not verified");
}
} catch (final CancellationException | ExecutionException | TimeoutException e) {
logger.error("Registration service failure", e);
throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE);
}
}
private void verifyByRecoveryPassword(final String number, final byte[] recoveryPassword)
throws InterruptedException {
try {
final boolean verified = registrationRecoveryPasswordsManager.verify(number, recoveryPassword)
.get(VERIFICATION_TIMEOUT_SECONDS, TimeUnit.SECONDS);
if (!verified) {
throw new ForbiddenException("recoveryPassword couldn't be verified");
}
} catch (final ExecutionException | TimeoutException e) {
throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE);
}
}
}

View File

@ -798,7 +798,7 @@ public class AccountController {
// This shouldn't happen, so conservatively assume we're over the rate-limit // This shouldn't happen, so conservatively assume we're over the rate-limit
// and indicate that the client should retry // and indicate that the client should retry
logger.error("Missing/bad Forwarded-For: {}", forwardedFor); logger.error("Missing/bad Forwarded-For: {}", forwardedFor);
return new RateLimitExceededException(Duration.ofHours(1)); return new RateLimitExceededException(Duration.ofHours(1), true);
}); });
rateLimiter.validate(mostRecentProxy); rateLimiter.validate(mostRecentProxy);

View File

@ -0,0 +1,132 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.annotation.Timed;
import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.Auth;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
import java.util.Optional;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import javax.ws.rs.BadRequestException;
import javax.ws.rs.Consumes;
import javax.ws.rs.ForbiddenException;
import javax.ws.rs.HeaderParam;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager;
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse;
import org.whispersystems.textsecuregcm.entities.ChangeNumberRequest;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.PhoneVerificationRequest;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ChangeNumberManager;
@Path("/v2/accounts")
public class AccountControllerV2 {
private static final String CHANGE_NUMBER_COUNTER_NAME = name(AccountControllerV2.class, "create");
private static final String VERIFICATION_TYPE_TAG_NAME = "verification";
private final AccountsManager accountsManager;
private final ChangeNumberManager changeNumberManager;
private final PhoneVerificationTokenManager phoneVerificationTokenManager;
private final RegistrationLockVerificationManager registrationLockVerificationManager;
private final RateLimiters rateLimiters;
public AccountControllerV2(final AccountsManager accountsManager, final ChangeNumberManager changeNumberManager,
final PhoneVerificationTokenManager phoneVerificationTokenManager,
final RegistrationLockVerificationManager registrationLockVerificationManager, final RateLimiters rateLimiters) {
this.accountsManager = accountsManager;
this.changeNumberManager = changeNumberManager;
this.phoneVerificationTokenManager = phoneVerificationTokenManager;
this.registrationLockVerificationManager = registrationLockVerificationManager;
this.rateLimiters = rateLimiters;
}
@Timed
@PUT
@Path("/number")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public AccountIdentityResponse changeNumber(@Auth final AuthenticatedAccount authenticatedAccount,
@NotNull @Valid final ChangeNumberRequest request, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent)
throws RateLimitExceededException, InterruptedException {
if (!authenticatedAccount.getAuthenticatedDevice().isMaster()) {
throw new ForbiddenException();
}
final String number = request.number();
// Only verify and check reglock if there's a data change to be made...
if (!authenticatedAccount.getAccount().getNumber().equals(number)) {
RateLimiter.adaptLegacyException(() -> rateLimiters.getRegistrationLimiter().validate(number));
final PhoneVerificationRequest.VerificationType verificationType = phoneVerificationTokenManager.verify(number,
request);
final Optional<Account> existingAccount = accountsManager.getByE164(number);
if (existingAccount.isPresent()) {
registrationLockVerificationManager.verifyRegistrationLock(existingAccount.get(), request.registrationLock());
}
Metrics.counter(CHANGE_NUMBER_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(VERIFICATION_TYPE_TAG_NAME, verificationType.name())))
.increment();
}
// ...but always attempt to make the change in case a client retries and needs to re-send messages
try {
final Account updatedAccount = changeNumberManager.changeNumber(
authenticatedAccount.getAccount(),
request.number(),
request.pniIdentityKey(),
request.devicePniSignedPrekeys(),
request.deviceMessages(),
request.pniRegistrationIds());
return new AccountIdentityResponse(
updatedAccount.getUuid(),
updatedAccount.getNumber(),
updatedAccount.getPhoneNumberIdentifier(),
updatedAccount.getUsernameHash().orElse(null),
updatedAccount.isStorageSupported());
} catch (MismatchedDevicesException e) {
throw new WebApplicationException(Response.status(409)
.type(MediaType.APPLICATION_JSON_TYPE)
.entity(new MismatchedDevices(e.getMissingDevices(),
e.getExtraDevices()))
.build());
} catch (StaleDevicesException e) {
throw new WebApplicationException(Response.status(410)
.type(MediaType.APPLICATION_JSON)
.entity(new StaleDevices(e.getStaleDevices()))
.build());
} catch (IllegalArgumentException e) {
throw new BadRequestException(e);
}
}
}

View File

@ -10,20 +10,27 @@ import javax.annotation.Nullable;
public class RateLimitExceededException extends Exception { public class RateLimitExceededException extends Exception {
private final @Nullable @Nullable
Duration retryDuration; private final Duration retryDuration;
private final boolean legacy;
/** /**
* Constructs a new exception indicating when it may become safe to retry * Constructs a new exception indicating when it may become safe to retry
* *
* @param retryDuration A duration to wait before retrying, null if no duration can be indicated * @param retryDuration A duration to wait before retrying, null if no duration can be indicated
* @param legacy whether to use a legacy status code when mapping the exception to an HTTP response
*/ */
public RateLimitExceededException(final @Nullable Duration retryDuration) { public RateLimitExceededException(@Nullable final Duration retryDuration, final boolean legacy) {
super(null, null, true, false); super(null, null, true, false);
this.retryDuration = retryDuration; this.retryDuration = retryDuration;
this.legacy = legacy;
} }
public Optional<Duration> getRetryDuration() { public Optional<Duration> getRetryDuration() {
return Optional.ofNullable(retryDuration); return Optional.ofNullable(retryDuration);
} }
public boolean isLegacy() {
return legacy;
}
} }

View File

@ -13,42 +13,33 @@ import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags; import io.micrometer.core.instrument.Tags;
import java.security.MessageDigest;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import javax.ws.rs.BadRequestException;
import javax.ws.rs.Consumes; import javax.ws.rs.Consumes;
import javax.ws.rs.ForbiddenException;
import javax.ws.rs.HeaderParam; import javax.ws.rs.HeaderParam;
import javax.ws.rs.NotAuthorizedException;
import javax.ws.rs.POST; import javax.ws.rs.POST;
import javax.ws.rs.Path; import javax.ws.rs.Path;
import javax.ws.rs.Produces; import javax.ws.rs.Produces;
import javax.ws.rs.ServerErrorException;
import javax.ws.rs.WebApplicationException; import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.BasicAuthorizationHeader; import org.whispersystems.textsecuregcm.auth.BasicAuthorizationHeader;
import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager;
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager; import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse; import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse;
import org.whispersystems.textsecuregcm.entities.PhoneVerificationRequest;
import org.whispersystems.textsecuregcm.entities.RegistrationRequest; import org.whispersystems.textsecuregcm.entities.RegistrationRequest;
import org.whispersystems.textsecuregcm.entities.RegistrationSession; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
@ -68,24 +59,17 @@ public class RegistrationController {
private static final String REGION_CODE_TAG_NAME = "regionCode"; private static final String REGION_CODE_TAG_NAME = "regionCode";
private static final String VERIFICATION_TYPE_TAG_NAME = "verification"; private static final String VERIFICATION_TYPE_TAG_NAME = "verification";
private static final Duration REGISTRATION_RPC_TIMEOUT = Duration.ofSeconds(15);
private static final long VERIFICATION_TIMEOUT_SECONDS = REGISTRATION_RPC_TIMEOUT.plusSeconds(1).getSeconds();
private final AccountsManager accounts; private final AccountsManager accounts;
private final RegistrationServiceClient registrationServiceClient; private final PhoneVerificationTokenManager phoneVerificationTokenManager;
private final RegistrationLockVerificationManager registrationLockVerificationManager; private final RegistrationLockVerificationManager registrationLockVerificationManager;
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager;
private final RateLimiters rateLimiters; private final RateLimiters rateLimiters;
public RegistrationController(final AccountsManager accounts, public RegistrationController(final AccountsManager accounts,
final RegistrationServiceClient registrationServiceClient, final PhoneVerificationTokenManager phoneVerificationTokenManager,
final RegistrationLockVerificationManager registrationLockVerificationManager, final RegistrationLockVerificationManager registrationLockVerificationManager, final RateLimiters rateLimiters) {
final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager,
final RateLimiters rateLimiters) {
this.accounts = accounts; this.accounts = accounts;
this.registrationServiceClient = registrationServiceClient; this.phoneVerificationTokenManager = phoneVerificationTokenManager;
this.registrationLockVerificationManager = registrationLockVerificationManager; this.registrationLockVerificationManager = registrationLockVerificationManager;
this.registrationRecoveryPasswordsManager = registrationRecoveryPasswordsManager;
this.rateLimiters = rateLimiters; this.rateLimiters = rateLimiters;
} }
@ -99,17 +83,13 @@ public class RegistrationController {
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@NotNull @Valid final RegistrationRequest registrationRequest) throws RateLimitExceededException, InterruptedException { @NotNull @Valid final RegistrationRequest registrationRequest) throws RateLimitExceededException, InterruptedException {
rateLimiters.getRegistrationLimiter().validate(registrationRequest.sessionId());
final String number = authorizationHeader.getUsername(); final String number = authorizationHeader.getUsername();
final String password = authorizationHeader.getPassword(); final String password = authorizationHeader.getPassword();
// decide on the method of verification based on the registration request parameters and verify RateLimiter.adaptLegacyException(() -> rateLimiters.getRegistrationLimiter().validate(number));
final RegistrationRequest.VerificationType verificationType = registrationRequest.verificationType();
switch (verificationType) { final PhoneVerificationRequest.VerificationType verificationType = phoneVerificationTokenManager.verify(number,
case SESSION -> verifyBySessionId(number, registrationRequest.decodeSessionId()); registrationRequest);
case RECOVERY_PASSWORD -> verifyByRecoveryPassword(number, registrationRequest.recoveryPassword());
}
final Optional<Account> existingAccount = accounts.getByE164(number); final Optional<Account> existingAccount = accounts.getByE164(number);
@ -146,34 +126,4 @@ public class RegistrationController {
existingAccount.map(Account::isStorageSupported).orElse(false)); existingAccount.map(Account::isStorageSupported).orElse(false));
} }
private void verifyBySessionId(final String number, final byte[] sessionId) throws InterruptedException {
try {
final RegistrationSession session = registrationServiceClient
.getSession(sessionId, REGISTRATION_RPC_TIMEOUT)
.get(VERIFICATION_TIMEOUT_SECONDS, TimeUnit.SECONDS)
.orElseThrow(() -> new NotAuthorizedException("session not verified"));
if (!MessageDigest.isEqual(number.getBytes(), session.number().getBytes())) {
throw new BadRequestException("number does not match session");
}
if (!session.verified()) {
throw new NotAuthorizedException("session not verified");
}
} catch (final CancellationException | ExecutionException | TimeoutException e) {
logger.error("Registration service failure", e);
throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE);
}
}
private void verifyByRecoveryPassword(final String number, final byte[] recoveryPassword) throws InterruptedException {
try {
final boolean verified = registrationRecoveryPasswordsManager.verify(number, recoveryPassword)
.get(VERIFICATION_TIMEOUT_SECONDS, TimeUnit.SECONDS);
if (!verified) {
throw new ForbiddenException("recoveryPassword couldn't be verified");
}
} catch (final ExecutionException | TimeoutException e) {
throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE);
}
}
} }

View File

@ -0,0 +1,27 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import javax.validation.Valid;
import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotNull;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
public record ChangeNumberRequest(String sessionId,
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) byte[] recoveryPassword,
@NotBlank String number,
@JsonProperty("reglock") @Nullable String registrationLock,
@NotBlank String pniIdentityKey,
@NotNull @Valid List<@NotNull @Valid IncomingMessage> deviceMessages,
@NotNull @Valid Map<Long, @NotNull @Valid SignedPreKey> devicePniSignedPrekeys,
@NotNull Map<Long, Integer> pniRegistrationIds) implements PhoneVerificationRequest {
}

View File

@ -0,0 +1,45 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import static org.apache.commons.lang3.StringUtils.isNotBlank;
import java.util.Base64;
import javax.validation.constraints.AssertTrue;
import javax.ws.rs.ClientErrorException;
import org.apache.http.HttpStatus;
public interface PhoneVerificationRequest {
enum VerificationType {
SESSION,
RECOVERY_PASSWORD
}
String sessionId();
byte[] recoveryPassword();
// for the @AssertTrue to work with bean validation, method name must follow 'isSmth()'/'getSmth()' naming convention
@AssertTrue
default boolean isValid() {
// checking that exactly one of sessionId/recoveryPassword is non-empty
return isNotBlank(sessionId()) ^ (recoveryPassword() != null && recoveryPassword().length > 0);
}
default PhoneVerificationRequest.VerificationType verificationType() {
return isNotBlank(sessionId()) ? PhoneVerificationRequest.VerificationType.SESSION
: PhoneVerificationRequest.VerificationType.RECOVERY_PASSWORD;
}
default byte[] decodeSessionId() {
try {
return Base64.getUrlDecoder().decode(sessionId());
} catch (final IllegalArgumentException e) {
throw new ClientErrorException("Malformed session ID", HttpStatus.SC_UNPROCESSABLE_ENTITY);
}
}
}

View File

@ -5,43 +5,15 @@
package org.whispersystems.textsecuregcm.entities; package org.whispersystems.textsecuregcm.entities;
import static org.apache.commons.lang3.StringUtils.isNotBlank;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import java.util.Base64;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.AssertTrue;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import javax.ws.rs.ClientErrorException;
import org.apache.http.HttpStatus;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
public record RegistrationRequest(String sessionId, public record RegistrationRequest(String sessionId,
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) byte[] recoveryPassword, @JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) byte[] recoveryPassword,
@NotNull @Valid AccountAttributes accountAttributes, @NotNull @Valid AccountAttributes accountAttributes,
boolean skipDeviceTransfer) { boolean skipDeviceTransfer) implements PhoneVerificationRequest {
public enum VerificationType {
SESSION,
RECOVERY_PASSWORD
}
// for the @AssertTrue to work with bean validation, method name must follow 'isSmth()'/'getSmth()' naming convention
@AssertTrue
public boolean isValid() {
// checking that exactly one of sessionId/recoveryPassword is non-empty
return isNotBlank(sessionId) ^ (recoveryPassword != null && recoveryPassword.length > 0);
}
public VerificationType verificationType() {
return isNotBlank(sessionId) ? VerificationType.SESSION : VerificationType.RECOVERY_PASSWORD;
}
public byte[] decodeSessionId() {
try {
return Base64.getUrlDecoder().decode(sessionId());
} catch (final IllegalArgumentException e) {
throw new ClientErrorException("Malformed session ID", HttpStatus.SC_UNPROCESSABLE_ENTITY);
}
}
} }

View File

@ -31,7 +31,7 @@ public class LockingRateLimiter extends RateLimiter {
public void validate(String key, int amount) throws RateLimitExceededException { public void validate(String key, int amount) throws RateLimitExceededException {
if (!acquireLock(key)) { if (!acquireLock(key)) {
meter.mark(); meter.mark();
throw new RateLimitExceededException(Duration.ZERO); throw new RateLimitExceededException(Duration.ZERO, true);
} }
try { try {

View File

@ -29,7 +29,8 @@ public class RateLimitByIpFilter implements ContainerRequestFilter {
private static final Logger logger = LoggerFactory.getLogger(RateLimitByIpFilter.class); private static final Logger logger = LoggerFactory.getLogger(RateLimitByIpFilter.class);
@VisibleForTesting @VisibleForTesting
static final RateLimitExceededException INVALID_HEADER_EXCEPTION = new RateLimitExceededException(Duration.ofHours(1)); static final RateLimitExceededException INVALID_HEADER_EXCEPTION = new RateLimitExceededException(Duration.ofHours(1),
true);
private static final ExceptionMapper<RateLimitExceededException> EXCEPTION_MAPPER = new RateLimitExceededExceptionMapper(); private static final ExceptionMapper<RateLimitExceededException> EXCEPTION_MAPPER = new RateLimitExceededExceptionMapper();

View File

@ -57,7 +57,7 @@ public class RateLimiter {
setBucket(key, bucket); setBucket(key, bucket);
} else { } else {
meter.mark(); meter.mark();
throw new RateLimitExceededException(bucket.getTimeUntilSpaceAvailable(amount)); throw new RateLimitExceededException(bucket.getTimeUntilSpaceAvailable(amount), true);
} }
} }
} }
@ -132,4 +132,22 @@ public class RateLimiter {
public boolean hasConfiguration(final RateLimitConfiguration configuration) { public boolean hasConfiguration(final RateLimitConfiguration configuration) {
return bucketSize == configuration.getBucketSize() && leakRatePerMinute == configuration.getLeakRatePerMinute(); return bucketSize == configuration.getBucketSize() && leakRatePerMinute == configuration.getLeakRatePerMinute();
} }
/**
* If the wrapped {@code validate()} call throws a {@link RateLimitExceededException}, it will adapt it to ensure that
* {@link RateLimitExceededException#isLegacy()} returns {@code true}
*/
public static void adaptLegacyException(final RateLimitValidator validator) throws RateLimitExceededException {
try {
validator.validate();
} catch (final RateLimitExceededException e) {
throw new RateLimitExceededException(e.getRetryDuration().orElse(null), false);
}
}
@FunctionalInterface
public interface RateLimitValidator {
void validate() throws RateLimitExceededException;
}
} }

View File

@ -4,36 +4,41 @@
*/ */
package org.whispersystems.textsecuregcm.mappers; package org.whispersystems.textsecuregcm.mappers;
import javax.ws.rs.core.Response;
import javax.ws.rs.ext.ExceptionMapper;
import javax.ws.rs.ext.Provider;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import javax.ws.rs.core.Response;
import javax.ws.rs.ext.ExceptionMapper;
import javax.ws.rs.ext.Provider;
@Provider @Provider
public class RateLimitExceededExceptionMapper implements ExceptionMapper<RateLimitExceededException> { public class RateLimitExceededExceptionMapper implements ExceptionMapper<RateLimitExceededException> {
private static final Logger logger = LoggerFactory.getLogger(RateLimitExceededExceptionMapper.class); private static final Logger logger = LoggerFactory.getLogger(RateLimitExceededExceptionMapper.class);
private static final int LEGACY_STATUS_CODE = 413;
private static final int STATUS_CODE = 429;
/** /**
* Convert a RateLimitExceededException to a 413 response with a * Convert a RateLimitExceededException to a {@value STATUS_CODE} (or legacy {@value LEGACY_STATUS_CODE}) response
* Retry-After header. * with a Retry-After header.
* *
* @param e A RateLimitExceededException potentially containing a recommended retry duration * @param e A RateLimitExceededException potentially containing a recommended retry duration
* @return the response * @return the response
*/ */
@Override @Override
public Response toResponse(RateLimitExceededException e) { public Response toResponse(RateLimitExceededException e) {
final int statusCode = e.isLegacy() ? LEGACY_STATUS_CODE : STATUS_CODE;
return e.getRetryDuration() return e.getRetryDuration()
.filter(d -> { .filter(d -> {
if (d.isNegative()) { if (d.isNegative()) {
logger.warn("Encountered a negative retry duration: {}, will not include a Retry-After header in response", d); logger.warn("Encountered a negative retry duration: {}, will not include a Retry-After header in response",
d);
} }
// only include non-negative durations in retry headers // only include non-negative durations in retry headers
return !d.isNegative(); return !d.isNegative();
}) })
.map(d -> Response.status(413).header("Retry-After", d.toSeconds())) .map(d -> Response.status(statusCode).header("Retry-After", d.toSeconds()))
.orElseGet(() -> Response.status(413)).build(); .orElseGet(() -> Response.status(statusCode)).build();
} }
} }

View File

@ -89,9 +89,12 @@ public class RegistrationServiceClient implements Managed {
case ERROR -> { case ERROR -> {
switch (response.getError().getErrorType()) { switch (response.getError().getErrorType()) {
case CREATE_REGISTRATION_SESSION_ERROR_TYPE_RATE_LIMITED -> throw new CompletionException(new RateLimitExceededException(Duration.ofSeconds(response.getError().getRetryAfterSeconds()))); case CREATE_REGISTRATION_SESSION_ERROR_TYPE_RATE_LIMITED -> throw new CompletionException(
new RateLimitExceededException(Duration.ofSeconds(response.getError().getRetryAfterSeconds()),
true));
case CREATE_REGISTRATION_SESSION_ERROR_TYPE_ILLEGAL_PHONE_NUMBER -> throw new IllegalArgumentException(); case CREATE_REGISTRATION_SESSION_ERROR_TYPE_ILLEGAL_PHONE_NUMBER -> throw new IllegalArgumentException();
default -> throw new RuntimeException("Unrecognized error type from registration service: " + response.getError().getErrorType()); default -> throw new RuntimeException(
"Unrecognized error type from registration service: " + response.getError().getErrorType());
} }
} }
@ -119,8 +122,9 @@ public class RegistrationServiceClient implements Managed {
.thenApply(response -> { .thenApply(response -> {
if (response.hasError()) { if (response.hasError()) {
switch (response.getError().getErrorType()) { switch (response.getError().getErrorType()) {
case SEND_VERIFICATION_CODE_ERROR_TYPE_RATE_LIMITED -> case SEND_VERIFICATION_CODE_ERROR_TYPE_RATE_LIMITED -> throw new CompletionException(
throw new CompletionException(new RateLimitExceededException(Duration.ofSeconds(response.getError().getRetryAfterSeconds()))); new RateLimitExceededException(Duration.ofSeconds(response.getError().getRetryAfterSeconds()),
true));
default -> throw new CompletionException(new RuntimeException("Failed to send verification code: " + response.getError().getErrorType())); default -> throw new CompletionException(new RuntimeException("Failed to send verification code: " + response.getError().getErrorType()));
} }
@ -142,8 +146,9 @@ public class RegistrationServiceClient implements Managed {
.thenApply(response -> { .thenApply(response -> {
if (response.hasError()) { if (response.hasError()) {
switch (response.getError().getErrorType()) { switch (response.getError().getErrorType()) {
case CHECK_VERIFICATION_CODE_ERROR_TYPE_RATE_LIMITED -> case CHECK_VERIFICATION_CODE_ERROR_TYPE_RATE_LIMITED -> throw new CompletionException(
throw new CompletionException(new RateLimitExceededException(Duration.ofSeconds(response.getError().getRetryAfterSeconds()))); new RateLimitExceededException(Duration.ofSeconds(response.getError().getRetryAfterSeconds()),
true));
default -> throw new CompletionException(new RuntimeException("Failed to check verification code: " + response.getError().getErrorType())); default -> throw new CompletionException(new RuntimeException("Failed to check verification code: " + response.getError().getErrorType()));
} }

View File

@ -331,11 +331,12 @@ class AccountControllerTest {
when(captchaChecker.verify(eq(VALID_CAPTCHA_TOKEN), anyString())) when(captchaChecker.verify(eq(VALID_CAPTCHA_TOKEN), anyString()))
.thenReturn(new AssessmentResult(true, "")); .thenReturn(new AssessmentResult(true, ""));
doThrow(new RateLimitExceededException(Duration.ZERO)).when(pinLimiter).validate(eq(SENDER_OVER_PIN)); doThrow(new RateLimitExceededException(Duration.ZERO, true)).when(pinLimiter).validate(eq(SENDER_OVER_PIN));
doThrow(new RateLimitExceededException(Duration.ZERO)).when(smsVoicePrefixLimiter).validate(SENDER_OVER_PREFIX.substring(0, 4+2)); doThrow(new RateLimitExceededException(Duration.ZERO, true)).when(smsVoicePrefixLimiter)
doThrow(new RateLimitExceededException(Duration.ZERO)).when(smsVoiceIpLimiter).validate(RATE_LIMITED_IP_HOST); .validate(SENDER_OVER_PREFIX.substring(0, 4 + 2));
doThrow(new RateLimitExceededException(Duration.ZERO)).when(smsVoiceIpLimiter).validate(RATE_LIMITED_HOST2); doThrow(new RateLimitExceededException(Duration.ZERO, true)).when(smsVoiceIpLimiter).validate(RATE_LIMITED_IP_HOST);
doThrow(new RateLimitExceededException(Duration.ZERO, true)).when(smsVoiceIpLimiter).validate(RATE_LIMITED_HOST2);
} }
@AfterEach @AfterEach
@ -571,7 +572,7 @@ class AccountControllerTest {
@Test @Test
void testSendCodeRateLimited() { void testSendCodeRateLimited() {
when(registrationServiceClient.createRegistrationSession(any(), any())) when(registrationServiceClient.createRegistrationSession(any(), any()))
.thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(Duration.ofMinutes(10)))); .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(Duration.ofMinutes(10), true)));
Response response = Response response =
resources.getJerseyTest() resources.getJerseyTest()
@ -2050,7 +2051,7 @@ class AccountControllerTest {
when(accountsManager.getByAccountIdentifier(accountIdentifier)).thenReturn(Optional.of(account)); when(accountsManager.getByAccountIdentifier(accountIdentifier)).thenReturn(Optional.of(account));
MockUtils.updateRateLimiterResponseToFail( MockUtils.updateRateLimiterResponseToFail(
rateLimiters, RateLimiters.Handle.CHECK_ACCOUNT_EXISTENCE, "127.0.0.1", expectedRetryAfter); rateLimiters, RateLimiters.Handle.CHECK_ACCOUNT_EXISTENCE, "127.0.0.1", expectedRetryAfter, true);
final Response response = resources.getJerseyTest() final Response response = resources.getJerseyTest()
.target(String.format("/v1/accounts/account/%s", accountIdentifier)) .target(String.format("/v1/accounts/account/%s", accountIdentifier))
@ -2115,7 +2116,7 @@ class AccountControllerTest {
void testLookupUsernameRateLimited() throws RateLimitExceededException { void testLookupUsernameRateLimited() throws RateLimitExceededException {
final Duration expectedRetryAfter = Duration.ofSeconds(13); final Duration expectedRetryAfter = Duration.ofSeconds(13);
MockUtils.updateRateLimiterResponseToFail( MockUtils.updateRateLimiterResponseToFail(
rateLimiters, RateLimiters.Handle.USERNAME_LOOKUP, "127.0.0.1", expectedRetryAfter); rateLimiters, RateLimiters.Handle.USERNAME_LOOKUP, "127.0.0.1", expectedRetryAfter, true);
final Response response = resources.getJerseyTest() final Response response = resources.getJerseyTest()
.target(String.format("v1/accounts/username_hash/%s", BASE_64_URL_USERNAME_HASH_1)) .target(String.format("v1/accounts/username_hash/%s", BASE_64_URL_USERNAME_HASH_1))
.request() .request()

View File

@ -0,0 +1,396 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableSet;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.client.Entity;
import javax.ws.rs.client.Invocation;
import javax.ws.rs.core.HttpHeaders;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.apache.http.HttpStatus;
import org.glassfish.jersey.server.ServerProperties;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager;
import org.whispersystems.textsecuregcm.auth.RegistrationLockError;
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse;
import org.whispersystems.textsecuregcm.entities.ChangeNumberRequest;
import org.whispersystems.textsecuregcm.entities.RegistrationSession;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.ImpossiblePhoneNumberExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.NonNormalizedPhoneNumberExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ChangeNumberManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@ExtendWith(DropwizardExtensionsSupport.class)
class AccountControllerV2Test {
public static final String NEW_NUMBER = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164);
private final AccountsManager accountsManager = mock(AccountsManager.class);
private final ChangeNumberManager changeNumberManager = mock(ChangeNumberManager.class);
private final RegistrationServiceClient registrationServiceClient = mock(RegistrationServiceClient.class);
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager = mock(
RegistrationRecoveryPasswordsManager.class);
private final RegistrationLockVerificationManager registrationLockVerificationManager = mock(
RegistrationLockVerificationManager.class);
private final RateLimiters rateLimiters = mock(RateLimiters.class);
private final RateLimiter registrationLimiter = mock(RateLimiter.class);
private final ResourceExtension resources = ResourceExtension.builder()
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
.addProvider(AuthHelper.getAuthFilter())
.addProvider(
new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(AuthenticatedAccount.class,
DisabledPermittedAuthenticatedAccount.class)))
.addProvider(new RateLimitExceededExceptionMapper())
.addProvider(new ImpossiblePhoneNumberExceptionMapper())
.addProvider(new NonNormalizedPhoneNumberExceptionMapper())
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(
new AccountControllerV2(accountsManager, changeNumberManager,
new PhoneVerificationTokenManager(registrationServiceClient, registrationRecoveryPasswordsManager),
registrationLockVerificationManager, rateLimiters))
.build();
@Nested
class ChangeNumber {
@BeforeEach
void setUp() throws Exception {
when(rateLimiters.getRegistrationLimiter()).thenReturn(registrationLimiter);
when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any())).thenAnswer(
(Answer<Account>) invocation -> {
final Account account = invocation.getArgument(0, Account.class);
final String number = invocation.getArgument(1, String.class);
final String pniIdentityKey = invocation.getArgument(2, String.class);
final UUID uuid = account.getUuid();
final List<Device> devices = account.getDevices();
final Account updatedAccount = mock(Account.class);
when(updatedAccount.getUuid()).thenReturn(uuid);
when(updatedAccount.getNumber()).thenReturn(number);
when(updatedAccount.getPhoneNumberIdentityKey()).thenReturn(pniIdentityKey);
when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(UUID.randomUUID());
when(updatedAccount.getDevices()).thenReturn(devices);
for (long i = 1; i <= 3; i++) {
final Optional<Device> d = account.getDevice(i);
when(updatedAccount.getDevice(i)).thenReturn(d);
}
return updatedAccount;
});
}
@Test
void changeNumberSuccess() throws Exception {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(new RegistrationSession(NEW_NUMBER, true))));
final AccountIdentityResponse accountIdentityResponse =
resources.getJerseyTest()
.target("/v2/accounts/number")
.request()
.header(HttpHeaders.AUTHORIZATION,
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(
new ChangeNumberRequest(encodeSessionId("session"), null, NEW_NUMBER, "123", "123",
Collections.emptyList(),
Collections.emptyMap(), Collections.emptyMap()),
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(NEW_NUMBER), any(), any(), any(),
any());
assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid());
assertEquals(NEW_NUMBER, accountIdentityResponse.number());
assertNotEquals(AuthHelper.VALID_PNI, accountIdentityResponse.pni());
}
@Test
void unprocessableRequestJson() {
final Invocation.Builder request = resources.getJerseyTest()
.target("/v2/accounts/number")
.request()
.header(HttpHeaders.AUTHORIZATION,
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
try (Response response = request.put(Entity.json(unprocessableJson()))) {
assertEquals(400, response.getStatus());
}
}
@Test
void missingBasicAuthorization() {
final Invocation.Builder request = resources.getJerseyTest()
.target("/v2/accounts/number")
.request();
try (Response response = request.put(Entity.json(requestJson("sessionId", NEW_NUMBER)))) {
assertEquals(401, response.getStatus());
}
}
@Test
void invalidBasicAuthorization() {
final Invocation.Builder request = resources.getJerseyTest()
.target("/v2/accounts/number")
.request()
.header(HttpHeaders.AUTHORIZATION, "Basic but-invalid");
try (Response response = request.put(Entity.json(invalidRequestJson()))) {
assertEquals(401, response.getStatus());
}
}
@Test
void invalidRequestBody() {
final Invocation.Builder request = resources.getJerseyTest()
.target("/v2/accounts/number")
.request()
.header(HttpHeaders.AUTHORIZATION,
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
try (Response response = request.put(Entity.json(invalidRequestJson()))) {
assertEquals(422, response.getStatus());
}
}
@Test
void rateLimitedNumber() throws Exception {
doThrow(new RateLimitExceededException(null, true))
.when(registrationLimiter).validate(anyString());
final Invocation.Builder request = resources.getJerseyTest()
.target("/v2/accounts/number")
.request()
.header(HttpHeaders.AUTHORIZATION,
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
try (Response response = request.put(Entity.json(requestJson("sessionId", NEW_NUMBER)))) {
assertEquals(429, response.getStatus());
}
}
@Test
void registrationServiceTimeout() {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v2/accounts/number")
.request()
.header(HttpHeaders.AUTHORIZATION,
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
try (Response response = request.put(Entity.json(requestJson("sessionId", NEW_NUMBER)))) {
assertEquals(HttpStatus.SC_SERVICE_UNAVAILABLE, response.getStatus());
}
}
@ParameterizedTest
@MethodSource
void registrationServiceSessionCheck(@Nullable final RegistrationSession session, final int expectedStatus,
final String message) {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.ofNullable(session)));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v2/accounts/number")
.request()
.header(HttpHeaders.AUTHORIZATION,
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
try (Response response = request.put(Entity.json(requestJson("sessionId", NEW_NUMBER)))) {
assertEquals(expectedStatus, response.getStatus(), message);
}
}
static Stream<Arguments> registrationServiceSessionCheck() {
return Stream.of(
Arguments.of(null, 401, "session not found"),
Arguments.of(new RegistrationSession("+18005551234", false), 400, "session number mismatch"),
Arguments.of(new RegistrationSession(NEW_NUMBER, false), 401, "session not verified")
);
}
@ParameterizedTest
@EnumSource(RegistrationLockError.class)
void registrationLock(final RegistrationLockError error) throws Exception {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(
CompletableFuture.completedFuture(Optional.of(new RegistrationSession(NEW_NUMBER, true))));
when(accountsManager.getByE164(any())).thenReturn(Optional.of(mock(Account.class)));
final Exception e = switch (error) {
case MISMATCH -> new WebApplicationException(error.getExpectedStatus());
case RATE_LIMITED -> new RateLimitExceededException(null, true);
};
doThrow(e)
.when(registrationLockVerificationManager).verifyRegistrationLock(any(), any());
final Invocation.Builder request = resources.getJerseyTest()
.target("/v2/accounts/number")
.request()
.header(HttpHeaders.AUTHORIZATION,
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
try (Response response = request.put(Entity.json(requestJson("sessionId", NEW_NUMBER)))) {
assertEquals(error.getExpectedStatus(), response.getStatus());
}
}
@Test
void recoveryPasswordManagerVerificationTrue() throws Exception {
when(registrationRecoveryPasswordsManager.verify(any(), any()))
.thenReturn(CompletableFuture.completedFuture(true));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v2/accounts/number")
.request()
.header(HttpHeaders.AUTHORIZATION,
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
final byte[] recoveryPassword = new byte[32];
try (Response response = request.put(Entity.json(requestJsonRecoveryPassword(recoveryPassword, NEW_NUMBER)))) {
assertEquals(200, response.getStatus());
final AccountIdentityResponse accountIdentityResponse = response.readEntity(AccountIdentityResponse.class);
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(NEW_NUMBER), any(), any(), any(),
any());
assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid());
assertEquals(NEW_NUMBER, accountIdentityResponse.number());
assertNotEquals(AuthHelper.VALID_PNI, accountIdentityResponse.pni());
}
}
@Test
void recoveryPasswordManagerVerificationFalse() {
when(registrationRecoveryPasswordsManager.verify(any(), any()))
.thenReturn(CompletableFuture.completedFuture(false));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v2/accounts/number")
.request()
.header(HttpHeaders.AUTHORIZATION,
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
try (Response response = request.put(Entity.json(requestJsonRecoveryPassword(new byte[32], NEW_NUMBER)))) {
assertEquals(403, response.getStatus());
}
}
/**
* Valid request JSON with the given Recovery Password
*/
private static String requestJsonRecoveryPassword(final byte[] recoveryPassword, final String newNumber) {
return requestJson("", recoveryPassword, newNumber);
}
/**
* Valid request JSON with the give session ID and recovery password
*/
private static String requestJson(final String sessionId, final byte[] recoveryPassword, final String newNumber) {
return String.format("""
{
"sessionId": "%s",
"recoveryPassword": "%s",
"number": "%s",
"reglock": "1234",
"pniIdentityKey": "5678",
"deviceMessages": [],
"devicePniSignedPrekeys": {},
"pniRegistrationIds": {}
}
""", encodeSessionId(sessionId), encodeRecoveryPassword(recoveryPassword), newNumber);
}
/**
* Valid request JSON with the give session ID
*/
private static String requestJson(final String sessionId, final String newNumber) {
return requestJson(sessionId, new byte[0], newNumber);
}
/**
* Request JSON in the shape of {@link org.whispersystems.textsecuregcm.entities.ChangeNumberRequest}, but that
* fails validation
*/
private static String invalidRequestJson() {
return """
{
"sessionId": null
}
""";
}
/**
* Request JSON that cannot be marshalled into
* {@link org.whispersystems.textsecuregcm.entities.ChangeNumberRequest}
*/
private static String unprocessableJson() {
return """
{
"sessionId": []
}
""";
}
private static String encodeSessionId(final String sessionId) {
return Base64.getUrlEncoder().encodeToString(sessionId.getBytes(StandardCharsets.UTF_8));
}
private static String encodeRecoveryPassword(final byte[] recoveryPassword) {
return Base64.getEncoder().encodeToString(recoveryPassword);
}
}
}

View File

@ -86,7 +86,8 @@ class ChallengeControllerTest {
"""; """;
final Duration retryAfter = Duration.ofMinutes(17); final Duration retryAfter = Duration.ofMinutes(17);
doThrow(new RateLimitExceededException(retryAfter)).when(rateLimitChallengeManager).answerPushChallenge(any(), any()); doThrow(new RateLimitExceededException(retryAfter, true)).when(rateLimitChallengeManager)
.answerPushChallenge(any(), any());
final Response response = EXTENSION.target("/v1/challenge") final Response response = EXTENSION.target("/v1/challenge")
.request() .request()
@ -128,7 +129,8 @@ class ChallengeControllerTest {
"""; """;
final Duration retryAfter = Duration.ofMinutes(17); final Duration retryAfter = Duration.ofMinutes(17);
doThrow(new RateLimitExceededException(retryAfter)).when(rateLimitChallengeManager).answerRecaptchaChallenge(any(), any(), any(), any()); doThrow(new RateLimitExceededException(retryAfter, true)).when(rateLimitChallengeManager)
.answerRecaptchaChallenge(any(), any(), any(), any());
final Response response = EXTENSION.target("/v1/challenge") final Response response = EXTENSION.target("/v1/challenge")
.request() .request()

View File

@ -255,7 +255,8 @@ class ProfileControllerTest {
@Test @Test
void testProfileGetByAciRateLimited() throws RateLimitExceededException { void testProfileGetByAciRateLimited() throws RateLimitExceededException {
doThrow(new RateLimitExceededException(Duration.ofSeconds(13))).when(rateLimiter).validate(AuthHelper.VALID_UUID); doThrow(new RateLimitExceededException(Duration.ofSeconds(13), true)).when(rateLimiter)
.validate(AuthHelper.VALID_UUID);
Response response= resources.getJerseyTest() Response response= resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_UUID_TWO) .target("/v1/profile/" + AuthHelper.VALID_UUID_TWO)
@ -326,7 +327,8 @@ class ProfileControllerTest {
@Test @Test
void testProfileGetByPniRateLimited() throws RateLimitExceededException { void testProfileGetByPniRateLimited() throws RateLimitExceededException {
doThrow(new RateLimitExceededException(Duration.ofSeconds(13))).when(rateLimiter).validate(AuthHelper.VALID_UUID); doThrow(new RateLimitExceededException(Duration.ofSeconds(13), true)).when(rateLimiter)
.validate(AuthHelper.VALID_UUID);
Response response= resources.getJerseyTest() Response response= resources.getJerseyTest()
.target("/v1/profile/" + AuthHelper.VALID_PNI_TWO) .target("/v1/profile/" + AuthHelper.VALID_PNI_TWO)

View File

@ -1,9 +1,26 @@
package org.whispersystems.textsecuregcm.controllers; package org.whispersystems.textsecuregcm.controllers;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension; import io.dropwizard.testing.junit5.ResourceExtension;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Base64;
import java.util.UUID;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -20,25 +37,6 @@ import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress; import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Base64;
import java.util.UUID;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ExtendWith(DropwizardExtensionsSupport.class) @ExtendWith(DropwizardExtensionsSupport.class)
class ProvisioningControllerTest { class ProvisioningControllerTest {
@ -101,7 +99,7 @@ class ProvisioningControllerTest {
final String destination = UUID.randomUUID().toString(); final String destination = UUID.randomUUID().toString();
final byte[] messageBody = "test".getBytes(StandardCharsets.UTF_8); final byte[] messageBody = "test".getBytes(StandardCharsets.UTF_8);
doThrow(new RateLimitExceededException(Duration.ZERO)) doThrow(new RateLimitExceededException(Duration.ZERO, true))
.when(messagesRateLimiter).validate(AuthHelper.VALID_UUID); .when(messagesRateLimiter).validate(AuthHelper.VALID_UUID);
try (final Response response = RESOURCE_EXTENSION.getJerseyTest() try (final Response response = RESOURCE_EXTENSION.getJerseyTest()

View File

@ -13,6 +13,7 @@ import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension; import io.dropwizard.testing.junit5.ResourceExtension;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
@ -37,6 +38,7 @@ import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager;
import org.whispersystems.textsecuregcm.auth.RegistrationLockError; import org.whispersystems.textsecuregcm.auth.RegistrationLockError;
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager; import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
@ -51,12 +53,18 @@ import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager; import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
@ExtendWith(DropwizardExtensionsSupport.class) @ExtendWith(DropwizardExtensionsSupport.class)
class RegistrationControllerTest { class RegistrationControllerTest {
private static final String NUMBER = "+18005551212"; private static final String NUMBER = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164);
public static final String PASSWORD = "password";
private final AccountsManager accountsManager = mock(AccountsManager.class); private final AccountsManager accountsManager = mock(AccountsManager.class);
private final RegistrationServiceClient registrationServiceClient = mock(RegistrationServiceClient.class); private final RegistrationServiceClient registrationServiceClient = mock(RegistrationServiceClient.class);
private final RegistrationLockVerificationManager registrationLockVerificationManager = mock( private final RegistrationLockVerificationManager registrationLockVerificationManager = mock(
@ -66,7 +74,6 @@ class RegistrationControllerTest {
private final RateLimiters rateLimiters = mock(RateLimiters.class); private final RateLimiters rateLimiters = mock(RateLimiters.class);
private final RateLimiter registrationLimiter = mock(RateLimiter.class); private final RateLimiter registrationLimiter = mock(RateLimiter.class);
private final RateLimiter pinLimiter = mock(RateLimiter.class);
private final ResourceExtension resources = ResourceExtension.builder() private final ResourceExtension resources = ResourceExtension.builder()
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE) .addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
@ -76,14 +83,14 @@ class RegistrationControllerTest {
.setMapper(SystemMapper.getMapper()) .setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource( .addResource(
new RegistrationController(accountsManager, registrationServiceClient, registrationLockVerificationManager, new RegistrationController(accountsManager,
registrationRecoveryPasswordsManager, rateLimiters)) new PhoneVerificationTokenManager(registrationServiceClient, registrationRecoveryPasswordsManager),
registrationLockVerificationManager, rateLimiters))
.build(); .build();
@BeforeEach @BeforeEach
void setUp() { void setUp() {
when(rateLimiters.getRegistrationLimiter()).thenReturn(registrationLimiter); when(rateLimiters.getRegistrationLimiter()).thenReturn(registrationLimiter);
when(rateLimiters.getPinLimiter()).thenReturn(pinLimiter);
} }
@Test @Test
@ -130,25 +137,23 @@ class RegistrationControllerTest {
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration") .target("/v1/registration")
.request() .request()
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
try (Response response = request.post(Entity.json(invalidRequestJson()))) { try (Response response = request.post(Entity.json(invalidRequestJson()))) {
assertEquals(422, response.getStatus()); assertEquals(422, response.getStatus());
} }
} }
@Test @Test
void rateLimitedSession() throws Exception { void rateLimitedNumber() throws Exception {
final String sessionId = "sessionId";
doThrow(RateLimitExceededException.class) doThrow(RateLimitExceededException.class)
.when(registrationLimiter).validate(encodeSessionId(sessionId)); .when(registrationLimiter).validate(NUMBER);
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration") .target("/v1/registration")
.request() .request()
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
try (Response response = request.post(Entity.json(requestJson(sessionId)))) { try (Response response = request.post(Entity.json(requestJson("sessionId")))) {
assertEquals(413, response.getStatus()); assertEquals(429, response.getStatus());
// In the future, change to assertEquals(429, response.getStatus());
} }
} }
@ -160,7 +165,7 @@ class RegistrationControllerTest {
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration") .target("/v1/registration")
.request() .request()
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
try (Response response = request.post(Entity.json(requestJson("sessionId")))) { try (Response response = request.post(Entity.json(requestJson("sessionId")))) {
assertEquals(HttpStatus.SC_SERVICE_UNAVAILABLE, response.getStatus()); assertEquals(HttpStatus.SC_SERVICE_UNAVAILABLE, response.getStatus());
} }
@ -174,7 +179,7 @@ class RegistrationControllerTest {
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration") .target("/v1/registration")
.request() .request()
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(new byte[32])))) { try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(new byte[32])))) {
assertEquals(HttpStatus.SC_SERVICE_UNAVAILABLE, response.getStatus()); assertEquals(HttpStatus.SC_SERVICE_UNAVAILABLE, response.getStatus());
} }
@ -190,7 +195,7 @@ class RegistrationControllerTest {
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration") .target("/v1/registration")
.request() .request()
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
try (Response response = request.post(Entity.json(requestJson("sessionId")))) { try (Response response = request.post(Entity.json(requestJson("sessionId")))) {
assertEquals(expectedStatus, response.getStatus(), message); assertEquals(expectedStatus, response.getStatus(), message);
} }
@ -214,7 +219,7 @@ class RegistrationControllerTest {
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration") .target("/v1/registration")
.request() .request()
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
final byte[] recoveryPassword = new byte[32]; final byte[] recoveryPassword = new byte[32];
try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(recoveryPassword)))) { try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(recoveryPassword)))) {
assertEquals(200, response.getStatus()); assertEquals(200, response.getStatus());
@ -229,7 +234,7 @@ class RegistrationControllerTest {
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration") .target("/v1/registration")
.request() .request()
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(new byte[32])))) { try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(new byte[32])))) {
assertEquals(403, response.getStatus()); assertEquals(403, response.getStatus());
} }
@ -245,7 +250,7 @@ class RegistrationControllerTest {
final Exception e = switch (error) { final Exception e = switch (error) {
case MISMATCH -> new WebApplicationException(error.getExpectedStatus()); case MISMATCH -> new WebApplicationException(error.getExpectedStatus());
case RATE_LIMITED -> new RateLimitExceededException(null); case RATE_LIMITED -> new RateLimitExceededException(null, true);
}; };
doThrow(e) doThrow(e)
.when(registrationLockVerificationManager).verifyRegistrationLock(any(), any()); .when(registrationLockVerificationManager).verifyRegistrationLock(any(), any());
@ -253,7 +258,7 @@ class RegistrationControllerTest {
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration") .target("/v1/registration")
.request() .request()
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
try (Response response = request.post(Entity.json(requestJson("sessionId")))) { try (Response response = request.post(Entity.json(requestJson("sessionId")))) {
assertEquals(error.getExpectedStatus(), response.getStatus()); assertEquals(error.getExpectedStatus(), response.getStatus());
} }
@ -286,7 +291,7 @@ class RegistrationControllerTest {
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration") .target("/v1/registration")
.request() .request()
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
try (Response response = request.post(Entity.json(requestJson("sessionId", new byte[0], skipDeviceTransfer)))) { try (Response response = request.post(Entity.json(requestJson("sessionId", new byte[0], skipDeviceTransfer)))) {
assertEquals(expectedStatus, response.getStatus()); assertEquals(expectedStatus, response.getStatus());
} }
@ -294,7 +299,7 @@ class RegistrationControllerTest {
// this is functionally the same as deviceTransferAvailable(existingAccount=false) // this is functionally the same as deviceTransferAvailable(existingAccount=false)
@Test @Test
void success() throws Exception { void registrationSuccess() throws Exception {
when(registrationServiceClient.getSession(any(), any())) when(registrationServiceClient.getSession(any(), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(new RegistrationSession(NUMBER, true)))); .thenReturn(CompletableFuture.completedFuture(Optional.of(new RegistrationSession(NUMBER, true))));
when(accountsManager.create(any(), any(), any(), any(), any())) when(accountsManager.create(any(), any(), any(), any(), any()))
@ -303,7 +308,7 @@ class RegistrationControllerTest {
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration") .target("/v1/registration")
.request() .request()
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
try (Response response = request.post(Entity.json(requestJson("sessionId")))) { try (Response response = request.post(Entity.json(requestJson("sessionId")))) {
assertEquals(200, response.getStatus()); assertEquals(200, response.getStatus());
} }
@ -365,13 +370,8 @@ class RegistrationControllerTest {
"""; """;
} }
private static String authorizationHeader(final String number) {
return "Basic " + Base64.getEncoder().encodeToString(
String.format("%s:password", number).getBytes(StandardCharsets.UTF_8));
}
private static String encodeSessionId(final String sessionId) { private static String encodeSessionId(final String sessionId) {
return Base64.getEncoder().encodeToString(sessionId.getBytes(StandardCharsets.UTF_8)); return Base64.getUrlEncoder().encodeToString(sessionId.getBytes(StandardCharsets.UTF_8));
} }
private static String encodeRecoveryPassword(final byte[] recoveryPassword) { private static String encodeRecoveryPassword(final byte[] recoveryPassword) {

View File

@ -72,11 +72,11 @@ public class RateLimitedByIpTest {
public void testRateLimits() throws Exception { public void testRateLimits() throws Exception {
Mockito.doNothing().when(RATE_LIMITER).validate(Mockito.eq(IP)); Mockito.doNothing().when(RATE_LIMITER).validate(Mockito.eq(IP));
validateSuccess("/test/strict", VALID_X_FORWARDED_FOR); validateSuccess("/test/strict", VALID_X_FORWARDED_FOR);
Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER)).when(RATE_LIMITER).validate(Mockito.eq(IP)); Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER, true)).when(RATE_LIMITER).validate(Mockito.eq(IP));
validateFailure("/test/strict", VALID_X_FORWARDED_FOR, RETRY_AFTER); validateFailure("/test/strict", VALID_X_FORWARDED_FOR, RETRY_AFTER);
Mockito.doNothing().when(RATE_LIMITER).validate(Mockito.eq(IP)); Mockito.doNothing().when(RATE_LIMITER).validate(Mockito.eq(IP));
validateSuccess("/test/strict", VALID_X_FORWARDED_FOR); validateSuccess("/test/strict", VALID_X_FORWARDED_FOR);
Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER)).when(RATE_LIMITER).validate(Mockito.eq(IP)); Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER, true)).when(RATE_LIMITER).validate(Mockito.eq(IP));
validateFailure("/test/strict", VALID_X_FORWARDED_FOR, RETRY_AFTER); validateFailure("/test/strict", VALID_X_FORWARDED_FOR, RETRY_AFTER);
} }
@ -92,7 +92,7 @@ public class RateLimitedByIpTest {
validateSuccess("/test/loose", ""); validateSuccess("/test/loose", "");
// also checking that even if rate limiter is failing -- it doesn't matter in the case of invalid IP // also checking that even if rate limiter is failing -- it doesn't matter in the case of invalid IP
Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER)).when(RATE_LIMITER).validate(Mockito.anyString()); Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER, true)).when(RATE_LIMITER).validate(Mockito.anyString());
validateFailure("/test/loose", VALID_X_FORWARDED_FOR, RETRY_AFTER); validateFailure("/test/loose", VALID_X_FORWARDED_FOR, RETRY_AFTER);
validateSuccess("/test/loose", INVALID_X_FORWARDED_FOR); validateSuccess("/test/loose", INVALID_X_FORWARDED_FOR);
validateSuccess("/test/loose", ""); validateSuccess("/test/loose", "");

View File

@ -340,7 +340,7 @@ class KeysControllerTest {
@Test @Test
void testGetKeysRateLimited() throws RateLimitExceededException { void testGetKeysRateLimited() throws RateLimitExceededException {
Duration retryAfter = Duration.ofSeconds(31); Duration retryAfter = Duration.ofSeconds(31);
doThrow(new RateLimitExceededException(retryAfter)).when(rateLimiter).validate(anyString()); doThrow(new RateLimitExceededException(retryAfter, true)).when(rateLimiter).validate(anyString());
Response result = resources.getJerseyTest() Response result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_PNI)) .target(String.format("/v2/keys/%s/*", EXISTS_PNI))

View File

@ -60,11 +60,12 @@ public final class MockUtils {
final RateLimiters rateLimitersMock, final RateLimiters rateLimitersMock,
final RateLimiters.Handle handle, final RateLimiters.Handle handle,
final String input, final String input,
final Duration retryAfter) { final Duration retryAfter,
final boolean legacyStatusCode) {
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class); final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
doReturn(Optional.of(mockRateLimiter)).when(rateLimitersMock).byHandle(eq(handle)); doReturn(Optional.of(mockRateLimiter)).when(rateLimitersMock).byHandle(eq(handle));
try { try {
doThrow(new RateLimitExceededException(retryAfter)).when(mockRateLimiter).validate(eq(input)); doThrow(new RateLimitExceededException(retryAfter, legacyStatusCode)).when(mockRateLimiter).validate(eq(input));
} catch (final RateLimitExceededException e) { } catch (final RateLimitExceededException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }