Introduce async account updaters

This commit is contained in:
Jon Chambers 2023-07-12 17:27:32 -04:00 committed by Jon Chambers
parent d17c7aaba6
commit 82ed783a2d
2 changed files with 149 additions and 6 deletions

View File

@ -34,6 +34,7 @@ import java.util.Optional;
import java.util.OptionalInt;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;
@ -59,6 +60,7 @@ import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.ParallelFlux;
@ -120,6 +122,8 @@ public class AccountsManager {
private static final Duration USERNAME_HASH_RESERVATION_TTL_MINUTES = Duration.ofMinutes(5);
private static final int MAX_UPDATE_ATTEMPTS = 10;
@FunctionalInterface
private interface AccountPersister {
void persistAccount(Account account) throws UsernameHashNotAvailableException;
@ -553,6 +557,14 @@ public class AccountsManager {
});
}
public CompletableFuture<Account> updateAsync(Account account, Consumer<Account> updater) {
return updateAsync(account, a -> {
updater.accept(a);
// assume that all updaters passed to the public method actually modify the account
return true;
});
}
/**
* Specialized version of {@link #updateDevice(Account, long, Consumer)} that minimizes potentially contentious and
* redundant updates of {@code device.lastSeen}
@ -606,6 +618,25 @@ public class AccountsManager {
return updatedAccount;
}
private CompletableFuture<Account> updateAsync(final Account account, final Function<Account, Boolean> updater) {
final Timer.Context timerContext = updateTimer.time();
return redisDeleteAsync(account)
.thenCompose(ignored -> {
final UUID uuid = account.getUuid();
return updateWithRetriesAsync(account,
updater,
a -> accounts.updateAsync(a).toCompletableFuture(),
() -> accounts.getByAccountIdentifierAsync(uuid).thenApply(Optional::orElseThrow),
AccountChangeValidator.GENERAL_CHANGE_VALIDATOR,
MAX_UPDATE_ATTEMPTS);
})
.thenCompose(updatedAccount -> redisSetAsync(updatedAccount).thenApply(ignored -> updatedAccount))
.whenComplete((ignored, throwable) -> timerContext.close());
}
private Account updateWithRetries(Account account,
final Function<Account, Boolean> updater,
final Consumer<Account> persister,
@ -660,6 +691,42 @@ public class AccountsManager {
throw new OptimisticLockRetryLimitExceededException();
}
private CompletionStage<Account> updateWithRetriesAsync(Account account,
final Function<Account, Boolean> updater,
final Function<Account, CompletionStage<Void>> persister,
final Supplier<CompletionStage<Account>> retriever,
final AccountChangeValidator changeValidator,
final int remainingTries) {
final Account originalAccount = cloneAccount(account);
if (!updater.apply(account)) {
return CompletableFuture.completedFuture(account);
}
if (remainingTries > 0) {
return persister.apply(account)
.thenApply(ignored -> {
final Account updatedAccount = cloneAccount(account);
account.markStale();
changeValidator.validateChange(originalAccount, updatedAccount);
return updatedAccount;
})
.exceptionallyCompose(throwable -> {
if (ExceptionUtils.unwrap(throwable) instanceof ContestedOptimisticLockException) {
return retriever.get().thenCompose(refreshedAccount ->
updateWithRetriesAsync(refreshedAccount, updater, persister, retriever, changeValidator, remainingTries - 1));
} else {
throw ExceptionUtils.wrap(throwable);
}
});
}
return CompletableFuture.failedFuture(new OptimisticLockRetryLimitExceededException());
}
private static Account cloneAccount(final Account account) {
try {
final Account clone = mapper.readValue(mapper.writeValueAsBytes(account), Account.class);
@ -680,6 +747,14 @@ public class AccountsManager {
});
}
public CompletableFuture<Account> updateDeviceAsync(final Account account, final long deviceId, final Consumer<Device> deviceUpdater) {
return updateAsync(account, a -> {
a.getDevice(deviceId).ifPresent(deviceUpdater);
// assume that all updaters passed to the public method actually modify the device
return true;
});
}
public Optional<Account> getByE164(final String number) {
return checkRedisThenAccounts(
getByNumberTimer,

View File

@ -101,6 +101,15 @@ class AccountsManagerTest {
return null;
};
private static final Answer<CompletableFuture<Void>> ACCOUNT_UPDATE_ASYNC_ANSWER = invocation -> {
// it is implicit in the update() contract is that a successful call will
// result in an incremented version
final Account updatedAccount = invocation.getArgument(0, Account.class);
updatedAccount.setVersion(updatedAccount.getVersion() + 1);
return CompletableFuture.completedFuture(null);
};
@BeforeEach
void setup() throws InterruptedException {
accounts = mock(Accounts.class);
@ -115,6 +124,11 @@ class AccountsManagerTest {
//noinspection unchecked
asyncCommands = mock(RedisAdvancedClusterAsyncCommands.class);
when(asyncCommands.del(any())).thenReturn(MockRedisFuture.completedFuture(0L));
when(asyncCommands.get(any())).thenReturn(MockRedisFuture.completedFuture(null));
when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(accounts.updateAsync(any())).thenReturn(CompletableFuture.completedFuture(null));
doAnswer((Answer<Void>) invocation -> {
final Account account = invocation.getArgument(0, Account.class);
@ -719,12 +733,6 @@ class AccountsManagerTest {
when(commands.get(eq("Account3::" + uuid))).thenReturn(null);
when(accounts.getByAccountIdentifier(uuid)).thenReturn(
Optional.of(AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[16])));
doThrow(ContestedOptimisticLockException.class)
.doAnswer(ACCOUNT_UPDATE_ANSWER)
.when(accounts).update(any());
when(accounts.getByAccountIdentifier(uuid)).thenReturn(
Optional.of(AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[16])));
doThrow(ContestedOptimisticLockException.class)
@ -743,6 +751,33 @@ class AccountsManagerTest {
verifyNoMoreInteractions(accounts);
}
@Test
void testUpdateAsync_optimisticLockingFailure() {
UUID uuid = UUID.randomUUID();
UUID pni = UUID.randomUUID();
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[16]);
when(asyncCommands.get(eq("Account3::" + uuid))).thenReturn(null);
when(accounts.getByAccountIdentifierAsync(uuid)).thenReturn(CompletableFuture.completedFuture(
Optional.of(AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[16]))));
when(accounts.updateAsync(any()))
.thenReturn(CompletableFuture.failedFuture(new ContestedOptimisticLockException()))
.thenAnswer(ACCOUNT_UPDATE_ASYNC_ANSWER);
final IdentityKey identityKey = new IdentityKey(Curve.generateKeyPair().getPublicKey());
account = accountsManager.updateAsync(account, a -> a.setIdentityKey(identityKey)).join();
assertEquals(1, account.getVersion());
assertEquals(identityKey, account.getIdentityKey());
verify(accounts, times(1)).getByAccountIdentifierAsync(uuid);
verify(accounts, times(2)).updateAsync(any());
verifyNoMoreInteractions(accounts);
}
@Test
void testUpdate_dynamoOptimisticLockingFailureDuringCreate() {
UUID uuid = UUID.randomUUID();
@ -793,6 +828,39 @@ class AccountsManagerTest {
verify(unknownDeviceUpdater, never()).accept(any(Device.class));
}
@Test
void testUpdateDeviceAsync() {
final UUID uuid = UUID.randomUUID();
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, UUID.randomUUID(), new ArrayList<>(), new byte[16]);
when(accounts.getByAccountIdentifierAsync(uuid)).thenReturn(CompletableFuture.completedFuture(
Optional.of(AccountsHelper.generateTestAccount("+14152222222", uuid, UUID.randomUUID(), new ArrayList<>(), new byte[16]))));
assertTrue(account.getDevices().isEmpty());
Device enabledDevice = new Device();
enabledDevice.setFetchesMessages(true);
enabledDevice.setSignedPreKey(KeysHelper.signedECPreKey(1, Curve.generateKeyPair()));
enabledDevice.setLastSeen(System.currentTimeMillis());
final long deviceId = account.getNextDeviceId();
enabledDevice.setId(deviceId);
account.addDevice(enabledDevice);
@SuppressWarnings("unchecked") Consumer<Device> deviceUpdater = mock(Consumer.class);
@SuppressWarnings("unchecked") Consumer<Device> unknownDeviceUpdater = mock(Consumer.class);
account = accountsManager.updateDeviceAsync(account, deviceId, deviceUpdater).join();
account = accountsManager.updateDeviceAsync(account, deviceId, d -> d.setName("deviceName")).join();
assertEquals("deviceName", account.getDevice(deviceId).orElseThrow().getName());
verify(deviceUpdater, times(1)).accept(any(Device.class));
accountsManager.updateDeviceAsync(account, account.getNextDeviceId(), unknownDeviceUpdater).join();
verify(unknownDeviceUpdater, never()).accept(any(Device.class));
}
@Test
void testCreateFreshAccount() throws InterruptedException {
when(accounts.create(any())).thenReturn(true);