From af1d21c225f321662c6ad1848c06a08f41f607d3 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Fri, 22 Nov 2024 16:55:59 -0500 Subject: [PATCH] Add methods for migrating E164-mapped registration recovery passwords to PNI-mapped records --- .../RegistrationRecoveryPasswords.java | 106 +++++++++- .../RegistrationRecoveryPasswordsManager.java | 34 ++++ .../storage/RegistrationRecoveryTest.java | 182 +++++++++++++++++- 3 files changed, 311 insertions(+), 11 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryPasswords.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryPasswords.java index 927545194..3d2dcb3ab 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryPasswords.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryPasswords.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.storage; import static java.util.Objects.requireNonNull; +import com.google.common.annotations.VisibleForTesting; import java.time.Clock; import java.time.Duration; import java.util.Map; @@ -15,15 +16,21 @@ import java.util.UUID; import java.util.concurrent.CompletableFuture; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.util.AttributeValues; +import org.whispersystems.textsecuregcm.util.ExceptionUtils; import org.whispersystems.textsecuregcm.util.Util; +import reactor.core.publisher.Flux; +import reactor.util.function.Tuple3; +import reactor.util.function.Tuples; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.Delete; import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; import software.amazon.awssdk.services.dynamodb.model.Put; +import software.amazon.awssdk.services.dynamodb.model.ScanRequest; import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem; import software.amazon.awssdk.services.dynamodb.model.TransactWriteItemsRequest; +import software.amazon.awssdk.services.dynamodb.model.TransactionCanceledException; public class RegistrationRecoveryPasswords extends AbstractDynamoDbStore { @@ -69,15 +76,10 @@ public class RegistrationRecoveryPasswords extends AbstractDynamoDbStore { .key(Map.of(KEY_E164, AttributeValues.fromString(number))) .consistentRead(true) .build()) - .thenApply(getItemResponse -> { - final Map item = getItemResponse.item(); - if (item == null || !item.containsKey(ATTR_SALT) || !item.containsKey(ATTR_HASH)) { - return Optional.empty(); - } - final String salt = item.get(ATTR_SALT).s(); - final String hash = item.get(ATTR_HASH).s(); - return Optional.of(new SaltedTokenHash(hash, salt)); - }); + .thenApply(getItemResponse -> Optional.ofNullable(getItemResponse.item()) + .filter(item -> item.containsKey(ATTR_SALT)) + .filter(item -> item.containsKey(ATTR_HASH)) + .map(RegistrationRecoveryPasswords::saltedTokenHashFromItem)); } public CompletableFuture> lookup(final UUID phoneNumberIdentifier) { @@ -130,7 +132,91 @@ public class RegistrationRecoveryPasswords extends AbstractDynamoDbStore { .build(); } - private long expirationSeconds() { + @VisibleForTesting + long expirationSeconds() { return clock.instant().plus(expiration).getEpochSecond(); } + + public Flux> getE164AssociatedRegistrationRecoveryPasswords() { + return Flux.from(asyncClient.scanPaginator(ScanRequest.builder() + .tableName(tableName) + .consistentRead(true) + .filterExpression("begins_with(#key, :e164Prefix)") + .expressionAttributeNames(Map.of("#key", KEY_E164)) + .expressionAttributeValues(Map.of(":e164Prefix", AttributeValue.fromS("+"))) + .build()) + .items()) + .map(item -> Tuples.of(item.get(KEY_E164).s(), saltedTokenHashFromItem(item), Long.parseLong(item.get(ATTR_EXP).n()))); + } + + public CompletableFuture insertPniRecord(final String phoneNumber, + final UUID phoneNumberIdentifier, + final SaltedTokenHash saltedTokenHash, + final long expirationSeconds) { + + // We try to write both the old and new record inside a transaction, but with different conditions. For the + // E164-based record, we insist that the record be entirely unchanged. This prevents us from writing an out-of-sync + // record if we read one thing in the `Scan` pass, but then somebody updated the record before we tried to write + // the PNI-based record. We refresh and retry if this happens. + // + // For the PNI-based record, we only want to write the record if one doesn't already exist for the given PNI. If one + // already exists, we'll just leave it alone. + return asyncClient.transactWriteItems(TransactWriteItemsRequest.builder() + .transactItems( + TransactWriteItem.builder() + .put(Put.builder() + .tableName(tableName) + .item(Map.of( + KEY_E164, AttributeValues.fromString(phoneNumber), + ATTR_EXP, AttributeValues.fromLong(expirationSeconds), + ATTR_SALT, AttributeValues.fromString(saltedTokenHash.salt()), + ATTR_HASH, AttributeValues.fromString(saltedTokenHash.hash()))) + .conditionExpression("#key = :key AND #expiration = :expiration AND #salt = :salt AND #hash = :hash") + .expressionAttributeNames(Map.of( + "#key", KEY_E164, + "#expiration", ATTR_EXP, + "#salt", ATTR_SALT, + "#hash", ATTR_HASH)) + .expressionAttributeValues(Map.of( + ":key", AttributeValues.fromString(phoneNumber), + ":expiration", AttributeValues.fromLong(expirationSeconds), + ":salt", AttributeValues.fromString(saltedTokenHash.salt()), + ":hash", AttributeValues.fromString(saltedTokenHash.hash()))) + .build()) + .build(), + + TransactWriteItem.builder() + .put(Put.builder() + .tableName(tableName) + .item(Map.of( + KEY_E164, AttributeValues.fromString(phoneNumberIdentifier.toString()), + ATTR_EXP, AttributeValues.fromLong(expirationSeconds), + ATTR_SALT, AttributeValues.fromString(saltedTokenHash.salt()), + ATTR_HASH, AttributeValues.fromString(saltedTokenHash.hash()))) + .conditionExpression("attribute_not_exists(#key)") + .expressionAttributeNames(Map.of("#key", KEY_E164)) + .build()) + .build()) + .build()) + .thenApply(ignored -> true) + .exceptionally(throwable -> { + if (ExceptionUtils.unwrap(throwable) instanceof TransactionCanceledException transactionCanceledException) { + if ("ConditionalCheckFailed".equals(transactionCanceledException.cancellationReasons().get(1).code())) { + // A PNI-associated record has already been stored; we can just treat this as success + return false; + } + + if ("ConditionalCheckFailed".equals(transactionCanceledException.cancellationReasons().get(0).code())) { + // No PNI-associated record is present, but the original record has changed + throw new ContestedOptimisticLockException(); + } + } + + throw ExceptionUtils.wrap(throwable); + }); + } + + private static SaltedTokenHash saltedTokenHashFromItem(final Map item) { + return new SaltedTokenHash(item.get(ATTR_HASH).s(), item.get(ATTR_SALT).s()); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryPasswordsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryPasswordsManager.java index ae58dbafb..3dbe43361 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryPasswordsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryPasswordsManager.java @@ -10,10 +10,12 @@ import static java.util.Objects.requireNonNull; import java.lang.invoke.MethodHandles; import java.util.HexFormat; import java.util.Optional; +import java.util.UUID; import java.util.concurrent.CompletableFuture; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; +import org.whispersystems.textsecuregcm.util.ExceptionUtils; import software.amazon.awssdk.services.dynamodb.model.ResourceNotFoundException; public class RegistrationRecoveryPasswordsManager { @@ -67,6 +69,38 @@ public class RegistrationRecoveryPasswordsManager { })); } + public CompletableFuture migrateE164Record(final String number, final SaltedTokenHash saltedTokenHash, final long expirationSeconds) { + return phoneNumberIdentifiers.getPhoneNumberIdentifier(number) + .thenCompose(phoneNumberIdentifier -> migrateE164Record(number, phoneNumberIdentifier, saltedTokenHash, expirationSeconds, 10)); + } + + public CompletableFuture migrateE164Record(final String number, + final UUID phoneNumberIdentifier, + final SaltedTokenHash saltedTokenHash, + final long expirationSeconds, + final int remainingAttempts) { + + if (remainingAttempts <= 0) { + return CompletableFuture.failedFuture(new ContestedOptimisticLockException()); + } + + return registrationRecoveryPasswords.insertPniRecord(number, phoneNumberIdentifier, saltedTokenHash, expirationSeconds) + .exceptionallyCompose(throwable -> { + if (ExceptionUtils.unwrap(throwable) instanceof ContestedOptimisticLockException) { + // Something about the original record changed; refresh and retry + return registrationRecoveryPasswords.lookup(number) + .thenCompose(maybeSaltedTokenHash -> maybeSaltedTokenHash + .map(refreshedSaltedTokenHash -> migrateE164Record(number, phoneNumberIdentifier, refreshedSaltedTokenHash, expirationSeconds, remainingAttempts - 1)) + .orElseGet(() -> { + // The original record was deleted, and we can declare victory + return CompletableFuture.completedFuture(false); + })); + } + + return CompletableFuture.failedFuture(throwable); + }); + } + private static String bytesToString(final byte[] bytes) { return HexFormat.of().formatHex(bytes); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryTest.java index ce612eaa4..497bcc7f1 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryTest.java @@ -8,26 +8,41 @@ package org.whispersystems.textsecuregcm.storage; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.nio.charset.StandardCharsets; import java.time.Clock; import java.time.Duration; +import java.time.Instant; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; -import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; +import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.MockUtils; import org.whispersystems.textsecuregcm.util.MutableClock; +import reactor.util.function.Tuples; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; +import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; public class RegistrationRecoveryTest { @@ -174,6 +189,84 @@ public class RegistrationRecoveryTest { assertFalse(manager.verify(NUMBER, wrongPassword).get()); } + @Test + void getE164AssociatedRegistrationRecoveryPasswords() { + registrationRecoveryPasswords.addOrReplace(NUMBER, PNI, ORIGINAL_HASH).join(); + + assertEquals(List.of(Tuples.of(NUMBER, ORIGINAL_HASH, registrationRecoveryPasswords.expirationSeconds())), + registrationRecoveryPasswords.getE164AssociatedRegistrationRecoveryPasswords().collectList().block()); + } + + @Test + void insertPniRecord() { + final long expirationSeconds = Instant.now().plusSeconds(3600).getEpochSecond(); + + DYNAMO_DB_EXTENSION.getDynamoDbClient().putItem(PutItemRequest.builder() + .tableName(Tables.REGISTRATION_RECOVERY_PASSWORDS.tableName()) + .item(Map.of( + RegistrationRecoveryPasswords.KEY_E164, AttributeValues.fromString(NUMBER), + RegistrationRecoveryPasswords.ATTR_EXP, AttributeValues.fromLong(expirationSeconds), + RegistrationRecoveryPasswords.ATTR_SALT, AttributeValues.fromString(ORIGINAL_HASH.salt()), + RegistrationRecoveryPasswords.ATTR_HASH, AttributeValues.fromString(ORIGINAL_HASH.hash()))) + .build()); + + assertTrue(registrationRecoveryPasswords.lookup(PNI).join().isEmpty()); + + assertTrue(() -> registrationRecoveryPasswords.insertPniRecord(NUMBER, PNI, ORIGINAL_HASH, expirationSeconds).join()); + assertEquals(Optional.of(ORIGINAL_HASH), registrationRecoveryPasswords.lookup(PNI).join()); + } + + @Test + void insertPniRecordOriginalDeleted() { + final CompletionException completionException = assertThrows(CompletionException.class, () -> + registrationRecoveryPasswords.insertPniRecord(NUMBER, PNI, ORIGINAL_HASH, 0L).join()); + + assertInstanceOf(ContestedOptimisticLockException.class, completionException.getCause()); + } + + @Test + void insertPniRecordOriginalChanged() { + final long expirationSeconds = Instant.now().plusSeconds(3600).getEpochSecond(); + + DYNAMO_DB_EXTENSION.getDynamoDbClient().putItem(PutItemRequest.builder() + .tableName(Tables.REGISTRATION_RECOVERY_PASSWORDS.tableName()) + .item(Map.of( + RegistrationRecoveryPasswords.KEY_E164, AttributeValues.fromString(NUMBER), + RegistrationRecoveryPasswords.ATTR_EXP, AttributeValues.fromLong(expirationSeconds), + RegistrationRecoveryPasswords.ATTR_SALT, AttributeValues.fromString(ORIGINAL_HASH.salt()), + RegistrationRecoveryPasswords.ATTR_HASH, AttributeValues.fromString(ORIGINAL_HASH.hash()))) + .build()); + + final CompletionException completionException = assertThrows(CompletionException.class, () -> + registrationRecoveryPasswords.insertPniRecord(NUMBER, PNI, ANOTHER_HASH, expirationSeconds).join()); + + assertInstanceOf(ContestedOptimisticLockException.class, completionException.getCause()); + } + + @Test + void insertPniRecordNewRecordAlreadyExists() { + registrationRecoveryPasswords.addOrReplace(NUMBER, PNI, ORIGINAL_HASH).join(); + + assertTrue(registrationRecoveryPasswords.lookup(NUMBER).join().isPresent()); + assertTrue(registrationRecoveryPasswords.lookup(PNI).join().isPresent()); + assertEquals(registrationRecoveryPasswords.lookup(NUMBER).join(), registrationRecoveryPasswords.lookup(PNI).join()); + + assertFalse(() -> + registrationRecoveryPasswords.insertPniRecord(NUMBER, PNI, ORIGINAL_HASH, registrationRecoveryPasswords.expirationSeconds()).join()); + } + + @Test + void insertPniRecordOriginalChangedNewRecordAlreadyExists() { + registrationRecoveryPasswords.addOrReplace(NUMBER, PNI, ORIGINAL_HASH).join(); + + assertTrue(registrationRecoveryPasswords.lookup(NUMBER).join().isPresent()); + assertTrue(registrationRecoveryPasswords.lookup(PNI).join().isPresent()); + assertEquals(registrationRecoveryPasswords.lookup(NUMBER).join(), registrationRecoveryPasswords.lookup(PNI).join()); + + assertFalse(() -> + registrationRecoveryPasswords.insertPniRecord(NUMBER, PNI, ANOTHER_HASH, registrationRecoveryPasswords.expirationSeconds()).join()); + } + private static long fetchTimestamp(final String number) throws ExecutionException, InterruptedException { return DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient().getItem(GetItemRequest.builder() .tableName(Tables.REGISTRATION_RECOVERY_PASSWORDS.tableName()) @@ -189,4 +282,91 @@ public class RegistrationRecoveryTest { }) .get(); } + + @Test + void migrateE164Record() { + final RegistrationRecoveryPasswords registrationRecoveryPasswords = mock(RegistrationRecoveryPasswords.class); + when(registrationRecoveryPasswords.insertPniRecord(any(), any(), any(), anyLong())) + .thenReturn(CompletableFuture.completedFuture(true)); + + final PhoneNumberIdentifiers phoneNumberIdentifiers = mock(PhoneNumberIdentifiers.class); + when(phoneNumberIdentifiers.getPhoneNumberIdentifier(NUMBER)).thenReturn(CompletableFuture.completedFuture(PNI)); + + final RegistrationRecoveryPasswordsManager migrationManager = + new RegistrationRecoveryPasswordsManager(registrationRecoveryPasswords, phoneNumberIdentifiers); + + assertTrue(() -> migrationManager.migrateE164Record(NUMBER, ORIGINAL_HASH, 1234).join()); + } + + @Test + void migrateE164RecordRetry() { + final RegistrationRecoveryPasswords registrationRecoveryPasswords = mock(RegistrationRecoveryPasswords.class); + + when(registrationRecoveryPasswords.insertPniRecord(eq(NUMBER), eq(PNI), eq(ORIGINAL_HASH), anyLong())) + .thenReturn(CompletableFuture.failedFuture(new ContestedOptimisticLockException())); + + when(registrationRecoveryPasswords.insertPniRecord(eq(NUMBER), eq(PNI), eq(ANOTHER_HASH), anyLong())) + .thenReturn(CompletableFuture.completedFuture(true)); + + when(registrationRecoveryPasswords.lookup(NUMBER)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(ANOTHER_HASH))); + + final PhoneNumberIdentifiers phoneNumberIdentifiers = mock(PhoneNumberIdentifiers.class); + when(phoneNumberIdentifiers.getPhoneNumberIdentifier(NUMBER)).thenReturn(CompletableFuture.completedFuture(PNI)); + + final RegistrationRecoveryPasswordsManager migrationManager = + new RegistrationRecoveryPasswordsManager(registrationRecoveryPasswords, phoneNumberIdentifiers); + + assertTrue(() -> migrationManager.migrateE164Record(NUMBER, ORIGINAL_HASH, 1234).join()); + + verify(registrationRecoveryPasswords).lookup(NUMBER); + verify(registrationRecoveryPasswords, times(2)).insertPniRecord(any(), any(), any(), anyLong()); + } + + @Test + void migrateE164RecordOriginalDeleted() { + final RegistrationRecoveryPasswords registrationRecoveryPasswords = mock(RegistrationRecoveryPasswords.class); + + when(registrationRecoveryPasswords.insertPniRecord(eq(NUMBER), eq(PNI), eq(ORIGINAL_HASH), anyLong())) + .thenReturn(CompletableFuture.failedFuture(new ContestedOptimisticLockException())); + + when(registrationRecoveryPasswords.lookup(NUMBER)) + .thenReturn(CompletableFuture.completedFuture(Optional.empty())); + + final PhoneNumberIdentifiers phoneNumberIdentifiers = mock(PhoneNumberIdentifiers.class); + when(phoneNumberIdentifiers.getPhoneNumberIdentifier(NUMBER)).thenReturn(CompletableFuture.completedFuture(PNI)); + + final RegistrationRecoveryPasswordsManager migrationManager = + new RegistrationRecoveryPasswordsManager(registrationRecoveryPasswords, phoneNumberIdentifiers); + + assertFalse(() -> migrationManager.migrateE164Record(NUMBER, ORIGINAL_HASH, 1234).join()); + + verify(registrationRecoveryPasswords).lookup(NUMBER); + verify(registrationRecoveryPasswords).insertPniRecord(any(), any(), any(), anyLong()); + } + + @Test + void migrateE164RecordRetryExhausted() { + final RegistrationRecoveryPasswords registrationRecoveryPasswords = mock(RegistrationRecoveryPasswords.class); + + when(registrationRecoveryPasswords.insertPniRecord(any(), any(), any(), anyLong())) + .thenReturn(CompletableFuture.failedFuture(new ContestedOptimisticLockException())); + + when(registrationRecoveryPasswords.lookup(NUMBER)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(ORIGINAL_HASH))); + + final PhoneNumberIdentifiers phoneNumberIdentifiers = mock(PhoneNumberIdentifiers.class); + when(phoneNumberIdentifiers.getPhoneNumberIdentifier(NUMBER)).thenReturn(CompletableFuture.completedFuture(PNI)); + + final RegistrationRecoveryPasswordsManager migrationManager = + new RegistrationRecoveryPasswordsManager(registrationRecoveryPasswords, phoneNumberIdentifiers); + + final CompletionException completionException = assertThrows(CompletionException.class, + () -> migrationManager.migrateE164Record(NUMBER, ORIGINAL_HASH, 1234).join()); + + assertInstanceOf(ContestedOptimisticLockException.class, completionException.getCause()); + + verify(registrationRecoveryPasswords, times(10)).lookup(NUMBER); + verify(registrationRecoveryPasswords, times(10)).insertPniRecord(any(), any(), any(), anyLong()); + } }