diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 53dd82f3a..4066a9947 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -76,6 +76,7 @@ import org.whispersystems.textsecuregcm.auth.CertificateGenerator; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccountAuthenticator; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator; +import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager; import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator; import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventListener; import org.whispersystems.textsecuregcm.badges.ConfiguredProfileBadgeConverter; @@ -101,6 +102,7 @@ import org.whispersystems.textsecuregcm.controllers.MessageController; import org.whispersystems.textsecuregcm.controllers.PaymentsController; import org.whispersystems.textsecuregcm.controllers.ProfileController; import org.whispersystems.textsecuregcm.controllers.ProvisioningController; +import org.whispersystems.textsecuregcm.controllers.RegistrationController; import org.whispersystems.textsecuregcm.controllers.RemoteConfigController; import org.whispersystems.textsecuregcm.controllers.SecureBackupController; import org.whispersystems.textsecuregcm.controllers.SecureStorageController; @@ -518,6 +520,9 @@ public class WhisperServerService extends Application deviceIds = updatedAccount.getDevices().stream().map(Device::getId).toList(); + clientPresenceManager.disconnectAllPresences(updatedAccount.getUuid(), deviceIds); + */ + + throw new WebApplicationException(Response.status(FAILURE_HTTP_STATUS) + .entity(new RegistrationLockFailure(existingRegistrationLock.getTimeRemaining(), + existingRegistrationLock.needsFailureCredentials() ? existingBackupCredentials : null)) + .build()); + } + + rateLimiters.getPinLimiter().clear(phoneNumber); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/RateLimitsConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/RateLimitsConfiguration.java index f98701542..1d598d10c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/RateLimitsConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/RateLimitsConfiguration.java @@ -23,15 +23,15 @@ public class RateLimitsConfiguration { @JsonProperty private RateLimitConfiguration smsVoicePrefix = new RateLimitConfiguration(1000, 1000); - @JsonProperty - private RateLimitConfiguration autoBlock = new RateLimitConfiguration(500, 500); - @JsonProperty private RateLimitConfiguration verifyNumber = new RateLimitConfiguration(2, 2); @JsonProperty private RateLimitConfiguration verifyPin = new RateLimitConfiguration(10, 1 / (24.0 * 60.0)); + @JsonProperty + private RateLimitConfiguration registration = new RateLimitConfiguration(2, 2); + @JsonProperty private RateLimitConfiguration attachments = new RateLimitConfiguration(50, 50); @@ -71,16 +71,9 @@ public class RateLimitsConfiguration { @JsonProperty private RateLimitConfiguration checkAccountExistence = new RateLimitConfiguration(1_000, 1_000 / 60.0); - @JsonProperty - private RateLimitConfiguration stories = new RateLimitConfiguration(10_000, 10_000 / (24.0 * 60.0)); - @JsonProperty private RateLimitConfiguration backupAuthCheck = new RateLimitConfiguration(100, 100 / (24.0 * 60.0)); - public RateLimitConfiguration getAutoBlock() { - return autoBlock; - } - public RateLimitConfiguration getAllocateDevice() { return allocateDevice; } @@ -129,6 +122,10 @@ public class RateLimitsConfiguration { return verifyPin; } + public RateLimitConfiguration getRegistration() { + return registration; + } + public RateLimitConfiguration getTurnAllocations() { return turnAllocations; } @@ -161,10 +158,6 @@ public class RateLimitsConfiguration { return checkAccountExistence; } - public RateLimitConfiguration getStories() { - return stories; - } - public RateLimitConfiguration getBackupAuthCheck() { return backupAuthCheck; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java index 63eb99dab..c91f2839c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java @@ -32,7 +32,6 @@ import java.util.Map; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletionException; -import javax.annotation.Nullable; import javax.servlet.http.HttpServletRequest; import javax.validation.Valid; import javax.validation.constraints.NotNull; @@ -61,10 +60,8 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.BasicAuthorizationHeader; import org.whispersystems.textsecuregcm.auth.ChangesDeviceEnabledState; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; -import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials; -import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator; +import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; -import org.whispersystems.textsecuregcm.auth.StoredRegistrationLock; import org.whispersystems.textsecuregcm.auth.StoredVerificationCode; import org.whispersystems.textsecuregcm.auth.TurnToken; import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator; @@ -82,7 +79,6 @@ import org.whispersystems.textsecuregcm.entities.DeviceName; import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; import org.whispersystems.textsecuregcm.entities.MismatchedDevices; import org.whispersystems.textsecuregcm.entities.RegistrationLock; -import org.whispersystems.textsecuregcm.entities.RegistrationLockFailure; import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashRequest; import org.whispersystems.textsecuregcm.entities.ReserveUsernameHashResponse; import org.whispersystems.textsecuregcm.entities.StaleDevices; @@ -91,7 +87,6 @@ import org.whispersystems.textsecuregcm.limits.RateLimitedByIp; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; -import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.PushNotification; import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.registration.ClientType; @@ -137,7 +132,7 @@ public class AccountController { .publishPercentiles(0.75, 0.95, 0.99, 0.999) .distributionStatisticExpiry(Duration.ofHours(2)) .register(Metrics.globalRegistry); - private static final String LOCKED_ACCOUNT_COUNTER_NAME = name(AccountController.class, "lockedAccount"); + private static final String CHALLENGE_PRESENT_TAG_NAME = "present"; private static final String CHALLENGE_MATCH_TAG_NAME = "matches"; private static final String COUNTRY_CODE_TAG_NAME = "countryCode"; @@ -150,25 +145,22 @@ public class AccountController { private static final String REGION_CODE_TAG_NAME = "regionCode"; private static final String VERIFICATION_TRANSPORT_TAG_NAME = "transport"; private static final String SCORE_TAG_NAME = "score"; - private static final String LOCK_REASON_TAG_NAME = "lockReason"; - private static final String ALREADY_LOCKED_TAG_NAME = "alreadyLocked"; - private final StoredVerificationCodeManager pendingAccounts; - private final AccountsManager accounts; - private final RateLimiters rateLimiters; - private final RegistrationServiceClient registrationServiceClient; + private final StoredVerificationCodeManager pendingAccounts; + private final AccountsManager accounts; + private final RateLimiters rateLimiters; + private final RegistrationServiceClient registrationServiceClient; private final DynamicConfigurationManager dynamicConfigurationManager; - private final TurnTokenGenerator turnTokenGenerator; - private final Map testDevices; - private final CaptchaChecker captchaChecker; - private final PushNotificationManager pushNotificationManager; - private final ExternalServiceCredentialsGenerator backupServiceCredentialsGenerator; + private final TurnTokenGenerator turnTokenGenerator; + private final Map testDevices; + private final CaptchaChecker captchaChecker; + private final PushNotificationManager pushNotificationManager; + private final RegistrationLockVerificationManager registrationLockVerificationManager; private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager; private final ChangeNumberManager changeNumberManager; private final Clock clock; - private final ClientPresenceManager clientPresenceManager; @VisibleForTesting static final Duration REGISTRATION_RPC_TIMEOUT = Duration.ofSeconds(15); @@ -184,9 +176,8 @@ public class AccountController { CaptchaChecker captchaChecker, PushNotificationManager pushNotificationManager, ChangeNumberManager changeNumberManager, + RegistrationLockVerificationManager registrationLockVerificationManager, RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager, - ExternalServiceCredentialsGenerator backupServiceCredentialsGenerator, - ClientPresenceManager clientPresenceManager, Clock clock ) { this.pendingAccounts = pendingAccounts; @@ -198,34 +189,12 @@ public class AccountController { this.turnTokenGenerator = turnTokenGenerator; this.captchaChecker = captchaChecker; this.pushNotificationManager = pushNotificationManager; - this.backupServiceCredentialsGenerator = backupServiceCredentialsGenerator; + this.registrationLockVerificationManager = registrationLockVerificationManager; this.changeNumberManager = changeNumberManager; - this.clientPresenceManager = clientPresenceManager; this.registrationRecoveryPasswordsManager = registrationRecoveryPasswordsManager; this.clock = clock; } - @VisibleForTesting - public AccountController( - StoredVerificationCodeManager pendingAccounts, - AccountsManager accounts, - RateLimiters rateLimiters, - RegistrationServiceClient registrationServiceClient, - DynamicConfigurationManager dynamicConfigurationManager, - TurnTokenGenerator turnTokenGenerator, - Map testDevices, - CaptchaChecker captchaChecker, - PushNotificationManager pushNotificationManager, - ChangeNumberManager changeNumberManager, - RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager, - ExternalServiceCredentialsGenerator backupServiceCredentialsGenerator - ) { - this(pendingAccounts, accounts, rateLimiters, - registrationServiceClient, dynamicConfigurationManager, turnTokenGenerator, testDevices, captchaChecker, - pushNotificationManager, changeNumberManager, registrationRecoveryPasswordsManager, - backupServiceCredentialsGenerator, null, Clock.systemUTC()); - } - @Timed @GET @Path("/{type}/preauth/{token}/{number}") @@ -424,7 +393,8 @@ public class AccountController { }); if (existingAccount.isPresent()) { - verifyRegistrationLock(existingAccount.get(), accountAttributes.getRegistrationLock()); + registrationLockVerificationManager.verifyRegistrationLock(existingAccount.get(), + accountAttributes.getRegistrationLock()); } if (availableForTransfer.orElse(false) && existingAccount.map(Account::isTransferSupported).orElse(false)) { @@ -488,7 +458,7 @@ public class AccountController { final Optional existingAccount = accounts.getByE164(number); if (existingAccount.isPresent()) { - verifyRegistrationLock(existingAccount.get(), request.registrationLock()); + registrationLockVerificationManager.verifyRegistrationLock(existingAccount.get(), request.registrationLock()); } rateLimiters.getVerifyLimiter().clear(number); @@ -823,51 +793,6 @@ public class AccountController { rateLimiter.validate(mostRecentProxy); } - private void verifyRegistrationLock(final Account existingAccount, @Nullable final String clientRegistrationLock) - throws RateLimitExceededException, WebApplicationException { - - final StoredRegistrationLock existingRegistrationLock = existingAccount.getRegistrationLock(); - final ExternalServiceCredentials existingBackupCredentials = - backupServiceCredentialsGenerator.generateForUuid(existingAccount.getUuid()); - - if (existingRegistrationLock.requiresClientRegistrationLock()) { - if (!Util.isEmpty(clientRegistrationLock)) { - rateLimiters.getPinLimiter().validate(existingAccount.getNumber()); - } - - final String phoneNumber = existingAccount.getNumber(); - - if (!existingRegistrationLock.verify(clientRegistrationLock)) { - // At this point, the client verified ownership of the phone number but doesn’t have the reglock PIN. - // Freezing the existing account credentials will definitively start the reglock timeout. - // Until the timeout, the current reglock can still be supplied, - // along with phone number verification, to restore access. - /* boolean alreadyLocked = existingAccount.hasLockedCredentials(); - Metrics.counter(LOCKED_ACCOUNT_COUNTER_NAME, - LOCK_REASON_TAG_NAME, "verifiedNumberFailedReglock", - ALREADY_LOCKED_TAG_NAME, Boolean.toString(alreadyLocked)) - .increment(); - - final Account updatedAccount; - if (!alreadyLocked) { - updatedAccount = accounts.update(existingAccount, Account::lockAuthenticationCredentials); - } else { - updatedAccount = existingAccount; - } - - List deviceIds = updatedAccount.getDevices().stream().map(Device::getId).toList(); - clientPresenceManager.disconnectAllPresences(updatedAccount.getUuid(), deviceIds); */ - - throw new WebApplicationException(Response.status(423) - .entity(new RegistrationLockFailure(existingRegistrationLock.getTimeRemaining(), - existingRegistrationLock.needsFailureCredentials() ? existingBackupCredentials : null)) - .build()); - } - - rateLimiters.getPinLimiter().clear(phoneNumber); - } - } - @VisibleForTesting static boolean pushChallengeMatches( final String number, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java new file mode 100644 index 000000000..b730d0b6d --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java @@ -0,0 +1,166 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.controllers; + +import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; + +import com.codahale.metrics.annotation.Timed; +import com.google.common.net.HttpHeaders; +import io.micrometer.core.instrument.DistributionSummary; +import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Tag; +import io.micrometer.core.instrument.Tags; +import java.security.MessageDigest; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Base64; +import java.util.Optional; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import javax.validation.Valid; +import javax.validation.constraints.NotNull; +import javax.ws.rs.BadRequestException; +import javax.ws.rs.ClientErrorException; +import javax.ws.rs.Consumes; +import javax.ws.rs.HeaderParam; +import javax.ws.rs.NotAuthorizedException; +import javax.ws.rs.POST; +import javax.ws.rs.Path; +import javax.ws.rs.Produces; +import javax.ws.rs.ServerErrorException; +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; +import org.apache.http.HttpStatus; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.auth.BasicAuthorizationHeader; +import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager; +import org.whispersystems.textsecuregcm.entities.AccountIdentityResponse; +import org.whispersystems.textsecuregcm.entities.RegistrationRequest; +import org.whispersystems.textsecuregcm.entities.RegistrationSession; +import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; +import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.util.HeaderUtils; +import org.whispersystems.textsecuregcm.util.Util; + +@Path("/v1/registration") +public class RegistrationController { + + private static final Logger logger = LoggerFactory.getLogger(RegistrationController.class); + + private static final DistributionSummary REREGISTRATION_IDLE_DAYS_DISTRIBUTION = DistributionSummary + .builder(name(RegistrationController.class, "reregistrationIdleDays")) + .publishPercentiles(0.75, 0.95, 0.99, 0.999) + .distributionStatisticExpiry(Duration.ofHours(2)) + .register(Metrics.globalRegistry); + + private static final String ACCOUNT_CREATED_COUNTER_NAME = name(RegistrationController.class, "accountCreated"); + private static final String COUNTRY_CODE_TAG_NAME = "countryCode"; + private static final String REGION_CODE_TAG_NAME = "regionCode"; + private static final String VERIFICATION_TYPE_TAG_NAME = "verification"; + + private static final Duration REGISTRATION_RPC_TIMEOUT = Duration.ofSeconds(15); + + private final AccountsManager accounts; + private final RegistrationServiceClient registrationServiceClient; + private final RegistrationLockVerificationManager registrationLockVerificationManager; + private final RateLimiters rateLimiters; + + public RegistrationController(final AccountsManager accounts, + final RegistrationServiceClient registrationServiceClient, + final RegistrationLockVerificationManager registrationLockVerificationManager, + final RateLimiters rateLimiters) { + this.accounts = accounts; + this.registrationServiceClient = registrationServiceClient; + this.registrationLockVerificationManager = registrationLockVerificationManager; + this.rateLimiters = rateLimiters; + } + + @Timed + @POST + @Consumes(MediaType.APPLICATION_JSON) + @Produces(MediaType.APPLICATION_JSON) + public AccountIdentityResponse register( + @HeaderParam(HttpHeaders.AUTHORIZATION) @NotNull final BasicAuthorizationHeader authorizationHeader, + @HeaderParam(HeaderUtils.X_SIGNAL_AGENT) final String signalAgent, + @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent, + @NotNull @Valid final RegistrationRequest registrationRequest) throws RateLimitExceededException, InterruptedException { + + rateLimiters.getRegistrationLimiter().validate(registrationRequest.sessionId()); + + final byte[] sessionId; + try { + sessionId = Base64.getDecoder().decode(registrationRequest.sessionId()); + } catch (final IllegalArgumentException e) { + throw new ClientErrorException("Malformed session ID", HttpStatus.SC_UNPROCESSABLE_ENTITY); + } + + final String number = authorizationHeader.getUsername(); + final String password = authorizationHeader.getPassword(); + + final String verificationType = "phoneNumberVerification"; + try { + final Optional maybeSession = registrationServiceClient.getSession(sessionId, + REGISTRATION_RPC_TIMEOUT) + .get(REGISTRATION_RPC_TIMEOUT.plusSeconds(1).getSeconds(), TimeUnit.SECONDS); + + final RegistrationSession session = maybeSession.orElseThrow( + () -> new NotAuthorizedException("session not verified")); + if (!MessageDigest.isEqual(number.getBytes(), session.number().getBytes())) { + throw new BadRequestException("number does not match session"); + } + if (!session.verified()) { + throw new NotAuthorizedException("session not verified"); + } + + } catch (final CancellationException | ExecutionException | TimeoutException e) { + logger.error("Registration service failure", e); + throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE); + } + + final Optional existingAccount = accounts.getByE164(number); + + existingAccount.ifPresent(account -> { + final Instant accountLastSeen = Instant.ofEpochMilli(account.getLastSeen()); + final Duration timeSinceLastSeen = Duration.between(accountLastSeen, Instant.now()); + REREGISTRATION_IDLE_DAYS_DISTRIBUTION.record(timeSinceLastSeen.toDays()); + }); + + if (existingAccount.isPresent()) { + registrationLockVerificationManager.verifyRegistrationLock(existingAccount.get(), + registrationRequest.accountAttributes().getRegistrationLock()); + } + + if (!registrationRequest.skipDeviceTransfer() && existingAccount.map(Account::isTransferSupported).orElse(false)) { + // If a device transfer is possible, clients must explicitly opt out of a transfer (i.e. after prompting the user) + // before we'll let them create a new account "from scratch" + throw new WebApplicationException(Response.status(409, "device transfer available").build()); + } + + final Account account = accounts.create(number, password, signalAgent, registrationRequest.accountAttributes(), + existingAccount.map(Account::getBadges).orElseGet(ArrayList::new)); + + Metrics.counter(ACCOUNT_CREATED_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), + Tag.of(COUNTRY_CODE_TAG_NAME, Util.getCountryCode(number)), + Tag.of(REGION_CODE_TAG_NAME, Util.getRegion(number)), + Tag.of(VERIFICATION_TYPE_TAG_NAME, verificationType))) + .increment(); + + return new AccountIdentityResponse(account.getUuid(), + account.getNumber(), + account.getPhoneNumberIdentifier(), + account.getUsernameHash().orElse(null), + existingAccount.map(Account::isStorageSupported).orElse(false)); + } + +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationRequest.java new file mode 100644 index 000000000..5efcced2e --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationRequest.java @@ -0,0 +1,16 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.entities; + +import javax.validation.Valid; +import javax.validation.constraints.NotBlank; +import javax.validation.constraints.NotNull; + +public record RegistrationRequest(@NotBlank String sessionId, + @NotNull @Valid AccountAttributes accountAttributes, + boolean skipDeviceTransfer) { + +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationSession.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationSession.java new file mode 100644 index 000000000..d7e05bfc5 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationSession.java @@ -0,0 +1,10 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.entities; + +public record RegistrationSession(String number, boolean verified) { + +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java index c79209f8e..15967b426 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java @@ -43,6 +43,7 @@ public class RateLimiters { private final RateLimiter smsVoicePrefixLimiter; private final RateLimiter verifyLimiter; private final RateLimiter pinLimiter; + private final RateLimiter registrationLimiter; private final RateLimiter attachmentLimiter; private final RateLimiter preKeysLimiter; private final RateLimiter messagesLimiter; @@ -54,7 +55,6 @@ public class RateLimiters { private final RateLimiter artPackLimiter; private final RateLimiter usernameSetLimiter; private final RateLimiter usernameReserveLimiter; - private final RateLimiter storiesLimiter; private final Map rateLimiterByHandle; @@ -66,6 +66,7 @@ public class RateLimiters { this.smsVoicePrefixLimiter = fromConfig("smsVoicePrefix", config.getSmsVoicePrefix(), cacheCluster); this.verifyLimiter = fromConfig("verify", config.getVerifyNumber(), cacheCluster); this.pinLimiter = fromConfig("pin", config.getVerifyPin(), cacheCluster); + this.registrationLimiter = fromConfig("registration", config.getRegistration(), cacheCluster); this.attachmentLimiter = fromConfig("attachmentCreate", config.getAttachments(), cacheCluster); this.preKeysLimiter = fromConfig("prekeys", config.getPreKeys(), cacheCluster); this.messagesLimiter = fromConfig("messages", config.getMessages(), cacheCluster); @@ -77,7 +78,6 @@ public class RateLimiters { this.artPackLimiter = fromConfig("artPack", config.getArtPack(), cacheCluster); this.usernameSetLimiter = fromConfig("usernameSet", config.getUsernameSet(), cacheCluster); this.usernameReserveLimiter = fromConfig("usernameReserve", config.getUsernameReserve(), cacheCluster); - this.storiesLimiter = fromConfig("stories", config.getStories(), cacheCluster); this.rateLimiterByHandle = Stream.of( fromConfig(Handle.BACKUP_AUTH_CHECK.id(), config.getBackupAuthCheck(), cacheCluster), @@ -138,6 +138,10 @@ public class RateLimiters { return pinLimiter; } + public RateLimiter getRegistrationLimiter() { + return registrationLimiter; + } + public RateLimiter getTurnLimiter() { return turnLimiter; } @@ -170,10 +174,6 @@ public class RateLimiters { return byHandle(Handle.CHECK_ACCOUNT_EXISTENCE).orElseThrow(); } - public RateLimiter getStoriesLimiter() { - return storiesLimiter; - } - private static RateLimiter fromConfig( final String name, final RateLimitsConfiguration.RateLimitConfiguration cfg, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/registration/RegistrationServiceClient.java b/service/src/main/java/org/whispersystems/textsecuregcm/registration/RegistrationServiceClient.java index aca4b8332..926161763 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/registration/RegistrationServiceClient.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/registration/RegistrationServiceClient.java @@ -3,6 +3,7 @@ package org.whispersystems.textsecuregcm.registration; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; +import com.google.i18n.phonenumbers.NumberParseException; import com.google.i18n.phonenumbers.PhoneNumberUtil; import com.google.i18n.phonenumbers.Phonenumber; import com.google.protobuf.ByteString; @@ -16,7 +17,7 @@ import java.io.ByteArrayInputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.time.Duration; -import java.time.Instant; +import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.Executor; @@ -24,11 +25,12 @@ import java.util.concurrent.TimeUnit; import org.apache.commons.lang3.StringUtils; import org.checkerframework.checker.nullness.qual.Nullable; import org.signal.registration.rpc.CheckVerificationCodeRequest; -import org.signal.registration.rpc.CheckVerificationCodeResponse; import org.signal.registration.rpc.CreateRegistrationSessionRequest; +import org.signal.registration.rpc.GetRegistrationSessionMetadataRequest; import org.signal.registration.rpc.RegistrationServiceGrpc; import org.signal.registration.rpc.SendVerificationCodeRequest; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import org.whispersystems.textsecuregcm.entities.RegistrationSession; public class RegistrationServiceClient implements Managed { @@ -36,6 +38,24 @@ public class RegistrationServiceClient implements Managed { private final RegistrationServiceGrpc.RegistrationServiceFutureStub stub; private final Executor callbackExecutor; + /** + * @param from an e164 in a {@code long} representation e.g. {@code 18005550123} + * @return the e164 in a {@code String} representation (e.g. {@code "+18005550123"}) + * @throws IllegalArgumentException if the number cannot be parsed to a string + */ + static String convertNumeralE164ToString(long from) { + + try { + final Phonenumber.PhoneNumber phoneNumber = PhoneNumberUtil.getInstance() + .parse("+" + from, null); + return PhoneNumberUtil.getInstance() + .format(phoneNumber, PhoneNumberUtil.PhoneNumberFormat.E164); + } catch (final NumberParseException e) { + throw new IllegalArgumentException("could not parse to phone number", e); + } + + } + public RegistrationServiceClient(final String host, final int port, final String apiKey, @@ -116,9 +136,9 @@ public class RegistrationServiceClient implements Managed { return toCompletableFuture(stub.withDeadline(toDeadline(timeout)) .checkVerificationCode(CheckVerificationCodeRequest.newBuilder() - .setSessionId(ByteString.copyFrom(sessionId)) - .setVerificationCode(verificationCode) - .build())) + .setSessionId(ByteString.copyFrom(sessionId)) + .setVerificationCode(verificationCode) + .build())) .thenApply(response -> { if (response.hasError()) { switch (response.getError().getErrorType()) { @@ -133,6 +153,26 @@ public class RegistrationServiceClient implements Managed { }); } + public CompletableFuture> getSession(final byte[] sessionId, + final Duration timeout) { + return toCompletableFuture(stub.withDeadline(toDeadline(timeout)).getSessionMetadata( + GetRegistrationSessionMetadataRequest.newBuilder() + .setSessionId(ByteString.copyFrom(sessionId)).build())) + .thenApply(response -> { + if (response.hasError()) { + switch (response.getError().getErrorType()) { + case GET_REGISTRATION_SESSION_METADATA_ERROR_TYPE_NOT_FOUND -> { + return Optional.empty(); + } + default -> throw new RuntimeException("Failed to get session: " + response.getError().getErrorType()); + } + } + + final String number = convertNumeralE164ToString(response.getSessionMetadata().getE164()); + return Optional.of(new RegistrationSession(number, response.getSessionMetadata().getVerified())); + }); + } + private static Deadline toDeadline(final Duration timeout) { return Deadline.after(timeout.toMillis(), TimeUnit.MILLISECONDS); } diff --git a/service/src/main/proto/RegistrationService.proto b/service/src/main/proto/RegistrationService.proto index bcba2ee2f..deab91865 100644 --- a/service/src/main/proto/RegistrationService.proto +++ b/service/src/main/proto/RegistrationService.proto @@ -10,6 +10,11 @@ service RegistrationService { */ rpc create_session (CreateRegistrationSessionRequest) returns (CreateRegistrationSessionResponse) {} + /** + * Retrieves session metadata for a given session. + */ + rpc get_session_metadata (GetRegistrationSessionMetadataRequest) returns (GetRegistrationSessionMetadataResponse) {} + /** * Sends a verification code to a destination phone number within the context * of a previously-created registration session. @@ -56,6 +61,11 @@ message RegistrationSessionMetadata { * of this session. */ bool verified = 2; + + /** + * The phone number associated with this registration session. + */ + uint64 e164 = 3; } message CreateRegistrationSessionError { @@ -95,6 +105,33 @@ enum CreateRegistrationSessionErrorType { CREATE_REGISTRATION_SESSION_ERROR_TYPE_ILLEGAL_PHONE_NUMBER = 2; } +message GetRegistrationSessionMetadataRequest { + /** + * The ID of the session for which to retrieve metadata. + */ + bytes session_id = 1; +} + +message GetRegistrationSessionMetadataResponse { + oneof response { + RegistrationSessionMetadata session_metadata = 1; + GetRegistrationSessionMetadataError error = 2; + } +} + +message GetRegistrationSessionMetadataError { + GetRegistrationSessionMetadataErrorType error_type = 1; +} + +enum GetRegistrationSessionMetadataErrorType { + GET_REGISTRATION_SESSION_METADATA_ERROR_TYPE_UNSPECIFIED = 0; + + /** + * No session was found with the given identifier. + */ + GET_REGISTRATION_SESSION_METADATA_ERROR_TYPE_NOT_FOUND = 1; +} + message SendVerificationCodeRequest { reserved 1; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/RegistrationLockError.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/RegistrationLockError.java new file mode 100644 index 000000000..5c3a00ab3 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/RegistrationLockError.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.auth; + +public enum RegistrationLockError { + MISMATCH(RegistrationLockVerificationManager.FAILURE_HTTP_STATUS), + RATE_LIMITED(413) // This will be changed to 429 in a future revision + ; + + private final int expectedStatus; + + RegistrationLockError(final int expectedStatus) { + this.expectedStatus = expectedStatus; + } + + public int getExpectedStatus() { + return expectedStatus; + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManagerTest.java new file mode 100644 index 000000000..28d4a9257 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/RegistrationLockVerificationManagerTest.java @@ -0,0 +1,118 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.auth; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.UUID; +import java.util.function.Consumer; +import java.util.stream.Stream; +import javax.annotation.Nullable; +import javax.ws.rs.WebApplicationException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.EnumSource; +import org.junit.jupiter.params.provider.MethodSource; +import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import org.whispersystems.textsecuregcm.limits.RateLimiter; +import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.util.Pair; + +class RegistrationLockVerificationManagerTest { + + private final AccountsManager accountsManager = mock(AccountsManager.class); + private final ClientPresenceManager clientPresenceManager = mock(ClientPresenceManager.class); + private final ExternalServiceCredentialsGenerator backupServiceCredentialsGeneraor = mock( + ExternalServiceCredentialsGenerator.class); + private final RateLimiters rateLimiters = mock(RateLimiters.class); + private final RegistrationLockVerificationManager registrationLockVerificationManager = new RegistrationLockVerificationManager( + accountsManager, clientPresenceManager, backupServiceCredentialsGeneraor, rateLimiters); + + private final RateLimiter pinLimiter = mock(RateLimiter.class); + + private Account account; + private StoredRegistrationLock existingRegistrationLock; + + @BeforeEach + void setUp() { + when(rateLimiters.getPinLimiter()).thenReturn(pinLimiter); + when(backupServiceCredentialsGeneraor.generateForUuid(any())) + .thenReturn(mock(ExternalServiceCredentials.class)); + + account = mock(Account.class); + when(account.getUuid()).thenReturn(UUID.randomUUID()); + when(account.getNumber()).thenReturn("+18005551212"); + existingRegistrationLock = mock(StoredRegistrationLock.class); + when(account.getRegistrationLock()).thenReturn(existingRegistrationLock); + } + + @ParameterizedTest + @EnumSource + void testErrors(RegistrationLockError error) throws Exception { + + when(existingRegistrationLock.requiresClientRegistrationLock()).thenReturn(true); + + final String submittedRegistrationLock = "reglock"; + + final Pair, Consumer> exceptionType = switch (error) { + case MISMATCH -> { + when(existingRegistrationLock.verify(submittedRegistrationLock)).thenReturn(false); + yield new Pair<>(WebApplicationException.class, e -> { + if (e instanceof WebApplicationException wae) { + assertEquals(RegistrationLockVerificationManager.FAILURE_HTTP_STATUS, wae.getResponse().getStatus()); + } else { + fail("Exception was not of expected type"); + } + }); + } + case RATE_LIMITED -> { + when(existingRegistrationLock.verify(any())).thenReturn(true); + doThrow(RateLimitExceededException.class).when(pinLimiter).validate(anyString()); + yield new Pair<>(RateLimitExceededException.class, ignored -> { + }); + } + }; + + final Exception e = assertThrows(exceptionType.first(), () -> + registrationLockVerificationManager.verifyRegistrationLock(account, submittedRegistrationLock)); + + exceptionType.second().accept(e); + } + + @ParameterizedTest + @MethodSource + void testSuccess(final boolean requiresClientRegistrationLock, @Nullable final String submittedRegistrationLock) { + + when(existingRegistrationLock.requiresClientRegistrationLock()) + .thenReturn(requiresClientRegistrationLock); + when(existingRegistrationLock.verify(submittedRegistrationLock)).thenReturn(true); + + assertDoesNotThrow( + () -> registrationLockVerificationManager.verifyRegistrationLock(account, submittedRegistrationLock)); + } + + static Stream testSuccess() { + return Stream.of( + Arguments.of(false, null), + Arguments.of(true, null), + Arguments.of(false, "reglock"), + Arguments.of(true, "reglock") + ); + } + +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java index 25fafea9b..da28e4ebb 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java @@ -67,13 +67,14 @@ import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator; +import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.auth.StoredRegistrationLock; import org.whispersystems.textsecuregcm.auth.StoredVerificationCode; import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator; import org.whispersystems.textsecuregcm.captcha.AssessmentResult; import org.whispersystems.textsecuregcm.captcha.CaptchaChecker; -import org.whispersystems.textsecuregcm.configuration.SecureStorageServiceConfiguration; +import org.whispersystems.textsecuregcm.configuration.SecureBackupServiceConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicCaptchaConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.entities.AccountAttributes; @@ -184,12 +185,15 @@ class AccountControllerTest { private byte[] registration_lock_key = new byte[32]; - private static final SecureStorageServiceConfiguration STORAGE_CFG = MockUtils.buildMock( - SecureStorageServiceConfiguration.class, - cfg -> when(cfg.decodeUserAuthenticationTokenSharedSecret()).thenReturn(new byte[32])); + private static final SecureBackupServiceConfiguration BACKUP_CFG = MockUtils.buildMock( + SecureBackupServiceConfiguration.class, + cfg -> when(cfg.getUserAuthenticationTokenSharedSecret()).thenReturn(new byte[32])); - private static final ExternalServiceCredentialsGenerator STORAGE_CREDENTIAL_GENERATOR = SecureStorageController - .credentialsGenerator(STORAGE_CFG); + private static final ExternalServiceCredentialsGenerator backupCredentialsGenerator = SecureBackupController.credentialsGenerator( + BACKUP_CFG); + + private static final RegistrationLockVerificationManager registrationLockVerificationManager = new RegistrationLockVerificationManager( + accountsManager, clientPresenceManager, backupCredentialsGenerator, rateLimiters); private static final ResourceExtension resources = ResourceExtension.builder() .addProvider(AuthHelper.getAuthFilter()) @@ -214,9 +218,8 @@ class AccountControllerTest { captchaChecker, pushNotificationManager, changeNumberManager, + registrationLockVerificationManager, registrationRecoveryPasswordsManager, - STORAGE_CREDENTIAL_GENERATOR, - clientPresenceManager, testClock)) .build(); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java new file mode 100644 index 000000000..fe1f2988a --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java @@ -0,0 +1,306 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.controllers; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; +import io.dropwizard.testing.junit5.ResourceExtension; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Stream; +import javax.annotation.Nullable; +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.client.Entity; +import javax.ws.rs.client.Invocation; +import javax.ws.rs.core.HttpHeaders; +import javax.ws.rs.core.Response; +import org.apache.http.HttpStatus; +import org.glassfish.jersey.server.ServerProperties; +import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.EnumSource; +import org.junit.jupiter.params.provider.MethodSource; +import org.whispersystems.textsecuregcm.auth.RegistrationLockError; +import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager; +import org.whispersystems.textsecuregcm.entities.RegistrationSession; +import org.whispersystems.textsecuregcm.limits.RateLimiter; +import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.mappers.ImpossiblePhoneNumberExceptionMapper; +import org.whispersystems.textsecuregcm.mappers.NonNormalizedPhoneNumberExceptionMapper; +import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; +import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.util.SystemMapper; + +@ExtendWith(DropwizardExtensionsSupport.class) +class RegistrationControllerTest { + + private static final String NUMBER = "+18005551212"; + private final AccountsManager accountsManager = mock(AccountsManager.class); + private final RegistrationServiceClient registrationServiceClient = mock(RegistrationServiceClient.class); + private final RegistrationLockVerificationManager registrationLockVerificationManager = mock( + RegistrationLockVerificationManager.class); + private final RateLimiters rateLimiters = mock(RateLimiters.class); + + private final RateLimiter registrationLimiter = mock(RateLimiter.class); + private final RateLimiter pinLimiter = mock(RateLimiter.class); + + private final ResourceExtension resources = ResourceExtension.builder() + .addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE) + .addProvider(new RateLimitExceededExceptionMapper()) + .addProvider(new ImpossiblePhoneNumberExceptionMapper()) + .addProvider(new NonNormalizedPhoneNumberExceptionMapper()) + .setMapper(SystemMapper.getMapper()) + .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) + .addResource( + new RegistrationController(accountsManager, registrationServiceClient, registrationLockVerificationManager, + rateLimiters)) + .build(); + + @BeforeEach + void setUp() { + when(rateLimiters.getRegistrationLimiter()).thenReturn(registrationLimiter); + when(rateLimiters.getPinLimiter()).thenReturn(pinLimiter); + } + + @Test + void unprocessableRequestJson() { + final Invocation.Builder request = resources.getJerseyTest() + .target("/v1/registration") + .request(); + try (Response response = request.post(Entity.json(unprocessableJson()))) { + assertEquals(400, response.getStatus()); + } + } + + @Test + void missingBasicAuthorization() { + final Invocation.Builder request = resources.getJerseyTest() + .target("/v1/registration") + .request(); + try (Response response = request.post(Entity.json(requestJson("sessionId")))) { + assertEquals(400, response.getStatus()); + } + } + + @Test + void invalidBasicAuthorization() { + final Invocation.Builder request = resources.getJerseyTest() + .target("/v1/registration") + .request() + .header(HttpHeaders.AUTHORIZATION, "Basic but-invalid"); + try (Response response = request.post(Entity.json(invalidRequestJson()))) { + assertEquals(401, response.getStatus()); + } + } + + @Test + void invalidRequestBody() { + final Invocation.Builder request = resources.getJerseyTest() + .target("/v1/registration") + .request() + .header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); + try (Response response = request.post(Entity.json(invalidRequestJson()))) { + assertEquals(422, response.getStatus()); + } + } + + @Test + void rateLimitedSession() throws Exception { + final String sessionId = "sessionId"; + doThrow(RateLimitExceededException.class) + .when(registrationLimiter).validate(encodeSessionId(sessionId)); + + final Invocation.Builder request = resources.getJerseyTest() + .target("/v1/registration") + .request() + .header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); + try (Response response = request.post(Entity.json(requestJson(sessionId)))) { + assertEquals(413, response.getStatus()); + // In the future, change to assertEquals(429, response.getStatus()); + } + } + + @Test + void registrationServiceTimeout() { + when(registrationServiceClient.getSession(any(), any())) + .thenReturn(CompletableFuture.failedFuture(new RuntimeException())); + + final Invocation.Builder request = resources.getJerseyTest() + .target("/v1/registration") + .request() + .header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); + try (Response response = request.post(Entity.json(requestJson("sessionId")))) { + assertEquals(HttpStatus.SC_SERVICE_UNAVAILABLE, response.getStatus()); + } + } + + @ParameterizedTest + @MethodSource + void registrationServiceSessionCheck(@Nullable final RegistrationSession session, final int expectedStatus, + final String message) { + when(registrationServiceClient.getSession(any(), any())) + .thenReturn(CompletableFuture.completedFuture(Optional.ofNullable(session))); + + final Invocation.Builder request = resources.getJerseyTest() + .target("/v1/registration") + .request() + .header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); + try (Response response = request.post(Entity.json(requestJson("sessionId")))) { + assertEquals(expectedStatus, response.getStatus(), message); + } + } + + static Stream registrationServiceSessionCheck() { + return Stream.of( + Arguments.of(null, 401, "session not found"), + Arguments.of(new RegistrationSession("+18005551234", false), 400, "session number mismatch"), + Arguments.of(new RegistrationSession(NUMBER, false), 401, "session not verified") + ); + } + + @ParameterizedTest + @EnumSource(RegistrationLockError.class) + void registrationLock(final RegistrationLockError error) throws Exception { + when(registrationServiceClient.getSession(any(), any())) + .thenReturn(CompletableFuture.completedFuture(Optional.of(new RegistrationSession(NUMBER, true)))); + + when(accountsManager.getByE164(any())).thenReturn(Optional.of(mock(Account.class))); + + final Exception e = switch (error) { + case MISMATCH -> new WebApplicationException(error.getExpectedStatus()); + case RATE_LIMITED -> new RateLimitExceededException(null); + }; + doThrow(e) + .when(registrationLockVerificationManager).verifyRegistrationLock(any(), any()); + + final Invocation.Builder request = resources.getJerseyTest() + .target("/v1/registration") + .request() + .header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); + try (Response response = request.post(Entity.json(requestJson("sessionId")))) { + assertEquals(error.getExpectedStatus(), response.getStatus()); + } + } + + @ParameterizedTest + @CsvSource({ + "false, false, false, 200", + "true, false, false, 200", + "true, false, true, 200", + "true, true, false, 409", + "true, true, true, 200" + }) + void deviceTransferAvailable(final boolean existingAccount, final boolean transferSupported, + final boolean skipDeviceTransfer, final int expectedStatus) throws Exception { + when(registrationServiceClient.getSession(any(), any())) + .thenReturn(CompletableFuture.completedFuture(Optional.of(new RegistrationSession(NUMBER, true)))); + + final Optional maybeAccount; + if (existingAccount) { + final Account account = mock(Account.class); + when(account.isTransferSupported()).thenReturn(transferSupported); + maybeAccount = Optional.of(account); + } else { + maybeAccount = Optional.empty(); + } + when(accountsManager.getByE164(any())).thenReturn(maybeAccount); + when(accountsManager.create(any(), any(), any(), any(), any())).thenReturn(mock(Account.class)); + + final Invocation.Builder request = resources.getJerseyTest() + .target("/v1/registration") + .request() + .header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); + try (Response response = request.post(Entity.json(requestJson("sessionId", skipDeviceTransfer)))) { + assertEquals(expectedStatus, response.getStatus()); + } + } + + // this is functionally the same as deviceTransferAvailable(existingAccount=false) + @Test + void success() throws Exception { + when(registrationServiceClient.getSession(any(), any())) + .thenReturn(CompletableFuture.completedFuture(Optional.of(new RegistrationSession(NUMBER, true)))); + when(accountsManager.create(any(), any(), any(), any(), any())) + .thenReturn(mock(Account.class)); + + final Invocation.Builder request = resources.getJerseyTest() + .target("/v1/registration") + .request() + .header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); + try (Response response = request.post(Entity.json(requestJson("sessionId")))) { + assertEquals(200, response.getStatus()); + } + } + + /** + * Valid request JSON with the give session ID and skipDeviceTransfer + */ + private static String requestJson(final String sessionId, final boolean skipDeviceTransfer) { + return String.format(""" + { + "sessionId": "%s", + "accountAttributes": {}, + "skipDeviceTransfer": %s + } + """, encodeSessionId(sessionId), skipDeviceTransfer); + } + + /** + * Valid request JSON with the give session ID + */ + private static String requestJson(final String sessionId) { + return requestJson(sessionId, false); + } + + /** + * Request JSON in the shape of {@link org.whispersystems.textsecuregcm.entities.RegistrationRequest}, but that fails + * validation + */ + private static String invalidRequestJson() { + return """ + { + "sessionId": null, + "accountAttributes": {}, + "skipDeviceTransfer": false + } + """; + } + + /** + * Request JSON that cannot be marshalled into {@link org.whispersystems.textsecuregcm.entities.RegistrationRequest} + */ + private static String unprocessableJson() { + return """ + { + "sessionId": [] + } + """; + } + + private static String authorizationHeader(final String number) { + return "Basic " + Base64.getEncoder().encodeToString( + String.format("%s:password", number).getBytes(StandardCharsets.UTF_8)); + } + + private static String encodeSessionId(final String sessionId) { + return Base64.getEncoder().encodeToString(sessionId.getBytes(StandardCharsets.UTF_8)); + } + +}