Refresh registration recovery password expirations before retrying an insertion

This commit is contained in:
Jon Chambers 2024-11-25 16:22:49 -05:00 committed by Jon Chambers
parent ffed19d198
commit ff4e2bdfb7
3 changed files with 30 additions and 9 deletions

View File

@ -17,6 +17,7 @@ import java.util.concurrent.CompletableFuture;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.ExceptionUtils; import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Scheduler;
@ -65,13 +66,26 @@ public class RegistrationRecoveryPasswords {
.tableName(tableName) .tableName(tableName)
.key(Map.of(KEY_E164, AttributeValues.fromString(number))) .key(Map.of(KEY_E164, AttributeValues.fromString(number)))
.consistentRead(true) .consistentRead(true)
.build()) .build())
.thenApply(getItemResponse -> Optional.ofNullable(getItemResponse.item()) .thenApply(getItemResponse -> Optional.ofNullable(getItemResponse.item())
.filter(item -> item.containsKey(ATTR_SALT)) .filter(item -> item.containsKey(ATTR_SALT))
.filter(item -> item.containsKey(ATTR_HASH)) .filter(item -> item.containsKey(ATTR_HASH))
.map(RegistrationRecoveryPasswords::saltedTokenHashFromItem)); .map(RegistrationRecoveryPasswords::saltedTokenHashFromItem));
} }
CompletableFuture<Optional<Pair<SaltedTokenHash, Long>>> lookupWithExpiration(final String key) {
return asyncClient.getItem(GetItemRequest.builder()
.tableName(tableName)
.key(Map.of(KEY_E164, AttributeValues.fromString(key)))
.consistentRead(true)
.build())
.thenApply(getItemResponse -> Optional.ofNullable(getItemResponse.item())
.filter(item -> item.containsKey(ATTR_SALT))
.filter(item -> item.containsKey(ATTR_HASH))
.filter(item -> item.containsKey(ATTR_EXP))
.map(item -> new Pair<>(saltedTokenHashFromItem(item), Long.parseLong(item.get(ATTR_EXP).n()))));
}
public CompletableFuture<Optional<SaltedTokenHash>> lookup(final UUID phoneNumberIdentifier) { public CompletableFuture<Optional<SaltedTokenHash>> lookup(final UUID phoneNumberIdentifier) {
return lookup(phoneNumberIdentifier.toString()); return lookup(phoneNumberIdentifier.toString());
} }

View File

@ -95,9 +95,9 @@ public class RegistrationRecoveryPasswordsManager {
.exceptionallyCompose(throwable -> { .exceptionallyCompose(throwable -> {
if (ExceptionUtils.unwrap(throwable) instanceof ContestedOptimisticLockException) { if (ExceptionUtils.unwrap(throwable) instanceof ContestedOptimisticLockException) {
// Something about the original record changed; refresh and retry // Something about the original record changed; refresh and retry
return registrationRecoveryPasswords.lookup(number) return registrationRecoveryPasswords.lookupWithExpiration(number)
.thenCompose(maybeSaltedTokenHash -> maybeSaltedTokenHash .thenCompose(maybePair -> maybePair
.map(refreshedSaltedTokenHash -> migrateE164Record(number, phoneNumberIdentifier, refreshedSaltedTokenHash, expirationSeconds, remainingAttempts - 1)) .map(pair -> migrateE164Record(number, phoneNumberIdentifier, pair.first(), pair.second(), remainingAttempts - 1))
.orElseGet(() -> { .orElseGet(() -> {
// The original record was deleted, and we can declare victory // The original record was deleted, and we can declare victory
return CompletableFuture.completedFuture(false); return CompletableFuture.completedFuture(false);

View File

@ -39,6 +39,7 @@ import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.MockUtils; import org.whispersystems.textsecuregcm.util.MockUtils;
import org.whispersystems.textsecuregcm.util.MutableClock; import org.whispersystems.textsecuregcm.util.MutableClock;
import org.whispersystems.textsecuregcm.util.Pair;
import reactor.core.scheduler.Schedulers; import reactor.core.scheduler.Schedulers;
import reactor.util.function.Tuples; import reactor.util.function.Tuples;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
@ -308,8 +309,8 @@ public class RegistrationRecoveryTest {
when(registrationRecoveryPasswords.insertPniRecord(eq(NUMBER), eq(PNI), eq(ANOTHER_HASH), anyLong())) when(registrationRecoveryPasswords.insertPniRecord(eq(NUMBER), eq(PNI), eq(ANOTHER_HASH), anyLong()))
.thenReturn(CompletableFuture.completedFuture(true)); .thenReturn(CompletableFuture.completedFuture(true));
when(registrationRecoveryPasswords.lookup(NUMBER)) when(registrationRecoveryPasswords.lookupWithExpiration(NUMBER))
.thenReturn(CompletableFuture.completedFuture(Optional.of(ANOTHER_HASH))); .thenReturn(CompletableFuture.completedFuture(Optional.of(new Pair<>(ANOTHER_HASH, 1234L))));
final PhoneNumberIdentifiers phoneNumberIdentifiers = mock(PhoneNumberIdentifiers.class); final PhoneNumberIdentifiers phoneNumberIdentifiers = mock(PhoneNumberIdentifiers.class);
when(phoneNumberIdentifiers.getPhoneNumberIdentifier(NUMBER)).thenReturn(CompletableFuture.completedFuture(PNI)); when(phoneNumberIdentifiers.getPhoneNumberIdentifier(NUMBER)).thenReturn(CompletableFuture.completedFuture(PNI));
@ -319,7 +320,7 @@ public class RegistrationRecoveryTest {
assertTrue(() -> migrationManager.migrateE164Record(NUMBER, ORIGINAL_HASH, 1234).join()); assertTrue(() -> migrationManager.migrateE164Record(NUMBER, ORIGINAL_HASH, 1234).join());
verify(registrationRecoveryPasswords).lookup(NUMBER); verify(registrationRecoveryPasswords).lookupWithExpiration(NUMBER);
verify(registrationRecoveryPasswords, times(2)).insertPniRecord(any(), any(), any(), anyLong()); verify(registrationRecoveryPasswords, times(2)).insertPniRecord(any(), any(), any(), anyLong());
} }
@ -333,6 +334,9 @@ public class RegistrationRecoveryTest {
when(registrationRecoveryPasswords.lookup(NUMBER)) when(registrationRecoveryPasswords.lookup(NUMBER))
.thenReturn(CompletableFuture.completedFuture(Optional.empty())); .thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(registrationRecoveryPasswords.lookupWithExpiration(NUMBER))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
final PhoneNumberIdentifiers phoneNumberIdentifiers = mock(PhoneNumberIdentifiers.class); final PhoneNumberIdentifiers phoneNumberIdentifiers = mock(PhoneNumberIdentifiers.class);
when(phoneNumberIdentifiers.getPhoneNumberIdentifier(NUMBER)).thenReturn(CompletableFuture.completedFuture(PNI)); when(phoneNumberIdentifiers.getPhoneNumberIdentifier(NUMBER)).thenReturn(CompletableFuture.completedFuture(PNI));
@ -341,7 +345,7 @@ public class RegistrationRecoveryTest {
assertFalse(() -> migrationManager.migrateE164Record(NUMBER, ORIGINAL_HASH, 1234).join()); assertFalse(() -> migrationManager.migrateE164Record(NUMBER, ORIGINAL_HASH, 1234).join());
verify(registrationRecoveryPasswords).lookup(NUMBER); verify(registrationRecoveryPasswords).lookupWithExpiration(NUMBER);
verify(registrationRecoveryPasswords).insertPniRecord(any(), any(), any(), anyLong()); verify(registrationRecoveryPasswords).insertPniRecord(any(), any(), any(), anyLong());
} }
@ -355,6 +359,9 @@ public class RegistrationRecoveryTest {
when(registrationRecoveryPasswords.lookup(NUMBER)) when(registrationRecoveryPasswords.lookup(NUMBER))
.thenReturn(CompletableFuture.completedFuture(Optional.of(ORIGINAL_HASH))); .thenReturn(CompletableFuture.completedFuture(Optional.of(ORIGINAL_HASH)));
when(registrationRecoveryPasswords.lookupWithExpiration(NUMBER))
.thenReturn(CompletableFuture.completedFuture(Optional.of(new Pair<>(ORIGINAL_HASH, CLOCK.instant().getEpochSecond() + EXPIRATION.getSeconds()))));
final PhoneNumberIdentifiers phoneNumberIdentifiers = mock(PhoneNumberIdentifiers.class); final PhoneNumberIdentifiers phoneNumberIdentifiers = mock(PhoneNumberIdentifiers.class);
when(phoneNumberIdentifiers.getPhoneNumberIdentifier(NUMBER)).thenReturn(CompletableFuture.completedFuture(PNI)); when(phoneNumberIdentifiers.getPhoneNumberIdentifier(NUMBER)).thenReturn(CompletableFuture.completedFuture(PNI));
@ -366,7 +373,7 @@ public class RegistrationRecoveryTest {
assertInstanceOf(ContestedOptimisticLockException.class, completionException.getCause()); assertInstanceOf(ContestedOptimisticLockException.class, completionException.getCause());
verify(registrationRecoveryPasswords, times(10)).lookup(NUMBER); verify(registrationRecoveryPasswords, times(10)).lookupWithExpiration(NUMBER);
verify(registrationRecoveryPasswords, times(10)).insertPniRecord(any(), any(), any(), anyLong()); verify(registrationRecoveryPasswords, times(10)).insertPniRecord(any(), any(), any(), anyLong());
} }
} }