diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountLockManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountLockManager.java index d8a8c9868..bf75e0f2d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountLockManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountLockManager.java @@ -8,11 +8,13 @@ import com.amazonaws.services.dynamodbv2.ReleaseLockOptions; import com.google.common.annotations.VisibleForTesting; import java.util.ArrayList; import java.util.List; +import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; import java.util.function.Supplier; +import java.util.stream.Stream; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; public class AccountLockManager { @@ -43,28 +45,35 @@ public class AccountLockManager { * account lifecycle changes (like deleting an account or changing a phone number). The given task runs once locks for * all given phone numbers have been acquired, and the locks are released as soon as the task completes by any means. * - * @param e164s the phone numbers for which to acquire a distributed, pessimistic lock - * @param task the task to execute once locks have been acquired + * @param e164s the phone numbers for which to acquire a distributed, pessimistic lock + * @param phoneNumberIdentifiers the phone number identifiers for which to acquire a distributed, pessimistic lock + * @param task the task to execute once locks have been acquired * @param lockAcquisitionExecutor the executor on which to run blocking lock acquire/release tasks. this executor * should not use virtual threads. - * * @throws InterruptedException if interrupted while acquiring a lock */ - public void withLock(final List e164s, final Runnable task, final Executor lockAcquisitionExecutor) { + public void withLock(final List e164s, final List phoneNumberIdentifiers, final Runnable task, + final Executor lockAcquisitionExecutor) { if (e164s.isEmpty()) { throw new IllegalArgumentException("List of e164s to lock must not be empty"); } + if (phoneNumberIdentifiers.isEmpty()) { + throw new IllegalArgumentException("List of PNIs to lock must not be empty"); + } - final List lockItems = new ArrayList<>(e164s.size()); + final List allIdentifiers = Stream.concat(e164s.stream(), + phoneNumberIdentifiers.stream().map(UUID::toString)) + .toList(); + final List lockItems = new ArrayList<>(allIdentifiers.size()); try { // Offload the acquire/release tasks to the dedicated lock acquisition executor. The lock client performs blocking // operations while holding locks which forces thread pinning when this method runs on a virtual thread. // https://github.com/awslabs/amazon-dynamodb-lock-client/issues/97 CompletableFuture.runAsync(() -> { - for (final String e164 : e164s) { + for (final String identifier : allIdentifiers) { try { - lockItems.add(lockClient.acquireLock(AcquireLockOptions.builder(e164) + lockItems.add(lockClient.acquireLock(AcquireLockOptions.builder(identifier) .withAcquireReleasedLocksConsistently(true) .build())); } catch (final InterruptedException e) { @@ -91,27 +100,32 @@ public class AccountLockManager { * account lifecycle changes (like deleting an account or changing a phone number). The given task runs once locks for * all given phone numbers have been acquired, and the locks are released as soon as the task completes by any means. * - * @param e164s the phone numbers for which to acquire a distributed, pessimistic lock - * @param taskSupplier a supplier for the task to execute once locks have been acquired - * @param executor the executor on which to acquire and release locks - * + * @param e164s the phone numbers for which to acquire a distributed, pessimistic lock + * @param phoneNumberIdentifiers the phone number identifiers for which to acquire a distributed, pessimistic lock + * @param taskSupplier a supplier for the task to execute once locks have been acquired + * @param executor the executor on which to acquire and release locks * @return a future that completes normally when the given task has executed successfully and all locks have been * released; the returned future may fail with an {@link InterruptedException} if interrupted while acquiring a lock */ - public CompletableFuture withLockAsync(final List e164s, - final Supplier> taskSupplier, - final Executor executor) { + public CompletableFuture withLockAsync(final List e164s, final List phoneNumberIdentifiers, + final Supplier> taskSupplier, final Executor executor) { if (e164s.isEmpty()) { throw new IllegalArgumentException("List of e164s to lock must not be empty"); } + if (phoneNumberIdentifiers.isEmpty()) { + throw new IllegalArgumentException("List of PNIs to lock must not be empty"); + } - final List lockItems = new ArrayList<>(e164s.size()); + final List allIdentifiers = Stream.concat(e164s.stream(), + phoneNumberIdentifiers.stream().map(UUID::toString)) + .toList(); + final List lockItems = new ArrayList<>(allIdentifiers.size()); return CompletableFuture.runAsync(() -> { - for (final String e164 : e164s) { + for (final String identifier : allIdentifiers) { try { - lockItems.add(lockClient.acquireLock(AcquireLockOptions.builder(e164) + lockItems.add(lockClient.acquireLock(AcquireLockOptions.builder(identifier) .withAcquireReleasedLocksConsistently(true) .build())); } catch (final InterruptedException e) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index 5c46b4934..de046cb58 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -271,15 +271,16 @@ public class AccountsManager extends RedisPubSubAdapter implemen @Nullable final String userAgent) throws InterruptedException { final Account account = new Account(); + final UUID phoneNumberIdentifier = phoneNumberIdentifiers.getPhoneNumberIdentifier(number); return createTimer.record(() -> { - accountLockManager.withLock(List.of(number), () -> { + accountLockManager.withLock(List.of(number), List.of(phoneNumberIdentifier), () -> { final Optional maybeRecentlyDeletedAccountIdentifier = accounts.findRecentlyDeletedAccountIdentifier(number); // Reuse the ACI from any recently-deleted account with this number to cover cases where somebody is // re-registering. - account.setNumber(number, phoneNumberIdentifiers.getPhoneNumberIdentifier(number)); + account.setNumber(number, phoneNumberIdentifier); account.setUuid(maybeRecentlyDeletedAccountIdentifier.orElseGet(UUID::randomUUID)); account.setIdentityKey(aciIdentityKey); account.setPhoneNumberIdentityKey(pniIdentityKey); @@ -363,9 +364,9 @@ public class AccountsManager extends RedisPubSubAdapter implemen // We exclude the primary device's repeated-use keys from deletion because new keys were provided as // part of the account creation process, and we don't want to delete the keys that just got added. return CompletableFuture.allOf(keysManager.deleteSingleUsePreKeys(aci), - keysManager.deleteSingleUsePreKeys(pni), - messagesManager.clear(aci), - profilesManager.deleteAll(aci)); + keysManager.deleteSingleUsePreKeys(pni), + messagesManager.clear(aci), + profilesManager.deleteAll(aci)); }) .join(); } @@ -375,7 +376,8 @@ public class AccountsManager extends RedisPubSubAdapter implemen Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), Tag.of("type", accountCreationType), Tag.of("hasPushToken", String.valueOf( - primaryDeviceSpec.apnRegistrationId().isPresent() || primaryDeviceSpec.gcmRegistrationId().isPresent())), + primaryDeviceSpec.apnRegistrationId().isPresent() || primaryDeviceSpec.gcmRegistrationId() + .isPresent())), Tag.of("pushTokenType", pushTokenType)); if (StringUtils.isNotBlank(previousPushTokenType)) { @@ -385,7 +387,8 @@ public class AccountsManager extends RedisPubSubAdapter implemen Metrics.counter(CREATE_COUNTER_NAME, tags).increment(); accountAttributes.recoveryPassword().ifPresent(registrationRecoveryPassword -> - registrationRecoveryPasswordsManager.storeForCurrentNumber(account.getNumber(), registrationRecoveryPassword)); + registrationRecoveryPasswordsManager.storeForCurrentNumber(account.getNumber(), + registrationRecoveryPassword)); }, accountLockExecutor); return account; @@ -394,6 +397,7 @@ public class AccountsManager extends RedisPubSubAdapter implemen public CompletableFuture> addDevice(final Account account, final DeviceSpec deviceSpec, final String linkDeviceToken) { return accountLockManager.withLockAsync(List.of(account.getNumber()), + List.of(account.getPhoneNumberIdentifier()), () -> addDevice(account.getIdentifier(IdentityType.ACI), deviceSpec, linkDeviceToken, MAX_UPDATE_ATTEMPTS), accountLockExecutor); } @@ -581,7 +585,7 @@ public class AccountsManager extends RedisPubSubAdapter implemen throw new IllegalArgumentException("Cannot remove primary device"); } - return accountLockManager.withLockAsync(List.of(account.getNumber()), + return accountLockManager.withLockAsync(List.of(account.getNumber()), List.of(account.getPhoneNumberIdentifier()), () -> removeDevice(account.getIdentifier(IdentityType.ACI), deviceId, MAX_UPDATE_ATTEMPTS), accountLockExecutor); } @@ -647,8 +651,10 @@ public class AccountsManager extends RedisPubSubAdapter implemen validateDevices(account, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds); final AtomicReference updatedAccount = new AtomicReference<>(); + final UUID targetPhoneNumberIdentifier = phoneNumberIdentifiers.getPhoneNumberIdentifier(targetNumber); - accountLockManager.withLock(List.of(account.getNumber(), targetNumber), () -> { + accountLockManager.withLock(List.of(account.getNumber(), targetNumber), + List.of(account.getPhoneNumberIdentifier(), targetPhoneNumberIdentifier), () -> { redisDelete(account); // There are three possible states for accounts associated with the target phone number: @@ -674,15 +680,14 @@ public class AccountsManager extends RedisPubSubAdapter implemen } final UUID uuid = account.getUuid(); - final UUID phoneNumberIdentifier = phoneNumberIdentifiers.getPhoneNumberIdentifier(targetNumber); CompletableFuture.allOf( - keysManager.deleteSingleUsePreKeys(phoneNumberIdentifier), + keysManager.deleteSingleUsePreKeys(targetPhoneNumberIdentifier), keysManager.deleteSingleUsePreKeys(originalPhoneNumberIdentifier)) .join(); final Collection keyWriteItems = - buildPniKeyWriteItems(uuid, phoneNumberIdentifier, pniSignedPreKeys, pniPqLastResortPreKeys); + buildPniKeyWriteItems(uuid, targetPhoneNumberIdentifier, pniSignedPreKeys, pniPqLastResortPreKeys); final Account numberChangedAccount = updateWithRetries( account, @@ -690,7 +695,7 @@ public class AccountsManager extends RedisPubSubAdapter implemen setPniKeys(account, pniIdentityKey, pniRegistrationIds); return true; }, - a -> accounts.changeNumber(a, targetNumber, phoneNumberIdentifier, maybeDisplacedUuid, keyWriteItems), + a -> accounts.changeNumber(a, targetNumber, targetPhoneNumberIdentifier, maybeDisplacedUuid, keyWriteItems), () -> accounts.getByAccountIdentifier(uuid).orElseThrow(), AccountChangeValidator.NUMBER_CHANGE_VALIDATOR); @@ -1220,7 +1225,9 @@ public class AccountsManager extends RedisPubSubAdapter implemen public CompletableFuture delete(final Account account, final DeletionReason deletionReason) { final Timer.Sample sample = Timer.start(); - return accountLockManager.withLockAsync(List.of(account.getNumber()), () -> delete(account), accountLockExecutor) + return accountLockManager.withLockAsync(List.of(account.getNumber()), List.of(account.getPhoneNumberIdentifier()), + () -> delete(account), + accountLockExecutor) .whenComplete((ignored, throwable) -> { sample.stop(deleteTimer); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ClientPublicKeysManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ClientPublicKeysManager.java index 482f66924..bb7881ab9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ClientPublicKeysManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ClientPublicKeysManager.java @@ -42,6 +42,7 @@ public class ClientPublicKeysManager { */ public CompletableFuture setPublicKey(final Account account, final byte deviceId, final ECPublicKey publicKey) { return accountLockManager.withLockAsync(List.of(account.getNumber()), + List.of(account.getPhoneNumberIdentifier()), () -> clientPublicKeys.setPublicKey(account.getIdentifier(IdentityType.ACI), deviceId, publicKey), accountLockExecutor); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountLockManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountLockManagerTest.java index 93655f1e2..b8866fb21 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountLockManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountLockManagerTest.java @@ -12,6 +12,7 @@ import com.amazonaws.services.dynamodbv2.ReleaseLockOptions; import com.google.i18n.phonenumbers.PhoneNumberUtil; import java.util.Collections; import java.util.List; +import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -33,6 +34,9 @@ class AccountLockManagerTest { private static final String SECOND_NUMBER = PhoneNumberUtil.getInstance().format( PhoneNumberUtil.getInstance().getExampleNumber("JP"), PhoneNumberUtil.PhoneNumberFormat.E164); + private static final UUID FIRST_PNI = UUID.randomUUID(); + private static final UUID SECOND_PNI = UUID.randomUUID(); + @BeforeEach void setUp() { lockClient = mock(AmazonDynamoDBLockClient.class); @@ -51,47 +55,53 @@ class AccountLockManagerTest { @Test void withLock() throws InterruptedException { - accountLockManager.withLock(List.of(FIRST_NUMBER, SECOND_NUMBER), () -> {}, executor); + accountLockManager.withLock(List.of(FIRST_NUMBER, SECOND_NUMBER), List.of(FIRST_PNI, SECOND_PNI), () -> { + }, executor); - verify(lockClient, times(2)).acquireLock(any()); - verify(lockClient, times(2)).releaseLock(any(ReleaseLockOptions.class)); + verify(lockClient, times(4)).acquireLock(any()); + verify(lockClient, times(4)).releaseLock(any(ReleaseLockOptions.class)); } @Test void withLockTaskThrowsException() throws InterruptedException { - assertThrows(RuntimeException.class, () -> accountLockManager.withLock(List.of(FIRST_NUMBER, SECOND_NUMBER), () -> { + assertThrows(RuntimeException.class, + () -> accountLockManager.withLock(List.of(FIRST_NUMBER, SECOND_NUMBER), List.of(FIRST_PNI, SECOND_PNI), () -> { throw new RuntimeException(); }, executor)); - verify(lockClient, times(2)).acquireLock(any()); - verify(lockClient, times(2)).releaseLock(any(ReleaseLockOptions.class)); + verify(lockClient, times(4)).acquireLock(any()); + verify(lockClient, times(4)).releaseLock(any(ReleaseLockOptions.class)); } @Test void withLockEmptyList() { final Runnable task = mock(Runnable.class); - assertThrows(IllegalArgumentException.class, () -> accountLockManager.withLock(Collections.emptyList(), () -> {}, executor)); + assertThrows(IllegalArgumentException.class, + () -> accountLockManager.withLock(Collections.emptyList(), Collections.emptyList(), () -> { + }, + executor)); verify(task, never()).run(); } @Test void withLockAsync() throws InterruptedException { accountLockManager.withLockAsync(List.of(FIRST_NUMBER, SECOND_NUMBER), - () -> CompletableFuture.completedFuture(null), executor).join(); + List.of(FIRST_PNI, SECOND_PNI), () -> CompletableFuture.completedFuture(null), executor).join(); - verify(lockClient, times(2)).acquireLock(any()); - verify(lockClient, times(2)).releaseLock(any(ReleaseLockOptions.class)); + verify(lockClient, times(4)).acquireLock(any()); + verify(lockClient, times(4)).releaseLock(any(ReleaseLockOptions.class)); } @Test void withLockAsyncTaskThrowsException() throws InterruptedException { assertThrows(RuntimeException.class, () -> accountLockManager.withLockAsync(List.of(FIRST_NUMBER, SECOND_NUMBER), - () -> CompletableFuture.failedFuture(new RuntimeException()), executor).join()); + List.of(FIRST_PNI, SECOND_PNI), () -> CompletableFuture.failedFuture(new RuntimeException()), executor) + .join()); - verify(lockClient, times(2)).acquireLock(any()); - verify(lockClient, times(2)).releaseLock(any(ReleaseLockOptions.class)); + verify(lockClient, times(4)).acquireLock(any()); + verify(lockClient, times(4)).releaseLock(any(ReleaseLockOptions.class)); } @Test @@ -100,7 +110,7 @@ class AccountLockManagerTest { assertThrows(IllegalArgumentException.class, () -> accountLockManager.withLockAsync(Collections.emptyList(), - () -> CompletableFuture.completedFuture(null), executor)); + Collections.emptyList(), () -> CompletableFuture.completedFuture(null), executor)); verify(task, never()).run(); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java index 237a8b552..cc981c16c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java @@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.atLeast; @@ -50,7 +51,6 @@ import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.identity.IdentityType; -import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; @@ -107,14 +107,14 @@ class AccountsManagerConcurrentModificationIntegrationTest { final AccountLockManager accountLockManager = mock(AccountLockManager.class); doAnswer(invocation -> { - final Runnable task = invocation.getArgument(1); + final Runnable task = invocation.getArgument(2); task.run(); return null; - }).when(accountLockManager).withLock(any(), any(), any()); + }).when(accountLockManager).withLock(any(), anyList(), any(), any()); - when(accountLockManager.withLockAsync(any(), any(), any())).thenAnswer(invocation -> { - final Supplier> taskSupplier = invocation.getArgument(1); + when(accountLockManager.withLockAsync(any(), anyList(), any(), any())).thenAnswer(invocation -> { + final Supplier> taskSupplier = invocation.getArgument(2); taskSupplier.get().join(); return CompletableFuture.completedFuture(null); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index 51181d2ba..7a2a3be1c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -15,6 +15,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyByte; +import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; @@ -208,14 +209,14 @@ class AccountsManagerTest { final AccountLockManager accountLockManager = mock(AccountLockManager.class); doAnswer(invocation -> { - final Runnable task = invocation.getArgument(1); + final Runnable task = invocation.getArgument(2); task.run(); return null; - }).when(accountLockManager).withLock(any(), any(), any()); + }).when(accountLockManager).withLock(any(), anyList(), any(), any()); - when(accountLockManager.withLockAsync(any(), any(), any())).thenAnswer(invocation -> { - final Supplier> taskSupplier = invocation.getArgument(1); + when(accountLockManager.withLockAsync(any(), anyList(), any(), any())).thenAnswer(invocation -> { + final Supplier> taskSupplier = invocation.getArgument(2); return taskSupplier.get(); }); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java index 2d4f2d158..7c6e1868c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java @@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -113,14 +114,14 @@ class AccountsManagerUsernameIntegrationTest { final AccountLockManager accountLockManager = mock(AccountLockManager.class); doAnswer(invocation -> { - final Runnable task = invocation.getArgument(1); + final Runnable task = invocation.getArgument(2); task.run(); return null; - }).when(accountLockManager).withLock(any(), any(), any()); + }).when(accountLockManager).withLock(any(), anyList(), any(), any()); - when(accountLockManager.withLockAsync(any(), any(), any())).thenAnswer(invocation -> { - final Supplier> taskSupplier = invocation.getArgument(1); + when(accountLockManager.withLockAsync(any(), anyList(), any(), any())).thenAnswer(invocation -> { + final Supplier> taskSupplier = invocation.getArgument(2); taskSupplier.get().join(); return CompletableFuture.completedFuture(null);