Use a `Callable` for tasks performed within the scope of a pessimistic lock

This commit is contained in:
Jon Chambers 2025-05-12 12:36:51 -04:00 committed by Jon Chambers
parent b95d08aaea
commit a4b98f38a6
6 changed files with 221 additions and 183 deletions

View File

@ -9,6 +9,7 @@ import com.google.common.annotations.VisibleForTesting;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.Executor;
@ -49,10 +50,15 @@ public class AccountLockManager {
* @param task the task to execute once locks have been acquired
* @param lockAcquisitionExecutor the executor on which to run blocking lock acquire/release tasks. this executor
* should not use virtual threads.
* @throws InterruptedException if interrupted while acquiring a lock
*
* @return the value returned by the given {@code task}
*
* @throws Exception if an exception is thrown by the given {@code task}
*/
public void withLock(final List<UUID> phoneNumberIdentifiers, final Runnable task,
final Executor lockAcquisitionExecutor) {
public <V> V withLock(final List<UUID> phoneNumberIdentifiers,
final Callable<V> task,
final Executor lockAcquisitionExecutor) throws Exception {
if (phoneNumberIdentifiers.isEmpty()) {
throw new IllegalArgumentException("List of PNIs to lock must not be empty");
}
@ -75,7 +81,7 @@ public class AccountLockManager {
}
}, lockAcquisitionExecutor).join();
task.run();
return task.call();
} finally {
CompletableFuture.runAsync(() -> {
for (final LockItem lockItem : lockItems) {

View File

@ -273,128 +273,148 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
final DeviceSpec primaryDeviceSpec,
@Nullable final String userAgent) throws InterruptedException {
final Account account = new Account();
final UUID pni = phoneNumberIdentifiers.getPhoneNumberIdentifier(number).join();
return createTimer.record(() -> {
accountLockManager.withLock(List.of(pni), () -> {
final Optional<UUID> maybeRecentlyDeletedAccountIdentifier =
accounts.findRecentlyDeletedAccountIdentifier(pni);
// Reuse the ACI from any recently-deleted account with this number to cover cases where somebody is
// re-registering.
account.setUuid(maybeRecentlyDeletedAccountIdentifier.orElseGet(UUID::randomUUID));
account.setNumber(number, pni);
account.setIdentityKey(aciIdentityKey);
account.setPhoneNumberIdentityKey(pniIdentityKey);
account.addDevice(primaryDeviceSpec.toDevice(Device.PRIMARY_ID, clock));
account.setRegistrationLockFromAttributes(accountAttributes);
account.setUnidentifiedAccessKey(accountAttributes.getUnidentifiedAccessKey());
account.setUnrestrictedUnidentifiedAccess(accountAttributes.isUnrestrictedUnidentifiedAccess());
account.setDiscoverableByPhoneNumber(accountAttributes.isDiscoverableByPhoneNumber());
account.setBadges(clock, accountBadges);
String accountCreationType = maybeRecentlyDeletedAccountIdentifier.isPresent() ? "recently-deleted" : "new";
final String pushTokenType;
if (primaryDeviceSpec.apnRegistrationId().isPresent()) {
pushTokenType = "apns";
} else if (primaryDeviceSpec.gcmRegistrationId().isPresent()) {
pushTokenType = "fcm";
} else {
pushTokenType = "none";
try {
return accountLockManager.withLock(List.of(pni),
() -> create(number, pni, accountAttributes, accountBadges, aciIdentityKey, pniIdentityKey, primaryDeviceSpec, userAgent), accountLockExecutor);
} catch (final Exception e) {
if (e instanceof RuntimeException runtimeException) {
throw runtimeException;
}
String previousPushTokenType = null;
try {
accounts.create(account, keysManager.buildWriteItemsForNewDevice(account.getIdentifier(IdentityType.ACI),
account.getIdentifier(IdentityType.PNI),
Device.PRIMARY_ID,
primaryDeviceSpec.aciSignedPreKey(),
primaryDeviceSpec.pniSignedPreKey(),
primaryDeviceSpec.aciPqLastResortPreKey(),
primaryDeviceSpec.pniPqLastResortPreKey()));
} catch (final AccountAlreadyExistsException e) {
accountCreationType = "re-registration";
if (StringUtils.isNotBlank(e.getExistingAccount().getPrimaryDevice().getApnId())) {
previousPushTokenType = "apns";
} else if (StringUtils.isNotBlank(e.getExistingAccount().getPrimaryDevice().getGcmId())) {
previousPushTokenType = "fcm";
} else {
previousPushTokenType = "none";
}
final UUID aci = e.getExistingAccount().getIdentifier(IdentityType.ACI);
account.setUuid(aci);
final List<TransactWriteItem> additionalWriteItems = Stream.concat(
keysManager.buildWriteItemsForNewDevice(account.getIdentifier(IdentityType.ACI),
account.getIdentifier(IdentityType.PNI),
Device.PRIMARY_ID,
primaryDeviceSpec.aciSignedPreKey(),
primaryDeviceSpec.pniSignedPreKey(),
primaryDeviceSpec.aciPqLastResortPreKey(),
primaryDeviceSpec.pniPqLastResortPreKey()).stream(),
e.getExistingAccount().getDevices()
.stream()
.map(Device::getId)
// No need to clear the keys for the primary device since we'll just overwrite them in the same
// transaction anyhow
.filter(existingDeviceId -> existingDeviceId != Device.PRIMARY_ID)
.flatMap(existingDeviceId ->
keysManager.buildWriteItemsForRemovedDevice(aci, pni, existingDeviceId).stream()))
.toList();
CompletableFuture.allOf(
keysManager.deleteSingleUsePreKeys(aci),
keysManager.deleteSingleUsePreKeys(pni),
messagesManager.clear(aci),
profilesManager.deleteAll(aci))
.thenCompose(ignored -> disconnectionRequestManager.requestDisconnection(aci))
.thenCompose(ignored -> accounts.reclaimAccount(e.getExistingAccount(), account, additionalWriteItems))
.thenCompose(ignored -> {
// We should have cleared all messages before overwriting the old account, but more may have arrived
// while we were working. Similarly, the old account holder could have added keys or profiles. We'll
// largely repeat the cleanup process after creating the account to make sure we really REALLY got
// everything.
//
// We exclude the primary device's repeated-use keys from deletion because new keys were provided as
// part of the account creation process, and we don't want to delete the keys that just got added.
return CompletableFuture.allOf(keysManager.deleteSingleUsePreKeys(aci),
keysManager.deleteSingleUsePreKeys(pni),
messagesManager.clear(aci),
profilesManager.deleteAll(aci));
})
.join();
}
redisSet(account);
Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of("type", accountCreationType),
Tag.of("hasPushToken", String.valueOf(
primaryDeviceSpec.apnRegistrationId().isPresent() || primaryDeviceSpec.gcmRegistrationId()
.isPresent())),
Tag.of("pushTokenType", pushTokenType));
if (StringUtils.isNotBlank(previousPushTokenType)) {
tags = tags.and(Tag.of("previousPushTokenType", previousPushTokenType));
}
Metrics.counter(CREATE_COUNTER_NAME, tags).increment();
accountAttributes.recoveryPassword().ifPresent(registrationRecoveryPassword ->
registrationRecoveryPasswordsManager.store(account.getIdentifier(IdentityType.PNI),
registrationRecoveryPassword));
}, accountLockExecutor);
return account;
logger.error("Unexpected exception while creating account", e);
throw new RuntimeException(e);
}
});
}
private Account create(final String number,
final UUID pni,
final AccountAttributes accountAttributes,
final List<AccountBadge> accountBadges,
final IdentityKey aciIdentityKey,
final IdentityKey pniIdentityKey,
final DeviceSpec primaryDeviceSpec,
@Nullable final String userAgent) {
final Account account = new Account();
final Optional<UUID> maybeRecentlyDeletedAccountIdentifier =
accounts.findRecentlyDeletedAccountIdentifier(pni);
// Reuse the ACI from any recently-deleted account with this number to cover cases where somebody is
// re-registering.
account.setUuid(maybeRecentlyDeletedAccountIdentifier.orElseGet(UUID::randomUUID));
account.setNumber(number, pni);
account.setIdentityKey(aciIdentityKey);
account.setPhoneNumberIdentityKey(pniIdentityKey);
account.addDevice(primaryDeviceSpec.toDevice(Device.PRIMARY_ID, clock));
account.setRegistrationLockFromAttributes(accountAttributes);
account.setUnidentifiedAccessKey(accountAttributes.getUnidentifiedAccessKey());
account.setUnrestrictedUnidentifiedAccess(accountAttributes.isUnrestrictedUnidentifiedAccess());
account.setDiscoverableByPhoneNumber(accountAttributes.isDiscoverableByPhoneNumber());
account.setBadges(clock, accountBadges);
String accountCreationType = maybeRecentlyDeletedAccountIdentifier.isPresent() ? "recently-deleted" : "new";
final String pushTokenType;
if (primaryDeviceSpec.apnRegistrationId().isPresent()) {
pushTokenType = "apns";
} else if (primaryDeviceSpec.gcmRegistrationId().isPresent()) {
pushTokenType = "fcm";
} else {
pushTokenType = "none";
}
String previousPushTokenType = null;
try {
accounts.create(account, keysManager.buildWriteItemsForNewDevice(account.getIdentifier(IdentityType.ACI),
account.getIdentifier(IdentityType.PNI),
Device.PRIMARY_ID,
primaryDeviceSpec.aciSignedPreKey(),
primaryDeviceSpec.pniSignedPreKey(),
primaryDeviceSpec.aciPqLastResortPreKey(),
primaryDeviceSpec.pniPqLastResortPreKey()));
} catch (final AccountAlreadyExistsException e) {
accountCreationType = "re-registration";
if (StringUtils.isNotBlank(e.getExistingAccount().getPrimaryDevice().getApnId())) {
previousPushTokenType = "apns";
} else if (StringUtils.isNotBlank(e.getExistingAccount().getPrimaryDevice().getGcmId())) {
previousPushTokenType = "fcm";
} else {
previousPushTokenType = "none";
}
final UUID aci = e.getExistingAccount().getIdentifier(IdentityType.ACI);
account.setUuid(aci);
final List<TransactWriteItem> additionalWriteItems = Stream.concat(
keysManager.buildWriteItemsForNewDevice(account.getIdentifier(IdentityType.ACI),
account.getIdentifier(IdentityType.PNI),
Device.PRIMARY_ID,
primaryDeviceSpec.aciSignedPreKey(),
primaryDeviceSpec.pniSignedPreKey(),
primaryDeviceSpec.aciPqLastResortPreKey(),
primaryDeviceSpec.pniPqLastResortPreKey()).stream(),
e.getExistingAccount().getDevices()
.stream()
.map(Device::getId)
// No need to clear the keys for the primary device since we'll just overwrite them in the same
// transaction anyhow
.filter(existingDeviceId -> existingDeviceId != Device.PRIMARY_ID)
.flatMap(existingDeviceId ->
keysManager.buildWriteItemsForRemovedDevice(aci, pni, existingDeviceId).stream()))
.toList();
CompletableFuture.allOf(
keysManager.deleteSingleUsePreKeys(aci),
keysManager.deleteSingleUsePreKeys(pni),
messagesManager.clear(aci),
profilesManager.deleteAll(aci))
.thenCompose(ignored -> disconnectionRequestManager.requestDisconnection(aci))
.thenCompose(ignored -> accounts.reclaimAccount(e.getExistingAccount(), account, additionalWriteItems))
.thenCompose(ignored -> {
// We should have cleared all messages before overwriting the old account, but more may have arrived
// while we were working. Similarly, the old account holder could have added keys or profiles. We'll
// largely repeat the cleanup process after creating the account to make sure we really REALLY got
// everything.
//
// We exclude the primary device's repeated-use keys from deletion because new keys were provided as
// part of the account creation process, and we don't want to delete the keys that just got added.
return CompletableFuture.allOf(keysManager.deleteSingleUsePreKeys(aci),
keysManager.deleteSingleUsePreKeys(pni),
messagesManager.clear(aci),
profilesManager.deleteAll(aci));
})
.join();
}
redisSet(account);
Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of("type", accountCreationType),
Tag.of("hasPushToken", String.valueOf(
primaryDeviceSpec.apnRegistrationId().isPresent() || primaryDeviceSpec.gcmRegistrationId()
.isPresent())),
Tag.of("pushTokenType", pushTokenType));
if (StringUtils.isNotBlank(previousPushTokenType)) {
tags = tags.and(Tag.of("previousPushTokenType", previousPushTokenType));
}
Metrics.counter(CREATE_COUNTER_NAME, tags).increment();
accountAttributes.recoveryPassword().ifPresent(registrationRecoveryPassword ->
registrationRecoveryPasswordsManager.store(account.getIdentifier(IdentityType.PNI),
registrationRecoveryPassword));
return account;
}
public CompletableFuture<Pair<Account, Device>> addDevice(final Account account, final DeviceSpec deviceSpec, final String linkDeviceToken) {
return accountLockManager.withLockAsync(List.of(account.getPhoneNumberIdentifier()),
() -> addDevice(account.getIdentifier(IdentityType.ACI), deviceSpec, linkDeviceToken, MAX_UPDATE_ATTEMPTS),
@ -655,57 +675,74 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
validateDevices(account, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds);
final AtomicReference<Account> updatedAccount = new AtomicReference<>();
accountLockManager.withLock(List.of(account.getPhoneNumberIdentifier(), targetPhoneNumberIdentifier), () -> {
redisDelete(account);
// There are three possible states for accounts associated with the target phone number:
//
// 1. An account exists with the target PNI; the caller has proved ownership of the number, so delete the
// account with the target PNI. This will leave a "deleted account" record for the deleted account mapping
// the UUID of the deleted account to the target PNI. We'll then overwrite that so it points to the
// original PNI to facilitate switching back and forth between numbers.
// 2. No account with the target PNI exists, but one has recently been deleted. In that case, add a "deleted
// account" record that maps the ACI of the recently-deleted account to the now-abandoned original PNI
// of the account changing its number (which facilitates ACI consistency in cases that a party is switching
// back and forth between numbers).
// 3. No account with the target PNI exists at all, in which case no additional action is needed.
final Optional<UUID> recentlyDeletedAci = accounts.findRecentlyDeletedAccountIdentifier(targetPhoneNumberIdentifier);
final Optional<Account> maybeExistingAccount = getByE164(targetNumber);
final Optional<UUID> maybeDisplacedUuid;
if (maybeExistingAccount.isPresent()) {
delete(maybeExistingAccount.get()).join();
maybeDisplacedUuid = maybeExistingAccount.map(Account::getUuid);
} else {
maybeDisplacedUuid = recentlyDeletedAci;
try {
return accountLockManager.withLock(List.of(account.getPhoneNumberIdentifier(), targetPhoneNumberIdentifier),
() -> changeNumber(account, targetNumber, targetPhoneNumberIdentifier, pniIdentityKey, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds), accountLockExecutor);
} catch (final Exception e) {
if (e instanceof MismatchedDevicesException mismatchedDevicesException) {
throw mismatchedDevicesException;
} if (e instanceof RuntimeException runtimeException) {
throw runtimeException;
}
final UUID uuid = account.getUuid();
logger.error("Unexpected exception when changing phone number", e);
throw new RuntimeException(e);
}
}
CompletableFuture.allOf(
keysManager.deleteSingleUsePreKeys(targetPhoneNumberIdentifier),
keysManager.deleteSingleUsePreKeys(originalPhoneNumberIdentifier))
.join();
private Account changeNumber(final Account account,
final String targetNumber,
final UUID targetPhoneNumberIdentifier,
final IdentityKey pniIdentityKey,
final Map<Byte, ECSignedPreKey> pniSignedPreKeys,
final Map<Byte, KEMSignedPreKey> pniPqLastResortPreKeys,
final Map<Byte, Integer> pniRegistrationIds) {
final UUID originalPhoneNumberIdentifier = account.getPhoneNumberIdentifier();
redisDelete(account);
// There are three possible states for accounts associated with the target phone number:
//
// 1. An account exists with the target PNI; the caller has proved ownership of the number, so delete the
// account with the target PNI. This will leave a "deleted account" record for the deleted account mapping
// the UUID of the deleted account to the target PNI. We'll then overwrite that so it points to the
// original PNI to facilitate switching back and forth between numbers.
// 2. No account with the target PNI exists, but one has recently been deleted. In that case, add a "deleted
// account" record that maps the ACI of the recently-deleted account to the now-abandoned original PNI
// of the account changing its number (which facilitates ACI consistency in cases that a party is switching
// back and forth between numbers).
// 3. No account with the target PNI exists at all, in which case no additional action is needed.
final Optional<UUID> recentlyDeletedAci = accounts.findRecentlyDeletedAccountIdentifier(targetPhoneNumberIdentifier);
final Optional<Account> maybeExistingAccount = getByE164(targetNumber);
final Optional<UUID> maybeDisplacedUuid;
if (maybeExistingAccount.isPresent()) {
delete(maybeExistingAccount.get()).join();
maybeDisplacedUuid = maybeExistingAccount.map(Account::getUuid);
} else {
maybeDisplacedUuid = recentlyDeletedAci;
}
final UUID uuid = account.getUuid();
CompletableFuture.allOf(
keysManager.deleteSingleUsePreKeys(targetPhoneNumberIdentifier),
keysManager.deleteSingleUsePreKeys(originalPhoneNumberIdentifier))
.join();
final Collection<TransactWriteItem> keyWriteItems =
buildPniKeyWriteItems(targetPhoneNumberIdentifier, pniSignedPreKeys, pniPqLastResortPreKeys);
final Account numberChangedAccount = updateWithRetries(
account,
a -> {
setPniKeys(account, pniIdentityKey, pniRegistrationIds);
return true;
},
a -> accounts.changeNumber(a, targetNumber, targetPhoneNumberIdentifier, maybeDisplacedUuid, keyWriteItems),
() -> accounts.getByAccountIdentifier(uuid).orElseThrow(),
AccountChangeValidator.NUMBER_CHANGE_VALIDATOR);
updatedAccount.set(numberChangedAccount);
}, accountLockExecutor);
return updatedAccount.get();
return updateWithRetries(
account,
a -> {
setPniKeys(account, pniIdentityKey, pniRegistrationIds);
return true;
},
a -> accounts.changeNumber(a, targetNumber, targetPhoneNumberIdentifier, maybeDisplacedUuid, keyWriteItems),
() -> accounts.getByAccountIdentifier(uuid).orElseThrow(),
AccountChangeValidator.NUMBER_CHANGE_VALIDATOR);
}
public Account updatePniKeys(final Account account,

View File

@ -47,9 +47,8 @@ class AccountLockManagerTest {
}
@Test
void withLock() throws InterruptedException {
accountLockManager.withLock(List.of(FIRST_PNI, SECOND_PNI), () -> {
}, executor);
void withLock() throws Exception {
accountLockManager.withLock(List.of(FIRST_PNI, SECOND_PNI), () -> null, executor);
verify(lockClient, times(2)).acquireLock(any());
verify(lockClient, times(2)).releaseLock(any(ReleaseLockOptions.class));
@ -69,8 +68,7 @@ class AccountLockManagerTest {
void withLockEmptyList() {
final Runnable task = mock(Runnable.class);
assertThrows(IllegalArgumentException.class, () -> accountLockManager.withLock(Collections.emptyList(), () -> {
},
assertThrows(IllegalArgumentException.class, () -> accountLockManager.withLock(Collections.emptyList(), () -> null,
executor));
verify(task, never()).run();
}

View File

@ -28,6 +28,7 @@ import java.util.ArrayList;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.LinkedBlockingDeque;
@ -84,7 +85,7 @@ class AccountsManagerConcurrentModificationIntegrationTest {
private Executor mutationExecutor = new ThreadPoolExecutor(20, 20, 5, TimeUnit.SECONDS, new LinkedBlockingDeque<>(20));
@BeforeEach
void setup() throws InterruptedException {
void setup() throws Exception {
@SuppressWarnings("unchecked") final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager =
mock(DynamicConfigurationManager.class);
@ -108,10 +109,8 @@ class AccountsManagerConcurrentModificationIntegrationTest {
final AccountLockManager accountLockManager = mock(AccountLockManager.class);
doAnswer(invocation -> {
final Runnable task = invocation.getArgument(1);
task.run();
return null;
final Callable<?> task = invocation.getArgument(1);
return task.call();
}).when(accountLockManager).withLock(anyList(), any(), any());
when(accountLockManager.withLockAsync(anyList(), any(), any())).thenAnswer(invocation -> {

View File

@ -52,6 +52,7 @@ import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.Executor;
@ -153,7 +154,7 @@ class AccountsManagerTest {
};
@BeforeEach
void setup() throws InterruptedException {
void setup() throws Exception {
accounts = mock(Accounts.class);
keysManager = mock(KeysManager.class);
messagesManager = mock(MessagesManager.class);
@ -213,10 +214,8 @@ class AccountsManagerTest {
final AccountLockManager accountLockManager = mock(AccountLockManager.class);
doAnswer(invocation -> {
final Runnable task = invocation.getArgument(1);
task.run();
return null;
final Callable<?> task = invocation.getArgument(1);
return task.call();
}).when(accountLockManager).withLock(anyList(), any(), any());
when(accountLockManager.withLockAsync(anyList(), any(), any())).thenAnswer(invocation -> {

View File

@ -28,6 +28,7 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors;
import java.util.function.Supplier;
@ -81,12 +82,12 @@ class AccountsManagerUsernameIntegrationTest {
private Accounts accounts;
@BeforeEach
void setup() throws InterruptedException {
void setup() throws Exception {
buildAccountsManager(1, 2, 10);
}
private void buildAccountsManager(final int initialWidth, int discriminatorMaxWidth, int attemptsPerWidth)
throws InterruptedException {
throws Exception {
@SuppressWarnings("unchecked") final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager =
mock(DynamicConfigurationManager.class);
@ -115,10 +116,8 @@ class AccountsManagerUsernameIntegrationTest {
final AccountLockManager accountLockManager = mock(AccountLockManager.class);
doAnswer(invocation -> {
final Runnable task = invocation.getArgument(1);
task.run();
return null;
final Callable<?> task = invocation.getArgument(1);
return task.call();
}).when(accountLockManager).withLock(anyList(), any(), any());
when(accountLockManager.withLockAsync(anyList(), any(), any())).thenAnswer(invocation -> {