From ff4e2bdfb72b8d2a4ed2cffac0232986b60b0822 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Mon, 25 Nov 2024 16:22:49 -0500 Subject: [PATCH] Refresh registration recovery password expirations before retrying an insertion --- .../storage/RegistrationRecoveryPasswords.java | 16 +++++++++++++++- .../RegistrationRecoveryPasswordsManager.java | 6 +++--- .../storage/RegistrationRecoveryTest.java | 17 ++++++++++++----- 3 files changed, 30 insertions(+), 9 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryPasswords.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryPasswords.java index fd1dbb3a9..08a35e2c6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryPasswords.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryPasswords.java @@ -17,6 +17,7 @@ import java.util.concurrent.CompletableFuture; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.ExceptionUtils; +import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Util; import reactor.core.publisher.Flux; import reactor.core.scheduler.Scheduler; @@ -65,13 +66,26 @@ public class RegistrationRecoveryPasswords { .tableName(tableName) .key(Map.of(KEY_E164, AttributeValues.fromString(number))) .consistentRead(true) - .build()) + .build()) .thenApply(getItemResponse -> Optional.ofNullable(getItemResponse.item()) .filter(item -> item.containsKey(ATTR_SALT)) .filter(item -> item.containsKey(ATTR_HASH)) .map(RegistrationRecoveryPasswords::saltedTokenHashFromItem)); } + CompletableFuture>> lookupWithExpiration(final String key) { + return asyncClient.getItem(GetItemRequest.builder() + .tableName(tableName) + .key(Map.of(KEY_E164, AttributeValues.fromString(key))) + .consistentRead(true) + .build()) + .thenApply(getItemResponse -> Optional.ofNullable(getItemResponse.item()) + .filter(item -> item.containsKey(ATTR_SALT)) + .filter(item -> item.containsKey(ATTR_HASH)) + .filter(item -> item.containsKey(ATTR_EXP)) + .map(item -> new Pair<>(saltedTokenHashFromItem(item), Long.parseLong(item.get(ATTR_EXP).n())))); + } + public CompletableFuture> lookup(final UUID phoneNumberIdentifier) { return lookup(phoneNumberIdentifier.toString()); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryPasswordsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryPasswordsManager.java index 89d860c4e..b396c6f0f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryPasswordsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryPasswordsManager.java @@ -95,9 +95,9 @@ public class RegistrationRecoveryPasswordsManager { .exceptionallyCompose(throwable -> { if (ExceptionUtils.unwrap(throwable) instanceof ContestedOptimisticLockException) { // Something about the original record changed; refresh and retry - return registrationRecoveryPasswords.lookup(number) - .thenCompose(maybeSaltedTokenHash -> maybeSaltedTokenHash - .map(refreshedSaltedTokenHash -> migrateE164Record(number, phoneNumberIdentifier, refreshedSaltedTokenHash, expirationSeconds, remainingAttempts - 1)) + return registrationRecoveryPasswords.lookupWithExpiration(number) + .thenCompose(maybePair -> maybePair + .map(pair -> migrateE164Record(number, phoneNumberIdentifier, pair.first(), pair.second(), remainingAttempts - 1)) .orElseGet(() -> { // The original record was deleted, and we can declare victory return CompletableFuture.completedFuture(false); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryTest.java index 4c1348747..d5fa3e359 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RegistrationRecoveryTest.java @@ -39,6 +39,7 @@ import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.MockUtils; import org.whispersystems.textsecuregcm.util.MutableClock; +import org.whispersystems.textsecuregcm.util.Pair; import reactor.core.scheduler.Schedulers; import reactor.util.function.Tuples; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; @@ -308,8 +309,8 @@ public class RegistrationRecoveryTest { when(registrationRecoveryPasswords.insertPniRecord(eq(NUMBER), eq(PNI), eq(ANOTHER_HASH), anyLong())) .thenReturn(CompletableFuture.completedFuture(true)); - when(registrationRecoveryPasswords.lookup(NUMBER)) - .thenReturn(CompletableFuture.completedFuture(Optional.of(ANOTHER_HASH))); + when(registrationRecoveryPasswords.lookupWithExpiration(NUMBER)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(new Pair<>(ANOTHER_HASH, 1234L)))); final PhoneNumberIdentifiers phoneNumberIdentifiers = mock(PhoneNumberIdentifiers.class); when(phoneNumberIdentifiers.getPhoneNumberIdentifier(NUMBER)).thenReturn(CompletableFuture.completedFuture(PNI)); @@ -319,7 +320,7 @@ public class RegistrationRecoveryTest { assertTrue(() -> migrationManager.migrateE164Record(NUMBER, ORIGINAL_HASH, 1234).join()); - verify(registrationRecoveryPasswords).lookup(NUMBER); + verify(registrationRecoveryPasswords).lookupWithExpiration(NUMBER); verify(registrationRecoveryPasswords, times(2)).insertPniRecord(any(), any(), any(), anyLong()); } @@ -333,6 +334,9 @@ public class RegistrationRecoveryTest { when(registrationRecoveryPasswords.lookup(NUMBER)) .thenReturn(CompletableFuture.completedFuture(Optional.empty())); + when(registrationRecoveryPasswords.lookupWithExpiration(NUMBER)) + .thenReturn(CompletableFuture.completedFuture(Optional.empty())); + final PhoneNumberIdentifiers phoneNumberIdentifiers = mock(PhoneNumberIdentifiers.class); when(phoneNumberIdentifiers.getPhoneNumberIdentifier(NUMBER)).thenReturn(CompletableFuture.completedFuture(PNI)); @@ -341,7 +345,7 @@ public class RegistrationRecoveryTest { assertFalse(() -> migrationManager.migrateE164Record(NUMBER, ORIGINAL_HASH, 1234).join()); - verify(registrationRecoveryPasswords).lookup(NUMBER); + verify(registrationRecoveryPasswords).lookupWithExpiration(NUMBER); verify(registrationRecoveryPasswords).insertPniRecord(any(), any(), any(), anyLong()); } @@ -355,6 +359,9 @@ public class RegistrationRecoveryTest { when(registrationRecoveryPasswords.lookup(NUMBER)) .thenReturn(CompletableFuture.completedFuture(Optional.of(ORIGINAL_HASH))); + when(registrationRecoveryPasswords.lookupWithExpiration(NUMBER)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(new Pair<>(ORIGINAL_HASH, CLOCK.instant().getEpochSecond() + EXPIRATION.getSeconds())))); + final PhoneNumberIdentifiers phoneNumberIdentifiers = mock(PhoneNumberIdentifiers.class); when(phoneNumberIdentifiers.getPhoneNumberIdentifier(NUMBER)).thenReturn(CompletableFuture.completedFuture(PNI)); @@ -366,7 +373,7 @@ public class RegistrationRecoveryTest { assertInstanceOf(ContestedOptimisticLockException.class, completionException.getCause()); - verify(registrationRecoveryPasswords, times(10)).lookup(NUMBER); + verify(registrationRecoveryPasswords, times(10)).lookupWithExpiration(NUMBER); verify(registrationRecoveryPasswords, times(10)).insertPniRecord(any(), any(), any(), anyLong()); } }