Registration Recovery Password support in `/v1/registration`

This commit is contained in:
Sergey Skrobotov 2023-02-08 13:11:10 -08:00
parent 4a3880b5ae
commit 7558489ad0
4 changed files with 167 additions and 41 deletions

View File

@ -747,7 +747,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getCdnConfiguration().getBucket(), zkProfileOperations, batchIdentityCheckExecutor),
new ProvisioningController(rateLimiters, provisioningManager),
new RegistrationController(accountsManager, registrationServiceClient, registrationLockVerificationManager,
rateLimiters),
registrationRecoveryPasswordsManager, rateLimiters),
new RemoteConfigController(remoteConfigsManager, adminEventLogger,
config.getRemoteConfigConfiguration().getAuthorizedTokens(),
config.getRemoteConfigConfiguration().getGlobalConfig()),

View File

@ -17,7 +17,6 @@ import java.security.MessageDigest;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Optional;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
@ -26,8 +25,8 @@ import java.util.concurrent.TimeoutException;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import javax.ws.rs.BadRequestException;
import javax.ws.rs.ClientErrorException;
import javax.ws.rs.Consumes;
import javax.ws.rs.ForbiddenException;
import javax.ws.rs.HeaderParam;
import javax.ws.rs.NotAuthorizedException;
import javax.ws.rs.POST;
@ -37,7 +36,6 @@ import javax.ws.rs.ServerErrorException;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.apache.http.HttpStatus;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.BasicAuthorizationHeader;
@ -50,6 +48,7 @@ import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util;
@ -70,19 +69,23 @@ public class RegistrationController {
private static final String VERIFICATION_TYPE_TAG_NAME = "verification";
private static final Duration REGISTRATION_RPC_TIMEOUT = Duration.ofSeconds(15);
private static final long VERIFICATION_TIMEOUT_SECONDS = REGISTRATION_RPC_TIMEOUT.plusSeconds(1).getSeconds();
private final AccountsManager accounts;
private final RegistrationServiceClient registrationServiceClient;
private final RegistrationLockVerificationManager registrationLockVerificationManager;
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager;
private final RateLimiters rateLimiters;
public RegistrationController(final AccountsManager accounts,
final RegistrationServiceClient registrationServiceClient,
final RegistrationLockVerificationManager registrationLockVerificationManager,
final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager,
final RateLimiters rateLimiters) {
this.accounts = accounts;
this.registrationServiceClient = registrationServiceClient;
this.registrationLockVerificationManager = registrationLockVerificationManager;
this.registrationRecoveryPasswordsManager = registrationRecoveryPasswordsManager;
this.rateLimiters = rateLimiters;
}
@ -98,34 +101,14 @@ public class RegistrationController {
rateLimiters.getRegistrationLimiter().validate(registrationRequest.sessionId());
final byte[] sessionId;
try {
sessionId = Base64.getDecoder().decode(registrationRequest.sessionId());
} catch (final IllegalArgumentException e) {
throw new ClientErrorException("Malformed session ID", HttpStatus.SC_UNPROCESSABLE_ENTITY);
}
final String number = authorizationHeader.getUsername();
final String password = authorizationHeader.getPassword();
final String verificationType = "phoneNumberVerification";
try {
final Optional<RegistrationSession> maybeSession = registrationServiceClient.getSession(sessionId,
REGISTRATION_RPC_TIMEOUT)
.get(REGISTRATION_RPC_TIMEOUT.plusSeconds(1).getSeconds(), TimeUnit.SECONDS);
final RegistrationSession session = maybeSession.orElseThrow(
() -> new NotAuthorizedException("session not verified"));
if (!MessageDigest.isEqual(number.getBytes(), session.number().getBytes())) {
throw new BadRequestException("number does not match session");
}
if (!session.verified()) {
throw new NotAuthorizedException("session not verified");
}
} catch (final CancellationException | ExecutionException | TimeoutException e) {
logger.error("Registration service failure", e);
throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE);
// decide on the method of verification based on the registration request parameters and verify
final RegistrationRequest.VerificationType verificationType = registrationRequest.verificationType();
switch (verificationType) {
case SESSION -> verifyBySessionId(number, registrationRequest.decodeSessionId());
case RECOVERY_PASSWORD -> verifyByRecoveryPassword(number, registrationRequest.recoveryPassword());
}
final Optional<Account> existingAccount = accounts.getByE164(number);
@ -150,10 +133,15 @@ public class RegistrationController {
final Account account = accounts.create(number, password, signalAgent, registrationRequest.accountAttributes(),
existingAccount.map(Account::getBadges).orElseGet(ArrayList::new));
// now that the number is verified and account is created,
// we can store recovery password for this number
registrationRequest.accountAttributes().recoveryPassword().ifPresent(recoveryPassword ->
registrationRecoveryPasswordsManager.storeForCurrentNumber(number, recoveryPassword));
Metrics.counter(ACCOUNT_CREATED_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(COUNTRY_CODE_TAG_NAME, Util.getCountryCode(number)),
Tag.of(REGION_CODE_TAG_NAME, Util.getRegion(number)),
Tag.of(VERIFICATION_TYPE_TAG_NAME, verificationType)))
Tag.of(VERIFICATION_TYPE_TAG_NAME, verificationType.name())))
.increment();
return new AccountIdentityResponse(account.getUuid(),
@ -163,4 +151,34 @@ public class RegistrationController {
existingAccount.map(Account::isStorageSupported).orElse(false));
}
private void verifyBySessionId(final String number, final byte[] sessionId) throws InterruptedException {
try {
final RegistrationSession session = registrationServiceClient
.getSession(sessionId, REGISTRATION_RPC_TIMEOUT)
.get(VERIFICATION_TIMEOUT_SECONDS, TimeUnit.SECONDS)
.orElseThrow(() -> new NotAuthorizedException("session not verified"));
if (!MessageDigest.isEqual(number.getBytes(), session.number().getBytes())) {
throw new BadRequestException("number does not match session");
}
if (!session.verified()) {
throw new NotAuthorizedException("session not verified");
}
} catch (final CancellationException | ExecutionException | TimeoutException e) {
logger.error("Registration service failure", e);
throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE);
}
}
private void verifyByRecoveryPassword(final String number, final byte[] recoveryPassword) throws InterruptedException {
try {
final boolean verified = registrationRecoveryPasswordsManager.verify(number, recoveryPassword)
.get(VERIFICATION_TIMEOUT_SECONDS, TimeUnit.SECONDS);
if (!verified) {
throw new ForbiddenException("recoveryPassword couldn't be verified");
}
} catch (final ExecutionException | TimeoutException e) {
throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE);
}
}
}

View File

@ -5,12 +5,43 @@
package org.whispersystems.textsecuregcm.entities;
import javax.validation.Valid;
import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotNull;
import static org.apache.commons.lang3.StringUtils.isNotBlank;
public record RegistrationRequest(@NotBlank String sessionId,
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import java.util.Base64;
import javax.validation.Valid;
import javax.validation.constraints.AssertTrue;
import javax.validation.constraints.NotNull;
import javax.ws.rs.ClientErrorException;
import org.apache.http.HttpStatus;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
public record RegistrationRequest(String sessionId,
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) byte[] recoveryPassword,
@NotNull @Valid AccountAttributes accountAttributes,
boolean skipDeviceTransfer) {
public enum VerificationType {
SESSION,
RECOVERY_PASSWORD
}
// for the @AssertTrue to work with bean validation, method name must follow 'isSmth()'/'getSmth()' naming convention
@AssertTrue
public boolean isValid() {
// checking that exactly one of sessionId/recoveryPassword is non-empty
return isNotBlank(sessionId) ^ (recoveryPassword != null && recoveryPassword.length > 0);
}
public VerificationType verificationType() {
return isNotBlank(sessionId) ? VerificationType.SESSION : VerificationType.RECOVERY_PASSWORD;
}
public byte[] decodeSessionId() {
try {
return Base64.getUrlDecoder().decode(sessionId());
} catch (final IllegalArgumentException e) {
throw new ClientErrorException("Malformed session ID", HttpStatus.SC_UNPROCESSABLE_ENTITY);
}
}
}

View File

@ -6,7 +6,10 @@
package org.whispersystems.textsecuregcm.controllers;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ -35,8 +38,11 @@ import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mockito;
import org.whispersystems.textsecuregcm.auth.RegistrationLockError;
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.RegistrationRequest;
import org.whispersystems.textsecuregcm.entities.RegistrationSession;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
@ -46,6 +52,7 @@ import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper
import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@ExtendWith(DropwizardExtensionsSupport.class)
@ -56,6 +63,8 @@ class RegistrationControllerTest {
private final RegistrationServiceClient registrationServiceClient = mock(RegistrationServiceClient.class);
private final RegistrationLockVerificationManager registrationLockVerificationManager = mock(
RegistrationLockVerificationManager.class);
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager = mock(
RegistrationRecoveryPasswordsManager.class);
private final RateLimiters rateLimiters = mock(RateLimiters.class);
private final RateLimiter registrationLimiter = mock(RateLimiter.class);
@ -70,7 +79,7 @@ class RegistrationControllerTest {
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(
new RegistrationController(accountsManager, registrationServiceClient, registrationLockVerificationManager,
rateLimiters))
registrationRecoveryPasswordsManager, rateLimiters))
.build();
@BeforeEach
@ -79,6 +88,14 @@ class RegistrationControllerTest {
when(rateLimiters.getPinLimiter()).thenReturn(pinLimiter);
}
@Test
public void testRegistrationRequest() throws Exception {
assertFalse(new RegistrationRequest("", new byte[0], new AccountAttributes(), true).isValid());
assertFalse(new RegistrationRequest("some", new byte[32], new AccountAttributes(), true).isValid());
assertTrue(new RegistrationRequest("", new byte[32], new AccountAttributes(), true).isValid());
assertTrue(new RegistrationRequest("some", new byte[0], new AccountAttributes(), true).isValid());
}
@Test
void unprocessableRequestJson() {
final Invocation.Builder request = resources.getJerseyTest()
@ -151,6 +168,20 @@ class RegistrationControllerTest {
}
}
@Test
void recoveryPasswordManagerVerificationFailureOrTimeout() {
when(registrationRecoveryPasswordsManager.verify(any(), any()))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
.request()
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(new byte[32])))) {
assertEquals(HttpStatus.SC_SERVICE_UNAVAILABLE, response.getStatus());
}
}
@ParameterizedTest
@MethodSource
void registrationServiceSessionCheck(@Nullable final RegistrationSession session, final int expectedStatus,
@ -175,6 +206,38 @@ class RegistrationControllerTest {
);
}
@Test
void recoveryPasswordManagerVerificationTrue() throws InterruptedException {
when(registrationRecoveryPasswordsManager.verify(any(), any()))
.thenReturn(CompletableFuture.completedFuture(true));
when(accountsManager.create(any(), any(), any(), any(), any()))
.thenReturn(mock(Account.class));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
.request()
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
final byte[] recoveryPassword = new byte[32];
try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(recoveryPassword)))) {
assertEquals(200, response.getStatus());
Mockito.verify(registrationRecoveryPasswordsManager).storeForCurrentNumber(eq(NUMBER), eq(recoveryPassword));
}
}
@Test
void recoveryPasswordManagerVerificationFalse() throws InterruptedException {
when(registrationRecoveryPasswordsManager.verify(any(), any()))
.thenReturn(CompletableFuture.completedFuture(false));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
.request()
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(new byte[32])))) {
assertEquals(403, response.getStatus());
}
}
@ParameterizedTest
@EnumSource(RegistrationLockError.class)
void registrationLock(final RegistrationLockError error) throws Exception {
@ -227,7 +290,7 @@ class RegistrationControllerTest {
.target("/v1/registration")
.request()
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
try (Response response = request.post(Entity.json(requestJson("sessionId", skipDeviceTransfer)))) {
try (Response response = request.post(Entity.json(requestJson("sessionId", new byte[0], skipDeviceTransfer)))) {
assertEquals(expectedStatus, response.getStatus());
}
}
@ -252,21 +315,32 @@ class RegistrationControllerTest {
/**
* Valid request JSON with the give session ID and skipDeviceTransfer
*/
private static String requestJson(final String sessionId, final boolean skipDeviceTransfer) {
private static String requestJson(final String sessionId, final byte[] recoveryPassword, final boolean skipDeviceTransfer) {
final String rp = encodeRecoveryPassword(recoveryPassword);
return String.format("""
{
"sessionId": "%s",
"accountAttributes": {},
"recoveryPassword": "%s",
"accountAttributes": {
"recoveryPassword": "%s"
},
"skipDeviceTransfer": %s
}
""", encodeSessionId(sessionId), skipDeviceTransfer);
""", encodeSessionId(sessionId), rp, rp, skipDeviceTransfer);
}
/**
* Valid request JSON with the give session ID
* Valid request JSON with the given session ID
*/
private static String requestJson(final String sessionId) {
return requestJson(sessionId, false);
return requestJson(sessionId, new byte[0], false);
}
/**
* Valid request JSON with the given Recovery Password
*/
private static String requestJsonRecoveryPassword(final byte[] recoveryPassword) {
return requestJson("", recoveryPassword, false);
}
/**
@ -303,4 +377,7 @@ class RegistrationControllerTest {
return Base64.getEncoder().encodeToString(sessionId.getBytes(StandardCharsets.UTF_8));
}
private static String encodeRecoveryPassword(final byte[] recoveryPassword) {
return Base64.getEncoder().encodeToString(recoveryPassword);
}
}