diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java index 273770b01..0af4b9c75 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java @@ -29,6 +29,7 @@ import java.util.ArrayList; import java.util.Map; import java.util.Optional; import java.util.UUID; +import java.util.concurrent.CompletionException; import javax.annotation.Nullable; import javax.servlet.http.HttpServletRequest; import javax.validation.Valid; @@ -232,7 +233,7 @@ public class AccountController { @PathParam("token") String pushToken, @PathParam("number") String number, @QueryParam("voip") @DefaultValue("true") boolean useVoip) - throws ImpossiblePhoneNumberException, NonNormalizedPhoneNumberException { + throws ImpossiblePhoneNumberException, NonNormalizedPhoneNumberException, RateLimitExceededException { final PushNotification.TokenType tokenType = switch(pushType) { case "apn" -> useVoip ? PushNotification.TokenType.APN_VOIP : PushNotification.TokenType.APN; @@ -242,9 +243,18 @@ public class AccountController { Util.requireNormalizedNumber(number); - String pushChallenge = generatePushChallenge(); - StoredVerificationCode storedVerificationCode = - new StoredVerificationCode(null, clock.millis(), pushChallenge, null); + final Phonenumber.PhoneNumber phoneNumber; + try { + phoneNumber = PhoneNumberUtil.getInstance().parse(number, null); + } catch (final NumberParseException e) { + // This should never happen since we just verified that the number is already normalized + throw new BadRequestException("Bad phone number"); + } + + final String pushChallenge = generatePushChallenge(); + final byte[] sessionId = createRegistrationSession(phoneNumber); + final StoredVerificationCode storedVerificationCode = + new StoredVerificationCode(null, clock.millis(), pushChallenge, sessionId); pendingAccounts.store(number, storedVerificationCode); pushNotificationManager.sendRegistrationChallengeNotification(pushToken, tokenType, storedVerificationCode.pushCode()); @@ -346,8 +356,16 @@ public class AccountController { } }).orElse(ClientType.UNKNOWN); - final byte[] sessionId = registrationServiceClient.sendRegistrationCode(phoneNumber, - messageTransport, clientType, acceptLanguage.orElse(null), REGISTRATION_RPC_TIMEOUT).join(); + // During the transition to explicit session creation, some previously-stored records may not have a session ID; + // after the transition, we can assume that any existing record has an associated session ID. + final byte[] sessionId = maybeStoredVerificationCode.isPresent() && maybeStoredVerificationCode.get().sessionId() != null ? + maybeStoredVerificationCode.get().sessionId() : createRegistrationSession(phoneNumber); + + registrationServiceClient.sendRegistrationCode(sessionId, + messageTransport, + clientType, + acceptLanguage.orElse(null), + REGISTRATION_RPC_TIMEOUT).join(); final StoredVerificationCode storedVerificationCode = new StoredVerificationCode(null, clock.millis(), @@ -940,4 +958,23 @@ public class AccountController { return Hex.toStringCondensed(challenge); } + + private byte[] createRegistrationSession(final Phonenumber.PhoneNumber phoneNumber) throws RateLimitExceededException { + + try { + return registrationServiceClient.createRegistrationSession(phoneNumber, REGISTRATION_RPC_TIMEOUT).join(); + } catch (final CompletionException e) { + Throwable cause = e; + + while (cause instanceof CompletionException) { + cause = cause.getCause(); + } + + if (cause instanceof RateLimitExceededException rateLimitExceededException) { + throw rateLimitExceededException; + } + + throw e; + } + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/registration/RegistrationServiceClient.java b/service/src/main/java/org/whispersystems/textsecuregcm/registration/RegistrationServiceClient.java index a75353596..1dfb5d4ac 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/registration/RegistrationServiceClient.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/registration/RegistrationServiceClient.java @@ -16,15 +16,19 @@ import java.io.ByteArrayInputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.time.Duration; +import java.time.Instant; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; import org.apache.commons.lang3.StringUtils; import org.checkerframework.checker.nullness.qual.Nullable; import org.signal.registration.rpc.CheckVerificationCodeRequest; import org.signal.registration.rpc.CheckVerificationCodeResponse; +import org.signal.registration.rpc.CreateRegistrationSessionRequest; import org.signal.registration.rpc.RegistrationServiceGrpc; import org.signal.registration.rpc.SendVerificationCodeRequest; +import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; public class RegistrationServiceClient implements Managed { @@ -52,17 +56,36 @@ public class RegistrationServiceClient implements Managed { this.callbackExecutor = callbackExecutor; } - public CompletableFuture sendRegistrationCode(final Phonenumber.PhoneNumber phoneNumber, + public CompletableFuture createRegistrationSession(final Phonenumber.PhoneNumber phoneNumber, final Duration timeout) { + final long e164 = Long.parseLong( + PhoneNumberUtil.getInstance().format(phoneNumber, PhoneNumberUtil.PhoneNumberFormat.E164).substring(1)); + + return toCompletableFuture(stub.withDeadline(toDeadline(timeout)) + .createSession(CreateRegistrationSessionRequest.newBuilder() + .setE164(e164) + .build())) + .thenApply(response -> switch (response.getResponseCase()) { + case SESSION_METADATA -> response.getSessionMetadata().getSessionId().toByteArray(); + + case ERROR -> { + switch (response.getError().getErrorType()) { + case ERROR_TYPE_RATE_LIMITED -> throw new CompletionException(new RateLimitExceededException(Duration.ofSeconds(response.getError().getRetryAfterSeconds()))); + default -> throw new RuntimeException("Unrecognized error type from registration service: " + response.getError().getErrorType()); + } + } + + case RESPONSE_NOT_SET -> throw new RuntimeException("No response from registration service"); + }); + } + + public CompletableFuture sendRegistrationCode(final byte[] sessionId, final MessageTransport messageTransport, final ClientType clientType, @Nullable final String acceptLanguage, final Duration timeout) { - final long e164 = Long.parseLong( - PhoneNumberUtil.getInstance().format(phoneNumber, PhoneNumberUtil.PhoneNumberFormat.E164).substring(1)); - final SendVerificationCodeRequest.Builder requestBuilder = SendVerificationCodeRequest.newBuilder() - .setE164(e164) + .setSessionId(ByteString.copyFrom(sessionId)) .setTransport(getRpcMessageTransport(messageTransport)) .setClientType(getRpcClientType(clientType)); diff --git a/service/src/main/proto/RegistrationService.proto b/service/src/main/proto/RegistrationService.proto index ac49978f2..5f19e9975 100644 --- a/service/src/main/proto/RegistrationService.proto +++ b/service/src/main/proto/RegistrationService.proto @@ -6,8 +6,13 @@ package org.signal.registration.rpc; service RegistrationService { /** - * Sends a verification code to a destination phone number and returns the - * ID of the newly-created registration session. + * Create a new registration session for a given destination phone number. + */ + rpc create_session (CreateRegistrationSessionRequest) returns (CreateRegistrationSessionResponse) {} + + /** + * Sends a verification code to a destination phone number within the context + * of a previously-created registration session. */ rpc send_verification_code (SendVerificationCodeRequest) returns (SendVerificationCodeResponse) {} @@ -18,9 +23,71 @@ service RegistrationService { rpc check_verification_code (CheckVerificationCodeRequest) returns (CheckVerificationCodeResponse) {} } +message CreateRegistrationSessionRequest { + /** + * The phone number for which to create a new registration session. + */ + uint64 e164 = 1; +} + +message CreateRegistrationSessionResponse { + oneof response { + /** + * Metadata for the newly-created session. + */ + RegistrationSessionMetadata session_metadata = 1; + + /** + * A response explaining why a session could not be created as requested. + */ + CreateRegistrationSessionError error = 2; + } +} + +message RegistrationSessionMetadata { + /** + * An opaque sequence of bytes that uniquely identifies the registration + * session associated with this registration attempt. + */ + bytes session_id = 1; +} + +message CreateRegistrationSessionError { + /** + * The type of error that prevented a session from being created. + */ + CreateRegistrationSessionErrorType error_type = 1; + + /** + * Indicates that this error is fatal and should not be retried without + * modification. Non-fatal errors may be retried without modification after + * the duration indicated by `retry_after_seconds`. + */ + bool fatal = 2; + + /** + * If this error is not fatal (see `fatal`), indicates the duration in seconds + * from the present after which the request may be retried without + * modification. This value has no meaning otherwise. + */ + uint64 retry_after_seconds = 3; +} + +enum CreateRegistrationSessionErrorType { + ERROR_TYPE_UNSPECIFIED = 0; + + /** + * Indicates that a session could not be created because too many requests to + * create a session for the given phone number have been received in some + * window of time. Callers should wait and try again later. + */ + ERROR_TYPE_RATE_LIMITED = 1; +} + message SendVerificationCodeRequest { /** - * The phone number to which to send a verification code. + * The phone number to which to send a verification code. Ignored (and may be + * null if `session_id` is set. */ uint64 e164 = 1; @@ -31,8 +98,8 @@ message SendVerificationCodeRequest { MessageTransport transport = 2; /** - * The value of the `Accept-Language` header provided by remote clients (if - * any). + * A prioritized list of languages accepted by the destination; should be + * provided in the same format as the value of an HTTP Accept-Language header. */ string accept_language = 3; @@ -40,6 +107,11 @@ message SendVerificationCodeRequest { * The type of client requesting a verification code. */ ClientType client_type = 4; + + /** + * The ID of a session within which to send (or re-send) a verification code. + */ + bytes session_id = 5; } enum MessageTransport { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java index 32adf01d1..1812b8425 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java @@ -29,7 +29,6 @@ import com.google.common.collect.ImmutableSet; import com.google.common.net.HttpHeaders; import com.google.i18n.phonenumbers.NumberParseException; import com.google.i18n.phonenumbers.PhoneNumberUtil; -import com.google.i18n.phonenumbers.Phonenumber; import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.ResourceExtension; @@ -339,7 +338,10 @@ class AccountControllerTest { } @Test - void testGetFcmPreauth() throws Exception { + void testGetFcmPreauth() throws NumberParseException { + when(registrationServiceClient.createRegistrationSession(any(), any())) + .thenReturn(CompletableFuture.completedFuture(new byte[16])); + Response response = resources.getJerseyTest() .target("/v1/accounts/fcm/preauth/mytoken/+14152222222") .request() @@ -349,6 +351,9 @@ class AccountControllerTest { final ArgumentCaptor challengeTokenCaptor = ArgumentCaptor.forClass(String.class); + verify(registrationServiceClient).createRegistrationSession( + eq(PhoneNumberUtil.getInstance().parse("+14152222222", null)), any()); + verify(pushNotificationManager).sendRegistrationChallengeNotification( eq("mytoken"), eq(PushNotification.TokenType.FCM), challengeTokenCaptor.capture()); @@ -356,7 +361,10 @@ class AccountControllerTest { } @Test - void testGetFcmPreauthIvoryCoast() throws Exception { + void testGetFcmPreauthIvoryCoast() throws NumberParseException { + when(registrationServiceClient.createRegistrationSession(any(), any())) + .thenReturn(CompletableFuture.completedFuture(new byte[16])); + Response response = resources.getJerseyTest() .target("/v1/accounts/fcm/preauth/mytoken/+2250707312345") .request() @@ -366,6 +374,9 @@ class AccountControllerTest { final ArgumentCaptor challengeTokenCaptor = ArgumentCaptor.forClass(String.class); + verify(registrationServiceClient).createRegistrationSession( + eq(PhoneNumberUtil.getInstance().parse("+2250707312345", null)), any()); + verify(pushNotificationManager).sendRegistrationChallengeNotification( eq("mytoken"), eq(PushNotification.TokenType.FCM), challengeTokenCaptor.capture()); @@ -373,7 +384,10 @@ class AccountControllerTest { } @Test - void testGetApnPreauth() throws Exception { + void testGetApnPreauth() throws NumberParseException { + when(registrationServiceClient.createRegistrationSession(any(), any())) + .thenReturn(CompletableFuture.completedFuture(new byte[16])); + Response response = resources.getJerseyTest() .target("/v1/accounts/apn/preauth/mytoken/+14152222222") .request() @@ -383,6 +397,9 @@ class AccountControllerTest { final ArgumentCaptor challengeTokenCaptor = ArgumentCaptor.forClass(String.class); + verify(registrationServiceClient).createRegistrationSession( + eq(PhoneNumberUtil.getInstance().parse("+14152222222", null)), any()); + verify(pushNotificationManager).sendRegistrationChallengeNotification( eq("mytoken"), eq(PushNotification.TokenType.APN_VOIP), challengeTokenCaptor.capture()); @@ -390,7 +407,10 @@ class AccountControllerTest { } @Test - void testGetApnPreauthExplicitVoip() throws Exception { + void testGetApnPreauthExplicitVoip() throws NumberParseException { + when(registrationServiceClient.createRegistrationSession(any(), any())) + .thenReturn(CompletableFuture.completedFuture(new byte[16])); + Response response = resources.getJerseyTest() .target("/v1/accounts/apn/preauth/mytoken/+14152222222") .queryParam("voip", "true") @@ -401,6 +421,9 @@ class AccountControllerTest { final ArgumentCaptor challengeTokenCaptor = ArgumentCaptor.forClass(String.class); + verify(registrationServiceClient).createRegistrationSession( + eq(PhoneNumberUtil.getInstance().parse("+14152222222", null)), any()); + verify(pushNotificationManager).sendRegistrationChallengeNotification( eq("mytoken"), eq(PushNotification.TokenType.APN_VOIP), challengeTokenCaptor.capture()); @@ -408,7 +431,10 @@ class AccountControllerTest { } @Test - void testGetApnPreauthExplicitNoVoip() throws Exception { + void testGetApnPreauthExplicitNoVoip() throws NumberParseException { + when(registrationServiceClient.createRegistrationSession(any(), any())) + .thenReturn(CompletableFuture.completedFuture(new byte[16])); + Response response = resources.getJerseyTest() .target("/v1/accounts/apn/preauth/mytoken/+14152222222") .queryParam("voip", "false") @@ -419,6 +445,9 @@ class AccountControllerTest { final ArgumentCaptor challengeTokenCaptor = ArgumentCaptor.forClass(String.class); + verify(registrationServiceClient).createRegistrationSession( + eq(PhoneNumberUtil.getInstance().parse("+14152222222", null)), any()); + verify(pushNotificationManager).sendRegistrationChallengeNotification( eq("mytoken"), eq(PushNotification.TokenType.APN), challengeTokenCaptor.capture()); @@ -435,6 +464,7 @@ class AccountControllerTest { assertThat(response.getStatus()).isEqualTo(400); assertThat(response.readEntity(String.class)).isBlank(); + verifyNoInteractions(registrationServiceClient); verifyNoInteractions(pushNotificationManager); } @@ -453,14 +483,43 @@ class AccountControllerTest { assertThat(responseEntity.getOriginalNumber()).isEqualTo(number); assertThat(responseEntity.getNormalizedNumber()).isEqualTo("+447700900111"); + verifyNoInteractions(registrationServiceClient); verifyNoInteractions(pushNotificationManager); } @Test - void testSendCode() throws NumberParseException { - + void testSendCodeWithExistingSessionFromPreauth() { final byte[] sessionId = "session-id".getBytes(StandardCharsets.UTF_8); + when(pendingAccountsManager.getCodeForNumber(SENDER)) + .thenReturn(Optional.of(new StoredVerificationCode(null, System.currentTimeMillis(), "1234-push", sessionId))); + + when(registrationServiceClient.sendRegistrationCode(eq(sessionId), any(), any(), any(), any())) + .thenReturn(CompletableFuture.completedFuture(sessionId)); + + Response response = + resources.getJerseyTest() + .target(String.format("/v1/accounts/sms/code/%s", SENDER)) + .queryParam("challenge", "1234-push") + .request() + .header(HttpHeaders.X_FORWARDED_FOR, NICE_HOST) + .get(); + + assertThat(response.getStatus()).isEqualTo(200); + + verify(registrationServiceClient).sendRegistrationCode(sessionId, MessageTransport.SMS, ClientType.UNKNOWN, null, AccountController.REGISTRATION_RPC_TIMEOUT); + verify(pendingAccountsManager).store(eq(SENDER), argThat( + storedVerificationCode -> Arrays.equals(storedVerificationCode.sessionId(), sessionId) && + "1234-push".equals(storedVerificationCode.pushCode()))); + } + + @Test + void testSendCode() { + final byte[] sessionId = "session-id".getBytes(StandardCharsets.UTF_8); + + when(registrationServiceClient.createRegistrationSession(any(), any())) + .thenReturn(CompletableFuture.completedFuture(sessionId)); + when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) .thenReturn(CompletableFuture.completedFuture(sessionId)); @@ -474,14 +533,30 @@ class AccountControllerTest { assertThat(response.getStatus()).isEqualTo(200); - final Phonenumber.PhoneNumber expectedPhoneNumber = PhoneNumberUtil.getInstance().parse(SENDER, null); - - verify(registrationServiceClient).sendRegistrationCode(expectedPhoneNumber, MessageTransport.SMS, ClientType.UNKNOWN, null, AccountController.REGISTRATION_RPC_TIMEOUT); + verify(registrationServiceClient).sendRegistrationCode(sessionId, MessageTransport.SMS, ClientType.UNKNOWN, null, AccountController.REGISTRATION_RPC_TIMEOUT); verify(pendingAccountsManager).store(eq(SENDER), argThat( storedVerificationCode -> Arrays.equals(storedVerificationCode.sessionId(), sessionId) && "1234-push".equals(storedVerificationCode.pushCode()))); } + @Test + void testSendCodeRateLimited() { + when(registrationServiceClient.createRegistrationSession(any(), any())) + .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(Duration.ofMinutes(10)))); + + Response response = + resources.getJerseyTest() + .target(String.format("/v1/accounts/sms/code/%s", SENDER)) + .queryParam("challenge", "1234-push") + .request() + .header(HttpHeaders.X_FORWARDED_FOR, NICE_HOST) + .get(); + + assertThat(response.getStatus()).isEqualTo(413); + + verify(registrationServiceClient, never()).sendRegistrationCode(any(), any(), any(), any(), any()); + } + @Test void testSendCodeImpossibleNumber() { final Response response = @@ -520,10 +595,15 @@ class AccountControllerTest { } @Test - public void testSendCodeVoiceNoLocale() throws NumberParseException { + public void testSendCodeVoiceNoLocale() { + + final byte[] sessionId = "session".getBytes(StandardCharsets.UTF_8); when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) - .thenReturn(CompletableFuture.completedFuture(new byte[16])); + .thenReturn(CompletableFuture.completedFuture(sessionId)); + + when(registrationServiceClient.createRegistrationSession(any(), any())) + .thenReturn(CompletableFuture.completedFuture(sessionId)); Response response = resources.getJerseyTest() @@ -533,17 +613,20 @@ class AccountControllerTest { .header(HttpHeaders.X_FORWARDED_FOR, NICE_HOST) .get(); - final Phonenumber.PhoneNumber phoneNumber = PhoneNumberUtil.getInstance().parse(SENDER, null); - assertThat(response.getStatus()).isEqualTo(200); - verify(registrationServiceClient).sendRegistrationCode(phoneNumber, MessageTransport.VOICE, ClientType.UNKNOWN, null, AccountController.REGISTRATION_RPC_TIMEOUT); + verify(registrationServiceClient).sendRegistrationCode(sessionId, MessageTransport.VOICE, ClientType.UNKNOWN, null, AccountController.REGISTRATION_RPC_TIMEOUT); } @Test - void testSendCodeWithValidPreauth() throws NumberParseException { + void testSendCodeWithValidPreauth() { + + final byte[] sessionId = "session".getBytes(StandardCharsets.UTF_8); when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) - .thenReturn(CompletableFuture.completedFuture(new byte[16])); + .thenReturn(CompletableFuture.completedFuture(sessionId)); + + when(registrationServiceClient.createRegistrationSession(any(), any())) + .thenReturn(CompletableFuture.completedFuture(sessionId)); Response response = resources.getJerseyTest() @@ -555,9 +638,7 @@ class AccountControllerTest { assertThat(response.getStatus()).isEqualTo(200); - final Phonenumber.PhoneNumber phoneNumber = PhoneNumberUtil.getInstance().parse(SENDER_PREAUTH, null); - - verify(registrationServiceClient).sendRegistrationCode(phoneNumber, MessageTransport.SMS, ClientType.UNKNOWN, null, AccountController.REGISTRATION_RPC_TIMEOUT); + verify(registrationServiceClient).sendRegistrationCode(sessionId, MessageTransport.SMS, ClientType.UNKNOWN, null, AccountController.REGISTRATION_RPC_TIMEOUT); } @Test @@ -590,10 +671,15 @@ class AccountControllerTest { } @Test - void testSendiOSCode() throws NumberParseException { + void testSendiOSCode() { + + final byte[] sessionId = "session".getBytes(StandardCharsets.UTF_8); when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) - .thenReturn(CompletableFuture.completedFuture(new byte[16])); + .thenReturn(CompletableFuture.completedFuture(sessionId)); + + when(registrationServiceClient.createRegistrationSession(any(), any())) + .thenReturn(CompletableFuture.completedFuture(sessionId)); Response response = resources.getJerseyTest() @@ -606,15 +692,18 @@ class AccountControllerTest { assertThat(response.getStatus()).isEqualTo(200); - final Phonenumber.PhoneNumber phoneNumber = PhoneNumberUtil.getInstance().parse(SENDER, null); - - verify(registrationServiceClient).sendRegistrationCode(phoneNumber, MessageTransport.SMS, ClientType.IOS, null, AccountController.REGISTRATION_RPC_TIMEOUT); + verify(registrationServiceClient).sendRegistrationCode(sessionId, MessageTransport.SMS, ClientType.IOS, null, AccountController.REGISTRATION_RPC_TIMEOUT); } @Test - void testSendAndroidNgCode() throws NumberParseException { + void testSendAndroidNgCode() { + final byte[] sessionId = "session".getBytes(StandardCharsets.UTF_8); + when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) - .thenReturn(CompletableFuture.completedFuture(new byte[16])); + .thenReturn(CompletableFuture.completedFuture(sessionId)); + + when(registrationServiceClient.createRegistrationSession(any(), any())) + .thenReturn(CompletableFuture.completedFuture(sessionId)); Response response = resources.getJerseyTest() @@ -627,16 +716,19 @@ class AccountControllerTest { assertThat(response.getStatus()).isEqualTo(200); - final Phonenumber.PhoneNumber phoneNumber = PhoneNumberUtil.getInstance().parse(SENDER, null); - - verify(registrationServiceClient).sendRegistrationCode(phoneNumber, MessageTransport.SMS, ClientType.ANDROID_WITHOUT_FCM, null, AccountController.REGISTRATION_RPC_TIMEOUT); + verify(registrationServiceClient).sendRegistrationCode(sessionId, MessageTransport.SMS, ClientType.ANDROID_WITHOUT_FCM, null, AccountController.REGISTRATION_RPC_TIMEOUT); } @Test void testSendWithValidCaptcha() throws NumberParseException, IOException { + final byte[] sessionId = "session".getBytes(StandardCharsets.UTF_8); + when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) - .thenReturn(CompletableFuture.completedFuture(new byte[16])); + .thenReturn(CompletableFuture.completedFuture(sessionId)); + + when(registrationServiceClient.createRegistrationSession(any(), any())) + .thenReturn(CompletableFuture.completedFuture(sessionId)); Response response = resources.getJerseyTest() @@ -648,10 +740,8 @@ class AccountControllerTest { assertThat(response.getStatus()).isEqualTo(200); - final Phonenumber.PhoneNumber phoneNumber = PhoneNumberUtil.getInstance().parse(SENDER, null); - verify(captchaChecker).verify(eq(VALID_CAPTCHA_TOKEN), eq(NICE_HOST)); - verify(registrationServiceClient).sendRegistrationCode(phoneNumber, MessageTransport.SMS, ClientType.UNKNOWN, null, AccountController.REGISTRATION_RPC_TIMEOUT); + verify(registrationServiceClient).sendRegistrationCode(sessionId, MessageTransport.SMS, ClientType.UNKNOWN, null, AccountController.REGISTRATION_RPC_TIMEOUT); } @Test @@ -742,8 +832,13 @@ class AccountControllerTest { when(pendingAccountsManager.getCodeForNumber(number)) .thenReturn(Optional.of(new StoredVerificationCode(null, System.currentTimeMillis(), challenge, null))); + final byte[] sessionId = "session".getBytes(StandardCharsets.UTF_8); + when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) - .thenReturn(CompletableFuture.completedFuture(new byte[16])); + .thenReturn(CompletableFuture.completedFuture(sessionId)); + + when(registrationServiceClient.createRegistrationSession(any(), any())) + .thenReturn(CompletableFuture.completedFuture(sessionId)); Response response = resources.getJerseyTest() @@ -755,9 +850,7 @@ class AccountControllerTest { if (expectSendCode) { assertThat(response.getStatus()).isEqualTo(200); - - final Phonenumber.PhoneNumber phoneNumber = PhoneNumberUtil.getInstance().parse(number, null); - verify(registrationServiceClient).sendRegistrationCode(phoneNumber, MessageTransport.SMS, ClientType.UNKNOWN, null, AccountController.REGISTRATION_RPC_TIMEOUT); + verify(registrationServiceClient).sendRegistrationCode(sessionId, MessageTransport.SMS, ClientType.UNKNOWN, null, AccountController.REGISTRATION_RPC_TIMEOUT); } else { assertThat(response.getStatus()).isEqualTo(402); verifyNoInteractions(registrationServiceClient); @@ -765,13 +858,17 @@ class AccountControllerTest { } @Test - void testSendRestrictedIn() throws NumberParseException { + void testSendRestrictedIn() { final String challenge = "challenge"; + final byte[] sessionId = "session".getBytes(StandardCharsets.UTF_8); when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.of(new StoredVerificationCode(null, System.currentTimeMillis(), challenge, null))); when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) - .thenReturn(CompletableFuture.completedFuture(new byte[16])); + .thenReturn(CompletableFuture.completedFuture(sessionId)); + + when(registrationServiceClient.createRegistrationSession(any(), any())) + .thenReturn(CompletableFuture.completedFuture(sessionId)); Response response = resources.getJerseyTest() @@ -783,18 +880,19 @@ class AccountControllerTest { assertThat(response.getStatus()).isEqualTo(200); - final Phonenumber.PhoneNumber phoneNumber = PhoneNumberUtil.getInstance().parse(SENDER, null); - - verify(registrationServiceClient).sendRegistrationCode(phoneNumber, MessageTransport.SMS, ClientType.UNKNOWN, null, AccountController.REGISTRATION_RPC_TIMEOUT); + verify(registrationServiceClient).sendRegistrationCode(sessionId, MessageTransport.SMS, ClientType.UNKNOWN, null, AccountController.REGISTRATION_RPC_TIMEOUT); } @Test - void testSendCodeTestDeviceNumber() throws NumberParseException { + void testSendCodeTestDeviceNumber() { final byte[] sessionId = "session-id".getBytes(StandardCharsets.UTF_8); when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) .thenReturn(CompletableFuture.completedFuture(sessionId)); + when(registrationServiceClient.createRegistrationSession(any(), any())) + .thenReturn(CompletableFuture.completedFuture(sessionId)); + Response response = resources.getJerseyTest() .target(String.format("/v1/accounts/sms/code/%s", TEST_NUMBER)) @@ -809,8 +907,7 @@ class AccountControllerTest { assertThat(response.getStatus()).isEqualTo(200); // Even though no actual SMS will be sent, we leave that decision to the registration service - final Phonenumber.PhoneNumber phoneNumber = PhoneNumberUtil.getInstance().parse(TEST_NUMBER, null); - verify(registrationServiceClient).sendRegistrationCode(phoneNumber, MessageTransport.SMS, ClientType.UNKNOWN, null, AccountController.REGISTRATION_RPC_TIMEOUT); + verify(registrationServiceClient).sendRegistrationCode(sessionId, MessageTransport.SMS, ClientType.UNKNOWN, null, AccountController.REGISTRATION_RPC_TIMEOUT); } @Test @@ -1766,8 +1863,7 @@ class AccountControllerTest { @ParameterizedTest @MethodSource - void testSignupCaptcha(final String message, final boolean enforced, final Set countryCodes, final int expectedResponseStatusCode) - throws NumberParseException { + void testSignupCaptcha(final String message, final boolean enforced, final Set countryCodes, final int expectedResponseStatusCode) { DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class); when(dynamicConfigurationManager.getConfiguration()) .thenReturn(dynamicConfiguration); @@ -1777,8 +1873,13 @@ class AccountControllerTest { when(dynamicConfiguration.getCaptchaConfiguration()) .thenReturn(signupCaptchaConfig); + final byte[] sessionId = "session".getBytes(StandardCharsets.UTF_8); + when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) - .thenReturn(CompletableFuture.completedFuture(new byte[16])); + .thenReturn(CompletableFuture.completedFuture(sessionId)); + + when(registrationServiceClient.createRegistrationSession(any(), any())) + .thenReturn(CompletableFuture.completedFuture(sessionId)); Response response = resources.getJerseyTest() @@ -1790,10 +1891,8 @@ class AccountControllerTest { assertThat(response.getStatus()).isEqualTo(expectedResponseStatusCode); - final Phonenumber.PhoneNumber phoneNumber = PhoneNumberUtil.getInstance().parse(SENDER, null); - verify(registrationServiceClient, 200 == expectedResponseStatusCode ? times(1) : never()) - .sendRegistrationCode(phoneNumber, MessageTransport.SMS, ClientType.UNKNOWN, null, AccountController.REGISTRATION_RPC_TIMEOUT); + .sendRegistrationCode(sessionId, MessageTransport.SMS, ClientType.UNKNOWN, null, AccountController.REGISTRATION_RPC_TIMEOUT); } static Stream testSignupCaptcha() {