Add collation key to registration service session creation rpc call

This commit is contained in:
Chris Eager 2025-01-22 12:01:12 -06:00 committed by Chris Eager
parent 5cc76f48aa
commit 47550d48e7
10 changed files with 67 additions and 17 deletions

View File

@ -92,6 +92,8 @@ paymentsService.coinGeckoApiKey: unset
currentReportingKey.secret: AAAAAAAAAAA= currentReportingKey.secret: AAAAAAAAAAA=
currentReportingKey.salt: AAAAAAAAAAA= currentReportingKey.salt: AAAAAAAAAAA=
registrationService.collationKeySalt: AAAAAAAAAAA=
turn.secret: AAAAAAAAAAA= turn.secret: AAAAAAAAAAA=
turn.cloudflare.apiToken: ABCDEFGHIJKLM turn.cloudflare.apiToken: ABCDEFGHIJKLM

View File

@ -399,6 +399,7 @@ registrationService:
"example": "example" "example": "example"
} }
identityTokenAudience: https://registration.example.com identityTokenAudience: https://registration.example.com
collationKeySalt: secret://registrationService.collationKeySalt
registrationCaCertificate: | # Registration service TLS certificate trust root registrationCaCertificate: | # Registration service TLS certificate trust root
-----BEGIN CERTIFICATE----- -----BEGIN CERTIFICATE-----
ABCDEFGHIJKLMNOPQRSTUVWXYZ/0123456789+abcdefghijklmnopqrstuvwxyz ABCDEFGHIJKLMNOPQRSTUVWXYZ/0123456789+abcdefghijklmnopqrstuvwxyz

View File

@ -3,9 +3,11 @@ package org.whispersystems.textsecuregcm.configuration;
import com.fasterxml.jackson.annotation.JsonTypeName; import com.fasterxml.jackson.annotation.JsonTypeName;
import io.dropwizard.core.setup.Environment; import io.dropwizard.core.setup.Environment;
import jakarta.validation.constraints.NotBlank; import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotNull;
import java.io.IOException; import java.io.IOException;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import org.whispersystems.textsecuregcm.configuration.secrets.SecretBytes;
import org.whispersystems.textsecuregcm.registration.IdentityTokenCallCredentials; import org.whispersystems.textsecuregcm.registration.IdentityTokenCallCredentials;
import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient; import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
@ -14,7 +16,8 @@ public record RegistrationServiceConfiguration(@NotBlank String host,
int port, int port,
@NotBlank String credentialConfigurationJson, @NotBlank String credentialConfigurationJson,
@NotBlank String identityTokenAudience, @NotBlank String identityTokenAudience,
@NotBlank String registrationCaCertificate) implements @NotBlank String registrationCaCertificate,
@NotNull SecretBytes collationKeySalt) implements
RegistrationServiceClientFactory { RegistrationServiceClientFactory {
@Override @Override
@ -26,7 +29,7 @@ public record RegistrationServiceConfiguration(@NotBlank String host,
environment.lifecycle().manage(callCredentials); environment.lifecycle().manage(callCredentials);
return new RegistrationServiceClient(host, port, callCredentials, registrationCaCertificate, return new RegistrationServiceClient(host, port, callCredentials, registrationCaCertificate, collationKeySalt.value(),
identityRefreshExecutor); identityRefreshExecutor);
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException(e); throw new RuntimeException(e);

View File

@ -173,7 +173,8 @@ public class VerificationController {
name = "Retry-After", name = "Retry-After",
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed", description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed",
schema = @Schema(implementation = Integer.class))) schema = @Schema(implementation = Integer.class)))
public VerificationSessionResponse createSession(@NotNull @Valid final CreateVerificationSessionRequest request) public VerificationSessionResponse createSession(@NotNull @Valid final CreateVerificationSessionRequest request,
@Context final ContainerRequestContext requestContext)
throws RateLimitExceededException, ObsoletePhoneNumberFormatException { throws RateLimitExceededException, ObsoletePhoneNumberFormatException {
final Pair<String, PushNotification.TokenType> pushTokenAndType = validateAndExtractPushToken( final Pair<String, PushNotification.TokenType> pushTokenAndType = validateAndExtractPushToken(
@ -188,7 +189,9 @@ public class VerificationController {
final RegistrationServiceSession registrationServiceSession; final RegistrationServiceSession registrationServiceSession;
try { try {
registrationServiceSession = registrationServiceClient.createRegistrationSession(phoneNumber, final String sourceHost = (String) requestContext.getProperty(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
registrationServiceSession = registrationServiceClient.createRegistrationSession(phoneNumber, sourceHost,
accountsManager.getByE164(request.getNumber()).isPresent(), accountsManager.getByE164(request.getNumber()).isPresent(),
REGISTRATION_RPC_TIMEOUT).join(); REGISTRATION_RPC_TIMEOUT).join();
} catch (final CancellationException e) { } catch (final CancellationException e) {

View File

@ -14,12 +14,15 @@ import io.grpc.TlsChannelCredentials;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.security.NoSuchAlgorithmException;
import java.time.Duration; import java.time.Duration;
import java.util.Base64;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionException;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import javax.crypto.Mac;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.checkerframework.checker.nullness.qual.Nullable; import org.checkerframework.checker.nullness.qual.Nullable;
import org.signal.registration.rpc.CheckVerificationCodeRequest; import org.signal.registration.rpc.CheckVerificationCodeRequest;
@ -35,9 +38,12 @@ import org.whispersystems.textsecuregcm.util.CompletableFutureUtil;
public class RegistrationServiceClient implements Managed { public class RegistrationServiceClient implements Managed {
private static final Base64.Encoder BASE64_UNPADDED_ENCODER = Base64.getEncoder().withoutPadding();
private final ManagedChannel channel; private final ManagedChannel channel;
private final RegistrationServiceGrpc.RegistrationServiceFutureStub stub; private final RegistrationServiceGrpc.RegistrationServiceFutureStub stub;
private final Executor callbackExecutor; private final Executor callbackExecutor;
private final byte[] collationKeySalt;
/** /**
* @param from an e164 in a {@code long} representation e.g. {@code 18005550123} * @param from an e164 in a {@code long} representation e.g. {@code 18005550123}
@ -60,6 +66,7 @@ public class RegistrationServiceClient implements Managed {
final int port, final int port,
final CallCredentials callCredentials, final CallCredentials callCredentials,
final String caCertificatePem, final String caCertificatePem,
final byte[] collationKeySalt,
final Executor callbackExecutor) throws IOException { final Executor callbackExecutor) throws IOException {
try (final ByteArrayInputStream certificateInputStream = new ByteArrayInputStream(caCertificatePem.getBytes(StandardCharsets.UTF_8))) { try (final ByteArrayInputStream certificateInputStream = new ByteArrayInputStream(caCertificatePem.getBytes(StandardCharsets.UTF_8))) {
@ -73,19 +80,22 @@ public class RegistrationServiceClient implements Managed {
} }
this.stub = RegistrationServiceGrpc.newFutureStub(channel).withCallCredentials(callCredentials); this.stub = RegistrationServiceGrpc.newFutureStub(channel).withCallCredentials(callCredentials);
this.collationKeySalt = collationKeySalt;
this.callbackExecutor = callbackExecutor; this.callbackExecutor = callbackExecutor;
} }
public CompletableFuture<RegistrationServiceSession> createRegistrationSession( public CompletableFuture<RegistrationServiceSession> createRegistrationSession(
final Phonenumber.PhoneNumber phoneNumber, final boolean accountExistsWithPhoneNumber, final Duration timeout) { final Phonenumber.PhoneNumber phoneNumber, final String sourceHost, final boolean accountExistsWithPhoneNumber, final Duration timeout) {
final long e164 = Long.parseLong( final long e164 = Long.parseLong(
PhoneNumberUtil.getInstance().format(phoneNumber, PhoneNumberUtil.PhoneNumberFormat.E164).substring(1)); PhoneNumberUtil.getInstance().format(phoneNumber, PhoneNumberUtil.PhoneNumberFormat.E164).substring(1));
final String rateLimitCollationKey = hmac(sourceHost, collationKeySalt);
return CompletableFutureUtil.toCompletableFuture(stub.withDeadline(toDeadline(timeout)) return CompletableFutureUtil.toCompletableFuture(stub.withDeadline(toDeadline(timeout))
.createSession(CreateRegistrationSessionRequest.newBuilder() .createSession(CreateRegistrationSessionRequest.newBuilder()
.setE164(e164) .setE164(e164)
.setAccountExistsWithE164(accountExistsWithPhoneNumber) .setAccountExistsWithE164(accountExistsWithPhoneNumber)
.setRateLimitCollationKey(rateLimitCollationKey)
.build()), callbackExecutor) .build()), callbackExecutor)
.thenApply(response -> switch (response.getResponseCase()) { .thenApply(response -> switch (response.getResponseCase()) {
case SESSION_METADATA -> buildSessionResponseFromMetadata(response.getSessionMetadata()); case SESSION_METADATA -> buildSessionResponseFromMetadata(response.getSessionMetadata());
@ -259,4 +269,18 @@ public class RegistrationServiceClient implements Managed {
channel.shutdown(); channel.shutdown();
} }
} }
private static String hmac(String sourceHost, byte[] collationKeySalt) {
final Mac hmacSha256;
try {
hmacSha256 = Mac.getInstance("HmacSHA256");
} catch (NoSuchAlgorithmException e) {
throw new AssertionError(e);
}
hmacSha256.update(sourceHost.getBytes(StandardCharsets.UTF_8));
hmacSha256.update(collationKeySalt);
return BASE64_UNPADDED_ENCODER.encodeToString(hmacSha256.doFinal());
}
} }

View File

@ -39,6 +39,12 @@ message CreateRegistrationSessionRequest {
* session represents a "re-registration" attempt). * session represents a "re-registration" attempt).
*/ */
bool account_exists_with_e164 = 2; bool account_exists_with_e164 = 2;
/**
* The session creation rate limit for the number will be
* collated by this key.
*/
string rate_limit_collation_key = 3;
} }
message CreateRegistrationSessionResponse { message CreateRegistrationSessionResponse {

View File

@ -23,6 +23,7 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import org.checkerframework.checker.nullness.qual.Nullable; import org.checkerframework.checker.nullness.qual.Nullable;
import org.whispersystems.textsecuregcm.configuration.secrets.SecretBytes;
import org.whispersystems.textsecuregcm.entities.RegistrationServiceSession; import org.whispersystems.textsecuregcm.entities.RegistrationServiceSession;
import org.whispersystems.textsecuregcm.registration.ClientType; import org.whispersystems.textsecuregcm.registration.ClientType;
import org.whispersystems.textsecuregcm.registration.MessageTransport; import org.whispersystems.textsecuregcm.registration.MessageTransport;
@ -35,12 +36,16 @@ public class StubRegistrationServiceClientFactory implements RegistrationService
@NotNull @NotNull
private String registrationCaCertificate; private String registrationCaCertificate;
@JsonProperty
@NotNull
private SecretBytes collationKeySalt;
@Override @Override
public RegistrationServiceClient build(final Environment environment, final Executor callbackExecutor, public RegistrationServiceClient build(final Environment environment, final Executor callbackExecutor,
final ScheduledExecutorService identityRefreshExecutor) { final ScheduledExecutorService identityRefreshExecutor) {
try { try {
return new StubRegistrationServiceClient(registrationCaCertificate); return new StubRegistrationServiceClient(registrationCaCertificate, collationKeySalt.value());
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@ -50,13 +55,13 @@ public class StubRegistrationServiceClientFactory implements RegistrationService
private final static Map<String, RegistrationServiceSession> SESSIONS = new ConcurrentHashMap<>(); private final static Map<String, RegistrationServiceSession> SESSIONS = new ConcurrentHashMap<>();
public StubRegistrationServiceClient(final String registrationCaCertificate) throws IOException { public StubRegistrationServiceClient(final String registrationCaCertificate, final byte[] collationKeySalt) throws IOException {
super("example.com", 8080, null, registrationCaCertificate, null); super("example.com", 8080, null, registrationCaCertificate, collationKeySalt, null);
} }
@Override @Override
public CompletableFuture<RegistrationServiceSession> createRegistrationSession( public CompletableFuture<RegistrationServiceSession> createRegistrationSession(
final Phonenumber.PhoneNumber phoneNumber, final boolean accountExistsWithPhoneNumber, final Duration timeout) { final Phonenumber.PhoneNumber phoneNumber, final String sourceHost, final boolean accountExistsWithPhoneNumber, final Duration timeout) {
final String e164 = PhoneNumberUtil.getInstance() final String e164 = PhoneNumberUtil.getInstance()
.format(phoneNumber, PhoneNumberUtil.PhoneNumberFormat.E164); .format(phoneNumber, PhoneNumberUtil.PhoneNumberFormat.E164);

View File

@ -84,6 +84,7 @@ import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager; import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.storage.VerificationSessionManager; import org.whispersystems.textsecuregcm.storage.VerificationSessionManager;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider;
@ExtendWith(DropwizardExtensionsSupport.class) @ExtendWith(DropwizardExtensionsSupport.class)
class VerificationControllerTest { class VerificationControllerTest {
@ -120,6 +121,7 @@ class VerificationControllerTest {
.addProvider(new NonNormalizedPhoneNumberExceptionMapper()) .addProvider(new NonNormalizedPhoneNumberExceptionMapper())
.addProvider(new ObsoletePhoneNumberFormatExceptionMapper()) .addProvider(new ObsoletePhoneNumberFormatExceptionMapper())
.addProvider(new RegistrationServiceSenderExceptionMapper()) .addProvider(new RegistrationServiceSenderExceptionMapper())
.addProvider(new TestRemoteAddressFilterProvider("127.0.0.1"))
.setMapper(SystemMapper.jsonMapper()) .setMapper(SystemMapper.jsonMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource( .addResource(
@ -190,7 +192,7 @@ class VerificationControllerTest {
@Test @Test
void createSessionRateLimited() { void createSessionRateLimited() {
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any())) when(registrationServiceClient.createRegistrationSession(any(), anyString(), anyBoolean(), any()))
.thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null))); .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null)));
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
@ -204,7 +206,7 @@ class VerificationControllerTest {
@Test @Test
void createSessionRegistrationServiceError() { void createSessionRegistrationServiceError() {
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any())) when(registrationServiceClient.createRegistrationSession(any(), anyString(), anyBoolean(), any()))
.thenReturn(CompletableFuture.failedFuture(new RuntimeException("expected service error"))); .thenReturn(CompletableFuture.failedFuture(new RuntimeException("expected service error")));
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
@ -219,7 +221,7 @@ class VerificationControllerTest {
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void createBeninSessionSuccess(final String requestedNumber, final String expectedNumber) { void createBeninSessionSuccess(final String requestedNumber, final String expectedNumber) {
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any())) when(registrationServiceClient.createRegistrationSession(any(), anyString(), anyBoolean(), any()))
.thenReturn( .thenReturn(
CompletableFuture.completedFuture( CompletableFuture.completedFuture(
new RegistrationServiceSession(SESSION_ID, requestedNumber, false, null, null, null, new RegistrationServiceSession(SESSION_ID, requestedNumber, false, null, null, null,
@ -236,7 +238,7 @@ class VerificationControllerTest {
final ArgumentCaptor<Phonenumber.PhoneNumber> phoneNumberArgumentCaptor = ArgumentCaptor.forClass( final ArgumentCaptor<Phonenumber.PhoneNumber> phoneNumberArgumentCaptor = ArgumentCaptor.forClass(
Phonenumber.PhoneNumber.class); Phonenumber.PhoneNumber.class);
verify(registrationServiceClient).createRegistrationSession(phoneNumberArgumentCaptor.capture(), anyBoolean(), any()); verify(registrationServiceClient).createRegistrationSession(phoneNumberArgumentCaptor.capture(), anyString(), anyBoolean(), any());
final Phonenumber.PhoneNumber phoneNumber = phoneNumberArgumentCaptor.getValue(); final Phonenumber.PhoneNumber phoneNumber = phoneNumberArgumentCaptor.getValue();
assertEquals(expectedNumber, PhoneNumberUtil.getInstance().format(phoneNumber, PhoneNumberUtil.PhoneNumberFormat.E164)); assertEquals(expectedNumber, PhoneNumberUtil.getInstance().format(phoneNumber, PhoneNumberUtil.PhoneNumberFormat.E164));
@ -260,7 +262,7 @@ class VerificationControllerTest {
.format(PhoneNumberUtil.getInstance().getExampleNumber("BJ"), PhoneNumberUtil.PhoneNumberFormat.E164); .format(PhoneNumberUtil.getInstance().getExampleNumber("BJ"), PhoneNumberUtil.PhoneNumberFormat.E164);
final String oldFormatBeninE164 = newFormatBeninE164.replaceFirst("01", ""); final String oldFormatBeninE164 = newFormatBeninE164.replaceFirst("01", "");
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any())) when(registrationServiceClient.createRegistrationSession(any(), anyString(), anyBoolean(), any()))
.thenReturn( .thenReturn(
CompletableFuture.completedFuture( CompletableFuture.completedFuture(
new RegistrationServiceSession(SESSION_ID, NUMBER, false, null, null, null, new RegistrationServiceSession(SESSION_ID, NUMBER, false, null, null, null,
@ -281,7 +283,7 @@ class VerificationControllerTest {
@MethodSource @MethodSource
void createSessionSuccess(final String pushToken, final String pushTokenType, void createSessionSuccess(final String pushToken, final String pushTokenType,
final List<VerificationSession.Information> expectedRequestedInformation) { final List<VerificationSession.Information> expectedRequestedInformation) {
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any())) when(registrationServiceClient.createRegistrationSession(any(), anyString(), anyBoolean(), any()))
.thenReturn( .thenReturn(
CompletableFuture.completedFuture( CompletableFuture.completedFuture(
new RegistrationServiceSession(SESSION_ID, NUMBER, false, null, null, null, new RegistrationServiceSession(SESSION_ID, NUMBER, false, null, null, null,
@ -315,7 +317,7 @@ class VerificationControllerTest {
@ParameterizedTest @ParameterizedTest
@ValueSource(booleans = {true, false}) @ValueSource(booleans = {true, false})
void createSessionReregistration(final boolean isReregistration) throws NumberParseException { void createSessionReregistration(final boolean isReregistration) throws NumberParseException {
when(registrationServiceClient.createRegistrationSession(any(), anyBoolean(), any())) when(registrationServiceClient.createRegistrationSession(any(), anyString(), anyBoolean(), any()))
.thenReturn( .thenReturn(
CompletableFuture.completedFuture( CompletableFuture.completedFuture(
new RegistrationServiceSession(SESSION_ID, NUMBER, false, null, null, null, new RegistrationServiceSession(SESSION_ID, NUMBER, false, null, null, null,
@ -337,6 +339,7 @@ class VerificationControllerTest {
verify(registrationServiceClient).createRegistrationSession( verify(registrationServiceClient).createRegistrationSession(
eq(PhoneNumberUtil.getInstance().parse(NUMBER, null)), eq(PhoneNumberUtil.getInstance().parse(NUMBER, null)),
anyString(),
eq(isReregistration), eq(isReregistration),
any() any()
); );

View File

@ -162,6 +162,8 @@ paymentsService.coinGeckoApiKey: unset
currentReportingKey.secret: AAAAAAAAAAA= currentReportingKey.secret: AAAAAAAAAAA=
currentReportingKey.salt: AAAAAAAAAAA= currentReportingKey.salt: AAAAAAAAAAA=
registrationService.collationKeySalt: AAAAAAAAAAA=
turn.secret: AAAAAAAAAAA= turn.secret: AAAAAAAAAAA=
turn.cloudflare.apiToken: ABCDEFGHIJKLM turn.cloudflare.apiToken: ABCDEFGHIJKLM

View File

@ -393,6 +393,7 @@ oneTimeDonations:
registrationService: registrationService:
type: stub type: stub
collationKeySalt: secret://registrationService.collationKeySalt
registrationCaCertificate: | registrationCaCertificate: |
-----BEGIN CERTIFICATE----- -----BEGIN CERTIFICATE-----
MIIDazCCAlOgAwIBAgIUW5lcNWkuynRVc8Rq5pO6mHQBuZAwDQYJKoZIhvcNAQEL MIIDazCCAlOgAwIBAgIUW5lcNWkuynRVc8Rq5pO6mHQBuZAwDQYJKoZIhvcNAQEL