diff --git a/integration-tests/src/main/java/org/signal/integration/IntegrationTools.java b/integration-tests/src/main/java/org/signal/integration/IntegrationTools.java index d8ce4da15..c8ee2a637 100644 --- a/integration-tests/src/main/java/org/signal/integration/IntegrationTools.java +++ b/integration-tests/src/main/java/org/signal/integration/IntegrationTools.java @@ -12,6 +12,7 @@ import java.util.concurrent.CompletableFuture; import org.signal.integration.config.Config; import org.whispersystems.textsecuregcm.metrics.NoopAwsSdkMetricPublisher; import org.whispersystems.textsecuregcm.registration.VerificationSession; +import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers; import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswords; import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager; import org.whispersystems.textsecuregcm.storage.VerificationSessionManager; @@ -37,6 +38,9 @@ public class IntegrationTools { final DynamoDbClient dynamoDbClient = config.dynamoDbClient().buildSyncClient(credentialsProvider, new NoopAwsSdkMetricPublisher()); + final PhoneNumberIdentifiers phoneNumberIdentifiers = + new PhoneNumberIdentifiers(dynamoDbAsyncClient, config.dynamoDbTables().phoneNumberIdentifiers()); + final RegistrationRecoveryPasswords registrationRecoveryPasswords = new RegistrationRecoveryPasswords( config.dynamoDbTables().registrationRecovery(), Duration.ofDays(1), dynamoDbClient, dynamoDbAsyncClient); @@ -44,7 +48,7 @@ public class IntegrationTools { dynamoDbAsyncClient, config.dynamoDbTables().verificationSessions(), Clock.systemUTC()); return new IntegrationTools( - new RegistrationRecoveryPasswordsManager(registrationRecoveryPasswords), + new RegistrationRecoveryPasswordsManager(registrationRecoveryPasswords, phoneNumberIdentifiers), new VerificationSessionManager(verificationSessions) ); } diff --git a/integration-tests/src/main/java/org/signal/integration/config/DynamoDbTables.java b/integration-tests/src/main/java/org/signal/integration/config/DynamoDbTables.java index 2e76e3c70..35c8c24f0 100644 --- a/integration-tests/src/main/java/org/signal/integration/config/DynamoDbTables.java +++ b/integration-tests/src/main/java/org/signal/integration/config/DynamoDbTables.java @@ -8,5 +8,6 @@ package org.signal.integration.config; import jakarta.validation.constraints.NotBlank; public record DynamoDbTables(@NotBlank String registrationRecovery, - @NotBlank String verificationSessions) { + @NotBlank String verificationSessions, + @NotBlank String phoneNumberIdentifiers) { } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index bf5a3eb2d..220429d03 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -583,8 +583,8 @@ public class WhisperServerService extends Application addOrReplace(final String number, final SaltedTokenHash data) { - return asyncClient.putItem(PutItemRequest.builder() - .tableName(tableName) - .item(Map.of( - KEY_E164, AttributeValues.fromString(number), - ATTR_EXP, AttributeValues.fromLong(expirationSeconds()), - ATTR_SALT, AttributeValues.fromString(data.salt()), - ATTR_HASH, AttributeValues.fromString(data.hash()))) - .build()) + public CompletableFuture> lookup(final UUID phoneNumberIdentifier) { + return lookup(phoneNumberIdentifier.toString()); + } + + public CompletableFuture addOrReplace(final String number, final UUID phoneNumberIdentifier, final SaltedTokenHash data) { + final long expirationSeconds = expirationSeconds(); + + return asyncClient.transactWriteItems(TransactWriteItemsRequest.builder() + .transactItems( + buildPutRecoveryPasswordWriteItem(number, expirationSeconds, data.salt(), data.hash()) + // buildPutRecoveryPasswordWriteItem(phoneNumberIdentifier.toString(), expirationSeconds, data.salt(), data.hash()) + ) + .build()) .thenRun(Util.NOOP); } - public CompletableFuture removeEntry(final String number) { - return asyncClient.deleteItem(DeleteItemRequest.builder() + private TransactWriteItem buildPutRecoveryPasswordWriteItem(final String key, + final long expirationSeconds, + final String salt, + final String hash) { + + return TransactWriteItem.builder() + .put(Put.builder() .tableName(tableName) - .key(Map.of(KEY_E164, AttributeValues.fromString(number))) + .item(Map.of( + KEY_E164, AttributeValues.fromString(key), + ATTR_EXP, AttributeValues.fromLong(expirationSeconds), + ATTR_SALT, AttributeValues.fromString(salt), + ATTR_HASH, AttributeValues.fromString(hash))) .build()) + .build(); + } + + public CompletableFuture removeEntry(final String number, final UUID phoneNumberIdentifier) { + return asyncClient.transactWriteItems(TransactWriteItemsRequest.builder() + .transactItems( + buildDeleteRecoveryPasswordWriteItem(number), + buildDeleteRecoveryPasswordWriteItem(phoneNumberIdentifier.toString())) + .build()) .thenRun(Util.NOOP); } + private TransactWriteItem buildDeleteRecoveryPasswordWriteItem(final String key) { + return TransactWriteItem.builder() + .delete(Delete.builder() + .tableName(tableName) + .key(Map.of(KEY_E164, AttributeValues.fromString(key))) + .build()) + .build(); + } + private long expirationSeconds() { return clock.instant().plus(expiration).getEpochSecond(); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryPasswordsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryPasswordsManager.java index 40c988219..ae58dbafb 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryPasswordsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryPasswordsManager.java @@ -21,10 +21,13 @@ public class RegistrationRecoveryPasswordsManager { private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); private final RegistrationRecoveryPasswords registrationRecoveryPasswords; + private final PhoneNumberIdentifiers phoneNumberIdentifiers; + public RegistrationRecoveryPasswordsManager(final RegistrationRecoveryPasswords registrationRecoveryPasswords, + final PhoneNumberIdentifiers phoneNumberIdentifiers) { - public RegistrationRecoveryPasswordsManager(final RegistrationRecoveryPasswords registrationRecoveryPasswords) { this.registrationRecoveryPasswords = requireNonNull(registrationRecoveryPasswords); + this.phoneNumberIdentifiers = phoneNumberIdentifiers; } public CompletableFuture verify(final String number, final byte[] password) { @@ -41,26 +44,27 @@ public class RegistrationRecoveryPasswordsManager { public CompletableFuture storeForCurrentNumber(final String number, final byte[] password) { final String token = bytesToString(password); final SaltedTokenHash tokenHash = SaltedTokenHash.generateFor(token); - return registrationRecoveryPasswords.addOrReplace(number, tokenHash) - .whenComplete((result, error) -> { - if (error != null) { - logger.warn("Failed to store Registration Recovery Password", error); - } - }); + + return phoneNumberIdentifiers.getPhoneNumberIdentifier(number) + .thenCompose(phoneNumberIdentifier -> registrationRecoveryPasswords.addOrReplace(number, phoneNumberIdentifier, tokenHash) + .whenComplete((result, error) -> { + if (error != null) { + logger.warn("Failed to store Registration Recovery Password", error); + } + })); } public CompletableFuture removeForNumber(final String number) { - // remove is a "fire-and-forget" operation, - // there is no action to be taken on its completion - return registrationRecoveryPasswords.removeEntry(number) - .whenComplete((ignored, error) -> { - if (error instanceof ResourceNotFoundException) { - // These will naturally happen if a recovery password is already deleted. Since we can remove - // the recovery password through many flows, we avoid creating log messages for these exceptions - } else if (error != null) { - logger.warn("Failed to remove Registration Recovery Password", error); - } - }); + return phoneNumberIdentifiers.getPhoneNumberIdentifier(number) + .thenCompose(phoneNumberIdentifier -> registrationRecoveryPasswords.removeEntry(number, phoneNumberIdentifier) + .whenComplete((ignored, error) -> { + if (error instanceof ResourceNotFoundException) { + // These will naturally happen if a recovery password is already deleted. Since we can remove + // the recovery password through many flows, we avoid creating log messages for these exceptions + } else if (error != null) { + logger.warn("Failed to remove Registration Recovery Password", error); + } + })); } private static String bytesToString(final byte[] bytes) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java index 745451b07..e39873c51 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java @@ -170,9 +170,6 @@ record CommandDependencies( dynamoDbAsyncClient ); - RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager = new RegistrationRecoveryPasswordsManager( - registrationRecoveryPasswords); - ClientPublicKeys clientPublicKeys = new ClientPublicKeys(dynamoDbAsyncClient, configuration.getDynamoDbTables().getClientPublicKeys().getTableName()); @@ -225,6 +222,8 @@ record CommandDependencies( configuration.getDynamoDbTables().getDeletedAccountsLock().getTableName()); ClientPublicKeysManager clientPublicKeysManager = new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor); + RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager = + new RegistrationRecoveryPasswordsManager(registrationRecoveryPasswords, phoneNumberIdentifiers); AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster, pubsubClient, accountLockManager, keys, messagesManager, profilesManager, secureStorageClient, secureValueRecovery2Client, disconnectionRequestManager, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryTest.java index 1763fb977..8424db754 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryTest.java @@ -5,6 +5,7 @@ package org.whispersystems.textsecuregcm.storage; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -14,6 +15,7 @@ import java.time.Clock; import java.time.Duration; import java.util.Map; import java.util.Optional; +import java.util.UUID; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import org.junit.jupiter.api.BeforeEach; @@ -34,6 +36,7 @@ public class RegistrationRecoveryTest { private static final Duration EXPIRATION = Duration.ofSeconds(1000); private static final String NUMBER = "+18005555555"; + private static final UUID PNI = UUID.randomUUID(); private static final SaltedTokenHash ORIGINAL_HASH = SaltedTokenHash.generateFor("pass1"); @@ -41,72 +44,106 @@ public class RegistrationRecoveryTest { @RegisterExtension private static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension( + Tables.PNI, Tables.REGISTRATION_RECOVERY_PASSWORDS); - private RegistrationRecoveryPasswords store; + private RegistrationRecoveryPasswords registrationRecoveryPasswords; private RegistrationRecoveryPasswordsManager manager; @BeforeEach public void before() throws Exception { CLOCK.setTimeMillis(Clock.systemUTC().millis()); - store = new RegistrationRecoveryPasswords( + registrationRecoveryPasswords = new RegistrationRecoveryPasswords( Tables.REGISTRATION_RECOVERY_PASSWORDS.tableName(), EXPIRATION, DYNAMO_DB_EXTENSION.getDynamoDbClient(), DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), CLOCK ); - manager = new RegistrationRecoveryPasswordsManager(store); + + final PhoneNumberIdentifiers phoneNumberIdentifiers = + new PhoneNumberIdentifiers(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), Tables.PNI.tableName()); + + manager = new RegistrationRecoveryPasswordsManager(registrationRecoveryPasswords, phoneNumberIdentifiers); } @Test public void testLookupAfterWrite() throws Exception { - store.addOrReplace(NUMBER, ORIGINAL_HASH).get(); + registrationRecoveryPasswords.addOrReplace(NUMBER, PNI, ORIGINAL_HASH).get(); final long initialExp = fetchTimestamp(NUMBER); final long expectedExpiration = CLOCK.instant().getEpochSecond() + EXPIRATION.getSeconds(); assertEquals(expectedExpiration, initialExp); - final Optional saltedTokenHash = store.lookup(NUMBER).get(); - assertTrue(saltedTokenHash.isPresent()); - assertEquals(ORIGINAL_HASH.salt(), saltedTokenHash.get().salt()); - assertEquals(ORIGINAL_HASH.hash(), saltedTokenHash.get().hash()); + { + final Optional saltedTokenHashByNumber = registrationRecoveryPasswords.lookup(NUMBER).get(); + assertTrue(saltedTokenHashByNumber.isPresent()); + assertEquals(ORIGINAL_HASH.salt(), saltedTokenHashByNumber.get().salt()); + assertEquals(ORIGINAL_HASH.hash(), saltedTokenHashByNumber.get().hash()); + } + + /* { + final Optional saltedTokenHashByPni = registrationRecoveryPasswords.lookup(PNI).get(); + assertTrue(saltedTokenHashByPni.isPresent()); + assertEquals(ORIGINAL_HASH.salt(), saltedTokenHashByPni.get().salt()); + assertEquals(ORIGINAL_HASH.hash(), saltedTokenHashByPni.get().hash()); + } */ } @Test public void testLookupAfterRefresh() throws Exception { - store.addOrReplace(NUMBER, ORIGINAL_HASH).get(); + registrationRecoveryPasswords.addOrReplace(NUMBER, PNI, ORIGINAL_HASH).get(); CLOCK.increment(50, TimeUnit.SECONDS); - store.addOrReplace(NUMBER, ORIGINAL_HASH).get(); + registrationRecoveryPasswords.addOrReplace(NUMBER, PNI, ORIGINAL_HASH).get(); final long updatedExp = fetchTimestamp(NUMBER); final long expectedExp = CLOCK.instant().getEpochSecond() + EXPIRATION.getSeconds(); assertEquals(expectedExp, updatedExp); - final Optional saltedTokenHash = store.lookup(NUMBER).get(); - assertTrue(saltedTokenHash.isPresent()); - assertEquals(ORIGINAL_HASH.salt(), saltedTokenHash.get().salt()); - assertEquals(ORIGINAL_HASH.hash(), saltedTokenHash.get().hash()); + { + final Optional saltedTokenHashByNumber = registrationRecoveryPasswords.lookup(NUMBER).get(); + assertTrue(saltedTokenHashByNumber.isPresent()); + assertEquals(ORIGINAL_HASH.salt(), saltedTokenHashByNumber.get().salt()); + assertEquals(ORIGINAL_HASH.hash(), saltedTokenHashByNumber.get().hash()); + } + + /* { + final Optional saltedTokenHashByPni = registrationRecoveryPasswords.lookup(PNI).get(); + assertTrue(saltedTokenHashByPni.isPresent()); + assertEquals(ORIGINAL_HASH.salt(), saltedTokenHashByPni.get().salt()); + assertEquals(ORIGINAL_HASH.hash(), saltedTokenHashByPni.get().hash()); + } */ } @Test public void testReplace() throws Exception { - store.addOrReplace(NUMBER, ORIGINAL_HASH).get(); - store.addOrReplace(NUMBER, ANOTHER_HASH).get(); + registrationRecoveryPasswords.addOrReplace(NUMBER, PNI, ORIGINAL_HASH).get(); + registrationRecoveryPasswords.addOrReplace(NUMBER, PNI, ANOTHER_HASH).get(); - final Optional saltedTokenHash = store.lookup(NUMBER).get(); - assertTrue(saltedTokenHash.isPresent()); - assertEquals(ANOTHER_HASH.salt(), saltedTokenHash.get().salt()); - assertEquals(ANOTHER_HASH.hash(), saltedTokenHash.get().hash()); + { + final Optional saltedTokenHashByNumber = registrationRecoveryPasswords.lookup(NUMBER).get(); + assertTrue(saltedTokenHashByNumber.isPresent()); + assertEquals(ANOTHER_HASH.salt(), saltedTokenHashByNumber.get().salt()); + assertEquals(ANOTHER_HASH.hash(), saltedTokenHashByNumber.get().hash()); + } + + /* { + final Optional saltedTokenHashByPni = registrationRecoveryPasswords.lookup(PNI).get(); + assertTrue(saltedTokenHashByPni.isPresent()); + assertEquals(ANOTHER_HASH.salt(), saltedTokenHashByPni.get().salt()); + assertEquals(ANOTHER_HASH.hash(), saltedTokenHashByPni.get().hash()); + } */ } @Test public void testRemove() throws Exception { - store.addOrReplace(NUMBER, ORIGINAL_HASH).get(); - assertTrue(store.lookup(NUMBER).get().isPresent()); + assertDoesNotThrow(() -> registrationRecoveryPasswords.removeEntry(NUMBER, PNI).join()); - store.removeEntry(NUMBER).get(); - assertTrue(store.lookup(NUMBER).get().isEmpty()); + registrationRecoveryPasswords.addOrReplace(NUMBER, PNI, ORIGINAL_HASH).get(); + assertTrue(registrationRecoveryPasswords.lookup(NUMBER).get().isPresent()); + + registrationRecoveryPasswords.removeEntry(NUMBER, PNI).get(); + assertTrue(registrationRecoveryPasswords.lookup(NUMBER).get().isEmpty()); } @Test