From e1ea3795bb76712686a11caff27e166827bd48ab Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Wed, 22 Feb 2023 12:24:43 -0500 Subject: [PATCH] Reuse registration sessions if possible when requesting pre-auth codes --- .../controllers/AccountController.java | 25 +++++- .../controllers/AccountControllerTest.java | 76 +++++++++++++++++-- 2 files changed, 89 insertions(+), 12 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java index f8d81ee0c..8a2da844a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java @@ -226,10 +226,27 @@ public class AccountController { throw new BadRequestException("Bad phone number"); } - final String pushChallenge = generatePushChallenge(); - final byte[] sessionId = createRegistrationSession(phoneNumber); - final StoredVerificationCode storedVerificationCode = - new StoredVerificationCode(null, clock.millis(), pushChallenge, sessionId); + final StoredVerificationCode storedVerificationCode; + { + final Optional maybeStoredVerificationCode = pendingAccounts.getCodeForNumber(number); + + if (maybeStoredVerificationCode.isPresent()) { + final StoredVerificationCode existingStoredVerificationCode = maybeStoredVerificationCode.get(); + + if (StringUtils.isBlank(existingStoredVerificationCode.pushCode())) { + storedVerificationCode = new StoredVerificationCode( + existingStoredVerificationCode.code(), + existingStoredVerificationCode.timestamp(), + generatePushChallenge(), + existingStoredVerificationCode.sessionId()); + } else { + storedVerificationCode = existingStoredVerificationCode; + } + } else { + final byte[] sessionId = createRegistrationSession(phoneNumber); + storedVerificationCode = new StoredVerificationCode(null, clock.millis(), generatePushChallenge(), sessionId); + } + } pendingAccounts.store(number, storedVerificationCode); pushNotificationManager.sendRegistrationChallengeNotification(pushToken, tokenType, storedVerificationCode.pushCode()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java index 98cd6ce22..30d162014 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java @@ -372,8 +372,10 @@ class AccountControllerTest { when(registrationServiceClient.createRegistrationSession(any(), any())) .thenReturn(CompletableFuture.completedFuture(new byte[16])); + when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.empty()); + Response response = resources.getJerseyTest() - .target("/v1/accounts/fcm/preauth/mytoken/+14152222222") + .target("/v1/accounts/fcm/preauth/mytoken/" + SENDER) .request() .get(); @@ -382,7 +384,7 @@ class AccountControllerTest { final ArgumentCaptor challengeTokenCaptor = ArgumentCaptor.forClass(String.class); verify(registrationServiceClient).createRegistrationSession( - eq(PhoneNumberUtil.getInstance().parse("+14152222222", null)), any()); + eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), any()); verify(pushNotificationManager).sendRegistrationChallengeNotification( eq("mytoken"), eq(PushNotification.TokenType.FCM), challengeTokenCaptor.capture()); @@ -418,8 +420,10 @@ class AccountControllerTest { when(registrationServiceClient.createRegistrationSession(any(), any())) .thenReturn(CompletableFuture.completedFuture(new byte[16])); + when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.empty()); + Response response = resources.getJerseyTest() - .target("/v1/accounts/apn/preauth/mytoken/+14152222222") + .target("/v1/accounts/apn/preauth/mytoken/" + SENDER) .request() .get(); @@ -428,7 +432,7 @@ class AccountControllerTest { final ArgumentCaptor challengeTokenCaptor = ArgumentCaptor.forClass(String.class); verify(registrationServiceClient).createRegistrationSession( - eq(PhoneNumberUtil.getInstance().parse("+14152222222", null)), any()); + eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), any()); verify(pushNotificationManager).sendRegistrationChallengeNotification( eq("mytoken"), eq(PushNotification.TokenType.APN_VOIP), challengeTokenCaptor.capture()); @@ -441,8 +445,10 @@ class AccountControllerTest { when(registrationServiceClient.createRegistrationSession(any(), any())) .thenReturn(CompletableFuture.completedFuture(new byte[16])); + when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.empty()); + Response response = resources.getJerseyTest() - .target("/v1/accounts/apn/preauth/mytoken/+14152222222") + .target("/v1/accounts/apn/preauth/mytoken/" + SENDER) .queryParam("voip", "true") .request() .get(); @@ -452,7 +458,7 @@ class AccountControllerTest { final ArgumentCaptor challengeTokenCaptor = ArgumentCaptor.forClass(String.class); verify(registrationServiceClient).createRegistrationSession( - eq(PhoneNumberUtil.getInstance().parse("+14152222222", null)), any()); + eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), any()); verify(pushNotificationManager).sendRegistrationChallengeNotification( eq("mytoken"), eq(PushNotification.TokenType.APN_VOIP), challengeTokenCaptor.capture()); @@ -465,8 +471,10 @@ class AccountControllerTest { when(registrationServiceClient.createRegistrationSession(any(), any())) .thenReturn(CompletableFuture.completedFuture(new byte[16])); + when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn(Optional.empty()); + Response response = resources.getJerseyTest() - .target("/v1/accounts/apn/preauth/mytoken/+14152222222") + .target("/v1/accounts/apn/preauth/mytoken/" + SENDER) .queryParam("voip", "false") .request() .get(); @@ -476,7 +484,7 @@ class AccountControllerTest { final ArgumentCaptor challengeTokenCaptor = ArgumentCaptor.forClass(String.class); verify(registrationServiceClient).createRegistrationSession( - eq(PhoneNumberUtil.getInstance().parse("+14152222222", null)), any()); + eq(PhoneNumberUtil.getInstance().parse(SENDER, null)), any()); verify(pushNotificationManager).sendRegistrationChallengeNotification( eq("mytoken"), eq(PushNotification.TokenType.APN), challengeTokenCaptor.capture()); @@ -517,6 +525,58 @@ class AccountControllerTest { verifyNoInteractions(pushNotificationManager); } + @Test + void testGetPreauthExistingSession() throws NumberParseException { + final String existingPushCode = "existing-push-code"; + + when(registrationServiceClient.createRegistrationSession(any(), any())) + .thenReturn(CompletableFuture.completedFuture(new byte[16])); + + when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn( + Optional.of(new StoredVerificationCode(null, System.currentTimeMillis(), existingPushCode, new byte[16]))); + + Response response = resources.getJerseyTest() + .target("/v1/accounts/apn/preauth/mytoken/" + SENDER) + .request() + .get(); + + assertThat(response.getStatus()).isEqualTo(200); + + final ArgumentCaptor challengeTokenCaptor = ArgumentCaptor.forClass(String.class); + + verify(registrationServiceClient, never()).createRegistrationSession(any(), any()); + + verify(pushNotificationManager).sendRegistrationChallengeNotification( + eq("mytoken"), eq(PushNotification.TokenType.APN_VOIP), challengeTokenCaptor.capture()); + + assertThat(challengeTokenCaptor.getValue()).isEqualTo(existingPushCode); + } + + @Test + void testGetPreauthExistingSessionWithoutPushCode() throws NumberParseException { + when(registrationServiceClient.createRegistrationSession(any(), any())) + .thenReturn(CompletableFuture.completedFuture(new byte[16])); + + when(pendingAccountsManager.getCodeForNumber(SENDER)).thenReturn( + Optional.of(new StoredVerificationCode(null, System.currentTimeMillis(), null, new byte[16]))); + + Response response = resources.getJerseyTest() + .target("/v1/accounts/apn/preauth/mytoken/" + SENDER) + .request() + .get(); + + assertThat(response.getStatus()).isEqualTo(200); + + final ArgumentCaptor challengeTokenCaptor = ArgumentCaptor.forClass(String.class); + + verify(registrationServiceClient, never()).createRegistrationSession(any(), any()); + + verify(pushNotificationManager).sendRegistrationChallengeNotification( + eq("mytoken"), eq(PushNotification.TokenType.APN_VOIP), challengeTokenCaptor.capture()); + + assertThat(challengeTokenCaptor.getValue().length()).isEqualTo(32); + } + @Test void testSendCodeWithExistingSessionFromPreauth() { final byte[] sessionId = "session-id".getBytes(StandardCharsets.UTF_8);