Add `PUT /v2/account/number`
This commit is contained in:
parent
8fc465b3e8
commit
c16006dc4b
|
@ -76,6 +76,7 @@ import org.whispersystems.textsecuregcm.auth.CertificateGenerator;
|
||||||
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccountAuthenticator;
|
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccountAuthenticator;
|
||||||
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
|
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
|
||||||
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator;
|
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator;
|
||||||
|
import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager;
|
||||||
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
|
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
|
||||||
import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator;
|
import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator;
|
||||||
import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventListener;
|
import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventListener;
|
||||||
|
@ -87,6 +88,7 @@ import org.whispersystems.textsecuregcm.captcha.RecaptchaClient;
|
||||||
import org.whispersystems.textsecuregcm.configuration.DirectoryServerConfiguration;
|
import org.whispersystems.textsecuregcm.configuration.DirectoryServerConfiguration;
|
||||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||||
import org.whispersystems.textsecuregcm.controllers.AccountController;
|
import org.whispersystems.textsecuregcm.controllers.AccountController;
|
||||||
|
import org.whispersystems.textsecuregcm.controllers.AccountControllerV2;
|
||||||
import org.whispersystems.textsecuregcm.controllers.ArtController;
|
import org.whispersystems.textsecuregcm.controllers.ArtController;
|
||||||
import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV2;
|
import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV2;
|
||||||
import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV3;
|
import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV3;
|
||||||
|
@ -524,16 +526,22 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
||||||
|
|
||||||
final RegistrationLockVerificationManager registrationLockVerificationManager = new RegistrationLockVerificationManager(
|
final RegistrationLockVerificationManager registrationLockVerificationManager = new RegistrationLockVerificationManager(
|
||||||
accountsManager, clientPresenceManager, backupCredentialsGenerator, rateLimiters);
|
accountsManager, clientPresenceManager, backupCredentialsGenerator, rateLimiters);
|
||||||
|
final PhoneVerificationTokenManager phoneVerificationTokenManager = new PhoneVerificationTokenManager(
|
||||||
|
registrationServiceClient, registrationRecoveryPasswordsManager);
|
||||||
|
|
||||||
ReportedMessageMetricsListener reportedMessageMetricsListener = new ReportedMessageMetricsListener(accountsManager);
|
final ReportedMessageMetricsListener reportedMessageMetricsListener = new ReportedMessageMetricsListener(
|
||||||
|
accountsManager);
|
||||||
reportMessageManager.addListener(reportedMessageMetricsListener);
|
reportMessageManager.addListener(reportedMessageMetricsListener);
|
||||||
|
|
||||||
AccountAuthenticator accountAuthenticator = new AccountAuthenticator(accountsManager);
|
final AccountAuthenticator accountAuthenticator = new AccountAuthenticator(accountsManager);
|
||||||
DisabledPermittedAccountAuthenticator disabledPermittedAccountAuthenticator = new DisabledPermittedAccountAuthenticator(accountsManager);
|
final DisabledPermittedAccountAuthenticator disabledPermittedAccountAuthenticator = new DisabledPermittedAccountAuthenticator(
|
||||||
|
accountsManager);
|
||||||
|
|
||||||
MessageSender messageSender = new MessageSender(clientPresenceManager, messagesManager, pushNotificationManager, pushLatencyManager);
|
final MessageSender messageSender = new MessageSender(clientPresenceManager, messagesManager,
|
||||||
ReceiptSender receiptSender = new ReceiptSender(accountsManager, messageSender, receiptSenderExecutor);
|
pushNotificationManager,
|
||||||
TurnTokenGenerator turnTokenGenerator = new TurnTokenGenerator(dynamicConfigurationManager);
|
pushLatencyManager);
|
||||||
|
final ReceiptSender receiptSender = new ReceiptSender(accountsManager, messageSender, receiptSenderExecutor);
|
||||||
|
final TurnTokenGenerator turnTokenGenerator = new TurnTokenGenerator(dynamicConfigurationManager);
|
||||||
|
|
||||||
RecaptchaClient recaptchaClient = new RecaptchaClient(
|
RecaptchaClient recaptchaClient = new RecaptchaClient(
|
||||||
config.getRecaptchaConfiguration().getProjectPath(),
|
config.getRecaptchaConfiguration().getProjectPath(),
|
||||||
|
@ -731,6 +739,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
||||||
}
|
}
|
||||||
|
|
||||||
final List<Object> commonControllers = Lists.newArrayList(
|
final List<Object> commonControllers = Lists.newArrayList(
|
||||||
|
new AccountControllerV2(accountsManager, changeNumberManager, phoneVerificationTokenManager,
|
||||||
|
registrationLockVerificationManager, rateLimiters),
|
||||||
new ArtController(rateLimiters, artCredentialsGenerator),
|
new ArtController(rateLimiters, artCredentialsGenerator),
|
||||||
new AttachmentControllerV2(rateLimiters, config.getAwsAttachmentsConfiguration().getAccessKey(), config.getAwsAttachmentsConfiguration().getAccessSecret(), config.getAwsAttachmentsConfiguration().getRegion(), config.getAwsAttachmentsConfiguration().getBucket()),
|
new AttachmentControllerV2(rateLimiters, config.getAwsAttachmentsConfiguration().getAccessKey(), config.getAwsAttachmentsConfiguration().getAccessSecret(), config.getAwsAttachmentsConfiguration().getRegion(), config.getAwsAttachmentsConfiguration().getBucket()),
|
||||||
new AttachmentControllerV3(rateLimiters, config.getGcpAttachmentsConfiguration().getDomain(), config.getGcpAttachmentsConfiguration().getEmail(), config.getGcpAttachmentsConfiguration().getMaxSizeInBytes(), config.getGcpAttachmentsConfiguration().getPathPrefix(), config.getGcpAttachmentsConfiguration().getRsaSigningKey()),
|
new AttachmentControllerV3(rateLimiters, config.getGcpAttachmentsConfiguration().getDomain(), config.getGcpAttachmentsConfiguration().getEmail(), config.getGcpAttachmentsConfiguration().getMaxSizeInBytes(), config.getGcpAttachmentsConfiguration().getPathPrefix(), config.getGcpAttachmentsConfiguration().getRsaSigningKey()),
|
||||||
|
@ -748,8 +758,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
||||||
profileBadgeConverter, config.getBadges(), cdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner,
|
profileBadgeConverter, config.getBadges(), cdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner,
|
||||||
config.getCdnConfiguration().getBucket(), zkProfileOperations, batchIdentityCheckExecutor),
|
config.getCdnConfiguration().getBucket(), zkProfileOperations, batchIdentityCheckExecutor),
|
||||||
new ProvisioningController(rateLimiters, provisioningManager),
|
new ProvisioningController(rateLimiters, provisioningManager),
|
||||||
new RegistrationController(accountsManager, registrationServiceClient, registrationLockVerificationManager,
|
new RegistrationController(accountsManager, phoneVerificationTokenManager, registrationLockVerificationManager,
|
||||||
registrationRecoveryPasswordsManager, rateLimiters),
|
rateLimiters),
|
||||||
new RemoteConfigController(remoteConfigsManager, adminEventLogger,
|
new RemoteConfigController(remoteConfigsManager, adminEventLogger,
|
||||||
config.getRemoteConfigConfiguration().getAuthorizedTokens(),
|
config.getRemoteConfigConfiguration().getAuthorizedTokens(),
|
||||||
config.getRemoteConfigConfiguration().getGlobalConfig()),
|
config.getRemoteConfigConfiguration().getGlobalConfig()),
|
||||||
|
|
|
@ -0,0 +1,98 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2023 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.whispersystems.textsecuregcm.auth;
|
||||||
|
|
||||||
|
import java.security.MessageDigest;
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.util.concurrent.CancellationException;
|
||||||
|
import java.util.concurrent.ExecutionException;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.concurrent.TimeoutException;
|
||||||
|
import javax.ws.rs.BadRequestException;
|
||||||
|
import javax.ws.rs.ForbiddenException;
|
||||||
|
import javax.ws.rs.NotAuthorizedException;
|
||||||
|
import javax.ws.rs.ServerErrorException;
|
||||||
|
import javax.ws.rs.core.Response;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.whispersystems.textsecuregcm.entities.PhoneVerificationRequest;
|
||||||
|
import org.whispersystems.textsecuregcm.entities.RegistrationSession;
|
||||||
|
import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
|
||||||
|
|
||||||
|
public class PhoneVerificationTokenManager {
|
||||||
|
|
||||||
|
private static final Logger logger = LoggerFactory.getLogger(PhoneVerificationTokenManager.class);
|
||||||
|
private static final Duration REGISTRATION_RPC_TIMEOUT = Duration.ofSeconds(15);
|
||||||
|
private static final long VERIFICATION_TIMEOUT_SECONDS = REGISTRATION_RPC_TIMEOUT.plusSeconds(1).getSeconds();
|
||||||
|
|
||||||
|
private final RegistrationServiceClient registrationServiceClient;
|
||||||
|
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager;
|
||||||
|
|
||||||
|
public PhoneVerificationTokenManager(final RegistrationServiceClient registrationServiceClient,
|
||||||
|
final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager) {
|
||||||
|
this.registrationServiceClient = registrationServiceClient;
|
||||||
|
this.registrationRecoveryPasswordsManager = registrationRecoveryPasswordsManager;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks if a {@link PhoneVerificationRequest} has a token that verifies the caller has confirmed access to the e164
|
||||||
|
* number
|
||||||
|
*
|
||||||
|
* @param number the e164 presented for verification
|
||||||
|
* @param request the request with exactly one verification token (RegistrationService sessionId or registration
|
||||||
|
* recovery password)
|
||||||
|
* @return if verification was successful, returns the verification type
|
||||||
|
* @throws BadRequestException if the number does not match the sessionId’s number
|
||||||
|
* @throws NotAuthorizedException if the session is not verified
|
||||||
|
* @throws ForbiddenException if the recovery password is not valid
|
||||||
|
* @throws InterruptedException if verification did not complete before a timeout
|
||||||
|
*/
|
||||||
|
public PhoneVerificationRequest.VerificationType verify(final String number, final PhoneVerificationRequest request)
|
||||||
|
throws InterruptedException {
|
||||||
|
|
||||||
|
final PhoneVerificationRequest.VerificationType verificationType = request.verificationType();
|
||||||
|
switch (verificationType) {
|
||||||
|
case SESSION -> verifyBySessionId(number, request.decodeSessionId());
|
||||||
|
case RECOVERY_PASSWORD -> verifyByRecoveryPassword(number, request.recoveryPassword());
|
||||||
|
}
|
||||||
|
|
||||||
|
return verificationType;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void verifyBySessionId(final String number, final byte[] sessionId) throws InterruptedException {
|
||||||
|
try {
|
||||||
|
final RegistrationSession session = registrationServiceClient
|
||||||
|
.getSession(sessionId, REGISTRATION_RPC_TIMEOUT)
|
||||||
|
.get(VERIFICATION_TIMEOUT_SECONDS, TimeUnit.SECONDS)
|
||||||
|
.orElseThrow(() -> new NotAuthorizedException("session not verified"));
|
||||||
|
|
||||||
|
if (!MessageDigest.isEqual(number.getBytes(), session.number().getBytes())) {
|
||||||
|
throw new BadRequestException("number does not match session");
|
||||||
|
}
|
||||||
|
if (!session.verified()) {
|
||||||
|
throw new NotAuthorizedException("session not verified");
|
||||||
|
}
|
||||||
|
} catch (final CancellationException | ExecutionException | TimeoutException e) {
|
||||||
|
logger.error("Registration service failure", e);
|
||||||
|
throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void verifyByRecoveryPassword(final String number, final byte[] recoveryPassword)
|
||||||
|
throws InterruptedException {
|
||||||
|
try {
|
||||||
|
final boolean verified = registrationRecoveryPasswordsManager.verify(number, recoveryPassword)
|
||||||
|
.get(VERIFICATION_TIMEOUT_SECONDS, TimeUnit.SECONDS);
|
||||||
|
if (!verified) {
|
||||||
|
throw new ForbiddenException("recoveryPassword couldn't be verified");
|
||||||
|
}
|
||||||
|
} catch (final ExecutionException | TimeoutException e) {
|
||||||
|
throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -798,7 +798,7 @@ public class AccountController {
|
||||||
// This shouldn't happen, so conservatively assume we're over the rate-limit
|
// This shouldn't happen, so conservatively assume we're over the rate-limit
|
||||||
// and indicate that the client should retry
|
// and indicate that the client should retry
|
||||||
logger.error("Missing/bad Forwarded-For: {}", forwardedFor);
|
logger.error("Missing/bad Forwarded-For: {}", forwardedFor);
|
||||||
return new RateLimitExceededException(Duration.ofHours(1));
|
return new RateLimitExceededException(Duration.ofHours(1), true);
|
||||||
});
|
});
|
||||||
|
|
||||||
rateLimiter.validate(mostRecentProxy);
|
rateLimiter.validate(mostRecentProxy);
|
||||||
|
|
|
@ -0,0 +1,132 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2023 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.whispersystems.textsecuregcm.controllers;
|
||||||
|
|
||||||
|
import static com.codahale.metrics.MetricRegistry.name;
|
||||||
|
|
||||||
|
import com.codahale.metrics.annotation.Timed;
|
||||||
|
import com.google.common.net.HttpHeaders;
|
||||||
|
import io.dropwizard.auth.Auth;
|
||||||
|
import io.micrometer.core.instrument.Metrics;
|
||||||
|
import io.micrometer.core.instrument.Tag;
|
||||||
|
import io.micrometer.core.instrument.Tags;
|
||||||
|
import java.util.Optional;
|
||||||
|
import javax.validation.Valid;
|
||||||
|
import javax.validation.constraints.NotNull;
|
||||||
|
import javax.ws.rs.BadRequestException;
|
||||||
|
import javax.ws.rs.Consumes;
|
||||||
|
import javax.ws.rs.ForbiddenException;
|
||||||
|
import javax.ws.rs.HeaderParam;
|
||||||
|
import javax.ws.rs.PUT;
|
||||||
|
import javax.ws.rs.Path;
|
||||||
|
import javax.ws.rs.Produces;
|
||||||
|
import javax.ws.rs.WebApplicationException;
|
||||||
|
import javax.ws.rs.core.MediaType;
|
||||||
|
import javax.ws.rs.core.Response;
|
||||||
|
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
|
||||||
|
import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager;
|
||||||
|
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
|
||||||
|
import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse;
|
||||||
|
import org.whispersystems.textsecuregcm.entities.ChangeNumberRequest;
|
||||||
|
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
|
||||||
|
import org.whispersystems.textsecuregcm.entities.PhoneVerificationRequest;
|
||||||
|
import org.whispersystems.textsecuregcm.entities.StaleDevices;
|
||||||
|
import org.whispersystems.textsecuregcm.limits.RateLimiter;
|
||||||
|
import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
||||||
|
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.Account;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.ChangeNumberManager;
|
||||||
|
|
||||||
|
@Path("/v2/accounts")
|
||||||
|
public class AccountControllerV2 {
|
||||||
|
|
||||||
|
private static final String CHANGE_NUMBER_COUNTER_NAME = name(AccountControllerV2.class, "create");
|
||||||
|
private static final String VERIFICATION_TYPE_TAG_NAME = "verification";
|
||||||
|
|
||||||
|
private final AccountsManager accountsManager;
|
||||||
|
private final ChangeNumberManager changeNumberManager;
|
||||||
|
private final PhoneVerificationTokenManager phoneVerificationTokenManager;
|
||||||
|
private final RegistrationLockVerificationManager registrationLockVerificationManager;
|
||||||
|
private final RateLimiters rateLimiters;
|
||||||
|
|
||||||
|
public AccountControllerV2(final AccountsManager accountsManager, final ChangeNumberManager changeNumberManager,
|
||||||
|
final PhoneVerificationTokenManager phoneVerificationTokenManager,
|
||||||
|
final RegistrationLockVerificationManager registrationLockVerificationManager, final RateLimiters rateLimiters) {
|
||||||
|
this.accountsManager = accountsManager;
|
||||||
|
this.changeNumberManager = changeNumberManager;
|
||||||
|
this.phoneVerificationTokenManager = phoneVerificationTokenManager;
|
||||||
|
this.registrationLockVerificationManager = registrationLockVerificationManager;
|
||||||
|
this.rateLimiters = rateLimiters;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Timed
|
||||||
|
@PUT
|
||||||
|
@Path("/number")
|
||||||
|
@Consumes(MediaType.APPLICATION_JSON)
|
||||||
|
@Produces(MediaType.APPLICATION_JSON)
|
||||||
|
public AccountIdentityResponse changeNumber(@Auth final AuthenticatedAccount authenticatedAccount,
|
||||||
|
@NotNull @Valid final ChangeNumberRequest request, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent)
|
||||||
|
throws RateLimitExceededException, InterruptedException {
|
||||||
|
|
||||||
|
if (!authenticatedAccount.getAuthenticatedDevice().isMaster()) {
|
||||||
|
throw new ForbiddenException();
|
||||||
|
}
|
||||||
|
|
||||||
|
final String number = request.number();
|
||||||
|
|
||||||
|
// Only verify and check reglock if there's a data change to be made...
|
||||||
|
if (!authenticatedAccount.getAccount().getNumber().equals(number)) {
|
||||||
|
|
||||||
|
RateLimiter.adaptLegacyException(() -> rateLimiters.getRegistrationLimiter().validate(number));
|
||||||
|
|
||||||
|
final PhoneVerificationRequest.VerificationType verificationType = phoneVerificationTokenManager.verify(number,
|
||||||
|
request);
|
||||||
|
|
||||||
|
final Optional<Account> existingAccount = accountsManager.getByE164(number);
|
||||||
|
|
||||||
|
if (existingAccount.isPresent()) {
|
||||||
|
registrationLockVerificationManager.verifyRegistrationLock(existingAccount.get(), request.registrationLock());
|
||||||
|
}
|
||||||
|
|
||||||
|
Metrics.counter(CHANGE_NUMBER_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent),
|
||||||
|
Tag.of(VERIFICATION_TYPE_TAG_NAME, verificationType.name())))
|
||||||
|
.increment();
|
||||||
|
}
|
||||||
|
|
||||||
|
// ...but always attempt to make the change in case a client retries and needs to re-send messages
|
||||||
|
try {
|
||||||
|
final Account updatedAccount = changeNumberManager.changeNumber(
|
||||||
|
authenticatedAccount.getAccount(),
|
||||||
|
request.number(),
|
||||||
|
request.pniIdentityKey(),
|
||||||
|
request.devicePniSignedPrekeys(),
|
||||||
|
request.deviceMessages(),
|
||||||
|
request.pniRegistrationIds());
|
||||||
|
|
||||||
|
return new AccountIdentityResponse(
|
||||||
|
updatedAccount.getUuid(),
|
||||||
|
updatedAccount.getNumber(),
|
||||||
|
updatedAccount.getPhoneNumberIdentifier(),
|
||||||
|
updatedAccount.getUsernameHash().orElse(null),
|
||||||
|
updatedAccount.isStorageSupported());
|
||||||
|
} catch (MismatchedDevicesException e) {
|
||||||
|
throw new WebApplicationException(Response.status(409)
|
||||||
|
.type(MediaType.APPLICATION_JSON_TYPE)
|
||||||
|
.entity(new MismatchedDevices(e.getMissingDevices(),
|
||||||
|
e.getExtraDevices()))
|
||||||
|
.build());
|
||||||
|
} catch (StaleDevicesException e) {
|
||||||
|
throw new WebApplicationException(Response.status(410)
|
||||||
|
.type(MediaType.APPLICATION_JSON)
|
||||||
|
.entity(new StaleDevices(e.getStaleDevices()))
|
||||||
|
.build());
|
||||||
|
} catch (IllegalArgumentException e) {
|
||||||
|
throw new BadRequestException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -10,20 +10,27 @@ import javax.annotation.Nullable;
|
||||||
|
|
||||||
public class RateLimitExceededException extends Exception {
|
public class RateLimitExceededException extends Exception {
|
||||||
|
|
||||||
private final @Nullable
|
@Nullable
|
||||||
Duration retryDuration;
|
private final Duration retryDuration;
|
||||||
|
private final boolean legacy;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Constructs a new exception indicating when it may become safe to retry
|
* Constructs a new exception indicating when it may become safe to retry
|
||||||
*
|
*
|
||||||
* @param retryDuration A duration to wait before retrying, null if no duration can be indicated
|
* @param retryDuration A duration to wait before retrying, null if no duration can be indicated
|
||||||
|
* @param legacy whether to use a legacy status code when mapping the exception to an HTTP response
|
||||||
*/
|
*/
|
||||||
public RateLimitExceededException(final @Nullable Duration retryDuration) {
|
public RateLimitExceededException(@Nullable final Duration retryDuration, final boolean legacy) {
|
||||||
super(null, null, true, false);
|
super(null, null, true, false);
|
||||||
this.retryDuration = retryDuration;
|
this.retryDuration = retryDuration;
|
||||||
|
this.legacy = legacy;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Optional<Duration> getRetryDuration() {
|
public Optional<Duration> getRetryDuration() {
|
||||||
return Optional.ofNullable(retryDuration);
|
return Optional.ofNullable(retryDuration);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public boolean isLegacy() {
|
||||||
|
return legacy;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,42 +13,33 @@ import io.micrometer.core.instrument.DistributionSummary;
|
||||||
import io.micrometer.core.instrument.Metrics;
|
import io.micrometer.core.instrument.Metrics;
|
||||||
import io.micrometer.core.instrument.Tag;
|
import io.micrometer.core.instrument.Tag;
|
||||||
import io.micrometer.core.instrument.Tags;
|
import io.micrometer.core.instrument.Tags;
|
||||||
import java.security.MessageDigest;
|
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
import java.time.Instant;
|
import java.time.Instant;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.concurrent.CancellationException;
|
|
||||||
import java.util.concurrent.ExecutionException;
|
|
||||||
import java.util.concurrent.TimeUnit;
|
|
||||||
import java.util.concurrent.TimeoutException;
|
|
||||||
import javax.validation.Valid;
|
import javax.validation.Valid;
|
||||||
import javax.validation.constraints.NotNull;
|
import javax.validation.constraints.NotNull;
|
||||||
import javax.ws.rs.BadRequestException;
|
|
||||||
import javax.ws.rs.Consumes;
|
import javax.ws.rs.Consumes;
|
||||||
import javax.ws.rs.ForbiddenException;
|
|
||||||
import javax.ws.rs.HeaderParam;
|
import javax.ws.rs.HeaderParam;
|
||||||
import javax.ws.rs.NotAuthorizedException;
|
|
||||||
import javax.ws.rs.POST;
|
import javax.ws.rs.POST;
|
||||||
import javax.ws.rs.Path;
|
import javax.ws.rs.Path;
|
||||||
import javax.ws.rs.Produces;
|
import javax.ws.rs.Produces;
|
||||||
import javax.ws.rs.ServerErrorException;
|
|
||||||
import javax.ws.rs.WebApplicationException;
|
import javax.ws.rs.WebApplicationException;
|
||||||
import javax.ws.rs.core.MediaType;
|
import javax.ws.rs.core.MediaType;
|
||||||
import javax.ws.rs.core.Response;
|
import javax.ws.rs.core.Response;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import org.whispersystems.textsecuregcm.auth.BasicAuthorizationHeader;
|
import org.whispersystems.textsecuregcm.auth.BasicAuthorizationHeader;
|
||||||
|
import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager;
|
||||||
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
|
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
|
||||||
import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse;
|
import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse;
|
||||||
|
import org.whispersystems.textsecuregcm.entities.PhoneVerificationRequest;
|
||||||
import org.whispersystems.textsecuregcm.entities.RegistrationRequest;
|
import org.whispersystems.textsecuregcm.entities.RegistrationRequest;
|
||||||
import org.whispersystems.textsecuregcm.entities.RegistrationSession;
|
import org.whispersystems.textsecuregcm.limits.RateLimiter;
|
||||||
import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
||||||
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
|
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
|
||||||
import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
|
|
||||||
import org.whispersystems.textsecuregcm.storage.Account;
|
import org.whispersystems.textsecuregcm.storage.Account;
|
||||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||||
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
|
|
||||||
import org.whispersystems.textsecuregcm.util.HeaderUtils;
|
import org.whispersystems.textsecuregcm.util.HeaderUtils;
|
||||||
import org.whispersystems.textsecuregcm.util.Util;
|
import org.whispersystems.textsecuregcm.util.Util;
|
||||||
|
|
||||||
|
@ -68,24 +59,17 @@ public class RegistrationController {
|
||||||
private static final String REGION_CODE_TAG_NAME = "regionCode";
|
private static final String REGION_CODE_TAG_NAME = "regionCode";
|
||||||
private static final String VERIFICATION_TYPE_TAG_NAME = "verification";
|
private static final String VERIFICATION_TYPE_TAG_NAME = "verification";
|
||||||
|
|
||||||
private static final Duration REGISTRATION_RPC_TIMEOUT = Duration.ofSeconds(15);
|
|
||||||
private static final long VERIFICATION_TIMEOUT_SECONDS = REGISTRATION_RPC_TIMEOUT.plusSeconds(1).getSeconds();
|
|
||||||
|
|
||||||
private final AccountsManager accounts;
|
private final AccountsManager accounts;
|
||||||
private final RegistrationServiceClient registrationServiceClient;
|
private final PhoneVerificationTokenManager phoneVerificationTokenManager;
|
||||||
private final RegistrationLockVerificationManager registrationLockVerificationManager;
|
private final RegistrationLockVerificationManager registrationLockVerificationManager;
|
||||||
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager;
|
|
||||||
private final RateLimiters rateLimiters;
|
private final RateLimiters rateLimiters;
|
||||||
|
|
||||||
public RegistrationController(final AccountsManager accounts,
|
public RegistrationController(final AccountsManager accounts,
|
||||||
final RegistrationServiceClient registrationServiceClient,
|
final PhoneVerificationTokenManager phoneVerificationTokenManager,
|
||||||
final RegistrationLockVerificationManager registrationLockVerificationManager,
|
final RegistrationLockVerificationManager registrationLockVerificationManager, final RateLimiters rateLimiters) {
|
||||||
final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager,
|
|
||||||
final RateLimiters rateLimiters) {
|
|
||||||
this.accounts = accounts;
|
this.accounts = accounts;
|
||||||
this.registrationServiceClient = registrationServiceClient;
|
this.phoneVerificationTokenManager = phoneVerificationTokenManager;
|
||||||
this.registrationLockVerificationManager = registrationLockVerificationManager;
|
this.registrationLockVerificationManager = registrationLockVerificationManager;
|
||||||
this.registrationRecoveryPasswordsManager = registrationRecoveryPasswordsManager;
|
|
||||||
this.rateLimiters = rateLimiters;
|
this.rateLimiters = rateLimiters;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -99,17 +83,13 @@ public class RegistrationController {
|
||||||
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
|
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
|
||||||
@NotNull @Valid final RegistrationRequest registrationRequest) throws RateLimitExceededException, InterruptedException {
|
@NotNull @Valid final RegistrationRequest registrationRequest) throws RateLimitExceededException, InterruptedException {
|
||||||
|
|
||||||
rateLimiters.getRegistrationLimiter().validate(registrationRequest.sessionId());
|
|
||||||
|
|
||||||
final String number = authorizationHeader.getUsername();
|
final String number = authorizationHeader.getUsername();
|
||||||
final String password = authorizationHeader.getPassword();
|
final String password = authorizationHeader.getPassword();
|
||||||
|
|
||||||
// decide on the method of verification based on the registration request parameters and verify
|
RateLimiter.adaptLegacyException(() -> rateLimiters.getRegistrationLimiter().validate(number));
|
||||||
final RegistrationRequest.VerificationType verificationType = registrationRequest.verificationType();
|
|
||||||
switch (verificationType) {
|
final PhoneVerificationRequest.VerificationType verificationType = phoneVerificationTokenManager.verify(number,
|
||||||
case SESSION -> verifyBySessionId(number, registrationRequest.decodeSessionId());
|
registrationRequest);
|
||||||
case RECOVERY_PASSWORD -> verifyByRecoveryPassword(number, registrationRequest.recoveryPassword());
|
|
||||||
}
|
|
||||||
|
|
||||||
final Optional<Account> existingAccount = accounts.getByE164(number);
|
final Optional<Account> existingAccount = accounts.getByE164(number);
|
||||||
|
|
||||||
|
@ -146,34 +126,4 @@ public class RegistrationController {
|
||||||
existingAccount.map(Account::isStorageSupported).orElse(false));
|
existingAccount.map(Account::isStorageSupported).orElse(false));
|
||||||
}
|
}
|
||||||
|
|
||||||
private void verifyBySessionId(final String number, final byte[] sessionId) throws InterruptedException {
|
|
||||||
try {
|
|
||||||
final RegistrationSession session = registrationServiceClient
|
|
||||||
.getSession(sessionId, REGISTRATION_RPC_TIMEOUT)
|
|
||||||
.get(VERIFICATION_TIMEOUT_SECONDS, TimeUnit.SECONDS)
|
|
||||||
.orElseThrow(() -> new NotAuthorizedException("session not verified"));
|
|
||||||
|
|
||||||
if (!MessageDigest.isEqual(number.getBytes(), session.number().getBytes())) {
|
|
||||||
throw new BadRequestException("number does not match session");
|
|
||||||
}
|
|
||||||
if (!session.verified()) {
|
|
||||||
throw new NotAuthorizedException("session not verified");
|
|
||||||
}
|
|
||||||
} catch (final CancellationException | ExecutionException | TimeoutException e) {
|
|
||||||
logger.error("Registration service failure", e);
|
|
||||||
throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void verifyByRecoveryPassword(final String number, final byte[] recoveryPassword) throws InterruptedException {
|
|
||||||
try {
|
|
||||||
final boolean verified = registrationRecoveryPasswordsManager.verify(number, recoveryPassword)
|
|
||||||
.get(VERIFICATION_TIMEOUT_SECONDS, TimeUnit.SECONDS);
|
|
||||||
if (!verified) {
|
|
||||||
throw new ForbiddenException("recoveryPassword couldn't be verified");
|
|
||||||
}
|
|
||||||
} catch (final ExecutionException | TimeoutException e) {
|
|
||||||
throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2023 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.whispersystems.textsecuregcm.entities;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import javax.annotation.Nullable;
|
||||||
|
import javax.validation.Valid;
|
||||||
|
import javax.validation.constraints.NotBlank;
|
||||||
|
import javax.validation.constraints.NotNull;
|
||||||
|
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
|
||||||
|
|
||||||
|
public record ChangeNumberRequest(String sessionId,
|
||||||
|
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) byte[] recoveryPassword,
|
||||||
|
@NotBlank String number,
|
||||||
|
@JsonProperty("reglock") @Nullable String registrationLock,
|
||||||
|
@NotBlank String pniIdentityKey,
|
||||||
|
@NotNull @Valid List<@NotNull @Valid IncomingMessage> deviceMessages,
|
||||||
|
@NotNull @Valid Map<Long, @NotNull @Valid SignedPreKey> devicePniSignedPrekeys,
|
||||||
|
@NotNull Map<Long, Integer> pniRegistrationIds) implements PhoneVerificationRequest {
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,45 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2023 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.whispersystems.textsecuregcm.entities;
|
||||||
|
|
||||||
|
import static org.apache.commons.lang3.StringUtils.isNotBlank;
|
||||||
|
|
||||||
|
import java.util.Base64;
|
||||||
|
import javax.validation.constraints.AssertTrue;
|
||||||
|
import javax.ws.rs.ClientErrorException;
|
||||||
|
import org.apache.http.HttpStatus;
|
||||||
|
|
||||||
|
public interface PhoneVerificationRequest {
|
||||||
|
|
||||||
|
enum VerificationType {
|
||||||
|
SESSION,
|
||||||
|
RECOVERY_PASSWORD
|
||||||
|
}
|
||||||
|
|
||||||
|
String sessionId();
|
||||||
|
|
||||||
|
byte[] recoveryPassword();
|
||||||
|
|
||||||
|
// for the @AssertTrue to work with bean validation, method name must follow 'isSmth()'/'getSmth()' naming convention
|
||||||
|
@AssertTrue
|
||||||
|
default boolean isValid() {
|
||||||
|
// checking that exactly one of sessionId/recoveryPassword is non-empty
|
||||||
|
return isNotBlank(sessionId()) ^ (recoveryPassword() != null && recoveryPassword().length > 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
default PhoneVerificationRequest.VerificationType verificationType() {
|
||||||
|
return isNotBlank(sessionId()) ? PhoneVerificationRequest.VerificationType.SESSION
|
||||||
|
: PhoneVerificationRequest.VerificationType.RECOVERY_PASSWORD;
|
||||||
|
}
|
||||||
|
|
||||||
|
default byte[] decodeSessionId() {
|
||||||
|
try {
|
||||||
|
return Base64.getUrlDecoder().decode(sessionId());
|
||||||
|
} catch (final IllegalArgumentException e) {
|
||||||
|
throw new ClientErrorException("Malformed session ID", HttpStatus.SC_UNPROCESSABLE_ENTITY);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -5,43 +5,15 @@
|
||||||
|
|
||||||
package org.whispersystems.textsecuregcm.entities;
|
package org.whispersystems.textsecuregcm.entities;
|
||||||
|
|
||||||
import static org.apache.commons.lang3.StringUtils.isNotBlank;
|
|
||||||
|
|
||||||
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
|
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
|
||||||
import java.util.Base64;
|
|
||||||
import javax.validation.Valid;
|
import javax.validation.Valid;
|
||||||
import javax.validation.constraints.AssertTrue;
|
|
||||||
import javax.validation.constraints.NotNull;
|
import javax.validation.constraints.NotNull;
|
||||||
import javax.ws.rs.ClientErrorException;
|
|
||||||
import org.apache.http.HttpStatus;
|
|
||||||
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
|
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
|
||||||
|
|
||||||
public record RegistrationRequest(String sessionId,
|
public record RegistrationRequest(String sessionId,
|
||||||
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) byte[] recoveryPassword,
|
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) byte[] recoveryPassword,
|
||||||
@NotNull @Valid AccountAttributes accountAttributes,
|
@NotNull @Valid AccountAttributes accountAttributes,
|
||||||
boolean skipDeviceTransfer) {
|
boolean skipDeviceTransfer) implements PhoneVerificationRequest {
|
||||||
|
|
||||||
public enum VerificationType {
|
|
||||||
SESSION,
|
|
||||||
RECOVERY_PASSWORD
|
|
||||||
}
|
|
||||||
|
|
||||||
// for the @AssertTrue to work with bean validation, method name must follow 'isSmth()'/'getSmth()' naming convention
|
|
||||||
@AssertTrue
|
|
||||||
public boolean isValid() {
|
|
||||||
// checking that exactly one of sessionId/recoveryPassword is non-empty
|
|
||||||
return isNotBlank(sessionId) ^ (recoveryPassword != null && recoveryPassword.length > 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
public VerificationType verificationType() {
|
|
||||||
return isNotBlank(sessionId) ? VerificationType.SESSION : VerificationType.RECOVERY_PASSWORD;
|
|
||||||
}
|
|
||||||
|
|
||||||
public byte[] decodeSessionId() {
|
|
||||||
try {
|
|
||||||
return Base64.getUrlDecoder().decode(sessionId());
|
|
||||||
} catch (final IllegalArgumentException e) {
|
|
||||||
throw new ClientErrorException("Malformed session ID", HttpStatus.SC_UNPROCESSABLE_ENTITY);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,7 +31,7 @@ public class LockingRateLimiter extends RateLimiter {
|
||||||
public void validate(String key, int amount) throws RateLimitExceededException {
|
public void validate(String key, int amount) throws RateLimitExceededException {
|
||||||
if (!acquireLock(key)) {
|
if (!acquireLock(key)) {
|
||||||
meter.mark();
|
meter.mark();
|
||||||
throw new RateLimitExceededException(Duration.ZERO);
|
throw new RateLimitExceededException(Duration.ZERO, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
|
|
@ -29,7 +29,8 @@ public class RateLimitByIpFilter implements ContainerRequestFilter {
|
||||||
private static final Logger logger = LoggerFactory.getLogger(RateLimitByIpFilter.class);
|
private static final Logger logger = LoggerFactory.getLogger(RateLimitByIpFilter.class);
|
||||||
|
|
||||||
@VisibleForTesting
|
@VisibleForTesting
|
||||||
static final RateLimitExceededException INVALID_HEADER_EXCEPTION = new RateLimitExceededException(Duration.ofHours(1));
|
static final RateLimitExceededException INVALID_HEADER_EXCEPTION = new RateLimitExceededException(Duration.ofHours(1),
|
||||||
|
true);
|
||||||
|
|
||||||
private static final ExceptionMapper<RateLimitExceededException> EXCEPTION_MAPPER = new RateLimitExceededExceptionMapper();
|
private static final ExceptionMapper<RateLimitExceededException> EXCEPTION_MAPPER = new RateLimitExceededExceptionMapper();
|
||||||
|
|
||||||
|
|
|
@ -57,7 +57,7 @@ public class RateLimiter {
|
||||||
setBucket(key, bucket);
|
setBucket(key, bucket);
|
||||||
} else {
|
} else {
|
||||||
meter.mark();
|
meter.mark();
|
||||||
throw new RateLimitExceededException(bucket.getTimeUntilSpaceAvailable(amount));
|
throw new RateLimitExceededException(bucket.getTimeUntilSpaceAvailable(amount), true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -132,4 +132,22 @@ public class RateLimiter {
|
||||||
public boolean hasConfiguration(final RateLimitConfiguration configuration) {
|
public boolean hasConfiguration(final RateLimitConfiguration configuration) {
|
||||||
return bucketSize == configuration.getBucketSize() && leakRatePerMinute == configuration.getLeakRatePerMinute();
|
return bucketSize == configuration.getBucketSize() && leakRatePerMinute == configuration.getLeakRatePerMinute();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* If the wrapped {@code validate()} call throws a {@link RateLimitExceededException}, it will adapt it to ensure that
|
||||||
|
* {@link RateLimitExceededException#isLegacy()} returns {@code true}
|
||||||
|
*/
|
||||||
|
public static void adaptLegacyException(final RateLimitValidator validator) throws RateLimitExceededException {
|
||||||
|
try {
|
||||||
|
validator.validate();
|
||||||
|
} catch (final RateLimitExceededException e) {
|
||||||
|
throw new RateLimitExceededException(e.getRetryDuration().orElse(null), false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@FunctionalInterface
|
||||||
|
public interface RateLimitValidator {
|
||||||
|
|
||||||
|
void validate() throws RateLimitExceededException;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,36 +4,41 @@
|
||||||
*/
|
*/
|
||||||
package org.whispersystems.textsecuregcm.mappers;
|
package org.whispersystems.textsecuregcm.mappers;
|
||||||
|
|
||||||
|
import javax.ws.rs.core.Response;
|
||||||
|
import javax.ws.rs.ext.ExceptionMapper;
|
||||||
|
import javax.ws.rs.ext.Provider;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||||
|
|
||||||
import javax.ws.rs.core.Response;
|
|
||||||
import javax.ws.rs.ext.ExceptionMapper;
|
|
||||||
import javax.ws.rs.ext.Provider;
|
|
||||||
|
|
||||||
@Provider
|
@Provider
|
||||||
public class RateLimitExceededExceptionMapper implements ExceptionMapper<RateLimitExceededException> {
|
public class RateLimitExceededExceptionMapper implements ExceptionMapper<RateLimitExceededException> {
|
||||||
|
|
||||||
private static final Logger logger = LoggerFactory.getLogger(RateLimitExceededExceptionMapper.class);
|
private static final Logger logger = LoggerFactory.getLogger(RateLimitExceededExceptionMapper.class);
|
||||||
|
|
||||||
|
private static final int LEGACY_STATUS_CODE = 413;
|
||||||
|
private static final int STATUS_CODE = 429;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert a RateLimitExceededException to a 413 response with a
|
* Convert a RateLimitExceededException to a {@value STATUS_CODE} (or legacy {@value LEGACY_STATUS_CODE}) response
|
||||||
* Retry-After header.
|
* with a Retry-After header.
|
||||||
*
|
*
|
||||||
* @param e A RateLimitExceededException potentially containing a recommended retry duration
|
* @param e A RateLimitExceededException potentially containing a recommended retry duration
|
||||||
* @return the response
|
* @return the response
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public Response toResponse(RateLimitExceededException e) {
|
public Response toResponse(RateLimitExceededException e) {
|
||||||
|
final int statusCode = e.isLegacy() ? LEGACY_STATUS_CODE : STATUS_CODE;
|
||||||
return e.getRetryDuration()
|
return e.getRetryDuration()
|
||||||
.filter(d -> {
|
.filter(d -> {
|
||||||
if (d.isNegative()) {
|
if (d.isNegative()) {
|
||||||
logger.warn("Encountered a negative retry duration: {}, will not include a Retry-After header in response", d);
|
logger.warn("Encountered a negative retry duration: {}, will not include a Retry-After header in response",
|
||||||
|
d);
|
||||||
}
|
}
|
||||||
// only include non-negative durations in retry headers
|
// only include non-negative durations in retry headers
|
||||||
return !d.isNegative();
|
return !d.isNegative();
|
||||||
})
|
})
|
||||||
.map(d -> Response.status(413).header("Retry-After", d.toSeconds()))
|
.map(d -> Response.status(statusCode).header("Retry-After", d.toSeconds()))
|
||||||
.orElseGet(() -> Response.status(413)).build();
|
.orElseGet(() -> Response.status(statusCode)).build();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -89,9 +89,12 @@ public class RegistrationServiceClient implements Managed {
|
||||||
|
|
||||||
case ERROR -> {
|
case ERROR -> {
|
||||||
switch (response.getError().getErrorType()) {
|
switch (response.getError().getErrorType()) {
|
||||||
case CREATE_REGISTRATION_SESSION_ERROR_TYPE_RATE_LIMITED -> throw new CompletionException(new RateLimitExceededException(Duration.ofSeconds(response.getError().getRetryAfterSeconds())));
|
case CREATE_REGISTRATION_SESSION_ERROR_TYPE_RATE_LIMITED -> throw new CompletionException(
|
||||||
|
new RateLimitExceededException(Duration.ofSeconds(response.getError().getRetryAfterSeconds()),
|
||||||
|
true));
|
||||||
case CREATE_REGISTRATION_SESSION_ERROR_TYPE_ILLEGAL_PHONE_NUMBER -> throw new IllegalArgumentException();
|
case CREATE_REGISTRATION_SESSION_ERROR_TYPE_ILLEGAL_PHONE_NUMBER -> throw new IllegalArgumentException();
|
||||||
default -> throw new RuntimeException("Unrecognized error type from registration service: " + response.getError().getErrorType());
|
default -> throw new RuntimeException(
|
||||||
|
"Unrecognized error type from registration service: " + response.getError().getErrorType());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,8 +122,9 @@ public class RegistrationServiceClient implements Managed {
|
||||||
.thenApply(response -> {
|
.thenApply(response -> {
|
||||||
if (response.hasError()) {
|
if (response.hasError()) {
|
||||||
switch (response.getError().getErrorType()) {
|
switch (response.getError().getErrorType()) {
|
||||||
case SEND_VERIFICATION_CODE_ERROR_TYPE_RATE_LIMITED ->
|
case SEND_VERIFICATION_CODE_ERROR_TYPE_RATE_LIMITED -> throw new CompletionException(
|
||||||
throw new CompletionException(new RateLimitExceededException(Duration.ofSeconds(response.getError().getRetryAfterSeconds())));
|
new RateLimitExceededException(Duration.ofSeconds(response.getError().getRetryAfterSeconds()),
|
||||||
|
true));
|
||||||
|
|
||||||
default -> throw new CompletionException(new RuntimeException("Failed to send verification code: " + response.getError().getErrorType()));
|
default -> throw new CompletionException(new RuntimeException("Failed to send verification code: " + response.getError().getErrorType()));
|
||||||
}
|
}
|
||||||
|
@ -142,8 +146,9 @@ public class RegistrationServiceClient implements Managed {
|
||||||
.thenApply(response -> {
|
.thenApply(response -> {
|
||||||
if (response.hasError()) {
|
if (response.hasError()) {
|
||||||
switch (response.getError().getErrorType()) {
|
switch (response.getError().getErrorType()) {
|
||||||
case CHECK_VERIFICATION_CODE_ERROR_TYPE_RATE_LIMITED ->
|
case CHECK_VERIFICATION_CODE_ERROR_TYPE_RATE_LIMITED -> throw new CompletionException(
|
||||||
throw new CompletionException(new RateLimitExceededException(Duration.ofSeconds(response.getError().getRetryAfterSeconds())));
|
new RateLimitExceededException(Duration.ofSeconds(response.getError().getRetryAfterSeconds()),
|
||||||
|
true));
|
||||||
|
|
||||||
default -> throw new CompletionException(new RuntimeException("Failed to check verification code: " + response.getError().getErrorType()));
|
default -> throw new CompletionException(new RuntimeException("Failed to check verification code: " + response.getError().getErrorType()));
|
||||||
}
|
}
|
||||||
|
|
|
@ -331,11 +331,12 @@ class AccountControllerTest {
|
||||||
when(captchaChecker.verify(eq(VALID_CAPTCHA_TOKEN), anyString()))
|
when(captchaChecker.verify(eq(VALID_CAPTCHA_TOKEN), anyString()))
|
||||||
.thenReturn(new AssessmentResult(true, ""));
|
.thenReturn(new AssessmentResult(true, ""));
|
||||||
|
|
||||||
doThrow(new RateLimitExceededException(Duration.ZERO)).when(pinLimiter).validate(eq(SENDER_OVER_PIN));
|
doThrow(new RateLimitExceededException(Duration.ZERO, true)).when(pinLimiter).validate(eq(SENDER_OVER_PIN));
|
||||||
|
|
||||||
doThrow(new RateLimitExceededException(Duration.ZERO)).when(smsVoicePrefixLimiter).validate(SENDER_OVER_PREFIX.substring(0, 4+2));
|
doThrow(new RateLimitExceededException(Duration.ZERO, true)).when(smsVoicePrefixLimiter)
|
||||||
doThrow(new RateLimitExceededException(Duration.ZERO)).when(smsVoiceIpLimiter).validate(RATE_LIMITED_IP_HOST);
|
.validate(SENDER_OVER_PREFIX.substring(0, 4 + 2));
|
||||||
doThrow(new RateLimitExceededException(Duration.ZERO)).when(smsVoiceIpLimiter).validate(RATE_LIMITED_HOST2);
|
doThrow(new RateLimitExceededException(Duration.ZERO, true)).when(smsVoiceIpLimiter).validate(RATE_LIMITED_IP_HOST);
|
||||||
|
doThrow(new RateLimitExceededException(Duration.ZERO, true)).when(smsVoiceIpLimiter).validate(RATE_LIMITED_HOST2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@AfterEach
|
@AfterEach
|
||||||
|
@ -571,7 +572,7 @@ class AccountControllerTest {
|
||||||
@Test
|
@Test
|
||||||
void testSendCodeRateLimited() {
|
void testSendCodeRateLimited() {
|
||||||
when(registrationServiceClient.createRegistrationSession(any(), any()))
|
when(registrationServiceClient.createRegistrationSession(any(), any()))
|
||||||
.thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(Duration.ofMinutes(10))));
|
.thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(Duration.ofMinutes(10), true)));
|
||||||
|
|
||||||
Response response =
|
Response response =
|
||||||
resources.getJerseyTest()
|
resources.getJerseyTest()
|
||||||
|
@ -2050,7 +2051,7 @@ class AccountControllerTest {
|
||||||
when(accountsManager.getByAccountIdentifier(accountIdentifier)).thenReturn(Optional.of(account));
|
when(accountsManager.getByAccountIdentifier(accountIdentifier)).thenReturn(Optional.of(account));
|
||||||
|
|
||||||
MockUtils.updateRateLimiterResponseToFail(
|
MockUtils.updateRateLimiterResponseToFail(
|
||||||
rateLimiters, RateLimiters.Handle.CHECK_ACCOUNT_EXISTENCE, "127.0.0.1", expectedRetryAfter);
|
rateLimiters, RateLimiters.Handle.CHECK_ACCOUNT_EXISTENCE, "127.0.0.1", expectedRetryAfter, true);
|
||||||
|
|
||||||
final Response response = resources.getJerseyTest()
|
final Response response = resources.getJerseyTest()
|
||||||
.target(String.format("/v1/accounts/account/%s", accountIdentifier))
|
.target(String.format("/v1/accounts/account/%s", accountIdentifier))
|
||||||
|
@ -2115,7 +2116,7 @@ class AccountControllerTest {
|
||||||
void testLookupUsernameRateLimited() throws RateLimitExceededException {
|
void testLookupUsernameRateLimited() throws RateLimitExceededException {
|
||||||
final Duration expectedRetryAfter = Duration.ofSeconds(13);
|
final Duration expectedRetryAfter = Duration.ofSeconds(13);
|
||||||
MockUtils.updateRateLimiterResponseToFail(
|
MockUtils.updateRateLimiterResponseToFail(
|
||||||
rateLimiters, RateLimiters.Handle.USERNAME_LOOKUP, "127.0.0.1", expectedRetryAfter);
|
rateLimiters, RateLimiters.Handle.USERNAME_LOOKUP, "127.0.0.1", expectedRetryAfter, true);
|
||||||
final Response response = resources.getJerseyTest()
|
final Response response = resources.getJerseyTest()
|
||||||
.target(String.format("v1/accounts/username_hash/%s", BASE_64_URL_USERNAME_HASH_1))
|
.target(String.format("v1/accounts/username_hash/%s", BASE_64_URL_USERNAME_HASH_1))
|
||||||
.request()
|
.request()
|
||||||
|
|
|
@ -0,0 +1,396 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2023 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.whispersystems.textsecuregcm.controllers;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertNotEquals;
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.ArgumentMatchers.anyString;
|
||||||
|
import static org.mockito.ArgumentMatchers.eq;
|
||||||
|
import static org.mockito.Mockito.doThrow;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
import com.google.common.collect.ImmutableSet;
|
||||||
|
import com.google.i18n.phonenumbers.PhoneNumberUtil;
|
||||||
|
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
|
||||||
|
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
|
||||||
|
import io.dropwizard.testing.junit5.ResourceExtension;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.util.Base64;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.UUID;
|
||||||
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
import javax.annotation.Nullable;
|
||||||
|
import javax.ws.rs.WebApplicationException;
|
||||||
|
import javax.ws.rs.client.Entity;
|
||||||
|
import javax.ws.rs.client.Invocation;
|
||||||
|
import javax.ws.rs.core.HttpHeaders;
|
||||||
|
import javax.ws.rs.core.MediaType;
|
||||||
|
import javax.ws.rs.core.Response;
|
||||||
|
import org.apache.http.HttpStatus;
|
||||||
|
import org.glassfish.jersey.server.ServerProperties;
|
||||||
|
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Nested;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
|
import org.junit.jupiter.params.provider.Arguments;
|
||||||
|
import org.junit.jupiter.params.provider.EnumSource;
|
||||||
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
import org.mockito.stubbing.Answer;
|
||||||
|
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
|
||||||
|
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
|
||||||
|
import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager;
|
||||||
|
import org.whispersystems.textsecuregcm.auth.RegistrationLockError;
|
||||||
|
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
|
||||||
|
import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse;
|
||||||
|
import org.whispersystems.textsecuregcm.entities.ChangeNumberRequest;
|
||||||
|
import org.whispersystems.textsecuregcm.entities.RegistrationSession;
|
||||||
|
import org.whispersystems.textsecuregcm.limits.RateLimiter;
|
||||||
|
import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
||||||
|
import org.whispersystems.textsecuregcm.mappers.ImpossiblePhoneNumberExceptionMapper;
|
||||||
|
import org.whispersystems.textsecuregcm.mappers.NonNormalizedPhoneNumberExceptionMapper;
|
||||||
|
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
|
||||||
|
import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.Account;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.ChangeNumberManager;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.Device;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
|
||||||
|
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
|
||||||
|
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||||
|
|
||||||
|
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||||
|
class AccountControllerV2Test {
|
||||||
|
|
||||||
|
public static final String NEW_NUMBER = PhoneNumberUtil.getInstance().format(
|
||||||
|
PhoneNumberUtil.getInstance().getExampleNumber("US"),
|
||||||
|
PhoneNumberUtil.PhoneNumberFormat.E164);
|
||||||
|
|
||||||
|
private final AccountsManager accountsManager = mock(AccountsManager.class);
|
||||||
|
private final ChangeNumberManager changeNumberManager = mock(ChangeNumberManager.class);
|
||||||
|
private final RegistrationServiceClient registrationServiceClient = mock(RegistrationServiceClient.class);
|
||||||
|
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager = mock(
|
||||||
|
RegistrationRecoveryPasswordsManager.class);
|
||||||
|
private final RegistrationLockVerificationManager registrationLockVerificationManager = mock(
|
||||||
|
RegistrationLockVerificationManager.class);
|
||||||
|
private final RateLimiters rateLimiters = mock(RateLimiters.class);
|
||||||
|
private final RateLimiter registrationLimiter = mock(RateLimiter.class);
|
||||||
|
|
||||||
|
private final ResourceExtension resources = ResourceExtension.builder()
|
||||||
|
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
|
||||||
|
.addProvider(AuthHelper.getAuthFilter())
|
||||||
|
.addProvider(
|
||||||
|
new PolymorphicAuthValueFactoryProvider.Binder<>(
|
||||||
|
ImmutableSet.of(AuthenticatedAccount.class,
|
||||||
|
DisabledPermittedAuthenticatedAccount.class)))
|
||||||
|
.addProvider(new RateLimitExceededExceptionMapper())
|
||||||
|
.addProvider(new ImpossiblePhoneNumberExceptionMapper())
|
||||||
|
.addProvider(new NonNormalizedPhoneNumberExceptionMapper())
|
||||||
|
.setMapper(SystemMapper.getMapper())
|
||||||
|
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
|
||||||
|
.addResource(
|
||||||
|
new AccountControllerV2(accountsManager, changeNumberManager,
|
||||||
|
new PhoneVerificationTokenManager(registrationServiceClient, registrationRecoveryPasswordsManager),
|
||||||
|
registrationLockVerificationManager, rateLimiters))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
@Nested
|
||||||
|
class ChangeNumber {
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() throws Exception {
|
||||||
|
when(rateLimiters.getRegistrationLimiter()).thenReturn(registrationLimiter);
|
||||||
|
|
||||||
|
when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any())).thenAnswer(
|
||||||
|
(Answer<Account>) invocation -> {
|
||||||
|
final Account account = invocation.getArgument(0, Account.class);
|
||||||
|
final String number = invocation.getArgument(1, String.class);
|
||||||
|
final String pniIdentityKey = invocation.getArgument(2, String.class);
|
||||||
|
|
||||||
|
final UUID uuid = account.getUuid();
|
||||||
|
final List<Device> devices = account.getDevices();
|
||||||
|
|
||||||
|
final Account updatedAccount = mock(Account.class);
|
||||||
|
when(updatedAccount.getUuid()).thenReturn(uuid);
|
||||||
|
when(updatedAccount.getNumber()).thenReturn(number);
|
||||||
|
when(updatedAccount.getPhoneNumberIdentityKey()).thenReturn(pniIdentityKey);
|
||||||
|
when(updatedAccount.getPhoneNumberIdentifier()).thenReturn(UUID.randomUUID());
|
||||||
|
when(updatedAccount.getDevices()).thenReturn(devices);
|
||||||
|
|
||||||
|
for (long i = 1; i <= 3; i++) {
|
||||||
|
final Optional<Device> d = account.getDevice(i);
|
||||||
|
when(updatedAccount.getDevice(i)).thenReturn(d);
|
||||||
|
}
|
||||||
|
|
||||||
|
return updatedAccount;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void changeNumberSuccess() throws Exception {
|
||||||
|
|
||||||
|
when(registrationServiceClient.getSession(any(), any()))
|
||||||
|
.thenReturn(CompletableFuture.completedFuture(Optional.of(new RegistrationSession(NEW_NUMBER, true))));
|
||||||
|
|
||||||
|
final AccountIdentityResponse accountIdentityResponse =
|
||||||
|
resources.getJerseyTest()
|
||||||
|
.target("/v2/accounts/number")
|
||||||
|
.request()
|
||||||
|
.header(HttpHeaders.AUTHORIZATION,
|
||||||
|
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
|
||||||
|
.put(Entity.entity(
|
||||||
|
new ChangeNumberRequest(encodeSessionId("session"), null, NEW_NUMBER, "123", "123",
|
||||||
|
Collections.emptyList(),
|
||||||
|
Collections.emptyMap(), Collections.emptyMap()),
|
||||||
|
MediaType.APPLICATION_JSON_TYPE), AccountIdentityResponse.class);
|
||||||
|
|
||||||
|
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(NEW_NUMBER), any(), any(), any(),
|
||||||
|
any());
|
||||||
|
|
||||||
|
assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid());
|
||||||
|
assertEquals(NEW_NUMBER, accountIdentityResponse.number());
|
||||||
|
assertNotEquals(AuthHelper.VALID_PNI, accountIdentityResponse.pni());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void unprocessableRequestJson() {
|
||||||
|
final Invocation.Builder request = resources.getJerseyTest()
|
||||||
|
.target("/v2/accounts/number")
|
||||||
|
.request()
|
||||||
|
.header(HttpHeaders.AUTHORIZATION,
|
||||||
|
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
|
||||||
|
try (Response response = request.put(Entity.json(unprocessableJson()))) {
|
||||||
|
assertEquals(400, response.getStatus());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void missingBasicAuthorization() {
|
||||||
|
final Invocation.Builder request = resources.getJerseyTest()
|
||||||
|
.target("/v2/accounts/number")
|
||||||
|
.request();
|
||||||
|
try (Response response = request.put(Entity.json(requestJson("sessionId", NEW_NUMBER)))) {
|
||||||
|
assertEquals(401, response.getStatus());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void invalidBasicAuthorization() {
|
||||||
|
final Invocation.Builder request = resources.getJerseyTest()
|
||||||
|
.target("/v2/accounts/number")
|
||||||
|
.request()
|
||||||
|
.header(HttpHeaders.AUTHORIZATION, "Basic but-invalid");
|
||||||
|
try (Response response = request.put(Entity.json(invalidRequestJson()))) {
|
||||||
|
assertEquals(401, response.getStatus());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void invalidRequestBody() {
|
||||||
|
final Invocation.Builder request = resources.getJerseyTest()
|
||||||
|
.target("/v2/accounts/number")
|
||||||
|
.request()
|
||||||
|
.header(HttpHeaders.AUTHORIZATION,
|
||||||
|
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
|
||||||
|
try (Response response = request.put(Entity.json(invalidRequestJson()))) {
|
||||||
|
assertEquals(422, response.getStatus());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void rateLimitedNumber() throws Exception {
|
||||||
|
doThrow(new RateLimitExceededException(null, true))
|
||||||
|
.when(registrationLimiter).validate(anyString());
|
||||||
|
|
||||||
|
final Invocation.Builder request = resources.getJerseyTest()
|
||||||
|
.target("/v2/accounts/number")
|
||||||
|
.request()
|
||||||
|
.header(HttpHeaders.AUTHORIZATION,
|
||||||
|
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
|
||||||
|
try (Response response = request.put(Entity.json(requestJson("sessionId", NEW_NUMBER)))) {
|
||||||
|
assertEquals(429, response.getStatus());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void registrationServiceTimeout() {
|
||||||
|
when(registrationServiceClient.getSession(any(), any()))
|
||||||
|
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()));
|
||||||
|
|
||||||
|
final Invocation.Builder request = resources.getJerseyTest()
|
||||||
|
.target("/v2/accounts/number")
|
||||||
|
.request()
|
||||||
|
.header(HttpHeaders.AUTHORIZATION,
|
||||||
|
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
|
||||||
|
try (Response response = request.put(Entity.json(requestJson("sessionId", NEW_NUMBER)))) {
|
||||||
|
assertEquals(HttpStatus.SC_SERVICE_UNAVAILABLE, response.getStatus());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@MethodSource
|
||||||
|
void registrationServiceSessionCheck(@Nullable final RegistrationSession session, final int expectedStatus,
|
||||||
|
final String message) {
|
||||||
|
when(registrationServiceClient.getSession(any(), any()))
|
||||||
|
.thenReturn(CompletableFuture.completedFuture(Optional.ofNullable(session)));
|
||||||
|
|
||||||
|
final Invocation.Builder request = resources.getJerseyTest()
|
||||||
|
.target("/v2/accounts/number")
|
||||||
|
.request()
|
||||||
|
.header(HttpHeaders.AUTHORIZATION,
|
||||||
|
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
|
||||||
|
try (Response response = request.put(Entity.json(requestJson("sessionId", NEW_NUMBER)))) {
|
||||||
|
assertEquals(expectedStatus, response.getStatus(), message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static Stream<Arguments> registrationServiceSessionCheck() {
|
||||||
|
return Stream.of(
|
||||||
|
Arguments.of(null, 401, "session not found"),
|
||||||
|
Arguments.of(new RegistrationSession("+18005551234", false), 400, "session number mismatch"),
|
||||||
|
Arguments.of(new RegistrationSession(NEW_NUMBER, false), 401, "session not verified")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@EnumSource(RegistrationLockError.class)
|
||||||
|
void registrationLock(final RegistrationLockError error) throws Exception {
|
||||||
|
when(registrationServiceClient.getSession(any(), any()))
|
||||||
|
.thenReturn(
|
||||||
|
CompletableFuture.completedFuture(Optional.of(new RegistrationSession(NEW_NUMBER, true))));
|
||||||
|
|
||||||
|
when(accountsManager.getByE164(any())).thenReturn(Optional.of(mock(Account.class)));
|
||||||
|
|
||||||
|
final Exception e = switch (error) {
|
||||||
|
case MISMATCH -> new WebApplicationException(error.getExpectedStatus());
|
||||||
|
case RATE_LIMITED -> new RateLimitExceededException(null, true);
|
||||||
|
};
|
||||||
|
doThrow(e)
|
||||||
|
.when(registrationLockVerificationManager).verifyRegistrationLock(any(), any());
|
||||||
|
|
||||||
|
final Invocation.Builder request = resources.getJerseyTest()
|
||||||
|
.target("/v2/accounts/number")
|
||||||
|
.request()
|
||||||
|
.header(HttpHeaders.AUTHORIZATION,
|
||||||
|
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
|
||||||
|
try (Response response = request.put(Entity.json(requestJson("sessionId", NEW_NUMBER)))) {
|
||||||
|
assertEquals(error.getExpectedStatus(), response.getStatus());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void recoveryPasswordManagerVerificationTrue() throws Exception {
|
||||||
|
when(registrationRecoveryPasswordsManager.verify(any(), any()))
|
||||||
|
.thenReturn(CompletableFuture.completedFuture(true));
|
||||||
|
|
||||||
|
final Invocation.Builder request = resources.getJerseyTest()
|
||||||
|
.target("/v2/accounts/number")
|
||||||
|
.request()
|
||||||
|
.header(HttpHeaders.AUTHORIZATION,
|
||||||
|
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
|
||||||
|
final byte[] recoveryPassword = new byte[32];
|
||||||
|
try (Response response = request.put(Entity.json(requestJsonRecoveryPassword(recoveryPassword, NEW_NUMBER)))) {
|
||||||
|
assertEquals(200, response.getStatus());
|
||||||
|
|
||||||
|
final AccountIdentityResponse accountIdentityResponse = response.readEntity(AccountIdentityResponse.class);
|
||||||
|
|
||||||
|
verify(changeNumberManager).changeNumber(eq(AuthHelper.VALID_ACCOUNT), eq(NEW_NUMBER), any(), any(), any(),
|
||||||
|
any());
|
||||||
|
|
||||||
|
assertEquals(AuthHelper.VALID_UUID, accountIdentityResponse.uuid());
|
||||||
|
assertEquals(NEW_NUMBER, accountIdentityResponse.number());
|
||||||
|
assertNotEquals(AuthHelper.VALID_PNI, accountIdentityResponse.pni());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void recoveryPasswordManagerVerificationFalse() {
|
||||||
|
when(registrationRecoveryPasswordsManager.verify(any(), any()))
|
||||||
|
.thenReturn(CompletableFuture.completedFuture(false));
|
||||||
|
|
||||||
|
final Invocation.Builder request = resources.getJerseyTest()
|
||||||
|
.target("/v2/accounts/number")
|
||||||
|
.request()
|
||||||
|
.header(HttpHeaders.AUTHORIZATION,
|
||||||
|
AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD));
|
||||||
|
try (Response response = request.put(Entity.json(requestJsonRecoveryPassword(new byte[32], NEW_NUMBER)))) {
|
||||||
|
assertEquals(403, response.getStatus());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Valid request JSON with the given Recovery Password
|
||||||
|
*/
|
||||||
|
private static String requestJsonRecoveryPassword(final byte[] recoveryPassword, final String newNumber) {
|
||||||
|
return requestJson("", recoveryPassword, newNumber);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Valid request JSON with the give session ID and recovery password
|
||||||
|
*/
|
||||||
|
private static String requestJson(final String sessionId, final byte[] recoveryPassword, final String newNumber) {
|
||||||
|
return String.format("""
|
||||||
|
{
|
||||||
|
"sessionId": "%s",
|
||||||
|
"recoveryPassword": "%s",
|
||||||
|
"number": "%s",
|
||||||
|
"reglock": "1234",
|
||||||
|
"pniIdentityKey": "5678",
|
||||||
|
"deviceMessages": [],
|
||||||
|
"devicePniSignedPrekeys": {},
|
||||||
|
"pniRegistrationIds": {}
|
||||||
|
}
|
||||||
|
""", encodeSessionId(sessionId), encodeRecoveryPassword(recoveryPassword), newNumber);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Valid request JSON with the give session ID
|
||||||
|
*/
|
||||||
|
private static String requestJson(final String sessionId, final String newNumber) {
|
||||||
|
return requestJson(sessionId, new byte[0], newNumber);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Request JSON in the shape of {@link org.whispersystems.textsecuregcm.entities.ChangeNumberRequest}, but that
|
||||||
|
* fails validation
|
||||||
|
*/
|
||||||
|
private static String invalidRequestJson() {
|
||||||
|
return """
|
||||||
|
{
|
||||||
|
"sessionId": null
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Request JSON that cannot be marshalled into
|
||||||
|
* {@link org.whispersystems.textsecuregcm.entities.ChangeNumberRequest}
|
||||||
|
*/
|
||||||
|
private static String unprocessableJson() {
|
||||||
|
return """
|
||||||
|
{
|
||||||
|
"sessionId": []
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
}
|
||||||
|
|
||||||
|
private static String encodeSessionId(final String sessionId) {
|
||||||
|
return Base64.getUrlEncoder().encodeToString(sessionId.getBytes(StandardCharsets.UTF_8));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static String encodeRecoveryPassword(final byte[] recoveryPassword) {
|
||||||
|
return Base64.getEncoder().encodeToString(recoveryPassword);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -86,7 +86,8 @@ class ChallengeControllerTest {
|
||||||
""";
|
""";
|
||||||
|
|
||||||
final Duration retryAfter = Duration.ofMinutes(17);
|
final Duration retryAfter = Duration.ofMinutes(17);
|
||||||
doThrow(new RateLimitExceededException(retryAfter)).when(rateLimitChallengeManager).answerPushChallenge(any(), any());
|
doThrow(new RateLimitExceededException(retryAfter, true)).when(rateLimitChallengeManager)
|
||||||
|
.answerPushChallenge(any(), any());
|
||||||
|
|
||||||
final Response response = EXTENSION.target("/v1/challenge")
|
final Response response = EXTENSION.target("/v1/challenge")
|
||||||
.request()
|
.request()
|
||||||
|
@ -128,7 +129,8 @@ class ChallengeControllerTest {
|
||||||
""";
|
""";
|
||||||
|
|
||||||
final Duration retryAfter = Duration.ofMinutes(17);
|
final Duration retryAfter = Duration.ofMinutes(17);
|
||||||
doThrow(new RateLimitExceededException(retryAfter)).when(rateLimitChallengeManager).answerRecaptchaChallenge(any(), any(), any(), any());
|
doThrow(new RateLimitExceededException(retryAfter, true)).when(rateLimitChallengeManager)
|
||||||
|
.answerRecaptchaChallenge(any(), any(), any(), any());
|
||||||
|
|
||||||
final Response response = EXTENSION.target("/v1/challenge")
|
final Response response = EXTENSION.target("/v1/challenge")
|
||||||
.request()
|
.request()
|
||||||
|
|
|
@ -255,7 +255,8 @@ class ProfileControllerTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testProfileGetByAciRateLimited() throws RateLimitExceededException {
|
void testProfileGetByAciRateLimited() throws RateLimitExceededException {
|
||||||
doThrow(new RateLimitExceededException(Duration.ofSeconds(13))).when(rateLimiter).validate(AuthHelper.VALID_UUID);
|
doThrow(new RateLimitExceededException(Duration.ofSeconds(13), true)).when(rateLimiter)
|
||||||
|
.validate(AuthHelper.VALID_UUID);
|
||||||
|
|
||||||
Response response= resources.getJerseyTest()
|
Response response= resources.getJerseyTest()
|
||||||
.target("/v1/profile/" + AuthHelper.VALID_UUID_TWO)
|
.target("/v1/profile/" + AuthHelper.VALID_UUID_TWO)
|
||||||
|
@ -326,7 +327,8 @@ class ProfileControllerTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testProfileGetByPniRateLimited() throws RateLimitExceededException {
|
void testProfileGetByPniRateLimited() throws RateLimitExceededException {
|
||||||
doThrow(new RateLimitExceededException(Duration.ofSeconds(13))).when(rateLimiter).validate(AuthHelper.VALID_UUID);
|
doThrow(new RateLimitExceededException(Duration.ofSeconds(13), true)).when(rateLimiter)
|
||||||
|
.validate(AuthHelper.VALID_UUID);
|
||||||
|
|
||||||
Response response= resources.getJerseyTest()
|
Response response= resources.getJerseyTest()
|
||||||
.target("/v1/profile/" + AuthHelper.VALID_PNI_TWO)
|
.target("/v1/profile/" + AuthHelper.VALID_PNI_TWO)
|
||||||
|
|
|
@ -1,9 +1,26 @@
|
||||||
package org.whispersystems.textsecuregcm.controllers;
|
package org.whispersystems.textsecuregcm.controllers;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.Mockito.doThrow;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.never;
|
||||||
|
import static org.mockito.Mockito.reset;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
import com.google.common.collect.ImmutableSet;
|
import com.google.common.collect.ImmutableSet;
|
||||||
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
|
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
|
||||||
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
|
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
|
||||||
import io.dropwizard.testing.junit5.ResourceExtension;
|
import io.dropwizard.testing.junit5.ResourceExtension;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.util.Base64;
|
||||||
|
import java.util.UUID;
|
||||||
|
import javax.ws.rs.client.Entity;
|
||||||
|
import javax.ws.rs.core.MediaType;
|
||||||
|
import javax.ws.rs.core.Response;
|
||||||
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
|
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
@ -20,25 +37,6 @@ import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
|
||||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||||
import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress;
|
import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress;
|
||||||
|
|
||||||
import javax.ws.rs.client.Entity;
|
|
||||||
import javax.ws.rs.core.MediaType;
|
|
||||||
import javax.ws.rs.core.Response;
|
|
||||||
|
|
||||||
import java.nio.charset.StandardCharsets;
|
|
||||||
import java.time.Duration;
|
|
||||||
import java.util.Base64;
|
|
||||||
import java.util.UUID;
|
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
|
||||||
import static org.mockito.ArgumentMatchers.any;
|
|
||||||
import static org.mockito.Mockito.doThrow;
|
|
||||||
import static org.mockito.Mockito.mock;
|
|
||||||
import static org.mockito.Mockito.never;
|
|
||||||
import static org.mockito.Mockito.reset;
|
|
||||||
import static org.mockito.Mockito.verify;
|
|
||||||
import static org.mockito.Mockito.when;
|
|
||||||
|
|
||||||
@ExtendWith(DropwizardExtensionsSupport.class)
|
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||||
class ProvisioningControllerTest {
|
class ProvisioningControllerTest {
|
||||||
|
|
||||||
|
@ -101,7 +99,7 @@ class ProvisioningControllerTest {
|
||||||
final String destination = UUID.randomUUID().toString();
|
final String destination = UUID.randomUUID().toString();
|
||||||
final byte[] messageBody = "test".getBytes(StandardCharsets.UTF_8);
|
final byte[] messageBody = "test".getBytes(StandardCharsets.UTF_8);
|
||||||
|
|
||||||
doThrow(new RateLimitExceededException(Duration.ZERO))
|
doThrow(new RateLimitExceededException(Duration.ZERO, true))
|
||||||
.when(messagesRateLimiter).validate(AuthHelper.VALID_UUID);
|
.when(messagesRateLimiter).validate(AuthHelper.VALID_UUID);
|
||||||
|
|
||||||
try (final Response response = RESOURCE_EXTENSION.getJerseyTest()
|
try (final Response response = RESOURCE_EXTENSION.getJerseyTest()
|
||||||
|
|
|
@ -13,6 +13,7 @@ import static org.mockito.Mockito.doThrow;
|
||||||
import static org.mockito.Mockito.mock;
|
import static org.mockito.Mockito.mock;
|
||||||
import static org.mockito.Mockito.when;
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
import com.google.i18n.phonenumbers.PhoneNumberUtil;
|
||||||
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
|
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
|
||||||
import io.dropwizard.testing.junit5.ResourceExtension;
|
import io.dropwizard.testing.junit5.ResourceExtension;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
|
@ -37,6 +38,7 @@ import org.junit.jupiter.params.provider.Arguments;
|
||||||
import org.junit.jupiter.params.provider.CsvSource;
|
import org.junit.jupiter.params.provider.CsvSource;
|
||||||
import org.junit.jupiter.params.provider.EnumSource;
|
import org.junit.jupiter.params.provider.EnumSource;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
|
import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager;
|
||||||
import org.whispersystems.textsecuregcm.auth.RegistrationLockError;
|
import org.whispersystems.textsecuregcm.auth.RegistrationLockError;
|
||||||
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
|
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
|
||||||
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
|
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
|
||||||
|
@ -51,12 +53,18 @@ import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
|
||||||
import org.whispersystems.textsecuregcm.storage.Account;
|
import org.whispersystems.textsecuregcm.storage.Account;
|
||||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||||
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
|
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
|
||||||
|
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
|
||||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||||
|
|
||||||
@ExtendWith(DropwizardExtensionsSupport.class)
|
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||||
class RegistrationControllerTest {
|
class RegistrationControllerTest {
|
||||||
|
|
||||||
private static final String NUMBER = "+18005551212";
|
private static final String NUMBER = PhoneNumberUtil.getInstance().format(
|
||||||
|
PhoneNumberUtil.getInstance().getExampleNumber("US"),
|
||||||
|
PhoneNumberUtil.PhoneNumberFormat.E164);
|
||||||
|
|
||||||
|
public static final String PASSWORD = "password";
|
||||||
|
|
||||||
private final AccountsManager accountsManager = mock(AccountsManager.class);
|
private final AccountsManager accountsManager = mock(AccountsManager.class);
|
||||||
private final RegistrationServiceClient registrationServiceClient = mock(RegistrationServiceClient.class);
|
private final RegistrationServiceClient registrationServiceClient = mock(RegistrationServiceClient.class);
|
||||||
private final RegistrationLockVerificationManager registrationLockVerificationManager = mock(
|
private final RegistrationLockVerificationManager registrationLockVerificationManager = mock(
|
||||||
|
@ -66,7 +74,6 @@ class RegistrationControllerTest {
|
||||||
private final RateLimiters rateLimiters = mock(RateLimiters.class);
|
private final RateLimiters rateLimiters = mock(RateLimiters.class);
|
||||||
|
|
||||||
private final RateLimiter registrationLimiter = mock(RateLimiter.class);
|
private final RateLimiter registrationLimiter = mock(RateLimiter.class);
|
||||||
private final RateLimiter pinLimiter = mock(RateLimiter.class);
|
|
||||||
|
|
||||||
private final ResourceExtension resources = ResourceExtension.builder()
|
private final ResourceExtension resources = ResourceExtension.builder()
|
||||||
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
|
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
|
||||||
|
@ -76,14 +83,14 @@ class RegistrationControllerTest {
|
||||||
.setMapper(SystemMapper.getMapper())
|
.setMapper(SystemMapper.getMapper())
|
||||||
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
|
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
|
||||||
.addResource(
|
.addResource(
|
||||||
new RegistrationController(accountsManager, registrationServiceClient, registrationLockVerificationManager,
|
new RegistrationController(accountsManager,
|
||||||
registrationRecoveryPasswordsManager, rateLimiters))
|
new PhoneVerificationTokenManager(registrationServiceClient, registrationRecoveryPasswordsManager),
|
||||||
|
registrationLockVerificationManager, rateLimiters))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
void setUp() {
|
void setUp() {
|
||||||
when(rateLimiters.getRegistrationLimiter()).thenReturn(registrationLimiter);
|
when(rateLimiters.getRegistrationLimiter()).thenReturn(registrationLimiter);
|
||||||
when(rateLimiters.getPinLimiter()).thenReturn(pinLimiter);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -130,25 +137,23 @@ class RegistrationControllerTest {
|
||||||
final Invocation.Builder request = resources.getJerseyTest()
|
final Invocation.Builder request = resources.getJerseyTest()
|
||||||
.target("/v1/registration")
|
.target("/v1/registration")
|
||||||
.request()
|
.request()
|
||||||
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
|
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
|
||||||
try (Response response = request.post(Entity.json(invalidRequestJson()))) {
|
try (Response response = request.post(Entity.json(invalidRequestJson()))) {
|
||||||
assertEquals(422, response.getStatus());
|
assertEquals(422, response.getStatus());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void rateLimitedSession() throws Exception {
|
void rateLimitedNumber() throws Exception {
|
||||||
final String sessionId = "sessionId";
|
|
||||||
doThrow(RateLimitExceededException.class)
|
doThrow(RateLimitExceededException.class)
|
||||||
.when(registrationLimiter).validate(encodeSessionId(sessionId));
|
.when(registrationLimiter).validate(NUMBER);
|
||||||
|
|
||||||
final Invocation.Builder request = resources.getJerseyTest()
|
final Invocation.Builder request = resources.getJerseyTest()
|
||||||
.target("/v1/registration")
|
.target("/v1/registration")
|
||||||
.request()
|
.request()
|
||||||
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
|
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
|
||||||
try (Response response = request.post(Entity.json(requestJson(sessionId)))) {
|
try (Response response = request.post(Entity.json(requestJson("sessionId")))) {
|
||||||
assertEquals(413, response.getStatus());
|
assertEquals(429, response.getStatus());
|
||||||
// In the future, change to assertEquals(429, response.getStatus());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -160,7 +165,7 @@ class RegistrationControllerTest {
|
||||||
final Invocation.Builder request = resources.getJerseyTest()
|
final Invocation.Builder request = resources.getJerseyTest()
|
||||||
.target("/v1/registration")
|
.target("/v1/registration")
|
||||||
.request()
|
.request()
|
||||||
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
|
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
|
||||||
try (Response response = request.post(Entity.json(requestJson("sessionId")))) {
|
try (Response response = request.post(Entity.json(requestJson("sessionId")))) {
|
||||||
assertEquals(HttpStatus.SC_SERVICE_UNAVAILABLE, response.getStatus());
|
assertEquals(HttpStatus.SC_SERVICE_UNAVAILABLE, response.getStatus());
|
||||||
}
|
}
|
||||||
|
@ -174,7 +179,7 @@ class RegistrationControllerTest {
|
||||||
final Invocation.Builder request = resources.getJerseyTest()
|
final Invocation.Builder request = resources.getJerseyTest()
|
||||||
.target("/v1/registration")
|
.target("/v1/registration")
|
||||||
.request()
|
.request()
|
||||||
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
|
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
|
||||||
try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(new byte[32])))) {
|
try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(new byte[32])))) {
|
||||||
assertEquals(HttpStatus.SC_SERVICE_UNAVAILABLE, response.getStatus());
|
assertEquals(HttpStatus.SC_SERVICE_UNAVAILABLE, response.getStatus());
|
||||||
}
|
}
|
||||||
|
@ -190,7 +195,7 @@ class RegistrationControllerTest {
|
||||||
final Invocation.Builder request = resources.getJerseyTest()
|
final Invocation.Builder request = resources.getJerseyTest()
|
||||||
.target("/v1/registration")
|
.target("/v1/registration")
|
||||||
.request()
|
.request()
|
||||||
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
|
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
|
||||||
try (Response response = request.post(Entity.json(requestJson("sessionId")))) {
|
try (Response response = request.post(Entity.json(requestJson("sessionId")))) {
|
||||||
assertEquals(expectedStatus, response.getStatus(), message);
|
assertEquals(expectedStatus, response.getStatus(), message);
|
||||||
}
|
}
|
||||||
|
@ -214,7 +219,7 @@ class RegistrationControllerTest {
|
||||||
final Invocation.Builder request = resources.getJerseyTest()
|
final Invocation.Builder request = resources.getJerseyTest()
|
||||||
.target("/v1/registration")
|
.target("/v1/registration")
|
||||||
.request()
|
.request()
|
||||||
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
|
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
|
||||||
final byte[] recoveryPassword = new byte[32];
|
final byte[] recoveryPassword = new byte[32];
|
||||||
try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(recoveryPassword)))) {
|
try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(recoveryPassword)))) {
|
||||||
assertEquals(200, response.getStatus());
|
assertEquals(200, response.getStatus());
|
||||||
|
@ -229,7 +234,7 @@ class RegistrationControllerTest {
|
||||||
final Invocation.Builder request = resources.getJerseyTest()
|
final Invocation.Builder request = resources.getJerseyTest()
|
||||||
.target("/v1/registration")
|
.target("/v1/registration")
|
||||||
.request()
|
.request()
|
||||||
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
|
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
|
||||||
try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(new byte[32])))) {
|
try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(new byte[32])))) {
|
||||||
assertEquals(403, response.getStatus());
|
assertEquals(403, response.getStatus());
|
||||||
}
|
}
|
||||||
|
@ -245,7 +250,7 @@ class RegistrationControllerTest {
|
||||||
|
|
||||||
final Exception e = switch (error) {
|
final Exception e = switch (error) {
|
||||||
case MISMATCH -> new WebApplicationException(error.getExpectedStatus());
|
case MISMATCH -> new WebApplicationException(error.getExpectedStatus());
|
||||||
case RATE_LIMITED -> new RateLimitExceededException(null);
|
case RATE_LIMITED -> new RateLimitExceededException(null, true);
|
||||||
};
|
};
|
||||||
doThrow(e)
|
doThrow(e)
|
||||||
.when(registrationLockVerificationManager).verifyRegistrationLock(any(), any());
|
.when(registrationLockVerificationManager).verifyRegistrationLock(any(), any());
|
||||||
|
@ -253,7 +258,7 @@ class RegistrationControllerTest {
|
||||||
final Invocation.Builder request = resources.getJerseyTest()
|
final Invocation.Builder request = resources.getJerseyTest()
|
||||||
.target("/v1/registration")
|
.target("/v1/registration")
|
||||||
.request()
|
.request()
|
||||||
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
|
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
|
||||||
try (Response response = request.post(Entity.json(requestJson("sessionId")))) {
|
try (Response response = request.post(Entity.json(requestJson("sessionId")))) {
|
||||||
assertEquals(error.getExpectedStatus(), response.getStatus());
|
assertEquals(error.getExpectedStatus(), response.getStatus());
|
||||||
}
|
}
|
||||||
|
@ -286,7 +291,7 @@ class RegistrationControllerTest {
|
||||||
final Invocation.Builder request = resources.getJerseyTest()
|
final Invocation.Builder request = resources.getJerseyTest()
|
||||||
.target("/v1/registration")
|
.target("/v1/registration")
|
||||||
.request()
|
.request()
|
||||||
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
|
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
|
||||||
try (Response response = request.post(Entity.json(requestJson("sessionId", new byte[0], skipDeviceTransfer)))) {
|
try (Response response = request.post(Entity.json(requestJson("sessionId", new byte[0], skipDeviceTransfer)))) {
|
||||||
assertEquals(expectedStatus, response.getStatus());
|
assertEquals(expectedStatus, response.getStatus());
|
||||||
}
|
}
|
||||||
|
@ -294,7 +299,7 @@ class RegistrationControllerTest {
|
||||||
|
|
||||||
// this is functionally the same as deviceTransferAvailable(existingAccount=false)
|
// this is functionally the same as deviceTransferAvailable(existingAccount=false)
|
||||||
@Test
|
@Test
|
||||||
void success() throws Exception {
|
void registrationSuccess() throws Exception {
|
||||||
when(registrationServiceClient.getSession(any(), any()))
|
when(registrationServiceClient.getSession(any(), any()))
|
||||||
.thenReturn(CompletableFuture.completedFuture(Optional.of(new RegistrationSession(NUMBER, true))));
|
.thenReturn(CompletableFuture.completedFuture(Optional.of(new RegistrationSession(NUMBER, true))));
|
||||||
when(accountsManager.create(any(), any(), any(), any(), any()))
|
when(accountsManager.create(any(), any(), any(), any(), any()))
|
||||||
|
@ -303,7 +308,7 @@ class RegistrationControllerTest {
|
||||||
final Invocation.Builder request = resources.getJerseyTest()
|
final Invocation.Builder request = resources.getJerseyTest()
|
||||||
.target("/v1/registration")
|
.target("/v1/registration")
|
||||||
.request()
|
.request()
|
||||||
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
|
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
|
||||||
try (Response response = request.post(Entity.json(requestJson("sessionId")))) {
|
try (Response response = request.post(Entity.json(requestJson("sessionId")))) {
|
||||||
assertEquals(200, response.getStatus());
|
assertEquals(200, response.getStatus());
|
||||||
}
|
}
|
||||||
|
@ -365,13 +370,8 @@ class RegistrationControllerTest {
|
||||||
""";
|
""";
|
||||||
}
|
}
|
||||||
|
|
||||||
private static String authorizationHeader(final String number) {
|
|
||||||
return "Basic " + Base64.getEncoder().encodeToString(
|
|
||||||
String.format("%s:password", number).getBytes(StandardCharsets.UTF_8));
|
|
||||||
}
|
|
||||||
|
|
||||||
private static String encodeSessionId(final String sessionId) {
|
private static String encodeSessionId(final String sessionId) {
|
||||||
return Base64.getEncoder().encodeToString(sessionId.getBytes(StandardCharsets.UTF_8));
|
return Base64.getUrlEncoder().encodeToString(sessionId.getBytes(StandardCharsets.UTF_8));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static String encodeRecoveryPassword(final byte[] recoveryPassword) {
|
private static String encodeRecoveryPassword(final byte[] recoveryPassword) {
|
||||||
|
|
|
@ -72,11 +72,11 @@ public class RateLimitedByIpTest {
|
||||||
public void testRateLimits() throws Exception {
|
public void testRateLimits() throws Exception {
|
||||||
Mockito.doNothing().when(RATE_LIMITER).validate(Mockito.eq(IP));
|
Mockito.doNothing().when(RATE_LIMITER).validate(Mockito.eq(IP));
|
||||||
validateSuccess("/test/strict", VALID_X_FORWARDED_FOR);
|
validateSuccess("/test/strict", VALID_X_FORWARDED_FOR);
|
||||||
Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER)).when(RATE_LIMITER).validate(Mockito.eq(IP));
|
Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER, true)).when(RATE_LIMITER).validate(Mockito.eq(IP));
|
||||||
validateFailure("/test/strict", VALID_X_FORWARDED_FOR, RETRY_AFTER);
|
validateFailure("/test/strict", VALID_X_FORWARDED_FOR, RETRY_AFTER);
|
||||||
Mockito.doNothing().when(RATE_LIMITER).validate(Mockito.eq(IP));
|
Mockito.doNothing().when(RATE_LIMITER).validate(Mockito.eq(IP));
|
||||||
validateSuccess("/test/strict", VALID_X_FORWARDED_FOR);
|
validateSuccess("/test/strict", VALID_X_FORWARDED_FOR);
|
||||||
Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER)).when(RATE_LIMITER).validate(Mockito.eq(IP));
|
Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER, true)).when(RATE_LIMITER).validate(Mockito.eq(IP));
|
||||||
validateFailure("/test/strict", VALID_X_FORWARDED_FOR, RETRY_AFTER);
|
validateFailure("/test/strict", VALID_X_FORWARDED_FOR, RETRY_AFTER);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -92,7 +92,7 @@ public class RateLimitedByIpTest {
|
||||||
validateSuccess("/test/loose", "");
|
validateSuccess("/test/loose", "");
|
||||||
|
|
||||||
// also checking that even if rate limiter is failing -- it doesn't matter in the case of invalid IP
|
// also checking that even if rate limiter is failing -- it doesn't matter in the case of invalid IP
|
||||||
Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER)).when(RATE_LIMITER).validate(Mockito.anyString());
|
Mockito.doThrow(new RateLimitExceededException(RETRY_AFTER, true)).when(RATE_LIMITER).validate(Mockito.anyString());
|
||||||
validateFailure("/test/loose", VALID_X_FORWARDED_FOR, RETRY_AFTER);
|
validateFailure("/test/loose", VALID_X_FORWARDED_FOR, RETRY_AFTER);
|
||||||
validateSuccess("/test/loose", INVALID_X_FORWARDED_FOR);
|
validateSuccess("/test/loose", INVALID_X_FORWARDED_FOR);
|
||||||
validateSuccess("/test/loose", "");
|
validateSuccess("/test/loose", "");
|
||||||
|
|
|
@ -340,7 +340,7 @@ class KeysControllerTest {
|
||||||
@Test
|
@Test
|
||||||
void testGetKeysRateLimited() throws RateLimitExceededException {
|
void testGetKeysRateLimited() throws RateLimitExceededException {
|
||||||
Duration retryAfter = Duration.ofSeconds(31);
|
Duration retryAfter = Duration.ofSeconds(31);
|
||||||
doThrow(new RateLimitExceededException(retryAfter)).when(rateLimiter).validate(anyString());
|
doThrow(new RateLimitExceededException(retryAfter, true)).when(rateLimiter).validate(anyString());
|
||||||
|
|
||||||
Response result = resources.getJerseyTest()
|
Response result = resources.getJerseyTest()
|
||||||
.target(String.format("/v2/keys/%s/*", EXISTS_PNI))
|
.target(String.format("/v2/keys/%s/*", EXISTS_PNI))
|
||||||
|
|
|
@ -60,11 +60,12 @@ public final class MockUtils {
|
||||||
final RateLimiters rateLimitersMock,
|
final RateLimiters rateLimitersMock,
|
||||||
final RateLimiters.Handle handle,
|
final RateLimiters.Handle handle,
|
||||||
final String input,
|
final String input,
|
||||||
final Duration retryAfter) {
|
final Duration retryAfter,
|
||||||
|
final boolean legacyStatusCode) {
|
||||||
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
|
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
|
||||||
doReturn(Optional.of(mockRateLimiter)).when(rateLimitersMock).byHandle(eq(handle));
|
doReturn(Optional.of(mockRateLimiter)).when(rateLimitersMock).byHandle(eq(handle));
|
||||||
try {
|
try {
|
||||||
doThrow(new RateLimitExceededException(retryAfter)).when(mockRateLimiter).validate(eq(input));
|
doThrow(new RateLimitExceededException(retryAfter, legacyStatusCode)).when(mockRateLimiter).validate(eq(input));
|
||||||
} catch (final RateLimitExceededException e) {
|
} catch (final RateLimitExceededException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue