From 5b7f91827af0947ceeacc3d3f09afb9cb5216bec Mon Sep 17 00:00:00 2001 From: Jon Chambers <63609320+jon-signal@users.noreply.github.com> Date: Tue, 19 Dec 2023 14:11:05 -0500 Subject: [PATCH] Remove signed pre-keys transactionally when removing devices --- .../storage/AccountLockManager.java | 9 +- .../textsecuregcm/storage/Accounts.java | 4 +- .../storage/AccountsManager.java | 113 +++++++++++----- .../textsecuregcm/storage/KeysManager.java | 65 +++++---- .../storage/MessagePersister.java | 2 +- .../storage/RepeatedUseSignedPreKeyStore.java | 62 ++------- .../grpc/DevicesGrpcServiceTest.java | 2 +- ...countCreationDeletionIntegrationTest.java} | 109 +++++++++++++-- .../storage/AccountsManagerTest.java | 67 +++++---- .../textsecuregcm/storage/AccountsTest.java | 4 +- ...va => AddRemoveDeviceIntegrationTest.java} | 128 +++++++++++++++++- .../storage/KeysManagerTest.java | 16 +-- .../storage/MessagePersisterTest.java | 6 +- .../RepeatedUseECSignedPreKeyStoreTest.java | 6 + .../RepeatedUseKEMSignedPreKeyStoreTest.java | 6 + .../RepeatedUseSignedPreKeyStoreTest.java | 64 ++++----- 16 files changed, 447 insertions(+), 216 deletions(-) rename service/src/test/java/org/whispersystems/textsecuregcm/storage/{AccountCreationIntegrationTest.java => AccountCreationDeletionIntegrationTest.java} (81%) rename service/src/test/java/org/whispersystems/textsecuregcm/storage/{LinkDeviceIntegrationTest.java => AddRemoveDeviceIntegrationTest.java} (53%) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountLockManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountLockManager.java index fd2c6c16f..c2f99bf16 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountLockManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountLockManager.java @@ -13,7 +13,6 @@ import java.util.concurrent.CompletionException; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; import java.util.function.Supplier; -import org.whispersystems.textsecuregcm.util.Util; import software.amazon.awssdk.services.dynamodb.DynamoDbClient; public class AccountLockManager { @@ -83,8 +82,9 @@ public class AccountLockManager { * * @return a future that completes normally when the given task has executed successfully and all locks have been * released; the returned future may fail with an {@link InterruptedException} if interrupted while acquiring a lock - */ public CompletableFuture withLockAsync(final List e164s, - final Supplier> taskSupplier, + */ + public CompletableFuture withLockAsync(final List e164s, + final Supplier> taskSupplier, final Executor executor) { if (e164s.isEmpty()) { @@ -107,7 +107,6 @@ public class AccountLockManager { .thenCompose(ignored -> taskSupplier.get()) .whenCompleteAsync((ignored, throwable) -> lockItems.forEach(lockItem -> lockClient.releaseLock(ReleaseLockOptions.builder(lockItem) .withBestEffort(true) - .build())), executor) - .thenRun(Util.NOOP); + .build())), executor); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java index 8e15356d1..21155d323 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java @@ -1035,7 +1035,7 @@ public class Accounts extends AbstractDynamoDbStore { return Optional.ofNullable(response.items().get(0).get(DELETED_ACCOUNTS_KEY_ACCOUNT_E164).s()); } - public CompletableFuture delete(final UUID uuid) { + public CompletableFuture delete(final UUID uuid, final List additionalWriteItems) { final Timer.Sample sample = Timer.start(); return getByAccountIdentifierAsync(uuid) @@ -1050,6 +1050,8 @@ public class Accounts extends AbstractDynamoDbStore { account.getUsernameHash().ifPresent(usernameHash -> transactWriteItems.add( buildDelete(usernamesConstraintTableName, ATTR_USERNAME_HASH, usernameHash))); + transactWriteItems.addAll(additionalWriteItems); + return asyncClient.transactWriteItems(TransactWriteItemsRequest.builder() .transactItems(transactWriteItems) .build()) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index cb79bcbbf..76baf98d7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -43,6 +43,7 @@ import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; +import java.util.stream.Stream; import javax.annotation.Nullable; import org.apache.commons.lang3.ObjectUtils; import org.apache.commons.lang3.StringUtils; @@ -204,7 +205,7 @@ public class AccountsManager { String accountCreationType = maybeRecentlyDeletedAccountIdentifier.isPresent() ? "recently-deleted" : "new"; try { - accounts.create(account, keysManager.buildWriteItemsForRepeatedUseKeys(account.getIdentifier(IdentityType.ACI), + accounts.create(account, keysManager.buildWriteItemsForNewDevice(account.getIdentifier(IdentityType.ACI), account.getIdentifier(IdentityType.PNI), Device.PRIMARY_ID, primaryDeviceSpec.aciSignedPreKey(), @@ -220,21 +221,31 @@ public class AccountsManager { account.setUuid(aci); account.setNumber(e.getExistingAccount().getNumber(), pni); - CompletableFuture.allOf( - keysManager.delete(aci), - keysManager.delete(pni), - messagesManager.clear(aci), - profilesManager.deleteAll(aci)) - .thenRunAsync(() -> clientPresenceManager.disconnectAllPresencesForUuid(aci), clientPresenceExecutor) - .thenCompose(ignored -> accounts.reclaimAccount(e.getExistingAccount(), - account, - keysManager.buildWriteItemsForRepeatedUseKeys(account.getIdentifier(IdentityType.ACI), + final List additionalWriteItems = Stream.concat( + keysManager.buildWriteItemsForNewDevice(account.getIdentifier(IdentityType.ACI), account.getIdentifier(IdentityType.PNI), Device.PRIMARY_ID, primaryDeviceSpec.aciSignedPreKey(), primaryDeviceSpec.pniSignedPreKey(), primaryDeviceSpec.aciPqLastResortPreKey(), - primaryDeviceSpec.pniPqLastResortPreKey()))) + primaryDeviceSpec.pniPqLastResortPreKey()).stream(), + e.getExistingAccount().getDevices() + .stream() + .map(Device::getId) + // No need to clear the keys for the primary device since we'll just overwrite them in the same + // transaction anyhow + .filter(existingDeviceId -> existingDeviceId != Device.PRIMARY_ID) + .flatMap(existingDeviceId -> + keysManager.buildWriteItemsForRemovedDevice(aci, pni, existingDeviceId).stream())) + .toList(); + + CompletableFuture.allOf( + keysManager.deleteSingleUsePreKeys(aci), + keysManager.deleteSingleUsePreKeys(pni), + messagesManager.clear(aci), + profilesManager.deleteAll(aci)) + .thenRunAsync(() -> clientPresenceManager.disconnectAllPresencesForUuid(aci), clientPresenceExecutor) + .thenCompose(ignored -> accounts.reclaimAccount(e.getExistingAccount(), account, additionalWriteItems)) .thenCompose(ignored -> { // We should have cleared all messages before overwriting the old account, but more may have arrived // while we were working. Similarly, the old account holder could have added keys or profiles. We'll @@ -243,8 +254,8 @@ public class AccountsManager { // // We exclude the primary device's repeated-use keys from deletion because new keys were provided as // part of the account creation process, and we don't want to delete the keys that just got added. - return CompletableFuture.allOf(keysManager.delete(aci, true), - keysManager.delete(pni, true), + return CompletableFuture.allOf(keysManager.deleteSingleUsePreKeys(aci), + keysManager.deleteSingleUsePreKeys(pni), messagesManager.clear(aci), profilesManager.deleteAll(aci)); }) @@ -264,7 +275,9 @@ public class AccountsManager { } public CompletableFuture> addDevice(final Account account, final DeviceSpec deviceSpec) { - return addDevice(account.getIdentifier(IdentityType.ACI), deviceSpec, MAX_UPDATE_ATTEMPTS); + return accountLockManager.withLockAsync(List.of(account.getNumber()), + () -> addDevice(account.getIdentifier(IdentityType.ACI), deviceSpec, MAX_UPDATE_ATTEMPTS), + accountLockExecutor); } private CompletableFuture> addDevice(final UUID accountIdentifier, final DeviceSpec deviceSpec, final int retries) { @@ -274,7 +287,7 @@ public class AccountsManager { final byte nextDeviceId = account.getNextDeviceId(); account.addDevice(deviceSpec.toDevice(nextDeviceId, clock)); - final List additionalWriteItems = keysManager.buildWriteItemsForRepeatedUseKeys( + final List additionalWriteItems = keysManager.buildWriteItemsForNewDevice( account.getIdentifier(IdentityType.ACI), account.getIdentifier(IdentityType.PNI), nextDeviceId, @@ -284,8 +297,8 @@ public class AccountsManager { deviceSpec.pniPqLastResortPreKey()); return CompletableFuture.allOf( - keysManager.delete(account.getUuid(), nextDeviceId), - keysManager.delete(account.getPhoneNumberIdentifier(), nextDeviceId), + keysManager.deleteSingleUsePreKeys(account.getUuid(), nextDeviceId), + keysManager.deleteSingleUsePreKeys(account.getPhoneNumberIdentifier(), nextDeviceId), messagesManager.clear(account.getUuid(), nextDeviceId)) .thenCompose(ignored -> accounts.updateTransactionallyAsync(account, additionalWriteItems)) .thenApply(ignored -> new Pair<>(account, account.getDevice(nextDeviceId).orElseThrow())); @@ -306,16 +319,43 @@ public class AccountsManager { throw new IllegalArgumentException("Cannot remove primary device"); } - return CompletableFuture.allOf( - keysManager.delete(account.getUuid(), deviceId), + return accountLockManager.withLockAsync(List.of(account.getNumber()), + () -> removeDevice(account.getIdentifier(IdentityType.ACI), deviceId, MAX_UPDATE_ATTEMPTS), + accountLockExecutor); + } + + private CompletableFuture removeDevice(final UUID accountIdentifier, final byte deviceId, final int retries) { + return accounts.getByAccountIdentifierAsync(accountIdentifier) + .thenApply(maybeAccount -> maybeAccount.orElseThrow(ContestedOptimisticLockException::new)) + .thenCompose(account -> CompletableFuture.allOf( + keysManager.deleteSingleUsePreKeys(account.getUuid(), deviceId), messagesManager.clear(account.getUuid(), deviceId)) - .thenCompose(ignored -> updateAsync(account, (Consumer) a -> a.removeDevice(deviceId))) - // ensure any messages that came in after the first clear() are also removed - .thenCompose(updatedAccount -> messagesManager.clear(account.getUuid(), deviceId) - .thenApply(ignored -> updatedAccount)) + .thenApply(ignored -> account)) + .thenCompose(account -> { + account.removeDevice(deviceId); + + return accounts.updateTransactionallyAsync(account, keysManager.buildWriteItemsForRemovedDevice( + account.getIdentifier(IdentityType.ACI), + account.getIdentifier(IdentityType.PNI), + deviceId)) + .thenApply(ignored -> account); + }) + .thenCompose(updatedAccount -> redisDeleteAsync(updatedAccount).thenApply(ignored -> updatedAccount)) + // Ensure any messages/single-use pre-keys that came in while we were working are also removed + .thenCompose(account -> CompletableFuture.allOf( + keysManager.deleteSingleUsePreKeys(account.getUuid(), deviceId), + messagesManager.clear(account.getUuid(), deviceId)) + .thenApply(ignored -> account)) + .exceptionallyCompose(throwable -> { + if (ExceptionUtils.unwrap(throwable) instanceof ContestedOptimisticLockException && retries > 0) { + return removeDevice(accountIdentifier, deviceId, retries - 1); + } + + return CompletableFuture.failedFuture(throwable); + }) .whenComplete((ignored, throwable) -> { if (throwable == null) { - clientPresenceManager.disconnectPresence(account.getUuid(), deviceId); + clientPresenceManager.disconnectPresence(accountIdentifier, deviceId); } }); } @@ -370,12 +410,12 @@ public class AccountsManager { final UUID phoneNumberIdentifier = phoneNumberIdentifiers.getPhoneNumberIdentifier(targetNumber); CompletableFuture.allOf( - keysManager.delete(phoneNumberIdentifier), - keysManager.delete(originalPhoneNumberIdentifier)) + keysManager.deleteSingleUsePreKeys(phoneNumberIdentifier), + keysManager.deleteSingleUsePreKeys(originalPhoneNumberIdentifier)) .join(); final Collection keyWriteItems = - buildKeyWriteItems(uuid, phoneNumberIdentifier, pniSignedPreKeys, pniPqLastResortPreKeys); + buildPniKeyWriteItems(uuid, phoneNumberIdentifier, pniSignedPreKeys, pniPqLastResortPreKeys); final Account numberChangedAccount = updateWithRetries( account, @@ -404,10 +444,10 @@ public class AccountsManager { final UUID pni = account.getIdentifier(IdentityType.PNI); final Collection keyWriteItems = - buildKeyWriteItems(pni, pni, pniSignedPreKeys, pniPqLastResortPreKeys); + buildPniKeyWriteItems(pni, pni, pniSignedPreKeys, pniPqLastResortPreKeys); return redisDeleteAsync(account) - .thenCompose(ignored -> keysManager.delete(pni)) + .thenCompose(ignored -> keysManager.deleteSingleUsePreKeys(pni)) .thenCompose(ignored -> updateTransactionallyWithRetriesAsync(account, a -> setPniKeys(a, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds), accounts::updateTransactionallyAsync, @@ -418,7 +458,7 @@ public class AccountsManager { .join(); } - private Collection buildKeyWriteItems( + private Collection buildPniKeyWriteItems( final UUID enabledDevicesIdentifier, final UUID phoneNumberIdentifier, @Nullable final Map pniSignedPreKeys, @@ -961,16 +1001,23 @@ public class AccountsManager { } private CompletableFuture delete(final Account account) { + final List additionalWriteItems = + account.getDevices().stream().flatMap(device -> keysManager.buildWriteItemsForRemovedDevice( + account.getIdentifier(IdentityType.ACI), + account.getIdentifier(IdentityType.PNI), + device.getId()).stream()) + .toList(); + return CompletableFuture.allOf( secureStorageClient.deleteStoredData(account.getUuid()), secureValueRecovery2Client.deleteBackups(account.getUuid()), - keysManager.delete(account.getUuid()), - keysManager.delete(account.getPhoneNumberIdentifier()), + keysManager.deleteSingleUsePreKeys(account.getUuid()), + keysManager.deleteSingleUsePreKeys(account.getPhoneNumberIdentifier()), messagesManager.clear(account.getUuid()), messagesManager.clear(account.getPhoneNumberIdentifier()), profilesManager.deleteAll(account.getUuid()), registrationRecoveryPasswordsManager.removeForNumber(account.getNumber())) - .thenCompose(ignored -> CompletableFuture.allOf(accounts.delete(account.getUuid()), redisDeleteAsync(account))) + .thenCompose(ignored -> CompletableFuture.allOf(accounts.delete(account.getUuid(), additionalWriteItems), redisDeleteAsync(account))) .thenRun(() -> RedisOperation.unchecked(() -> account.getDevices().forEach(device -> clientPresenceManager.disconnectPresence(account.getUuid(), device.getId())))); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java index c7d27084b..dee58ecc9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java @@ -6,6 +6,7 @@ package org.whispersystems.textsecuregcm.storage; import com.google.common.annotations.VisibleForTesting; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; @@ -46,7 +47,7 @@ public class KeysManager { final ECSignedPreKey ecSignedPreKey) { return dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().storeEcSignedPreKeys() - ? Optional.of(ecSignedPreKeys.buildTransactWriteItem(identifier, deviceId, ecSignedPreKey)) + ? Optional.of(ecSignedPreKeys.buildTransactWriteItemForInsertion(identifier, deviceId, ecSignedPreKey)) : Optional.empty(); } @@ -54,10 +55,10 @@ public class KeysManager { final byte deviceId, final KEMSignedPreKey lastResortSignedPreKey) { - return pqLastResortKeys.buildTransactWriteItem(identifier, deviceId, lastResortSignedPreKey); + return pqLastResortKeys.buildTransactWriteItemForInsertion(identifier, deviceId, lastResortSignedPreKey); } - public List buildWriteItemsForRepeatedUseKeys(final UUID accountIdentifier, + public List buildWriteItemsForNewDevice(final UUID accountIdentifier, final UUID phoneNumberIdentifier, final byte deviceId, final ECSignedPreKey aciSignedPreKey, @@ -65,10 +66,38 @@ public class KeysManager { final KEMSignedPreKey aciPqLastResortPreKey, final KEMSignedPreKey pniLastResortPreKey) { - return List.of(ecSignedPreKeys.buildTransactWriteItem(accountIdentifier, deviceId, aciSignedPreKey), - ecSignedPreKeys.buildTransactWriteItem(phoneNumberIdentifier, deviceId, pniSignedPreKey), - pqLastResortKeys.buildTransactWriteItem(accountIdentifier, deviceId, aciPqLastResortPreKey), - pqLastResortKeys.buildTransactWriteItem(phoneNumberIdentifier, deviceId, pniLastResortPreKey)); + final List writeItems = new ArrayList<>(List.of( + pqLastResortKeys.buildTransactWriteItemForInsertion(accountIdentifier, deviceId, aciPqLastResortPreKey), + pqLastResortKeys.buildTransactWriteItemForInsertion(phoneNumberIdentifier, deviceId, pniLastResortPreKey) + )); + + if (dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().storeEcSignedPreKeys()) { + writeItems.addAll(List.of( + ecSignedPreKeys.buildTransactWriteItemForInsertion(accountIdentifier, deviceId, aciSignedPreKey), + ecSignedPreKeys.buildTransactWriteItemForInsertion(phoneNumberIdentifier, deviceId, pniSignedPreKey) + )); + } + + return writeItems; + } + + public List buildWriteItemsForRemovedDevice(final UUID accountIdentifier, + final UUID phoneNumberIdentifier, + final byte deviceId) { + + final List writeItems = new ArrayList<>(List.of( + pqLastResortKeys.buildTransactWriteItemForDeletion(accountIdentifier, deviceId), + pqLastResortKeys.buildTransactWriteItemForDeletion(phoneNumberIdentifier, deviceId) + )); + + if (dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().deleteEcSignedPreKeys()) { + writeItems.addAll(List.of( + ecSignedPreKeys.buildTransactWriteItemForDeletion(accountIdentifier, deviceId), + ecSignedPreKeys.buildTransactWriteItemForDeletion(phoneNumberIdentifier, deviceId) + )); + } + + return writeItems; } public CompletableFuture storeEcSignedPreKeys(final UUID identifier, final Map keys) { @@ -130,27 +159,17 @@ public class KeysManager { return pqPreKeys.getCount(identifier, deviceId); } - public CompletableFuture delete(final UUID identifier) { - return delete(identifier, false); - } - - public CompletableFuture delete(final UUID identifier, final boolean excludePrimaryDevice) { + public CompletableFuture deleteSingleUsePreKeys(final UUID identifier) { return CompletableFuture.allOf( ecPreKeys.delete(identifier), - pqPreKeys.delete(identifier), - dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().deleteEcSignedPreKeys() - ? ecSignedPreKeys.delete(identifier, excludePrimaryDevice) - : CompletableFuture.completedFuture(null), - pqLastResortKeys.delete(identifier, excludePrimaryDevice)); + pqPreKeys.delete(identifier) + ); } - public CompletableFuture delete(final UUID accountUuid, final byte deviceId) { + public CompletableFuture deleteSingleUsePreKeys(final UUID accountUuid, final byte deviceId) { return CompletableFuture.allOf( ecPreKeys.delete(accountUuid, deviceId), - pqPreKeys.delete(accountUuid, deviceId), - dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().deleteEcSignedPreKeys() - ? ecSignedPreKeys.delete(accountUuid, deviceId) - : CompletableFuture.completedFuture(null), - pqLastResortKeys.delete(accountUuid, deviceId)); + pqPreKeys.delete(accountUuid, deviceId) + ); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java index c584f6528..67db03944 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagePersister.java @@ -266,7 +266,7 @@ public class MessagePersister implements Managed { clientPresenceManager.disconnectPresence(account.getUuid(), deviceToDelete.getId()); CompletableFuture .allOf( - keysManager.delete(account.getUuid(), deviceToDelete.getId()), + keysManager.deleteSingleUsePreKeys(account.getUuid(), deviceToDelete.getId()), messagesManager.clear(account.getUuid(), deviceToDelete.getId())) .orTimeout((UNLINK_TIMEOUT.toSeconds() * 3) / 4, TimeUnit.SECONDS) .join(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java index 9fe0966c8..07ac03cc6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java @@ -15,10 +15,9 @@ import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.util.AttributeValues; import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; -import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; +import software.amazon.awssdk.services.dynamodb.model.Delete; import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; import software.amazon.awssdk.services.dynamodb.model.Put; import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; @@ -47,8 +46,6 @@ public abstract class RepeatedUseSignedPreKeyStore> { private final Timer storeSingleKeyTimer = Metrics.timer(MetricsUtil.name(getClass(), "storeSingleKey")); private final Timer storeKeyBatchTimer = Metrics.timer(MetricsUtil.name(getClass(), "storeKeyBatch")); - private final Timer deleteForDeviceTimer = Metrics.timer(MetricsUtil.name(getClass(), "deleteForDevice")); - private final Timer deleteForAccountTimer = Metrics.timer(MetricsUtil.name(getClass(), "deleteForAccount")); private final String findKeyTimerName = MetricsUtil.name(getClass(), "findKey"); @@ -112,7 +109,7 @@ public abstract class RepeatedUseSignedPreKeyStore> { .thenRun(() -> sample.stop(storeKeyBatchTimer)); } - TransactWriteItem buildTransactWriteItem(final UUID identifier, final byte deviceId, final K preKey) { + TransactWriteItem buildTransactWriteItemForInsertion(final UUID identifier, final byte deviceId, final K preKey) { return TransactWriteItem.builder() .put(Put.builder() .tableName(tableName) @@ -121,6 +118,15 @@ public abstract class RepeatedUseSignedPreKeyStore> { .build(); } + public TransactWriteItem buildTransactWriteItemForDeletion(final UUID identifier, final byte deviceId) { + return TransactWriteItem.builder() + .delete(Delete.builder() + .tableName(tableName) + .key(getPrimaryKey(identifier, deviceId)) + .build()) + .build(); + } + /** * Finds a repeated-use pre-key for a specific device. * @@ -147,52 +153,6 @@ public abstract class RepeatedUseSignedPreKeyStore> { return findFuture; } - /** - * Clears all repeated-use pre-keys associated with the given account/identity. - * - * @param identifier the identifier for the account/identity for which to clear repeated-use pre-keys - * @param excludePrimaryDevice whether to exclude the primary device from repeated-use key deletion; this is intended - * for cases when a user "re-registers" and displaces an existing account record and has - * provided new repeated-use keys for the primary device in the process of creating the - * new account - * - * @return a future that completes once repeated-use pre-keys have been cleared from all devices associated with the - * target account/identity - */ - public CompletableFuture delete(final UUID identifier, final boolean excludePrimaryDevice) { - final Timer.Sample sample = Timer.start(); - - return getDeviceIdsWithKeys(identifier) - .filter(deviceId -> deviceId != Device.PRIMARY_ID || !excludePrimaryDevice) - .map(deviceId -> DeleteItemRequest.builder() - .tableName(tableName) - .key(getPrimaryKey(identifier, deviceId)) - .build()) - .flatMap(deleteItemRequest -> Mono.fromFuture(() -> dynamoDbAsyncClient.deleteItem(deleteItemRequest))) - // Idiom: wait for everything to finish, but discard the results - .reduce(0, (a, b) -> 0) - .toFuture() - .thenRun(() -> sample.stop(deleteForAccountTimer)); - } - - /** - * Removes the repeated-use pre-key associated with a specific device. - * - * @param identifier the identifier for the account/identity with which the target device is associated - * @param deviceId the identifier for the device within the given account/identity - * - * @return a future that completes once the repeated-use pre-key has been removed from the target device - */ - public CompletableFuture delete(final UUID identifier, final byte deviceId) { - final Timer.Sample sample = Timer.start(); - - return dynamoDbAsyncClient.deleteItem(DeleteItemRequest.builder() - .tableName(tableName) - .key(getPrimaryKey(identifier, deviceId)) - .build()) - .thenRun(() -> sample.stop(deleteForDeviceTimer)); - } - public Flux getDeviceIdsWithKeys(final UUID identifier) { return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder() .tableName(tableName) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/DevicesGrpcServiceTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/DevicesGrpcServiceTest.java index 1d38fb9d7..ca1fada62 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/grpc/DevicesGrpcServiceTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/grpc/DevicesGrpcServiceTest.java @@ -97,7 +97,7 @@ class DevicesGrpcServiceTest extends SimpleBaseGrpcTest dynamicConfigurationManager = mock(DynamicConfigurationManager.class); - DynamicConfiguration dynamicConfiguration = new DynamicConfiguration(); + final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class); when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); + when(dynamicConfiguration.getEcPreKeyMigrationConfiguration()) + .thenReturn(new DynamicECPreKeyMigrationConfiguration(true, true)); keysManager = new KeysManager( DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), @@ -248,6 +253,25 @@ public class AccountCreationIntegrationTest { pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey); + + assertEquals(Optional.of(aciSignedPreKey), keysManager.getEcSignedPreKey(account.getUuid(), Device.PRIMARY_ID).join()); + assertEquals(Optional.of(pniSignedPreKey), keysManager.getEcSignedPreKey(account.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join()); + assertEquals(Optional.of(aciPqLastResortPreKey), keysManager.getLastResort(account.getUuid(), Device.PRIMARY_ID).join()); + assertEquals(Optional.of(pniPqLastResortPreKey), keysManager.getLastResort(account.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join()); + } + + @SuppressWarnings("unused") + static ArgumentSets createAccount() { + return ArgumentSets + // deliveryChannels + .argumentsForFirstParameter( + new DeliveryChannels(true, null, null, null), + new DeliveryChannels(false, "apns-token", null, null), + new DeliveryChannels(false, "apns-token", "apns-voip-token", null), + new DeliveryChannels(false, null, null, "fcm-token")) + + // discoverableByPhoneNumber + .argumentsForNextParameter(true, false); } @CartesianTest @@ -375,18 +399,77 @@ public class AccountCreationIntegrationTest { assertEquals(existingAccountUuid, reregisteredAccount.getUuid()); } - @SuppressWarnings("unused") - static ArgumentSets createAccount() { - return ArgumentSets - // deliveryChannels - .argumentsForFirstParameter( - new DeliveryChannels(true, null, null, null), - new DeliveryChannels(false, "apns-token", null, null), - new DeliveryChannels(false, "apns-token", "apns-voip-token", null), - new DeliveryChannels(false, null, null, "fcm-token")) + @Test + void deleteAccount() throws InterruptedException { + final String number = PhoneNumberUtil.getInstance().format( + PhoneNumberUtil.getInstance().getExampleNumber("US"), + PhoneNumberUtil.PhoneNumberFormat.E164); - // discoverableByPhoneNumber - .argumentsForNextParameter(true, false); + final String password = RandomStringUtils.randomAlphanumeric(16); + final String signalAgent = RandomStringUtils.randomAlphabetic(3); + final int registrationId = ThreadLocalRandom.current().nextInt(Device.MAX_REGISTRATION_ID); + final int pniRegistrationId = ThreadLocalRandom.current().nextInt(Device.MAX_REGISTRATION_ID); + final byte[] deviceName = RandomStringUtils.randomAlphabetic(16).getBytes(StandardCharsets.UTF_8); + final String registrationLockSecret = RandomStringUtils.randomAlphanumeric(16); + + final Device.DeviceCapabilities deviceCapabilities = new Device.DeviceCapabilities( + ThreadLocalRandom.current().nextBoolean(), + ThreadLocalRandom.current().nextBoolean(), + ThreadLocalRandom.current().nextBoolean(), + ThreadLocalRandom.current().nextBoolean()); + + final AccountAttributes accountAttributes = new AccountAttributes(true, + registrationId, + pniRegistrationId, + deviceName, + registrationLockSecret, + true, + deviceCapabilities); + + final List badges = new ArrayList<>(List.of(new AccountBadge( + RandomStringUtils.randomAlphabetic(8), + CLOCK.instant().plus(Duration.ofDays(7)), + true))); + + final ECKeyPair aciKeyPair = Curve.generateKeyPair(); + final ECKeyPair pniKeyPair = Curve.generateKeyPair(); + + final ECSignedPreKey aciSignedPreKey = KeysHelper.signedECPreKey(1, aciKeyPair); + final ECSignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(2, pniKeyPair); + final KEMSignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciKeyPair); + final KEMSignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniKeyPair); + + final Account account = accountsManager.create(number, + accountAttributes, + badges, + new IdentityKey(aciKeyPair.getPublicKey()), + new IdentityKey(pniKeyPair.getPublicKey()), + new DeviceSpec( + deviceName, + password, + signalAgent, + deviceCapabilities, + registrationId, + pniRegistrationId, + true, + Optional.empty(), + Optional.empty(), + aciSignedPreKey, + pniSignedPreKey, + aciPqLastResortPreKey, + pniPqLastResortPreKey)); + + final UUID aci = account.getIdentifier(IdentityType.ACI); + + assertTrue(accountsManager.getByAccountIdentifier(aci).isPresent()); + + accountsManager.delete(account, AccountsManager.DeletionReason.ADMIN_DELETED).join(); + + assertFalse(accountsManager.getByAccountIdentifier(aci).isPresent()); + assertFalse(keysManager.getEcSignedPreKey(account.getUuid(), Device.PRIMARY_ID).join().isPresent()); + assertFalse(keysManager.getEcSignedPreKey(account.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join().isPresent()); + assertFalse(keysManager.getLastResort(account.getUuid(), Device.PRIMARY_ID).join().isPresent()); + assertFalse(keysManager.getLastResort(account.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join().isPresent()); } @SuppressWarnings("OptionalUsedAsFieldOrParameterType") diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index bb563d864..74eb76f3b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -14,7 +14,6 @@ import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.argThat; @@ -159,7 +158,7 @@ class AccountsManagerTest { when(accounts.updateAsync(any())).thenReturn(CompletableFuture.completedFuture(null)); when(accounts.updateTransactionallyAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); - when(accounts.delete(any())).thenReturn(CompletableFuture.completedFuture(null)); + when(accounts.delete(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); doAnswer((Answer) invocation -> { final Account account = invocation.getArgument(0, Account.class); @@ -207,9 +206,7 @@ class AccountsManagerTest { when(accountLockManager.withLockAsync(any(), any(), any())).thenAnswer(invocation -> { final Supplier> taskSupplier = invocation.getArgument(1); - taskSupplier.get().join(); - - return CompletableFuture.completedFuture(null); + return taskSupplier.get(); }); final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager = @@ -217,8 +214,7 @@ class AccountsManagerTest { when(registrationRecoveryPasswordsManager.removeForNumber(anyString())).thenReturn(CompletableFuture.completedFuture(null)); - when(keysManager.delete(any())).thenReturn(CompletableFuture.completedFuture(null)); - when(keysManager.delete(any(), anyBoolean())).thenReturn(CompletableFuture.completedFuture(null)); + when(keysManager.deleteSingleUsePreKeys(any())).thenReturn(CompletableFuture.completedFuture(null)); when(messagesManager.clear(any())).thenReturn(CompletableFuture.completedFuture(null)); when(profilesManager.deleteAll(any())).thenReturn(CompletableFuture.completedFuture(null)); @@ -959,7 +955,10 @@ class AccountsManagerTest { Account account = AccountsHelper.generateTestAccount("+14152222222", List.of(primaryDevice, linkedDevice)); - when(keysManager.delete(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); + when(accounts.getByAccountIdentifierAsync(account.getUuid())) + .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); + + when(keysManager.deleteSingleUsePreKeys(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); when(messagesManager.clear(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); assertTrue(account.getDevice(linkedDevice.getId()).isPresent()); @@ -968,7 +967,7 @@ class AccountsManagerTest { assertFalse(account.getDevice(linkedDevice.getId()).isPresent()); verify(messagesManager, times(2)).clear(account.getUuid(), linkedDevice.getId()); - verify(keysManager).delete(account.getUuid(), linkedDevice.getId()); + verify(keysManager, times(2)).deleteSingleUsePreKeys(account.getUuid(), linkedDevice.getId()); verify(clientPresenceManager).disconnectPresence(account.getUuid(), linkedDevice.getId()); } @@ -979,14 +978,14 @@ class AccountsManagerTest { final Account account = AccountsHelper.generateTestAccount("+14152222222", List.of(primaryDevice)); - when(keysManager.delete(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); + when(keysManager.deleteSingleUsePreKeys(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); when(messagesManager.clear(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); assertThrows(IllegalArgumentException.class, () -> accountsManager.removeDevice(account, Device.PRIMARY_ID)); - assertDoesNotThrow(() -> account.getPrimaryDevice()); + assertDoesNotThrow(account::getPrimaryDevice); verify(messagesManager, never()).clear(any(), anyByte()); - verify(keysManager, never()).delete(any(), anyByte()); + verify(keysManager, never()).deleteSingleUsePreKeys(any(), anyByte()); verify(clientPresenceManager, never()).disconnectPresence(any(), anyByte()); } @@ -1035,10 +1034,8 @@ class AccountsManagerTest { verify(accounts) .create(argThat(account -> e164.equals(account.getNumber()) && existingUuid.equals(account.getUuid())), any()); - verify(keysManager).delete(existingUuid); - verify(keysManager).delete(phoneNumberIdentifiersByE164.get(e164)); - verify(keysManager).delete(existingUuid, true); - verify(keysManager).delete(phoneNumberIdentifiersByE164.get(e164), true); + verify(keysManager, times(2)).deleteSingleUsePreKeys(existingUuid); + verify(keysManager, times(2)).deleteSingleUsePreKeys(phoneNumberIdentifiersByE164.get(e164)); verify(messagesManager, times(2)).clear(existingUuid); verify(profilesManager, times(2)).deleteAll(existingUuid); verify(clientPresenceManager).disconnectAllPresencesForUuid(existingUuid); @@ -1060,7 +1057,7 @@ class AccountsManagerTest { argThat(a -> e164.equals(a.getNumber()) && recentlyDeletedUuid.equals(a.getUuid())), any()); - verify(keysManager).buildWriteItemsForRepeatedUseKeys(eq(account.getIdentifier(IdentityType.ACI)), + verify(keysManager).buildWriteItemsForNewDevice(eq(account.getIdentifier(IdentityType.ACI)), eq(account.getIdentifier(IdentityType.PNI)), eq(Device.PRIMARY_ID), any(), @@ -1119,7 +1116,7 @@ class AccountsManagerTest { final KEMSignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciKeyPair); final KEMSignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniKeyPair); - when(keysManager.delete(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); + when(keysManager.deleteSingleUsePreKeys(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); when(messagesManager.clear(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); when(accounts.getByAccountIdentifierAsync(aci)).thenReturn(CompletableFuture.completedFuture(Optional.of(account))); when(accounts.updateTransactionallyAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); @@ -1142,11 +1139,11 @@ class AccountsManagerTest { pniPqLastResortPreKey)) .join(); - verify(keysManager).delete(aci, nextDeviceId); - verify(keysManager).delete(pni, nextDeviceId); + verify(keysManager).deleteSingleUsePreKeys(aci, nextDeviceId); + verify(keysManager).deleteSingleUsePreKeys(pni, nextDeviceId); verify(messagesManager).clear(aci, nextDeviceId); - verify(keysManager).buildWriteItemsForRepeatedUseKeys( + verify(keysManager).buildWriteItemsForNewDevice( aci, pni, nextDeviceId, @@ -1207,8 +1204,8 @@ class AccountsManagerTest { assertTrue(phoneNumberIdentifiersByE164.containsKey(targetNumber)); - verify(keysManager).delete(originalPni); - verify(keysManager).delete(phoneNumberIdentifiersByE164.get(targetNumber)); + verify(keysManager).deleteSingleUsePreKeys(originalPni); + verify(keysManager).deleteSingleUsePreKeys(phoneNumberIdentifiersByE164.get(targetNumber)); } @Test @@ -1219,7 +1216,7 @@ class AccountsManagerTest { account = accountsManager.changeNumber(account, number, null, null, null, null); assertEquals(number, account.getNumber()); - verify(keysManager, never()).delete(any()); + verify(keysManager, never()).deleteSingleUsePreKeys(any()); } @Test @@ -1258,10 +1255,10 @@ class AccountsManagerTest { assertTrue(phoneNumberIdentifiersByE164.containsKey(targetNumber)); final UUID newPni = phoneNumberIdentifiersByE164.get(targetNumber); - verify(keysManager).delete(existingAccountUuid); - verify(keysManager).delete(originalPni); - verify(keysManager, atLeastOnce()).delete(targetPni); - verify(keysManager).delete(newPni); + verify(keysManager).deleteSingleUsePreKeys(existingAccountUuid); + verify(keysManager).deleteSingleUsePreKeys(originalPni); + verify(keysManager, atLeastOnce()).deleteSingleUsePreKeys(targetPni); + verify(keysManager).deleteSingleUsePreKeys(newPni); verifyNoMoreInteractions(keysManager); } @@ -1302,10 +1299,10 @@ class AccountsManagerTest { assertTrue(phoneNumberIdentifiersByE164.containsKey(targetNumber)); final UUID newPni = phoneNumberIdentifiersByE164.get(targetNumber); - verify(keysManager).delete(existingAccountUuid); - verify(keysManager, atLeastOnce()).delete(targetPni); - verify(keysManager).delete(newPni); - verify(keysManager).delete(originalPni); + verify(keysManager).deleteSingleUsePreKeys(existingAccountUuid); + verify(keysManager, atLeastOnce()).deleteSingleUsePreKeys(targetPni); + verify(keysManager).deleteSingleUsePreKeys(newPni); + verify(keysManager).deleteSingleUsePreKeys(originalPni); verify(keysManager).getPqEnabledDevices(uuid); verify(keysManager).buildWriteItemForEcSignedPreKey(eq(newPni), eq(Device.PRIMARY_ID), any()); verify(keysManager).buildWriteItemForEcSignedPreKey(eq(newPni), eq(deviceId2), any()); @@ -1408,7 +1405,7 @@ class AccountsManagerTest { verify(accounts).updateTransactionallyAsync(any(), any()); - verify(keysManager).delete(oldPni); + verify(keysManager).deleteSingleUsePreKeys(oldPni); verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(Device.PRIMARY_ID), any()); verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(deviceId2), any()); verify(keysManager, never()).buildWriteItemForLastResortKey(any(), anyByte(), any()); @@ -1471,7 +1468,7 @@ class AccountsManagerTest { verify(accounts).updateTransactionallyAsync(any(), any()); - verify(keysManager).delete(oldPni); + verify(keysManager).deleteSingleUsePreKeys(oldPni); verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(Device.PRIMARY_ID), any()); verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(deviceId2), any()); verify(keysManager).buildWriteItemForLastResortKey(eq(oldPni), eq(Device.PRIMARY_ID), any()); @@ -1534,7 +1531,7 @@ class AccountsManagerTest { verify(accounts).updateTransactionallyAsync(any(), any()); - verify(keysManager).delete(oldPni); + verify(keysManager).deleteSingleUsePreKeys(oldPni); verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(Device.PRIMARY_ID), any()); verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(deviceId2), any()); verify(keysManager, never()).buildWriteItemForLastResortKey(any(), anyByte(), any()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java index 007559324..4adb16617 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java @@ -202,7 +202,7 @@ class AccountsTest { assertPhoneNumberConstraintExists("+14151112222", account.getUuid()); assertPhoneNumberIdentifierConstraintExists(account.getPhoneNumberIdentifier(), account.getUuid()); - accounts.delete(originalUuid).join(); + accounts.delete(originalUuid, Collections.emptyList()).join(); assertThat(accounts.findRecentlyDeletedAccountIdentifier(account.getNumber())).hasValue(originalUuid); freshUser = createAccount(account); @@ -679,7 +679,7 @@ class AccountsTest { assertThat(accounts.getByAccountIdentifier(deletedAccount.getUuid())).isPresent(); assertThat(accounts.getByAccountIdentifier(retainedAccount.getUuid())).isPresent(); - accounts.delete(deletedAccount.getUuid()).join(); + accounts.delete(deletedAccount.getUuid(), Collections.emptyList()).join(); assertThat(accounts.getByAccountIdentifier(deletedAccount.getUuid())).isNotPresent(); assertThat(accounts.findRecentlyDeletedAccountIdentifier(deletedAccount.getNumber())).hasValue(deletedAccount.getUuid()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/LinkDeviceIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java similarity index 53% rename from service/src/test/java/org/whispersystems/textsecuregcm/storage/LinkDeviceIntegrationTest.java rename to service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java index 349db1aaf..fe7a7cfae 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/LinkDeviceIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java @@ -1,6 +1,9 @@ package org.whispersystems.textsecuregcm.storage; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.Mockito.mock; @@ -12,7 +15,9 @@ import java.time.Clock; import java.time.Instant; import java.time.ZoneId; import java.util.Optional; +import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; @@ -23,7 +28,9 @@ import org.junit.jupiter.api.extension.RegisterExtension; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicECPreKeyMigrationConfiguration; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; @@ -32,7 +39,7 @@ import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.util.Pair; -public class LinkDeviceIntegrationTest { +public class AddRemoveDeviceIntegrationTest { @RegisterExtension static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension( @@ -56,16 +63,19 @@ public class LinkDeviceIntegrationTest { private ExecutorService accountLockExecutor; private ExecutorService clientPresenceExecutor; - private AccountsManager accountsManager; private KeysManager keysManager; + private MessagesManager messagesManager; + private AccountsManager accountsManager; @BeforeEach void setUp() { @SuppressWarnings("unchecked") final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class); - DynamicConfiguration dynamicConfiguration = new DynamicConfiguration(); + final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class); when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); + when(dynamicConfiguration.getEcPreKeyMigrationConfiguration()) + .thenReturn(new DynamicECPreKeyMigrationConfiguration(true, true)); keysManager = new KeysManager( DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), @@ -100,7 +110,7 @@ public class LinkDeviceIntegrationTest { new PhoneNumberIdentifiers(DYNAMO_DB_EXTENSION.getDynamoDbClient(), DynamoDbExtensionSchema.Tables.PNI.tableName()); - final MessagesManager messagesManager = mock(MessagesManager.class); + messagesManager = mock(MessagesManager.class); when(messagesManager.clear(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); final ProfilesManager profilesManager = mock(ProfilesManager.class); @@ -143,7 +153,7 @@ public class LinkDeviceIntegrationTest { } @Test - void linkDevice() throws InterruptedException { + void addDevice() throws InterruptedException { final String number = PhoneNumberUtil.getInstance().format( PhoneNumberUtil.getInstance().getExampleNumber("US"), PhoneNumberUtil.PhoneNumberFormat.E164); @@ -176,5 +186,113 @@ public class LinkDeviceIntegrationTest { assertEquals(2, accountsManager.getByAccountIdentifier(updatedAccountAndDevice.first().getUuid()).orElseThrow().getDevices() .size()); + + final byte addedDeviceId = updatedAccountAndDevice.second().getId(); + + assertTrue(keysManager.getEcSignedPreKey(updatedAccountAndDevice.first().getUuid(), addedDeviceId).join().isPresent()); + assertTrue(keysManager.getEcSignedPreKey(updatedAccountAndDevice.first().getPhoneNumberIdentifier(), addedDeviceId).join().isPresent()); + assertTrue(keysManager.getLastResort(updatedAccountAndDevice.first().getUuid(), addedDeviceId).join().isPresent()); + assertTrue(keysManager.getLastResort(updatedAccountAndDevice.first().getPhoneNumberIdentifier(), addedDeviceId).join().isPresent()); + } + + @Test + void removeDevice() throws InterruptedException { + final String number = PhoneNumberUtil.getInstance().format( + PhoneNumberUtil.getInstance().getExampleNumber("US"), + PhoneNumberUtil.PhoneNumberFormat.E164); + + final ECKeyPair aciKeyPair = Curve.generateKeyPair(); + final ECKeyPair pniKeyPair = Curve.generateKeyPair(); + + final Account account = AccountsHelper.createAccount(accountsManager, number); + assertEquals(1, accountsManager.getByAccountIdentifier(account.getUuid()).orElseThrow().getDevices().size()); + + final Pair updatedAccountAndDevice = + accountsManager.addDevice(account, new DeviceSpec( + "device-name".getBytes(StandardCharsets.UTF_8), + "password", + "OWT", + new Device.DeviceCapabilities(true, true, true, true), + 1, + 2, + true, + Optional.empty(), + Optional.empty(), + KeysHelper.signedECPreKey(1, aciKeyPair), + KeysHelper.signedECPreKey(2, pniKeyPair), + KeysHelper.signedKEMPreKey(3, aciKeyPair), + KeysHelper.signedKEMPreKey(4, pniKeyPair))) + .join(); + + final byte addedDeviceId = updatedAccountAndDevice.second().getId(); + + final Account updatedAccount = accountsManager.removeDevice(updatedAccountAndDevice.first(), addedDeviceId).join(); + + assertEquals(1, updatedAccount.getDevices().size()); + + assertFalse(keysManager.getEcSignedPreKey(updatedAccount.getUuid(), addedDeviceId).join().isPresent()); + assertFalse(keysManager.getEcSignedPreKey(updatedAccount.getPhoneNumberIdentifier(), addedDeviceId).join().isPresent()); + assertFalse(keysManager.getLastResort(updatedAccount.getUuid(), addedDeviceId).join().isPresent()); + assertFalse(keysManager.getLastResort(updatedAccount.getPhoneNumberIdentifier(), addedDeviceId).join().isPresent()); + + assertTrue(keysManager.getEcSignedPreKey(updatedAccount.getUuid(), Device.PRIMARY_ID).join().isPresent()); + assertTrue(keysManager.getEcSignedPreKey(updatedAccount.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join().isPresent()); + assertTrue(keysManager.getLastResort(updatedAccount.getUuid(), Device.PRIMARY_ID).join().isPresent()); + assertTrue(keysManager.getLastResort(updatedAccount.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join().isPresent()); + } + + @Test + void removeDevicePartialFailure() throws InterruptedException { + final String number = PhoneNumberUtil.getInstance().format( + PhoneNumberUtil.getInstance().getExampleNumber("US"), + PhoneNumberUtil.PhoneNumberFormat.E164); + + final ECKeyPair aciKeyPair = Curve.generateKeyPair(); + final ECKeyPair pniKeyPair = Curve.generateKeyPair(); + + final Account account = AccountsHelper.createAccount(accountsManager, number); + assertEquals(1, accountsManager.getByAccountIdentifier(account.getUuid()).orElseThrow().getDevices().size()); + + final UUID aci = account.getIdentifier(IdentityType.ACI); + final UUID pni = account.getIdentifier(IdentityType.PNI); + + final Pair updatedAccountAndDevice = + accountsManager.addDevice(account, new DeviceSpec( + "device-name".getBytes(StandardCharsets.UTF_8), + "password", + "OWT", + new Device.DeviceCapabilities(true, true, true, true), + 1, + 2, + true, + Optional.empty(), + Optional.empty(), + KeysHelper.signedECPreKey(1, aciKeyPair), + KeysHelper.signedECPreKey(2, pniKeyPair), + KeysHelper.signedKEMPreKey(3, aciKeyPair), + KeysHelper.signedKEMPreKey(4, pniKeyPair))) + .join(); + + final byte addedDeviceId = updatedAccountAndDevice.second().getId(); + + when(messagesManager.clear(any(), anyByte())) + .thenReturn(CompletableFuture.failedFuture(new RuntimeException("OH NO"))); + + assertThrows(CompletionException.class, + () -> accountsManager.removeDevice(updatedAccountAndDevice.first(), addedDeviceId).join()); + + final Account retrievedAccount = accountsManager.getByAccountIdentifierAsync(aci).join().orElseThrow(); + + assertEquals(2, retrievedAccount.getDevices().size()); + + assertTrue(keysManager.getEcSignedPreKey(retrievedAccount.getUuid(), addedDeviceId).join().isPresent()); + assertTrue(keysManager.getEcSignedPreKey(retrievedAccount.getPhoneNumberIdentifier(), addedDeviceId).join().isPresent()); + assertTrue(keysManager.getLastResort(retrievedAccount.getUuid(), addedDeviceId).join().isPresent()); + assertTrue(keysManager.getLastResort(retrievedAccount.getPhoneNumberIdentifier(), addedDeviceId).join().isPresent()); + + assertTrue(keysManager.getEcSignedPreKey(retrievedAccount.getUuid(), Device.PRIMARY_ID).join().isPresent()); + assertTrue(keysManager.getEcSignedPreKey(retrievedAccount.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join().isPresent()); + assertTrue(keysManager.getLastResort(retrievedAccount.getUuid(), Device.PRIMARY_ID).join().isPresent()); + assertTrue(keysManager.getLastResort(retrievedAccount.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join().isPresent()); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java index f5449e6cf..9802222fa 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java @@ -140,7 +140,7 @@ class KeysManagerTest { } @Test - void testDeleteByAccount() { + void testDeleteSingleUsePreKeysByAccount() { int keyId = 1; for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) { @@ -157,18 +157,18 @@ class KeysManagerTest { assertTrue(keysManager.getLastResort(ACCOUNT_UUID, deviceId).join().isPresent()); } - keysManager.delete(ACCOUNT_UUID).join(); + keysManager.deleteSingleUsePreKeys(ACCOUNT_UUID).join(); for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) { assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, deviceId).join()); assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, deviceId).join()); - assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId).join().isPresent()); - assertFalse(keysManager.getLastResort(ACCOUNT_UUID, deviceId).join().isPresent()); + assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId).join().isPresent()); + assertTrue(keysManager.getLastResort(ACCOUNT_UUID, deviceId).join().isPresent()); } } @Test - void testDeleteByAccountAndDevice() { + void testDeleteSingleUsePreKeysByAccountAndDevice() { int keyId = 1; for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) { @@ -185,12 +185,12 @@ class KeysManagerTest { assertTrue(keysManager.getLastResort(ACCOUNT_UUID, deviceId).join().isPresent()); } - keysManager.delete(ACCOUNT_UUID, DEVICE_ID).join(); + keysManager.deleteSingleUsePreKeys(ACCOUNT_UUID, DEVICE_ID).join(); assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join()); assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); - assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); - assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); + assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); + assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, (byte) (DEVICE_ID + 1)).join()); assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, (byte) (DEVICE_ID + 1)).join()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java index 949fe7499..9669459df 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagePersisterTest.java @@ -276,7 +276,7 @@ class MessagePersisterTest { when(messagesManager.persistMessages(any(UUID.class), anyByte(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build()); when(messagesManager.clear(any(UUID.class), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); - when(keysManager.delete(any(), eq(inactiveId))).thenReturn(CompletableFuture.completedFuture(null)); + when(keysManager.deleteSingleUsePreKeys(any(), eq(inactiveId))).thenReturn(CompletableFuture.completedFuture(null)); assertTimeoutPreemptively(Duration.ofSeconds(1), () -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE_ID)); @@ -326,7 +326,7 @@ class MessagePersisterTest { when(messagesManager.persistMessages(any(UUID.class), anyByte(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build()); when(messagesManager.clear(any(UUID.class), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); - when(keysManager.delete(any(), eq(deviceIdB))).thenReturn(CompletableFuture.completedFuture(null)); + when(keysManager.deleteSingleUsePreKeys(any(), eq(deviceIdB))).thenReturn(CompletableFuture.completedFuture(null)); assertTimeoutPreemptively(Duration.ofSeconds(1), () -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE_ID)); @@ -376,7 +376,7 @@ class MessagePersisterTest { when(messagesManager.persistMessages(any(UUID.class), anyByte(), anyList())).thenThrow(ItemCollectionSizeLimitExceededException.builder().build()); when(messagesManager.clear(any(UUID.class), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); - when(keysManager.delete(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); + when(keysManager.deleteSingleUsePreKeys(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null)); assertTimeoutPreemptively(Duration.ofSeconds(1), () -> messagePersister.persistQueue(destinationAccount, DESTINATION_DEVICE_ID)); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStoreTest.java index c2d0d7fb4..24650531c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStoreTest.java @@ -18,6 +18,7 @@ import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; +import software.amazon.awssdk.services.dynamodb.DynamoDbClient; class RepeatedUseECSignedPreKeyStoreTest extends RepeatedUseSignedPreKeyStoreTest { @@ -47,6 +48,11 @@ class RepeatedUseECSignedPreKeyStoreTest extends RepeatedUseSignedPreKeyStoreTes return KeysHelper.signedECPreKey(currentKeyId++, IDENTITY_KEY_PAIR); } + @Override + protected DynamoDbClient getDynamoDbClient() { + return DYNAMO_DB_EXTENSION.getDynamoDbClient(); + } + @Test void storeIfAbsent() { final UUID identifier = UUID.randomUUID(); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseKEMSignedPreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseKEMSignedPreKeyStoreTest.java index 57a0c4751..0725fd663 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseKEMSignedPreKeyStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseKEMSignedPreKeyStoreTest.java @@ -11,6 +11,7 @@ import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; +import software.amazon.awssdk.services.dynamodb.DynamoDbClient; class RepeatedUseKEMSignedPreKeyStoreTest extends RepeatedUseSignedPreKeyStoreTest { @@ -35,6 +36,11 @@ class RepeatedUseKEMSignedPreKeyStoreTest extends RepeatedUseSignedPreKeyStoreTe return keyStore; } + @Override + protected DynamoDbClient getDynamoDbClient() { + return DYNAMO_DB_EXTENSION.getDynamoDbClient(); + } + @Override protected KEMSignedPreKey generateSignedPreKey() { return KeysHelper.signedKEMPreKey(currentKeyId++, IDENTITY_KEY_PAIR); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java index 17fbd3152..9cd7c1986 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java @@ -5,16 +5,16 @@ package org.whispersystems.textsecuregcm.storage; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import org.whispersystems.textsecuregcm.entities.SignedPreKey; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; import java.util.Map; import java.util.Optional; import java.util.UUID; - -import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.Test; +import org.whispersystems.textsecuregcm.entities.SignedPreKey; +import software.amazon.awssdk.services.dynamodb.DynamoDbClient; +import software.amazon.awssdk.services.dynamodb.model.TransactWriteItemsRequest; abstract class RepeatedUseSignedPreKeyStoreTest> { @@ -22,6 +22,8 @@ abstract class RepeatedUseSignedPreKeyStoreTest> { protected abstract K generateSignedPreKey(); + protected abstract DynamoDbClient getDynamoDbClient(); + @Test void storeFind() { final RepeatedUseSignedPreKeyStore keys = getKeyStore(); @@ -52,7 +54,23 @@ abstract class RepeatedUseSignedPreKeyStoreTest> { } @Test - void deleteForDevice() { + void buildTransactWriteItemForInsertion() { + final RepeatedUseSignedPreKeyStore keys = getKeyStore(); + + assertEquals(Optional.empty(), keys.find(UUID.randomUUID(), Device.PRIMARY_ID).join()); + + final UUID identifier = UUID.randomUUID(); + final K signedPreKey = generateSignedPreKey(); + + getDynamoDbClient().transactWriteItems(TransactWriteItemsRequest.builder() + .transactItems(keys.buildTransactWriteItemForInsertion(identifier, Device.PRIMARY_ID, signedPreKey)) + .build()); + + assertEquals(Optional.of(signedPreKey), keys.find(identifier, Device.PRIMARY_ID).join()); + } + + @Test + void buildTransactWriteItemForDeletion() { final RepeatedUseSignedPreKeyStore keys = getKeyStore(); final UUID identifier = UUID.randomUUID(); @@ -63,36 +81,12 @@ abstract class RepeatedUseSignedPreKeyStoreTest> { ); keys.store(identifier, signedPreKeys).join(); - keys.delete(identifier, Device.PRIMARY_ID).join(); + + getDynamoDbClient().transactWriteItems(TransactWriteItemsRequest.builder() + .transactItems(keys.buildTransactWriteItemForDeletion(identifier, Device.PRIMARY_ID)) + .build()); assertEquals(Optional.empty(), keys.find(identifier, Device.PRIMARY_ID).join()); assertEquals(Optional.of(signedPreKeys.get(deviceId2)), keys.find(identifier, deviceId2).join()); } - - @ParameterizedTest - @ValueSource(booleans = {true, false}) - void deleteForAllDevices(final boolean excludePrimaryDevice) { - final RepeatedUseSignedPreKeyStore keys = getKeyStore(); - - assertDoesNotThrow(() -> keys.delete(UUID.randomUUID(), excludePrimaryDevice).join()); - - final byte deviceId2 = Device.PRIMARY_ID + 1; - - final UUID identifier = UUID.randomUUID(); - final Map signedPreKeys = Map.of( - Device.PRIMARY_ID, generateSignedPreKey(), - deviceId2, generateSignedPreKey() - ); - - keys.store(identifier, signedPreKeys).join(); - keys.delete(identifier, excludePrimaryDevice).join(); - - if (excludePrimaryDevice) { - assertTrue(keys.find(identifier, Device.PRIMARY_ID).join().isPresent()); - } else { - assertEquals(Optional.empty(), keys.find(identifier, Device.PRIMARY_ID).join()); - } - - assertEquals(Optional.empty(), keys.find(identifier, deviceId2).join()); - } }