From 82ed783a2dba0ced4a5ae9ac85e012381dc6da8c Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Wed, 12 Jul 2023 17:27:32 -0400 Subject: [PATCH] Introduce async account updaters --- .../storage/AccountsManager.java | 75 +++++++++++++++++ .../storage/AccountsManagerTest.java | 80 +++++++++++++++++-- 2 files changed, 149 insertions(+), 6 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index 1828355ee..267874a2b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -34,6 +34,7 @@ import java.util.Optional; import java.util.OptionalInt; import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; @@ -59,6 +60,7 @@ import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator; +import org.whispersystems.textsecuregcm.util.ExceptionUtils; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.Util; import reactor.core.publisher.ParallelFlux; @@ -120,6 +122,8 @@ public class AccountsManager { private static final Duration USERNAME_HASH_RESERVATION_TTL_MINUTES = Duration.ofMinutes(5); + private static final int MAX_UPDATE_ATTEMPTS = 10; + @FunctionalInterface private interface AccountPersister { void persistAccount(Account account) throws UsernameHashNotAvailableException; @@ -553,6 +557,14 @@ public class AccountsManager { }); } + public CompletableFuture updateAsync(Account account, Consumer updater) { + return updateAsync(account, a -> { + updater.accept(a); + // assume that all updaters passed to the public method actually modify the account + return true; + }); + } + /** * Specialized version of {@link #updateDevice(Account, long, Consumer)} that minimizes potentially contentious and * redundant updates of {@code device.lastSeen} @@ -606,6 +618,25 @@ public class AccountsManager { return updatedAccount; } + private CompletableFuture updateAsync(final Account account, final Function updater) { + + final Timer.Context timerContext = updateTimer.time(); + + return redisDeleteAsync(account) + .thenCompose(ignored -> { + final UUID uuid = account.getUuid(); + + return updateWithRetriesAsync(account, + updater, + a -> accounts.updateAsync(a).toCompletableFuture(), + () -> accounts.getByAccountIdentifierAsync(uuid).thenApply(Optional::orElseThrow), + AccountChangeValidator.GENERAL_CHANGE_VALIDATOR, + MAX_UPDATE_ATTEMPTS); + }) + .thenCompose(updatedAccount -> redisSetAsync(updatedAccount).thenApply(ignored -> updatedAccount)) + .whenComplete((ignored, throwable) -> timerContext.close()); + } + private Account updateWithRetries(Account account, final Function updater, final Consumer persister, @@ -660,6 +691,42 @@ public class AccountsManager { throw new OptimisticLockRetryLimitExceededException(); } + private CompletionStage updateWithRetriesAsync(Account account, + final Function updater, + final Function> persister, + final Supplier> retriever, + final AccountChangeValidator changeValidator, + final int remainingTries) { + + final Account originalAccount = cloneAccount(account); + + if (!updater.apply(account)) { + return CompletableFuture.completedFuture(account); + } + + if (remainingTries > 0) { + return persister.apply(account) + .thenApply(ignored -> { + final Account updatedAccount = cloneAccount(account); + account.markStale(); + + changeValidator.validateChange(originalAccount, updatedAccount); + + return updatedAccount; + }) + .exceptionallyCompose(throwable -> { + if (ExceptionUtils.unwrap(throwable) instanceof ContestedOptimisticLockException) { + return retriever.get().thenCompose(refreshedAccount -> + updateWithRetriesAsync(refreshedAccount, updater, persister, retriever, changeValidator, remainingTries - 1)); + } else { + throw ExceptionUtils.wrap(throwable); + } + }); + } + + return CompletableFuture.failedFuture(new OptimisticLockRetryLimitExceededException()); + } + private static Account cloneAccount(final Account account) { try { final Account clone = mapper.readValue(mapper.writeValueAsBytes(account), Account.class); @@ -680,6 +747,14 @@ public class AccountsManager { }); } + public CompletableFuture updateDeviceAsync(final Account account, final long deviceId, final Consumer deviceUpdater) { + return updateAsync(account, a -> { + a.getDevice(deviceId).ifPresent(deviceUpdater); + // assume that all updaters passed to the public method actually modify the device + return true; + }); + } + public Optional getByE164(final String number) { return checkRedisThenAccounts( getByNumberTimer, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index 2d84e140d..76c1fd908 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -101,6 +101,15 @@ class AccountsManagerTest { return null; }; + private static final Answer> ACCOUNT_UPDATE_ASYNC_ANSWER = invocation -> { + // it is implicit in the update() contract is that a successful call will + // result in an incremented version + final Account updatedAccount = invocation.getArgument(0, Account.class); + updatedAccount.setVersion(updatedAccount.getVersion() + 1); + + return CompletableFuture.completedFuture(null); + }; + @BeforeEach void setup() throws InterruptedException { accounts = mock(Accounts.class); @@ -115,6 +124,11 @@ class AccountsManagerTest { //noinspection unchecked asyncCommands = mock(RedisAdvancedClusterAsyncCommands.class); + when(asyncCommands.del(any())).thenReturn(MockRedisFuture.completedFuture(0L)); + when(asyncCommands.get(any())).thenReturn(MockRedisFuture.completedFuture(null)); + when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); + + when(accounts.updateAsync(any())).thenReturn(CompletableFuture.completedFuture(null)); doAnswer((Answer) invocation -> { final Account account = invocation.getArgument(0, Account.class); @@ -719,12 +733,6 @@ class AccountsManagerTest { when(commands.get(eq("Account3::" + uuid))).thenReturn(null); - when(accounts.getByAccountIdentifier(uuid)).thenReturn( - Optional.of(AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[16]))); - doThrow(ContestedOptimisticLockException.class) - .doAnswer(ACCOUNT_UPDATE_ANSWER) - .when(accounts).update(any()); - when(accounts.getByAccountIdentifier(uuid)).thenReturn( Optional.of(AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[16]))); doThrow(ContestedOptimisticLockException.class) @@ -743,6 +751,33 @@ class AccountsManagerTest { verifyNoMoreInteractions(accounts); } + @Test + void testUpdateAsync_optimisticLockingFailure() { + UUID uuid = UUID.randomUUID(); + UUID pni = UUID.randomUUID(); + Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[16]); + + when(asyncCommands.get(eq("Account3::" + uuid))).thenReturn(null); + + when(accounts.getByAccountIdentifierAsync(uuid)).thenReturn(CompletableFuture.completedFuture( + Optional.of(AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[16])))); + + when(accounts.updateAsync(any())) + .thenReturn(CompletableFuture.failedFuture(new ContestedOptimisticLockException())) + .thenAnswer(ACCOUNT_UPDATE_ASYNC_ANSWER); + + final IdentityKey identityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey()); + + account = accountsManager.updateAsync(account, a -> a.setIdentityKey(identityKey)).join(); + + assertEquals(1, account.getVersion()); + assertEquals(identityKey, account.getIdentityKey()); + + verify(accounts, times(1)).getByAccountIdentifierAsync(uuid); + verify(accounts, times(2)).updateAsync(any()); + verifyNoMoreInteractions(accounts); + } + @Test void testUpdate_dynamoOptimisticLockingFailureDuringCreate() { UUID uuid = UUID.randomUUID(); @@ -793,6 +828,39 @@ class AccountsManagerTest { verify(unknownDeviceUpdater, never()).accept(any(Device.class)); } + @Test + void testUpdateDeviceAsync() { + final UUID uuid = UUID.randomUUID(); + Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, UUID.randomUUID(), new ArrayList<>(), new byte[16]); + + when(accounts.getByAccountIdentifierAsync(uuid)).thenReturn(CompletableFuture.completedFuture( + Optional.of(AccountsHelper.generateTestAccount("+14152222222", uuid, UUID.randomUUID(), new ArrayList<>(), new byte[16])))); + + assertTrue(account.getDevices().isEmpty()); + + Device enabledDevice = new Device(); + enabledDevice.setFetchesMessages(true); + enabledDevice.setSignedPreKey(KeysHelper.signedECPreKey(1, Curve.generateKeyPair())); + enabledDevice.setLastSeen(System.currentTimeMillis()); + final long deviceId = account.getNextDeviceId(); + enabledDevice.setId(deviceId); + account.addDevice(enabledDevice); + + @SuppressWarnings("unchecked") Consumer deviceUpdater = mock(Consumer.class); + @SuppressWarnings("unchecked") Consumer unknownDeviceUpdater = mock(Consumer.class); + + account = accountsManager.updateDeviceAsync(account, deviceId, deviceUpdater).join(); + account = accountsManager.updateDeviceAsync(account, deviceId, d -> d.setName("deviceName")).join(); + + assertEquals("deviceName", account.getDevice(deviceId).orElseThrow().getName()); + + verify(deviceUpdater, times(1)).accept(any(Device.class)); + + accountsManager.updateDeviceAsync(account, account.getNextDeviceId(), unknownDeviceUpdater).join(); + + verify(unknownDeviceUpdater, never()).accept(any(Device.class)); + } + @Test void testCreateFreshAccount() throws InterruptedException { when(accounts.create(any())).thenReturn(true);