diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 4066a9947..843c468a7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -747,7 +747,7 @@ public class WhisperServerService extends Application maybeSession = registrationServiceClient.getSession(sessionId, - REGISTRATION_RPC_TIMEOUT) - .get(REGISTRATION_RPC_TIMEOUT.plusSeconds(1).getSeconds(), TimeUnit.SECONDS); - - final RegistrationSession session = maybeSession.orElseThrow( - () -> new NotAuthorizedException("session not verified")); - if (!MessageDigest.isEqual(number.getBytes(), session.number().getBytes())) { - throw new BadRequestException("number does not match session"); - } - if (!session.verified()) { - throw new NotAuthorizedException("session not verified"); - } - - } catch (final CancellationException | ExecutionException | TimeoutException e) { - logger.error("Registration service failure", e); - throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE); + // decide on the method of verification based on the registration request parameters and verify + final RegistrationRequest.VerificationType verificationType = registrationRequest.verificationType(); + switch (verificationType) { + case SESSION -> verifyBySessionId(number, registrationRequest.decodeSessionId()); + case RECOVERY_PASSWORD -> verifyByRecoveryPassword(number, registrationRequest.recoveryPassword()); } final Optional existingAccount = accounts.getByE164(number); @@ -150,10 +133,15 @@ public class RegistrationController { final Account account = accounts.create(number, password, signalAgent, registrationRequest.accountAttributes(), existingAccount.map(Account::getBadges).orElseGet(ArrayList::new)); + // now that the number is verified and account is created, + // we can store recovery password for this number + registrationRequest.accountAttributes().recoveryPassword().ifPresent(recoveryPassword -> + registrationRecoveryPasswordsManager.storeForCurrentNumber(number, recoveryPassword)); + Metrics.counter(ACCOUNT_CREATED_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), Tag.of(COUNTRY_CODE_TAG_NAME, Util.getCountryCode(number)), Tag.of(REGION_CODE_TAG_NAME, Util.getRegion(number)), - Tag.of(VERIFICATION_TYPE_TAG_NAME, verificationType))) + Tag.of(VERIFICATION_TYPE_TAG_NAME, verificationType.name()))) .increment(); return new AccountIdentityResponse(account.getUuid(), @@ -163,4 +151,34 @@ public class RegistrationController { existingAccount.map(Account::isStorageSupported).orElse(false)); } + private void verifyBySessionId(final String number, final byte[] sessionId) throws InterruptedException { + try { + final RegistrationSession session = registrationServiceClient + .getSession(sessionId, REGISTRATION_RPC_TIMEOUT) + .get(VERIFICATION_TIMEOUT_SECONDS, TimeUnit.SECONDS) + .orElseThrow(() -> new NotAuthorizedException("session not verified")); + + if (!MessageDigest.isEqual(number.getBytes(), session.number().getBytes())) { + throw new BadRequestException("number does not match session"); + } + if (!session.verified()) { + throw new NotAuthorizedException("session not verified"); + } + } catch (final CancellationException | ExecutionException | TimeoutException e) { + logger.error("Registration service failure", e); + throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE); + } + } + + private void verifyByRecoveryPassword(final String number, final byte[] recoveryPassword) throws InterruptedException { + try { + final boolean verified = registrationRecoveryPasswordsManager.verify(number, recoveryPassword) + .get(VERIFICATION_TIMEOUT_SECONDS, TimeUnit.SECONDS); + if (!verified) { + throw new ForbiddenException("recoveryPassword couldn't be verified"); + } + } catch (final ExecutionException | TimeoutException e) { + throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE); + } + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationRequest.java index 5efcced2e..e0bb4105d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationRequest.java @@ -5,12 +5,43 @@ package org.whispersystems.textsecuregcm.entities; -import javax.validation.Valid; -import javax.validation.constraints.NotBlank; -import javax.validation.constraints.NotNull; +import static org.apache.commons.lang3.StringUtils.isNotBlank; -public record RegistrationRequest(@NotBlank String sessionId, +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import java.util.Base64; +import javax.validation.Valid; +import javax.validation.constraints.AssertTrue; +import javax.validation.constraints.NotNull; +import javax.ws.rs.ClientErrorException; +import org.apache.http.HttpStatus; +import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; + +public record RegistrationRequest(String sessionId, + @JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) byte[] recoveryPassword, @NotNull @Valid AccountAttributes accountAttributes, boolean skipDeviceTransfer) { + public enum VerificationType { + SESSION, + RECOVERY_PASSWORD + } + + // for the @AssertTrue to work with bean validation, method name must follow 'isSmth()'/'getSmth()' naming convention + @AssertTrue + public boolean isValid() { + // checking that exactly one of sessionId/recoveryPassword is non-empty + return isNotBlank(sessionId) ^ (recoveryPassword != null && recoveryPassword.length > 0); + } + + public VerificationType verificationType() { + return isNotBlank(sessionId) ? VerificationType.SESSION : VerificationType.RECOVERY_PASSWORD; + } + + public byte[] decodeSessionId() { + try { + return Base64.getUrlDecoder().decode(sessionId()); + } catch (final IllegalArgumentException e) { + throw new ClientErrorException("Malformed session ID", HttpStatus.SC_UNPROCESSABLE_ENTITY); + } + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java index fe1f2988a..a6ce4f025 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java @@ -6,7 +6,10 @@ package org.whispersystems.textsecuregcm.controllers; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -35,8 +38,11 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; import org.whispersystems.textsecuregcm.auth.RegistrationLockError; import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager; +import org.whispersystems.textsecuregcm.entities.AccountAttributes; +import org.whispersystems.textsecuregcm.entities.RegistrationRequest; import org.whispersystems.textsecuregcm.entities.RegistrationSession; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; @@ -46,6 +52,7 @@ import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager; import org.whispersystems.textsecuregcm.util.SystemMapper; @ExtendWith(DropwizardExtensionsSupport.class) @@ -56,6 +63,8 @@ class RegistrationControllerTest { private final RegistrationServiceClient registrationServiceClient = mock(RegistrationServiceClient.class); private final RegistrationLockVerificationManager registrationLockVerificationManager = mock( RegistrationLockVerificationManager.class); + private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager = mock( + RegistrationRecoveryPasswordsManager.class); private final RateLimiters rateLimiters = mock(RateLimiters.class); private final RateLimiter registrationLimiter = mock(RateLimiter.class); @@ -70,7 +79,7 @@ class RegistrationControllerTest { .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .addResource( new RegistrationController(accountsManager, registrationServiceClient, registrationLockVerificationManager, - rateLimiters)) + registrationRecoveryPasswordsManager, rateLimiters)) .build(); @BeforeEach @@ -79,6 +88,14 @@ class RegistrationControllerTest { when(rateLimiters.getPinLimiter()).thenReturn(pinLimiter); } + @Test + public void testRegistrationRequest() throws Exception { + assertFalse(new RegistrationRequest("", new byte[0], new AccountAttributes(), true).isValid()); + assertFalse(new RegistrationRequest("some", new byte[32], new AccountAttributes(), true).isValid()); + assertTrue(new RegistrationRequest("", new byte[32], new AccountAttributes(), true).isValid()); + assertTrue(new RegistrationRequest("some", new byte[0], new AccountAttributes(), true).isValid()); + } + @Test void unprocessableRequestJson() { final Invocation.Builder request = resources.getJerseyTest() @@ -151,6 +168,20 @@ class RegistrationControllerTest { } } + @Test + void recoveryPasswordManagerVerificationFailureOrTimeout() { + when(registrationRecoveryPasswordsManager.verify(any(), any())) + .thenReturn(CompletableFuture.failedFuture(new RuntimeException())); + + final Invocation.Builder request = resources.getJerseyTest() + .target("/v1/registration") + .request() + .header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); + try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(new byte[32])))) { + assertEquals(HttpStatus.SC_SERVICE_UNAVAILABLE, response.getStatus()); + } + } + @ParameterizedTest @MethodSource void registrationServiceSessionCheck(@Nullable final RegistrationSession session, final int expectedStatus, @@ -175,6 +206,38 @@ class RegistrationControllerTest { ); } + @Test + void recoveryPasswordManagerVerificationTrue() throws InterruptedException { + when(registrationRecoveryPasswordsManager.verify(any(), any())) + .thenReturn(CompletableFuture.completedFuture(true)); + when(accountsManager.create(any(), any(), any(), any(), any())) + .thenReturn(mock(Account.class)); + + final Invocation.Builder request = resources.getJerseyTest() + .target("/v1/registration") + .request() + .header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); + final byte[] recoveryPassword = new byte[32]; + try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(recoveryPassword)))) { + assertEquals(200, response.getStatus()); + Mockito.verify(registrationRecoveryPasswordsManager).storeForCurrentNumber(eq(NUMBER), eq(recoveryPassword)); + } + } + + @Test + void recoveryPasswordManagerVerificationFalse() throws InterruptedException { + when(registrationRecoveryPasswordsManager.verify(any(), any())) + .thenReturn(CompletableFuture.completedFuture(false)); + + final Invocation.Builder request = resources.getJerseyTest() + .target("/v1/registration") + .request() + .header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); + try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(new byte[32])))) { + assertEquals(403, response.getStatus()); + } + } + @ParameterizedTest @EnumSource(RegistrationLockError.class) void registrationLock(final RegistrationLockError error) throws Exception { @@ -227,7 +290,7 @@ class RegistrationControllerTest { .target("/v1/registration") .request() .header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); - try (Response response = request.post(Entity.json(requestJson("sessionId", skipDeviceTransfer)))) { + try (Response response = request.post(Entity.json(requestJson("sessionId", new byte[0], skipDeviceTransfer)))) { assertEquals(expectedStatus, response.getStatus()); } } @@ -252,21 +315,32 @@ class RegistrationControllerTest { /** * Valid request JSON with the give session ID and skipDeviceTransfer */ - private static String requestJson(final String sessionId, final boolean skipDeviceTransfer) { + private static String requestJson(final String sessionId, final byte[] recoveryPassword, final boolean skipDeviceTransfer) { + final String rp = encodeRecoveryPassword(recoveryPassword); return String.format(""" { "sessionId": "%s", - "accountAttributes": {}, + "recoveryPassword": "%s", + "accountAttributes": { + "recoveryPassword": "%s" + }, "skipDeviceTransfer": %s } - """, encodeSessionId(sessionId), skipDeviceTransfer); + """, encodeSessionId(sessionId), rp, rp, skipDeviceTransfer); } /** - * Valid request JSON with the give session ID + * Valid request JSON with the given session ID */ private static String requestJson(final String sessionId) { - return requestJson(sessionId, false); + return requestJson(sessionId, new byte[0], false); + } + + /** + * Valid request JSON with the given Recovery Password + */ + private static String requestJsonRecoveryPassword(final byte[] recoveryPassword) { + return requestJson("", recoveryPassword, false); } /** @@ -303,4 +377,7 @@ class RegistrationControllerTest { return Base64.getEncoder().encodeToString(sessionId.getBytes(StandardCharsets.UTF_8)); } + private static String encodeRecoveryPassword(final byte[] recoveryPassword) { + return Base64.getEncoder().encodeToString(recoveryPassword); + } }