Send "account already exists" flag when creating registration sessions

This commit is contained in:
Jon Chambers 2023-03-14 17:25:49 -04:00 committed by Jon Chambers
parent 2052e62c01
commit 35606a9afd
7 changed files with 96 additions and 39 deletions

View File

@ -767,7 +767,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getCdnConfiguration().getBucket()), config.getCdnConfiguration().getBucket()),
new VerificationController(registrationServiceClient, new VerificationSessionManager(verificationSessions), new VerificationController(registrationServiceClient, new VerificationSessionManager(verificationSessions),
pushNotificationManager, registrationCaptchaManager, registrationRecoveryPasswordsManager, rateLimiters, pushNotificationManager, registrationCaptchaManager, registrationRecoveryPasswordsManager, rateLimiters,
clock) accountsManager, clock)
); );
if (config.getSubscription() != null && config.getOneTimeDonations() != null) { if (config.getSubscription() != null && config.getOneTimeDonations() != null) {
commonControllers.add(new SubscriptionController(clock, config.getSubscription(), config.getOneTimeDonations(), commonControllers.add(new SubscriptionController(clock, config.getSubscription(), config.getOneTimeDonations(),

View File

@ -238,7 +238,7 @@ public class AccountController {
storedVerificationCode = existingStoredVerificationCode; storedVerificationCode = existingStoredVerificationCode;
} }
} else { } else {
final byte[] sessionId = createRegistrationSession(phoneNumber); final byte[] sessionId = createRegistrationSession(phoneNumber, accounts.getByE164(number).isPresent());
storedVerificationCode = new StoredVerificationCode(null, clock.millis(), generatePushChallenge(), sessionId); storedVerificationCode = new StoredVerificationCode(null, clock.millis(), generatePushChallenge(), sessionId);
new StoredVerificationCode(null, clock.millis(), generatePushChallenge(), sessionId); new StoredVerificationCode(null, clock.millis(), generatePushChallenge(), sessionId);
} }
@ -345,8 +345,9 @@ public class AccountController {
// During the transition to explicit session creation, some previously-stored records may not have a session ID; // During the transition to explicit session creation, some previously-stored records may not have a session ID;
// after the transition, we can assume that any existing record has an associated session ID. // after the transition, we can assume that any existing record has an associated session ID.
final byte[] sessionId = maybeStoredVerificationCode.isPresent() && maybeStoredVerificationCode.get().sessionId() != null ? final byte[] sessionId = maybeStoredVerificationCode.isPresent() && maybeStoredVerificationCode.get().sessionId() != null
maybeStoredVerificationCode.get().sessionId() : createRegistrationSession(phoneNumber); ? maybeStoredVerificationCode.get().sessionId()
: createRegistrationSession(phoneNumber, accounts.getByE164(number).isPresent());
sendVerificationCode(sessionId, messageTransport, clientType, acceptLanguage); sendVerificationCode(sessionId, messageTransport, clientType, acceptLanguage);
@ -859,10 +860,11 @@ public class AccountController {
return HexFormat.of().formatHex(challenge); return HexFormat.of().formatHex(challenge);
} }
private byte[] createRegistrationSession(final Phonenumber.PhoneNumber phoneNumber) throws RateLimitExceededException { private byte[] createRegistrationSession(final Phonenumber.PhoneNumber phoneNumber,
final boolean accountExistsWithPhoneNumber) throws RateLimitExceededException {
try { try {
return registrationServiceClient.createRegistrationSession(phoneNumber, REGISTRATION_RPC_TIMEOUT).join(); return registrationServiceClient.createRegistrationSession(phoneNumber, accountExistsWithPhoneNumber, REGISTRATION_RPC_TIMEOUT).join();
} catch (final CompletionException e) { } catch (final CompletionException e) {
rethrowRateLimitException(e); rethrowRateLimitException(e);

View File

@ -76,6 +76,7 @@ import org.whispersystems.textsecuregcm.registration.RegistrationServiceExceptio
import org.whispersystems.textsecuregcm.registration.RegistrationServiceSenderException; import org.whispersystems.textsecuregcm.registration.RegistrationServiceSenderException;
import org.whispersystems.textsecuregcm.registration.VerificationSession; import org.whispersystems.textsecuregcm.registration.VerificationSession;
import org.whispersystems.textsecuregcm.spam.FilterSpam; import org.whispersystems.textsecuregcm.spam.FilterSpam;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager; import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.storage.VerificationSessionManager; import org.whispersystems.textsecuregcm.storage.VerificationSessionManager;
import org.whispersystems.textsecuregcm.util.ExceptionUtils; import org.whispersystems.textsecuregcm.util.ExceptionUtils;
@ -112,6 +113,7 @@ public class VerificationController {
private final RegistrationCaptchaManager registrationCaptchaManager; private final RegistrationCaptchaManager registrationCaptchaManager;
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager; private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager;
private final RateLimiters rateLimiters; private final RateLimiters rateLimiters;
private final AccountsManager accountsManager;
private final Clock clock; private final Clock clock;
@ -119,7 +121,9 @@ public class VerificationController {
final VerificationSessionManager verificationSessionManager, final VerificationSessionManager verificationSessionManager,
final PushNotificationManager pushNotificationManager, final PushNotificationManager pushNotificationManager,
final RegistrationCaptchaManager registrationCaptchaManager, final RegistrationCaptchaManager registrationCaptchaManager,
final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager, final RateLimiters rateLimiters, final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager,
final RateLimiters rateLimiters,
final AccountsManager accountsManager,
final Clock clock) { final Clock clock) {
this.registrationServiceClient = registrationServiceClient; this.registrationServiceClient = registrationServiceClient;
this.verificationSessionManager = verificationSessionManager; this.verificationSessionManager = verificationSessionManager;
@ -127,6 +131,7 @@ public class VerificationController {
this.registrationCaptchaManager = registrationCaptchaManager; this.registrationCaptchaManager = registrationCaptchaManager;
this.registrationRecoveryPasswordsManager = registrationRecoveryPasswordsManager; this.registrationRecoveryPasswordsManager = registrationRecoveryPasswordsManager;
this.rateLimiters = rateLimiters; this.rateLimiters = rateLimiters;
this.accountsManager = accountsManager;
this.clock = clock; this.clock = clock;
} }
@ -151,6 +156,7 @@ public class VerificationController {
final RegistrationServiceSession registrationServiceSession; final RegistrationServiceSession registrationServiceSession;
try { try {
registrationServiceSession = registrationServiceClient.createRegistrationSessionSession(phoneNumber, registrationServiceSession = registrationServiceClient.createRegistrationSessionSession(phoneNumber,
accountsManager.getByE164(request.getNumber()).isPresent(),
REGISTRATION_RPC_TIMEOUT).join(); REGISTRATION_RPC_TIMEOUT).join();
} catch (final CancellationException e) { } catch (final CancellationException e) {

View File

@ -81,13 +81,14 @@ public class RegistrationServiceClient implements Managed {
// The Session suffix methods distinguish the new methods, which return Sessions, from the old. // The Session suffix methods distinguish the new methods, which return Sessions, from the old.
// Once the deprecated methods are removed, the names can be streamlined. // Once the deprecated methods are removed, the names can be streamlined.
public CompletableFuture<RegistrationServiceSession> createRegistrationSessionSession( public CompletableFuture<RegistrationServiceSession> createRegistrationSessionSession(
final Phonenumber.PhoneNumber phoneNumber, final Duration timeout) { final Phonenumber.PhoneNumber phoneNumber, final boolean accountExistsWithPhoneNumber, final Duration timeout) {
final long e164 = Long.parseLong( final long e164 = Long.parseLong(
PhoneNumberUtil.getInstance().format(phoneNumber, PhoneNumberUtil.PhoneNumberFormat.E164).substring(1)); PhoneNumberUtil.getInstance().format(phoneNumber, PhoneNumberUtil.PhoneNumberFormat.E164).substring(1));
return toCompletableFuture(stub.withDeadline(toDeadline(timeout)) return toCompletableFuture(stub.withDeadline(toDeadline(timeout))
.createSession(CreateRegistrationSessionRequest.newBuilder() .createSession(CreateRegistrationSessionRequest.newBuilder()
.setE164(e164) .setE164(e164)
.setAccountExistsWithE164(accountExistsWithPhoneNumber)
.build())) .build()))
.thenApply(response -> switch (response.getResponseCase()) { .thenApply(response -> switch (response.getResponseCase()) {
case SESSION_METADATA -> buildSessionResponseFromMetadata(response.getSessionMetadata()); case SESSION_METADATA -> buildSessionResponseFromMetadata(response.getSessionMetadata());
@ -111,8 +112,8 @@ public class RegistrationServiceClient implements Managed {
@Deprecated @Deprecated
public CompletableFuture<byte[]> createRegistrationSession(final Phonenumber.PhoneNumber phoneNumber, public CompletableFuture<byte[]> createRegistrationSession(final Phonenumber.PhoneNumber phoneNumber,
final Duration timeout) { final boolean accountExistsWithPhoneNumber, final Duration timeout) {
return createRegistrationSessionSession(phoneNumber, timeout) return createRegistrationSessionSession(phoneNumber, accountExistsWithPhoneNumber, timeout)
.thenApply(RegistrationServiceSession::id); .thenApply(RegistrationServiceSession::id);
} }

View File

@ -35,6 +35,12 @@ message CreateRegistrationSessionRequest {
* The phone number for which to create a new registration session. * The phone number for which to create a new registration session.
*/ */
uint64 e164 = 1; uint64 e164 = 1;
/**
* Indicates whether an account already exists with the given e164 (i.e. this
* session represents a "re-registration" attempt).
*/
bool account_exists_with_e164 = 2;
} }
message CreateRegistrationSessionResponse { message CreateRegistrationSessionResponse {

View File

@ -9,6 +9,7 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.argThat;
@ -386,7 +387,7 @@ class AccountControllerTest {
@Test @Test
void testGetFcmPreauth() throws NumberParseException { void testGetFcmPreauth() throws NumberParseException {
when(registrationServiceClient.createRegistrationSession(any(), any())) when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(new byte[16])); .thenReturn(CompletableFuture.completedFuture(new byte[16]));
when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.empty()); when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.empty());
@ -401,7 +402,7 @@ class AccountControllerTest {
final ArgumentCaptor<String> challengeTokenCaptor = ArgumentCaptor.forClass(String.class); final ArgumentCaptor<String> challengeTokenCaptor = ArgumentCaptor.forClass(String.class);
verify(registrationServiceClient).createRegistrationSession( verify(registrationServiceClient).createRegistrationSession(
eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), any()); eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), anyBoolean(), any());
verify(pushNotificationManager).sendRegistrationChallengeNotification( verify(pushNotificationManager).sendRegistrationChallengeNotification(
eq("mytoken"), eq(PushNotification.TokenType.FCM), challengeTokenCaptor.capture()); eq("mytoken"), eq(PushNotification.TokenType.FCM), challengeTokenCaptor.capture());
@ -411,7 +412,7 @@ class AccountControllerTest {
@Test @Test
void testGetFcmPreauthIvoryCoast() throws NumberParseException { void testGetFcmPreauthIvoryCoast() throws NumberParseException {
when(registrationServiceClient.createRegistrationSession(any(), any())) when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(new byte[16])); .thenReturn(CompletableFuture.completedFuture(new byte[16]));
Response response = resources.getJerseyTest() Response response = resources.getJerseyTest()
@ -424,7 +425,7 @@ class AccountControllerTest {
final ArgumentCaptor<String> challengeTokenCaptor = ArgumentCaptor.forClass(String.class); final ArgumentCaptor<String> challengeTokenCaptor = ArgumentCaptor.forClass(String.class);
verify(registrationServiceClient).createRegistrationSession( verify(registrationServiceClient).createRegistrationSession(
eq(PhoneNumberUtil.getInstance().parse("+2250707312345", null)), any()); eq(PhoneNumberUtil.getInstance().parse("+2250707312345", null)), anyBoolean(), any());
verify(pushNotificationManager).sendRegistrationChallengeNotification( verify(pushNotificationManager).sendRegistrationChallengeNotification(
eq("mytoken"), eq(PushNotification.TokenType.FCM), challengeTokenCaptor.capture()); eq("mytoken"), eq(PushNotification.TokenType.FCM), challengeTokenCaptor.capture());
@ -434,7 +435,7 @@ class AccountControllerTest {
@Test @Test
void testGetApnPreauth() throws NumberParseException { void testGetApnPreauth() throws NumberParseException {
when(registrationServiceClient.createRegistrationSession(any(), any())) when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(new byte[16])); .thenReturn(CompletableFuture.completedFuture(new byte[16]));
when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.empty()); when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.empty());
@ -449,7 +450,7 @@ class AccountControllerTest {
final ArgumentCaptor<String> challengeTokenCaptor = ArgumentCaptor.forClass(String.class); final ArgumentCaptor<String> challengeTokenCaptor = ArgumentCaptor.forClass(String.class);
verify(registrationServiceClient).createRegistrationSession( verify(registrationServiceClient).createRegistrationSession(
eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), any()); eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), anyBoolean(), any());
verify(pushNotificationManager).sendRegistrationChallengeNotification( verify(pushNotificationManager).sendRegistrationChallengeNotification(
eq("mytoken"), eq(PushNotification.TokenType.APN_VOIP), challengeTokenCaptor.capture()); eq("mytoken"), eq(PushNotification.TokenType.APN_VOIP), challengeTokenCaptor.capture());
@ -459,7 +460,7 @@ class AccountControllerTest {
@Test @Test
void testGetApnPreauthExplicitVoip() throws NumberParseException { void testGetApnPreauthExplicitVoip() throws NumberParseException {
when(registrationServiceClient.createRegistrationSession(any(), any())) when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(new byte[16])); .thenReturn(CompletableFuture.completedFuture(new byte[16]));
when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.empty()); when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.empty());
@ -475,7 +476,7 @@ class AccountControllerTest {
final ArgumentCaptor<String> challengeTokenCaptor = ArgumentCaptor.forClass(String.class); final ArgumentCaptor<String> challengeTokenCaptor = ArgumentCaptor.forClass(String.class);
verify(registrationServiceClient).createRegistrationSession( verify(registrationServiceClient).createRegistrationSession(
eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), any()); eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), anyBoolean(), any());
verify(pushNotificationManager).sendRegistrationChallengeNotification( verify(pushNotificationManager).sendRegistrationChallengeNotification(
eq("mytoken"), eq(PushNotification.TokenType.APN_VOIP), challengeTokenCaptor.capture()); eq("mytoken"), eq(PushNotification.TokenType.APN_VOIP), challengeTokenCaptor.capture());
@ -485,7 +486,7 @@ class AccountControllerTest {
@Test @Test
void testGetApnPreauthExplicitNoVoip() throws NumberParseException { void testGetApnPreauthExplicitNoVoip() throws NumberParseException {
when(registrationServiceClient.createRegistrationSession(any(), any())) when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(new byte[16])); .thenReturn(CompletableFuture.completedFuture(new byte[16]));
when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.empty()); when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.empty());
@ -501,7 +502,7 @@ class AccountControllerTest {
final ArgumentCaptor<String> challengeTokenCaptor = ArgumentCaptor.forClass(String.class); final ArgumentCaptor<String> challengeTokenCaptor = ArgumentCaptor.forClass(String.class);
verify(registrationServiceClient).createRegistrationSession( verify(registrationServiceClient).createRegistrationSession(
eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), any()); eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), anyBoolean(), any());
verify(pushNotificationManager).sendRegistrationChallengeNotification( verify(pushNotificationManager).sendRegistrationChallengeNotification(
eq("mytoken"), eq(PushNotification.TokenType.APN), challengeTokenCaptor.capture()); eq("mytoken"), eq(PushNotification.TokenType.APN), challengeTokenCaptor.capture());
@ -546,7 +547,7 @@ class AccountControllerTest {
void testGetPreauthExistingSession() throws NumberParseException { void testGetPreauthExistingSession() throws NumberParseException {
final String existingPushCode = "existing-push-code"; final String existingPushCode = "existing-push-code";
when(registrationServiceClient.createRegistrationSession(any(), any())) when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(new byte[16])); .thenReturn(CompletableFuture.completedFuture(new byte[16]));
when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn( when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(
@ -561,7 +562,7 @@ class AccountControllerTest {
final ArgumentCaptor<String> challengeTokenCaptor = ArgumentCaptor.forClass(String.class); final ArgumentCaptor<String> challengeTokenCaptor = ArgumentCaptor.forClass(String.class);
verify(registrationServiceClient, never()).createRegistrationSession(any(), any()); verify(registrationServiceClient, never()).createRegistrationSession(any(), anyBoolean(), any());
verify(pushNotificationManager).sendRegistrationChallengeNotification( verify(pushNotificationManager).sendRegistrationChallengeNotification(
eq("mytoken"), eq(PushNotification.TokenType.APN_VOIP), challengeTokenCaptor.capture()); eq("mytoken"), eq(PushNotification.TokenType.APN_VOIP), challengeTokenCaptor.capture());
@ -571,7 +572,7 @@ class AccountControllerTest {
@Test @Test
void testGetPreauthExistingSessionWithoutPushCode() throws NumberParseException { void testGetPreauthExistingSessionWithoutPushCode() throws NumberParseException {
when(registrationServiceClient.createRegistrationSession(any(), any())) when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(new byte[16])); .thenReturn(CompletableFuture.completedFuture(new byte[16]));
when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn( when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(
@ -586,7 +587,7 @@ class AccountControllerTest {
final ArgumentCaptor<String> challengeTokenCaptor = ArgumentCaptor.forClass(String.class); final ArgumentCaptor<String> challengeTokenCaptor = ArgumentCaptor.forClass(String.class);
verify(registrationServiceClient, never()).createRegistrationSession(any(), any()); verify(registrationServiceClient, never()).createRegistrationSession(any(), anyBoolean(), any());
verify(pushNotificationManager).sendRegistrationChallengeNotification( verify(pushNotificationManager).sendRegistrationChallengeNotification(
eq("mytoken"), eq(PushNotification.TokenType.APN_VOIP), challengeTokenCaptor.capture()); eq("mytoken"), eq(PushNotification.TokenType.APN_VOIP), challengeTokenCaptor.capture());
@ -624,7 +625,7 @@ class AccountControllerTest {
void testSendCode() { void testSendCode() {
final byte[] sessionId = "session-id".getBytes(StandardCharsets.UTF_8); final byte[] sessionId = "session-id".getBytes(StandardCharsets.UTF_8);
when(registrationServiceClient.createRegistrationSession(any(), any())) when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId)); .thenReturn(CompletableFuture.completedFuture(sessionId));
when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any()))
@ -648,7 +649,7 @@ class AccountControllerTest {
@Test @Test
void testSendCodeRateLimited() { void testSendCodeRateLimited() {
when(registrationServiceClient.createRegistrationSession(any(), any())) when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(Duration.ofMinutes(10), true))); .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(Duration.ofMinutes(10), true)));
Response response = Response response =
@ -709,7 +710,7 @@ class AccountControllerTest {
when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId)); .thenReturn(CompletableFuture.completedFuture(sessionId));
when(registrationServiceClient.createRegistrationSession(any(), any())) when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId)); .thenReturn(CompletableFuture.completedFuture(sessionId));
Response response = Response response =
@ -732,7 +733,7 @@ class AccountControllerTest {
when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId)); .thenReturn(CompletableFuture.completedFuture(sessionId));
when(registrationServiceClient.createRegistrationSession(any(), any())) when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId)); .thenReturn(CompletableFuture.completedFuture(sessionId));
Response response = Response response =
@ -785,7 +786,7 @@ class AccountControllerTest {
when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId)); .thenReturn(CompletableFuture.completedFuture(sessionId));
when(registrationServiceClient.createRegistrationSession(any(), any())) when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId)); .thenReturn(CompletableFuture.completedFuture(sessionId));
Response response = Response response =
@ -809,7 +810,7 @@ class AccountControllerTest {
when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId)); .thenReturn(CompletableFuture.completedFuture(sessionId));
when(registrationServiceClient.createRegistrationSession(any(), any())) when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId)); .thenReturn(CompletableFuture.completedFuture(sessionId));
Response response = Response response =
@ -834,7 +835,7 @@ class AccountControllerTest {
when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId)); .thenReturn(CompletableFuture.completedFuture(sessionId));
when(registrationServiceClient.createRegistrationSession(any(), any())) when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId)); .thenReturn(CompletableFuture.completedFuture(sessionId));
Response response = Response response =
@ -944,7 +945,7 @@ class AccountControllerTest {
when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId)); .thenReturn(CompletableFuture.completedFuture(sessionId));
when(registrationServiceClient.createRegistrationSession(any(), any())) when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId)); .thenReturn(CompletableFuture.completedFuture(sessionId));
Response response = Response response =
@ -975,7 +976,7 @@ class AccountControllerTest {
when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId)); .thenReturn(CompletableFuture.completedFuture(sessionId));
when(registrationServiceClient.createRegistrationSession(any(), any())) when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId)); .thenReturn(CompletableFuture.completedFuture(sessionId));
Response response = Response response =
@ -998,7 +999,7 @@ class AccountControllerTest {
when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId)); .thenReturn(CompletableFuture.completedFuture(sessionId));
when(registrationServiceClient.createRegistrationSession(any(), any())) when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId)); .thenReturn(CompletableFuture.completedFuture(sessionId));
Response response = Response response =
@ -2092,7 +2093,7 @@ class AccountControllerTest {
when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any())) when(registrationServiceClient.sendRegistrationCode(any(), any(), any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId)); .thenReturn(CompletableFuture.completedFuture(sessionId));
when(registrationServiceClient.createRegistrationSession(any(), any())) when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.completedFuture(sessionId)); .thenReturn(CompletableFuture.completedFuture(sessionId));
Response response = Response response =

View File

@ -11,7 +11,9 @@ import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -19,6 +21,8 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.common.net.HttpHeaders; import com.google.common.net.HttpHeaders;
import com.google.i18n.phonenumbers.NumberParseException;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension; import io.dropwizard.testing.junit5.ResourceExtension;
import java.io.IOException; import java.io.IOException;
@ -46,6 +50,7 @@ import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.captcha.AssessmentResult; import org.whispersystems.textsecuregcm.captcha.AssessmentResult;
import org.whispersystems.textsecuregcm.captcha.RegistrationCaptchaManager; import org.whispersystems.textsecuregcm.captcha.RegistrationCaptchaManager;
@ -62,6 +67,8 @@ import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
import org.whispersystems.textsecuregcm.registration.RegistrationServiceException; import org.whispersystems.textsecuregcm.registration.RegistrationServiceException;
import org.whispersystems.textsecuregcm.registration.RegistrationServiceSenderException; import org.whispersystems.textsecuregcm.registration.RegistrationServiceSenderException;
import org.whispersystems.textsecuregcm.registration.VerificationSession; import org.whispersystems.textsecuregcm.registration.VerificationSession;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager; import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.storage.VerificationSessionManager; import org.whispersystems.textsecuregcm.storage.VerificationSessionManager;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -81,6 +88,7 @@ class VerificationControllerTest {
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager = mock( private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager = mock(
RegistrationRecoveryPasswordsManager.class); RegistrationRecoveryPasswordsManager.class);
private final RateLimiters rateLimiters = mock(RateLimiters.class); private final RateLimiters rateLimiters = mock(RateLimiters.class);
private final AccountsManager accountsManager = mock(AccountsManager.class);
private final Clock clock = Clock.systemUTC(); private final Clock clock = Clock.systemUTC();
private final RateLimiter captchaLimiter = mock(RateLimiter.class); private final RateLimiter captchaLimiter = mock(RateLimiter.class);
@ -96,7 +104,7 @@ class VerificationControllerTest {
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource( .addResource(
new VerificationController(registrationServiceClient, verificationSessionManager, pushNotificationManager, new VerificationController(registrationServiceClient, verificationSessionManager, pushNotificationManager,
registrationCaptchaManager, registrationRecoveryPasswordsManager, rateLimiters, clock)) registrationCaptchaManager, registrationRecoveryPasswordsManager, rateLimiters, accountsManager, clock))
.build(); .build();
@BeforeEach @BeforeEach
@ -105,6 +113,8 @@ class VerificationControllerTest {
.thenReturn(captchaLimiter); .thenReturn(captchaLimiter);
when(rateLimiters.getVerificationPushChallengeLimiter()) when(rateLimiters.getVerificationPushChallengeLimiter())
.thenReturn(pushChallengeLimiter); .thenReturn(pushChallengeLimiter);
when(accountsManager.getByE164(any())).thenReturn(Optional.empty());
} }
@ParameterizedTest @ParameterizedTest
@ -153,7 +163,7 @@ class VerificationControllerTest {
@Test @Test
void createSessionRateLimited() { void createSessionRateLimited() {
when(registrationServiceClient.createRegistrationSessionSession(any(), any())) when(registrationServiceClient.createRegistrationSessionSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null, true))); .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null, true)));
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
@ -167,7 +177,7 @@ class VerificationControllerTest {
@Test @Test
void createSessionRegistrationServiceError() { void createSessionRegistrationServiceError() {
when(registrationServiceClient.createRegistrationSessionSession(any(), any())) when(registrationServiceClient.createRegistrationSessionSession(any(), anyBoolean(), any()))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException("expected service error"))); .thenReturn(CompletableFuture.failedFuture(new RuntimeException("expected service error")));
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
@ -183,7 +193,7 @@ class VerificationControllerTest {
@MethodSource @MethodSource
void createSessionSuccess(final String pushToken, final String pushTokenType, void createSessionSuccess(final String pushToken, final String pushTokenType,
final List<VerificationSession.Information> expectedRequestedInformation) { final List<VerificationSession.Information> expectedRequestedInformation) {
when(registrationServiceClient.createRegistrationSessionSession(any(), any())) when(registrationServiceClient.createRegistrationSessionSession(any(), anyBoolean(), any()))
.thenReturn( .thenReturn(
CompletableFuture.completedFuture( CompletableFuture.completedFuture(
new RegistrationServiceSession(SESSION_ID, NUMBER, false, null, null, null, new RegistrationServiceSession(SESSION_ID, NUMBER, false, null, null, null,
@ -214,6 +224,37 @@ class VerificationControllerTest {
); );
} }
@ParameterizedTest
@ValueSource(booleans = {true, false})
void createSessionReregistration(final boolean isReregistration) throws NumberParseException {
when(registrationServiceClient.createRegistrationSessionSession(any(), anyBoolean(), any()))
.thenReturn(
CompletableFuture.completedFuture(
new RegistrationServiceSession(SESSION_ID, NUMBER, false, null, null, null,
SESSION_EXPIRATION_SECONDS)));
when(verificationSessionManager.insert(any(), any()))
.thenReturn(CompletableFuture.completedFuture(null));
when(accountsManager.getByE164(NUMBER))
.thenReturn(isReregistration ? Optional.of(mock(Account.class)) : Optional.empty());
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/verification/session")
.request()
.header(HttpHeaders.X_FORWARDED_FOR, "127.0.0.1");
try (final Response response = request.post(Entity.json(createSessionJson(NUMBER, null, null)))) {
assertEquals(HttpStatus.SC_OK, response.getStatus());
verify(registrationServiceClient).createRegistrationSessionSession(
eq(PhoneNumberUtil.getInstance().parse(NUMBER, null)),
eq(isReregistration),
any()
);
}
}
@Test @Test
void patchSessionMalformedId() { void patchSessionMalformedId() {
final String invalidSessionId = "()()()"; final String invalidSessionId = "()()()";