Registration Recovery Password support in `/v1/registration`

This commit is contained in:
Sergey Skrobotov 2023-02-08 13:11:10 -08:00
parent 4a3880b5ae
commit 7558489ad0
4 changed files with 167 additions and 41 deletions

View File

@ -747,7 +747,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getCdnConfiguration().getBucket(), zkProfileOperations, batchIdentityCheckExecutor), config.getCdnConfiguration().getBucket(), zkProfileOperations, batchIdentityCheckExecutor),
new ProvisioningController(rateLimiters, provisioningManager), new ProvisioningController(rateLimiters, provisioningManager),
new RegistrationController(accountsManager, registrationServiceClient, registrationLockVerificationManager, new RegistrationController(accountsManager, registrationServiceClient, registrationLockVerificationManager,
rateLimiters), registrationRecoveryPasswordsManager, rateLimiters),
new RemoteConfigController(remoteConfigsManager, adminEventLogger, new RemoteConfigController(remoteConfigsManager, adminEventLogger,
config.getRemoteConfigConfiguration().getAuthorizedTokens(), config.getRemoteConfigConfiguration().getAuthorizedTokens(),
config.getRemoteConfigConfiguration().getGlobalConfig()), config.getRemoteConfigConfiguration().getGlobalConfig()),

View File

@ -17,7 +17,6 @@ import java.security.MessageDigest;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Base64;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.CancellationException; import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
@ -26,8 +25,8 @@ import java.util.concurrent.TimeoutException;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import javax.ws.rs.BadRequestException; import javax.ws.rs.BadRequestException;
import javax.ws.rs.ClientErrorException;
import javax.ws.rs.Consumes; import javax.ws.rs.Consumes;
import javax.ws.rs.ForbiddenException;
import javax.ws.rs.HeaderParam; import javax.ws.rs.HeaderParam;
import javax.ws.rs.NotAuthorizedException; import javax.ws.rs.NotAuthorizedException;
import javax.ws.rs.POST; import javax.ws.rs.POST;
@ -37,7 +36,6 @@ import javax.ws.rs.ServerErrorException;
import javax.ws.rs.WebApplicationException; import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import org.apache.http.HttpStatus;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.BasicAuthorizationHeader; import org.whispersystems.textsecuregcm.auth.BasicAuthorizationHeader;
@ -50,6 +48,7 @@ import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient; import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
@ -70,19 +69,23 @@ public class RegistrationController {
private static final String VERIFICATION_TYPE_TAG_NAME = "verification"; private static final String VERIFICATION_TYPE_TAG_NAME = "verification";
private static final Duration REGISTRATION_RPC_TIMEOUT = Duration.ofSeconds(15); private static final Duration REGISTRATION_RPC_TIMEOUT = Duration.ofSeconds(15);
private static final long VERIFICATION_TIMEOUT_SECONDS = REGISTRATION_RPC_TIMEOUT.plusSeconds(1).getSeconds();
private final AccountsManager accounts; private final AccountsManager accounts;
private final RegistrationServiceClient registrationServiceClient; private final RegistrationServiceClient registrationServiceClient;
private final RegistrationLockVerificationManager registrationLockVerificationManager; private final RegistrationLockVerificationManager registrationLockVerificationManager;
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager;
private final RateLimiters rateLimiters; private final RateLimiters rateLimiters;
public RegistrationController(final AccountsManager accounts, public RegistrationController(final AccountsManager accounts,
final RegistrationServiceClient registrationServiceClient, final RegistrationServiceClient registrationServiceClient,
final RegistrationLockVerificationManager registrationLockVerificationManager, final RegistrationLockVerificationManager registrationLockVerificationManager,
final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager,
final RateLimiters rateLimiters) { final RateLimiters rateLimiters) {
this.accounts = accounts; this.accounts = accounts;
this.registrationServiceClient = registrationServiceClient; this.registrationServiceClient = registrationServiceClient;
this.registrationLockVerificationManager = registrationLockVerificationManager; this.registrationLockVerificationManager = registrationLockVerificationManager;
this.registrationRecoveryPasswordsManager = registrationRecoveryPasswordsManager;
this.rateLimiters = rateLimiters; this.rateLimiters = rateLimiters;
} }
@ -98,34 +101,14 @@ public class RegistrationController {
rateLimiters.getRegistrationLimiter().validate(registrationRequest.sessionId()); rateLimiters.getRegistrationLimiter().validate(registrationRequest.sessionId());
final byte[] sessionId;
try {
sessionId = Base64.getDecoder().decode(registrationRequest.sessionId());
} catch (final IllegalArgumentException e) {
throw new ClientErrorException("Malformed session ID", HttpStatus.SC_UNPROCESSABLE_ENTITY);
}
final String number = authorizationHeader.getUsername(); final String number = authorizationHeader.getUsername();
final String password = authorizationHeader.getPassword(); final String password = authorizationHeader.getPassword();
final String verificationType = "phoneNumberVerification"; // decide on the method of verification based on the registration request parameters and verify
try { final RegistrationRequest.VerificationType verificationType = registrationRequest.verificationType();
final Optional<RegistrationSession> maybeSession = registrationServiceClient.getSession(sessionId, switch (verificationType) {
REGISTRATION_RPC_TIMEOUT) case SESSION -> verifyBySessionId(number, registrationRequest.decodeSessionId());
.get(REGISTRATION_RPC_TIMEOUT.plusSeconds(1).getSeconds(), TimeUnit.SECONDS); case RECOVERY_PASSWORD -> verifyByRecoveryPassword(number, registrationRequest.recoveryPassword());
final RegistrationSession session = maybeSession.orElseThrow(
() -> new NotAuthorizedException("session not verified"));
if (!MessageDigest.isEqual(number.getBytes(), session.number().getBytes())) {
throw new BadRequestException("number does not match session");
}
if (!session.verified()) {
throw new NotAuthorizedException("session not verified");
}
} catch (final CancellationException | ExecutionException | TimeoutException e) {
logger.error("Registration service failure", e);
throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE);
} }
final Optional<Account> existingAccount = accounts.getByE164(number); final Optional<Account> existingAccount = accounts.getByE164(number);
@ -150,10 +133,15 @@ public class RegistrationController {
final Account account = accounts.create(number, password, signalAgent, registrationRequest.accountAttributes(), final Account account = accounts.create(number, password, signalAgent, registrationRequest.accountAttributes(),
existingAccount.map(Account::getBadges).orElseGet(ArrayList::new)); existingAccount.map(Account::getBadges).orElseGet(ArrayList::new));
// now that the number is verified and account is created,
// we can store recovery password for this number
registrationRequest.accountAttributes().recoveryPassword().ifPresent(recoveryPassword ->
registrationRecoveryPasswordsManager.storeForCurrentNumber(number, recoveryPassword));
Metrics.counter(ACCOUNT_CREATED_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), Metrics.counter(ACCOUNT_CREATED_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(COUNTRY_CODE_TAG_NAME, Util.getCountryCode(number)), Tag.of(COUNTRY_CODE_TAG_NAME, Util.getCountryCode(number)),
Tag.of(REGION_CODE_TAG_NAME, Util.getRegion(number)), Tag.of(REGION_CODE_TAG_NAME, Util.getRegion(number)),
Tag.of(VERIFICATION_TYPE_TAG_NAME, verificationType))) Tag.of(VERIFICATION_TYPE_TAG_NAME, verificationType.name())))
.increment(); .increment();
return new AccountIdentityResponse(account.getUuid(), return new AccountIdentityResponse(account.getUuid(),
@ -163,4 +151,34 @@ public class RegistrationController {
existingAccount.map(Account::isStorageSupported).orElse(false)); existingAccount.map(Account::isStorageSupported).orElse(false));
} }
private void verifyBySessionId(final String number, final byte[] sessionId) throws InterruptedException {
try {
final RegistrationSession session = registrationServiceClient
.getSession(sessionId, REGISTRATION_RPC_TIMEOUT)
.get(VERIFICATION_TIMEOUT_SECONDS, TimeUnit.SECONDS)
.orElseThrow(() -> new NotAuthorizedException("session not verified"));
if (!MessageDigest.isEqual(number.getBytes(), session.number().getBytes())) {
throw new BadRequestException("number does not match session");
}
if (!session.verified()) {
throw new NotAuthorizedException("session not verified");
}
} catch (final CancellationException | ExecutionException | TimeoutException e) {
logger.error("Registration service failure", e);
throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE);
}
}
private void verifyByRecoveryPassword(final String number, final byte[] recoveryPassword) throws InterruptedException {
try {
final boolean verified = registrationRecoveryPasswordsManager.verify(number, recoveryPassword)
.get(VERIFICATION_TIMEOUT_SECONDS, TimeUnit.SECONDS);
if (!verified) {
throw new ForbiddenException("recoveryPassword couldn't be verified");
}
} catch (final ExecutionException | TimeoutException e) {
throw new ServerErrorException(Response.Status.SERVICE_UNAVAILABLE);
}
}
} }

View File

@ -5,12 +5,43 @@
package org.whispersystems.textsecuregcm.entities; package org.whispersystems.textsecuregcm.entities;
import javax.validation.Valid; import static org.apache.commons.lang3.StringUtils.isNotBlank;
import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotNull;
public record RegistrationRequest(@NotBlank String sessionId, import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import java.util.Base64;
import javax.validation.Valid;
import javax.validation.constraints.AssertTrue;
import javax.validation.constraints.NotNull;
import javax.ws.rs.ClientErrorException;
import org.apache.http.HttpStatus;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
public record RegistrationRequest(String sessionId,
@JsonDeserialize(using = ByteArrayAdapter.Deserializing.class) byte[] recoveryPassword,
@NotNull @Valid AccountAttributes accountAttributes, @NotNull @Valid AccountAttributes accountAttributes,
boolean skipDeviceTransfer) { boolean skipDeviceTransfer) {
public enum VerificationType {
SESSION,
RECOVERY_PASSWORD
}
// for the @AssertTrue to work with bean validation, method name must follow 'isSmth()'/'getSmth()' naming convention
@AssertTrue
public boolean isValid() {
// checking that exactly one of sessionId/recoveryPassword is non-empty
return isNotBlank(sessionId) ^ (recoveryPassword != null && recoveryPassword.length > 0);
}
public VerificationType verificationType() {
return isNotBlank(sessionId) ? VerificationType.SESSION : VerificationType.RECOVERY_PASSWORD;
}
public byte[] decodeSessionId() {
try {
return Base64.getUrlDecoder().decode(sessionId());
} catch (final IllegalArgumentException e) {
throw new ClientErrorException("Malformed session ID", HttpStatus.SC_UNPROCESSABLE_ENTITY);
}
}
} }

View File

@ -6,7 +6,10 @@
package org.whispersystems.textsecuregcm.controllers; package org.whispersystems.textsecuregcm.controllers;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -35,8 +38,11 @@ import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mockito;
import org.whispersystems.textsecuregcm.auth.RegistrationLockError; import org.whispersystems.textsecuregcm.auth.RegistrationLockError;
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager; import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.RegistrationRequest;
import org.whispersystems.textsecuregcm.entities.RegistrationSession; import org.whispersystems.textsecuregcm.entities.RegistrationSession;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
@ -46,6 +52,7 @@ import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper
import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient; import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
@ExtendWith(DropwizardExtensionsSupport.class) @ExtendWith(DropwizardExtensionsSupport.class)
@ -56,6 +63,8 @@ class RegistrationControllerTest {
private final RegistrationServiceClient registrationServiceClient = mock(RegistrationServiceClient.class); private final RegistrationServiceClient registrationServiceClient = mock(RegistrationServiceClient.class);
private final RegistrationLockVerificationManager registrationLockVerificationManager = mock( private final RegistrationLockVerificationManager registrationLockVerificationManager = mock(
RegistrationLockVerificationManager.class); RegistrationLockVerificationManager.class);
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager = mock(
RegistrationRecoveryPasswordsManager.class);
private final RateLimiters rateLimiters = mock(RateLimiters.class); private final RateLimiters rateLimiters = mock(RateLimiters.class);
private final RateLimiter registrationLimiter = mock(RateLimiter.class); private final RateLimiter registrationLimiter = mock(RateLimiter.class);
@ -70,7 +79,7 @@ class RegistrationControllerTest {
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource( .addResource(
new RegistrationController(accountsManager, registrationServiceClient, registrationLockVerificationManager, new RegistrationController(accountsManager, registrationServiceClient, registrationLockVerificationManager,
rateLimiters)) registrationRecoveryPasswordsManager, rateLimiters))
.build(); .build();
@BeforeEach @BeforeEach
@ -79,6 +88,14 @@ class RegistrationControllerTest {
when(rateLimiters.getPinLimiter()).thenReturn(pinLimiter); when(rateLimiters.getPinLimiter()).thenReturn(pinLimiter);
} }
@Test
public void testRegistrationRequest() throws Exception {
assertFalse(new RegistrationRequest("", new byte[0], new AccountAttributes(), true).isValid());
assertFalse(new RegistrationRequest("some", new byte[32], new AccountAttributes(), true).isValid());
assertTrue(new RegistrationRequest("", new byte[32], new AccountAttributes(), true).isValid());
assertTrue(new RegistrationRequest("some", new byte[0], new AccountAttributes(), true).isValid());
}
@Test @Test
void unprocessableRequestJson() { void unprocessableRequestJson() {
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
@ -151,6 +168,20 @@ class RegistrationControllerTest {
} }
} }
@Test
void recoveryPasswordManagerVerificationFailureOrTimeout() {
when(registrationRecoveryPasswordsManager.verify(any(), any()))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException()));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
.request()
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(new byte[32])))) {
assertEquals(HttpStatus.SC_SERVICE_UNAVAILABLE, response.getStatus());
}
}
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void registrationServiceSessionCheck(@Nullable final RegistrationSession session, final int expectedStatus, void registrationServiceSessionCheck(@Nullable final RegistrationSession session, final int expectedStatus,
@ -175,6 +206,38 @@ class RegistrationControllerTest {
); );
} }
@Test
void recoveryPasswordManagerVerificationTrue() throws InterruptedException {
when(registrationRecoveryPasswordsManager.verify(any(), any()))
.thenReturn(CompletableFuture.completedFuture(true));
when(accountsManager.create(any(), any(), any(), any(), any()))
.thenReturn(mock(Account.class));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
.request()
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
final byte[] recoveryPassword = new byte[32];
try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(recoveryPassword)))) {
assertEquals(200, response.getStatus());
Mockito.verify(registrationRecoveryPasswordsManager).storeForCurrentNumber(eq(NUMBER), eq(recoveryPassword));
}
}
@Test
void recoveryPasswordManagerVerificationFalse() throws InterruptedException {
when(registrationRecoveryPasswordsManager.verify(any(), any()))
.thenReturn(CompletableFuture.completedFuture(false));
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
.request()
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
try (Response response = request.post(Entity.json(requestJsonRecoveryPassword(new byte[32])))) {
assertEquals(403, response.getStatus());
}
}
@ParameterizedTest @ParameterizedTest
@EnumSource(RegistrationLockError.class) @EnumSource(RegistrationLockError.class)
void registrationLock(final RegistrationLockError error) throws Exception { void registrationLock(final RegistrationLockError error) throws Exception {
@ -227,7 +290,7 @@ class RegistrationControllerTest {
.target("/v1/registration") .target("/v1/registration")
.request() .request()
.header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER)); .header(HttpHeaders.AUTHORIZATION, authorizationHeader(NUMBER));
try (Response response = request.post(Entity.json(requestJson("sessionId", skipDeviceTransfer)))) { try (Response response = request.post(Entity.json(requestJson("sessionId", new byte[0], skipDeviceTransfer)))) {
assertEquals(expectedStatus, response.getStatus()); assertEquals(expectedStatus, response.getStatus());
} }
} }
@ -252,21 +315,32 @@ class RegistrationControllerTest {
/** /**
* Valid request JSON with the give session ID and skipDeviceTransfer * Valid request JSON with the give session ID and skipDeviceTransfer
*/ */
private static String requestJson(final String sessionId, final boolean skipDeviceTransfer) { private static String requestJson(final String sessionId, final byte[] recoveryPassword, final boolean skipDeviceTransfer) {
final String rp = encodeRecoveryPassword(recoveryPassword);
return String.format(""" return String.format("""
{ {
"sessionId": "%s", "sessionId": "%s",
"accountAttributes": {}, "recoveryPassword": "%s",
"accountAttributes": {
"recoveryPassword": "%s"
},
"skipDeviceTransfer": %s "skipDeviceTransfer": %s
} }
""", encodeSessionId(sessionId), skipDeviceTransfer); """, encodeSessionId(sessionId), rp, rp, skipDeviceTransfer);
} }
/** /**
* Valid request JSON with the give session ID * Valid request JSON with the given session ID
*/ */
private static String requestJson(final String sessionId) { private static String requestJson(final String sessionId) {
return requestJson(sessionId, false); return requestJson(sessionId, new byte[0], false);
}
/**
* Valid request JSON with the given Recovery Password
*/
private static String requestJsonRecoveryPassword(final byte[] recoveryPassword) {
return requestJson("", recoveryPassword, false);
} }
/** /**
@ -303,4 +377,7 @@ class RegistrationControllerTest {
return Base64.getEncoder().encodeToString(sessionId.getBytes(StandardCharsets.UTF_8)); return Base64.getEncoder().encodeToString(sessionId.getBytes(StandardCharsets.UTF_8));
} }
private static String encodeRecoveryPassword(final byte[] recoveryPassword) {
return Base64.getEncoder().encodeToString(recoveryPassword);
}
} }