diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index c552c772b..bf7bb7f7a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -767,7 +767,7 @@ public class WhisperServerService extends Application createRegistrationSessionSession( - final Phonenumber.PhoneNumber phoneNumber, final Duration timeout) { + final Phonenumber.PhoneNumber phoneNumber, final boolean accountExistsWithPhoneNumber, final Duration timeout) { final long e164 = Long.parseLong( PhoneNumberUtil.getInstance().format(phoneNumber, PhoneNumberUtil.PhoneNumberFormat.E164).substring(1)); return toCompletableFuture(stub.withDeadline(toDeadline(timeout)) .createSession(CreateRegistrationSessionRequest.newBuilder() .setE164(e164) + .setAccountExistsWithE164(accountExistsWithPhoneNumber) .build())) .thenApply(response -> switch (response.getResponseCase()) { case SESSION_METADATA -> buildSessionResponseFromMetadata(response.getSessionMetadata()); @@ -111,8 +112,8 @@ public class RegistrationServiceClient implements Managed { @Deprecated public CompletableFuture createRegistrationSession(final Phonenumber.PhoneNumber phoneNumber, - final Duration timeout) { - return createRegistrationSessionSession(phoneNumber, timeout) + final boolean accountExistsWithPhoneNumber, final Duration timeout) { + return createRegistrationSessionSession(phoneNumber, accountExistsWithPhoneNumber, timeout) .thenApply(RegistrationServiceSession::id); } diff --git a/service/src/main/proto/RegistrationService.proto b/service/src/main/proto/RegistrationService.proto index 8f8c2ae7f..ea3f33ca2 100644 --- a/service/src/main/proto/RegistrationService.proto +++ b/service/src/main/proto/RegistrationService.proto @@ -35,6 +35,12 @@ message CreateRegistrationSessionRequest { * The phone number for which to create a new registration session. */ uint64 e164 = 1; + + /** + * Indicates whether an account already exists with the given e164 (i.e. this + * session represents a "re-registration" attempt). + */ + bool account_exists_with_e164 = 2; } message CreateRegistrationSessionResponse { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java index 54c5520ab..e150c7ceb 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java @@ -9,6 +9,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.argThat; @@ -386,7 +387,7 @@ class AccountControllerTest { @Test void testGetFcmPreauth() throws NumberParseException { - when(registrationServiceClient.createRegistrationSession(any(), any())) + when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any())) .thenReturn(CompletableFuture.completedFuture(new byte[16])); when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.empty()); @@ -401,7 +402,7 @@ class AccountControllerTest { final ArgumentCaptor challengeTokenCaptor = ArgumentCaptor.forClass(String.class); verify(registrationServiceClient).createRegistrationSession( - eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), any()); + eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), anyBoolean(), any()); verify(pushNotificationManager).sendRegistrationChallengeNotification( eq("mytoken"), eq(PushNotification.TokenType.FCM), challengeTokenCaptor.capture()); @@ -411,7 +412,7 @@ class AccountControllerTest { @Test void testGetFcmPreauthIvoryCoast() throws NumberParseException { - when(registrationServiceClient.createRegistrationSession(any(), any())) + when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any())) .thenReturn(CompletableFuture.completedFuture(new byte[16])); Response response = resources.getJerseyTest() @@ -424,7 +425,7 @@ class AccountControllerTest { final ArgumentCaptor challengeTokenCaptor = ArgumentCaptor.forClass(String.class); verify(registrationServiceClient).createRegistrationSession( - eq(PhoneNumberUtil.getInstance().parse("+2250707312345", null)), any()); + eq(PhoneNumberUtil.getInstance().parse("+2250707312345", null)), anyBoolean(), any()); verify(pushNotificationManager).sendRegistrationChallengeNotification( eq("mytoken"), eq(PushNotification.TokenType.FCM), challengeTokenCaptor.capture()); @@ -434,7 +435,7 @@ class AccountControllerTest { @Test void testGetApnPreauth() throws NumberParseException { - when(registrationServiceClient.createRegistrationSession(any(), any())) + when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any())) .thenReturn(CompletableFuture.completedFuture(new byte[16])); when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.empty()); @@ -449,7 +450,7 @@ class AccountControllerTest { final ArgumentCaptor challengeTokenCaptor = ArgumentCaptor.forClass(String.class); verify(registrationServiceClient).createRegistrationSession( - eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), any()); + eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), anyBoolean(), any()); verify(pushNotificationManager).sendRegistrationChallengeNotification( eq("mytoken"), eq(PushNotification.TokenType.APN_VOIP), challengeTokenCaptor.capture()); @@ -459,7 +460,7 @@ class AccountControllerTest { @Test void testGetApnPreauthExplicitVoip() throws NumberParseException { - when(registrationServiceClient.createRegistrationSession(any(), any())) + when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any())) .thenReturn(CompletableFuture.completedFuture(new byte[16])); when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.empty()); @@ -475,7 +476,7 @@ class AccountControllerTest { final ArgumentCaptor challengeTokenCaptor = ArgumentCaptor.forClass(String.class); verify(registrationServiceClient).createRegistrationSession( - eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), any()); + eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), anyBoolean(), any()); verify(pushNotificationManager).sendRegistrationChallengeNotification( eq("mytoken"), eq(PushNotification.TokenType.APN_VOIP), challengeTokenCaptor.capture()); @@ -485,7 +486,7 @@ class AccountControllerTest { @Test void testGetApnPreauthExplicitNoVoip() throws NumberParseException { - when(registrationServiceClient.createRegistrationSession(any(), any())) + when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any())) .thenReturn(CompletableFuture.completedFuture(new byte[16])); when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.empty()); @@ -501,7 +502,7 @@ class AccountControllerTest { final ArgumentCaptor challengeTokenCaptor = ArgumentCaptor.forClass(String.class); verify(registrationServiceClient).createRegistrationSession( - eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), any()); + eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), anyBoolean(), any()); verify(pushNotificationManager).sendRegistrationChallengeNotification( eq("mytoken"), eq(PushNotification.TokenType.APN), challengeTokenCaptor.capture()); @@ -546,7 +547,7 @@ class AccountControllerTest { void testGetPreauthExistingSession() throws NumberParseException { final String existingPushCode = "existing-push-code"; - when(registrationServiceClient.createRegistrationSession(any(), any())) + when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any())) .thenReturn(CompletableFuture.completedFuture(new byte[16])); when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn( @@ -561,7 +562,7 @@ class AccountControllerTest { final ArgumentCaptor challengeTokenCaptor = ArgumentCaptor.forClass(String.class); - verify(registrationServiceClient, never()).createRegistrationSession(any(), any()); + verify(registrationServiceClient, never()).createRegistrationSession(any(), anyBoolean(), any()); verify(pushNotificationManager).sendRegistrationChallengeNotification( eq("mytoken"), eq(PushNotification.TokenType.APN_VOIP), challengeTokenCaptor.capture()); @@ -571,7 +572,7 @@ class AccountControllerTest { @Test void testGetPreauthExistingSessionWithoutPushCode() throws NumberParseException { - when(registrationServiceClient.createRegistrationSession(any(), any())) + when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any())) .thenReturn(CompletableFuture.completedFuture(new byte[16])); when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn( @@ -586,7 +587,7 @@ class AccountControllerTest { final ArgumentCaptor challengeTokenCaptor = ArgumentCaptor.forClass(String.class); - verify(registrationServiceClient, never()).createRegistrationSession(any(), any()); + verify(registrationServiceClient, never()).createRegistrationSession(any(), anyBoolean(), any()); verify(pushNotificationManager).sendRegistrationChallengeNotification( eq("mytoken"), eq(PushNotification.TokenType.APN_VOIP), challengeTokenCaptor.capture()); @@ -624,7 +625,7 @@ class AccountControllerTest { void testSendCode() { final byte[] sessionId = "session-id".getBytes(StandardCharsets.UTF_8); - when(registrationServiceClient.createRegistrationSession(any(), any())) + when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any())) .thenReturn(CompletableFuture.completedFuture(sessionId)); when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) @@ -648,7 +649,7 @@ class AccountControllerTest { @Test void testSendCodeRateLimited() { - when(registrationServiceClient.createRegistrationSession(any(), any())) + when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any())) .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(Duration.ofMinutes(10), true))); Response response = @@ -709,7 +710,7 @@ class AccountControllerTest { when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) .thenReturn(CompletableFuture.completedFuture(sessionId)); - when(registrationServiceClient.createRegistrationSession(any(), any())) + when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any())) .thenReturn(CompletableFuture.completedFuture(sessionId)); Response response = @@ -732,7 +733,7 @@ class AccountControllerTest { when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) .thenReturn(CompletableFuture.completedFuture(sessionId)); - when(registrationServiceClient.createRegistrationSession(any(), any())) + when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any())) .thenReturn(CompletableFuture.completedFuture(sessionId)); Response response = @@ -785,7 +786,7 @@ class AccountControllerTest { when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) .thenReturn(CompletableFuture.completedFuture(sessionId)); - when(registrationServiceClient.createRegistrationSession(any(), any())) + when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any())) .thenReturn(CompletableFuture.completedFuture(sessionId)); Response response = @@ -809,7 +810,7 @@ class AccountControllerTest { when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) .thenReturn(CompletableFuture.completedFuture(sessionId)); - when(registrationServiceClient.createRegistrationSession(any(), any())) + when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any())) .thenReturn(CompletableFuture.completedFuture(sessionId)); Response response = @@ -834,7 +835,7 @@ class AccountControllerTest { when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) .thenReturn(CompletableFuture.completedFuture(sessionId)); - when(registrationServiceClient.createRegistrationSession(any(), any())) + when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any())) .thenReturn(CompletableFuture.completedFuture(sessionId)); Response response = @@ -944,7 +945,7 @@ class AccountControllerTest { when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) .thenReturn(CompletableFuture.completedFuture(sessionId)); - when(registrationServiceClient.createRegistrationSession(any(), any())) + when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any())) .thenReturn(CompletableFuture.completedFuture(sessionId)); Response response = @@ -975,7 +976,7 @@ class AccountControllerTest { when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) .thenReturn(CompletableFuture.completedFuture(sessionId)); - when(registrationServiceClient.createRegistrationSession(any(), any())) + when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any())) .thenReturn(CompletableFuture.completedFuture(sessionId)); Response response = @@ -998,7 +999,7 @@ class AccountControllerTest { when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) .thenReturn(CompletableFuture.completedFuture(sessionId)); - when(registrationServiceClient.createRegistrationSession(any(), any())) + when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any())) .thenReturn(CompletableFuture.completedFuture(sessionId)); Response response = @@ -2092,7 +2093,7 @@ class AccountControllerTest { when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) .thenReturn(CompletableFuture.completedFuture(sessionId)); - when(registrationServiceClient.createRegistrationSession(any(), any())) + when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any())) .thenReturn(CompletableFuture.completedFuture(sessionId)); Response response = diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/VerificationControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/VerificationControllerTest.java index e9e3f994f..9b62bbb82 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/VerificationControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/VerificationControllerTest.java @@ -11,7 +11,9 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -19,6 +21,8 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; import com.google.common.net.HttpHeaders; +import com.google.i18n.phonenumbers.NumberParseException; +import com.google.i18n.phonenumbers.PhoneNumberUtil; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.ResourceExtension; import java.io.IOException; @@ -46,6 +50,7 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentCaptor; import org.whispersystems.textsecuregcm.captcha.AssessmentResult; import org.whispersystems.textsecuregcm.captcha.RegistrationCaptchaManager; @@ -62,6 +67,8 @@ import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient; import org.whispersystems.textsecuregcm.registration.RegistrationServiceException; import org.whispersystems.textsecuregcm.registration.RegistrationServiceSenderException; import org.whispersystems.textsecuregcm.registration.VerificationSession; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager; import org.whispersystems.textsecuregcm.storage.VerificationSessionManager; import org.whispersystems.textsecuregcm.util.SystemMapper; @@ -81,6 +88,7 @@ class VerificationControllerTest { private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager = mock( RegistrationRecoveryPasswordsManager.class); private final RateLimiters rateLimiters = mock(RateLimiters.class); + private final AccountsManager accountsManager = mock(AccountsManager.class); private final Clock clock = Clock.systemUTC(); private final RateLimiter captchaLimiter = mock(RateLimiter.class); @@ -96,7 +104,7 @@ class VerificationControllerTest { .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .addResource( new VerificationController(registrationServiceClient, verificationSessionManager, pushNotificationManager, - registrationCaptchaManager, registrationRecoveryPasswordsManager, rateLimiters, clock)) + registrationCaptchaManager, registrationRecoveryPasswordsManager, rateLimiters, accountsManager, clock)) .build(); @BeforeEach @@ -105,6 +113,8 @@ class VerificationControllerTest { .thenReturn(captchaLimiter); when(rateLimiters.getVerificationPushChallengeLimiter()) .thenReturn(pushChallengeLimiter); + + when(accountsManager.getByE164(any())).thenReturn(Optional.empty()); } @ParameterizedTest @@ -153,7 +163,7 @@ class VerificationControllerTest { @Test void createSessionRateLimited() { - when(registrationServiceClient.createRegistrationSessionSession(any(), any())) + when(registrationServiceClient.createRegistrationSessionSession(any(), anyBoolean(), any())) .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null, true))); final Invocation.Builder request = resources.getJerseyTest() @@ -167,7 +177,7 @@ class VerificationControllerTest { @Test void createSessionRegistrationServiceError() { - when(registrationServiceClient.createRegistrationSessionSession(any(), any())) + when(registrationServiceClient.createRegistrationSessionSession(any(), anyBoolean(), any())) .thenReturn(CompletableFuture.failedFuture(new RuntimeException("expected service error"))); final Invocation.Builder request = resources.getJerseyTest() @@ -183,7 +193,7 @@ class VerificationControllerTest { @MethodSource void createSessionSuccess(final String pushToken, final String pushTokenType, final List expectedRequestedInformation) { - when(registrationServiceClient.createRegistrationSessionSession(any(), any())) + when(registrationServiceClient.createRegistrationSessionSession(any(), anyBoolean(), any())) .thenReturn( CompletableFuture.completedFuture( new RegistrationServiceSession(SESSION_ID, NUMBER, false, null, null, null, @@ -214,6 +224,37 @@ class VerificationControllerTest { ); } + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void createSessionReregistration(final boolean isReregistration) throws NumberParseException { + when(registrationServiceClient.createRegistrationSessionSession(any(), anyBoolean(), any())) + .thenReturn( + CompletableFuture.completedFuture( + new RegistrationServiceSession(SESSION_ID, NUMBER, false, null, null, null, + SESSION_EXPIRATION_SECONDS))); + + when(verificationSessionManager.insert(any(), any())) + .thenReturn(CompletableFuture.completedFuture(null)); + + when(accountsManager.getByE164(NUMBER)) + .thenReturn(isReregistration ? Optional.of(mock(Account.class)) : Optional.empty()); + + final Invocation.Builder request = resources.getJerseyTest() + .target("/v1/verification/session") + .request() + .header(HttpHeaders.X_FORWARDED_FOR, "127.0.0.1"); + + try (final Response response = request.post(Entity.json(createSessionJson(NUMBER, null, null)))) { + assertEquals(HttpStatus.SC_OK, response.getStatus()); + + verify(registrationServiceClient).createRegistrationSessionSession( + eq(PhoneNumberUtil.getInstance().parse(NUMBER, null)), + eq(isReregistration), + any() + ); + } + } + @Test void patchSessionMalformedId() { final String invalidSessionId = "()()()";