From 7980da9ce5a2c431841d59d93041e0e2fcd123c7 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Mon, 20 May 2024 10:48:16 -0400 Subject: [PATCH] Set client public keys in the scope of a pessimistic account lock --- .../textsecuregcm/WhisperServerService.java | 3 ++- .../controllers/DeviceController.java | 2 +- .../storage/ClientPublicKeysManager.java | 21 +++++++++++++++---- .../workers/CommandDependencies.java | 4 ++-- .../controllers/DeviceControllerTest.java | 2 +- ...ccountCreationDeletionIntegrationTest.java | 6 +++--- ...ntsManagerChangeNumberIntegrationTest.java | 5 +++-- .../AddRemoveDeviceIntegrationTest.java | 12 +++++------ 8 files changed, 35 insertions(+), 20 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index bd514ecd4..fb163af0f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -582,7 +582,8 @@ public class WhisperServerService extends Application setPublicKey(@Auth final AuthenticatedAccount auth, final SetPublicKeyRequest setPublicKeyRequest) { - return clientPublicKeysManager.setPublicKey(auth.getAccount().getIdentifier(IdentityType.ACI), + return clientPublicKeysManager.setPublicKey(auth.getAccount(), auth.getAuthenticatedDevice().getId(), setPublicKeyRequest.publicKey()); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ClientPublicKeysManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ClientPublicKeysManager.java index 02c0a99a1..482f66924 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ClientPublicKeysManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ClientPublicKeysManager.java @@ -1,9 +1,12 @@ package org.whispersystems.textsecuregcm.storage; +import java.util.List; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; import org.signal.libsignal.protocol.ecc.ECPublicKey; +import org.whispersystems.textsecuregcm.identity.IdentityType; import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem; /** @@ -14,8 +17,16 @@ public class ClientPublicKeysManager { private final ClientPublicKeys clientPublicKeys; - public ClientPublicKeysManager(final ClientPublicKeys clientPublicKeys) { + private final AccountLockManager accountLockManager; + private final Executor accountLockExecutor; + + public ClientPublicKeysManager(final ClientPublicKeys clientPublicKeys, + final AccountLockManager accountLockManager, + final Executor accountLockExecutor) { + this.clientPublicKeys = clientPublicKeys; + this.accountLockManager = accountLockManager; + this.accountLockExecutor = accountLockExecutor; } /** @@ -23,14 +34,16 @@ public class ClientPublicKeysManager { * is intended for use for adding public keys to existing accounts/devices as a migration step. Callers should use * {@link #buildTransactWriteItemForInsertion(UUID, byte, ECPublicKey)} instead when creating new accounts/devices. * - * @param accountIdentifier the identifier for the target account + * @param account the target account * @param deviceId the identifier for the target device * @param publicKey the public key to store for the target account/device * @return a future that completes when the given key has been stored */ - public CompletableFuture setPublicKey(final UUID accountIdentifier, final byte deviceId, final ECPublicKey publicKey) { - return clientPublicKeys.setPublicKey(accountIdentifier, deviceId, publicKey); + public CompletableFuture setPublicKey(final Account account, final byte deviceId, final ECPublicKey publicKey) { + return accountLockManager.withLockAsync(List.of(account.getNumber()), + () -> clientPublicKeys.setPublicKey(account.getIdentifier(IdentityType.ACI), deviceId, publicKey), + accountLockExecutor); } /** diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java index 1fa08eaf4..07598850f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java @@ -150,8 +150,6 @@ record CommandDependencies( ClientPublicKeys clientPublicKeys = new ClientPublicKeys(dynamoDbAsyncClient, configuration.getDynamoDbTables().getClientPublicKeys().getTableName()); - ClientPublicKeysManager clientPublicKeysManager = new ClientPublicKeysManager(clientPublicKeys); - Accounts accounts = new Accounts( dynamoDbClient, dynamoDbAsyncClient, @@ -201,6 +199,8 @@ record CommandDependencies( reportMessageManager, messageDeletionExecutor); AccountLockManager accountLockManager = new AccountLockManager(dynamoDbClient, configuration.getDynamoDbTables().getDeletedAccountsLock().getTableName()); + ClientPublicKeysManager clientPublicKeysManager = + new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor); AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster, accountLockManager, keys, messagesManager, profilesManager, secureStorageClient, secureValueRecovery2Client, clientPresenceManager, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java index 7f6d84087..29d17dd02 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java @@ -827,6 +827,6 @@ class DeviceControllerTest { assertEquals(204, response.getStatus()); } - verify(clientPublicKeysManager).setPublicKey(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE.getId(), request.publicKey()); + verify(clientPublicKeysManager).setPublicKey(AuthHelper.VALID_ACCOUNT, AuthHelper.VALID_DEVICE.getId(), request.publicKey()); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java index d1d7c8334..9c3e9b473 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java @@ -98,8 +98,6 @@ public class AccountCreationDeletionIntegrationTest { final ClientPublicKeys clientPublicKeys = new ClientPublicKeys(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), DynamoDbExtensionSchema.Tables.CLIENT_PUBLIC_KEYS.tableName()); - clientPublicKeysManager = new ClientPublicKeysManager(clientPublicKeys); - final Accounts accounts = new Accounts( DYNAMO_DB_EXTENSION.getDynamoDbClient(), DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), @@ -115,6 +113,8 @@ public class AccountCreationDeletionIntegrationTest { final AccountLockManager accountLockManager = new AccountLockManager(DYNAMO_DB_EXTENSION.getDynamoDbClient(), DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS_LOCK.tableName()); + clientPublicKeysManager = new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor); + final SecureStorageClient secureStorageClient = mock(SecureStorageClient.class); when(secureStorageClient.deleteStoredData(any())).thenReturn(CompletableFuture.completedFuture(null)); @@ -459,7 +459,7 @@ public class AccountCreationDeletionIntegrationTest { aciPqLastResortPreKey, pniPqLastResortPreKey)); - clientPublicKeysManager.setPublicKey(account.getIdentifier(IdentityType.ACI), Device.PRIMARY_ID, Curve.generateKeyPair().getPublicKey()).join(); + clientPublicKeysManager.setPublicKey(account, Device.PRIMARY_ID, Curve.generateKeyPair().getPublicKey()).join(); final UUID aci = account.getIdentifier(IdentityType.ACI); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java index e8627e8ae..2504cc81f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java @@ -91,8 +91,6 @@ class AccountsManagerChangeNumberIntegrationTest { final ClientPublicKeys clientPublicKeys = new ClientPublicKeys(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), DynamoDbExtensionSchema.Tables.CLIENT_PUBLIC_KEYS.tableName()); - final ClientPublicKeysManager clientPublicKeysManager = new ClientPublicKeysManager(clientPublicKeys); - final Accounts accounts = new Accounts( DYNAMO_DB_EXTENSION.getDynamoDbClient(), DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), @@ -108,6 +106,9 @@ class AccountsManagerChangeNumberIntegrationTest { final AccountLockManager accountLockManager = new AccountLockManager(DYNAMO_DB_EXTENSION.getDynamoDbClient(), Tables.DELETED_ACCOUNTS_LOCK.tableName()); + final ClientPublicKeysManager clientPublicKeysManager = + new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor); + final SecureStorageClient secureStorageClient = mock(SecureStorageClient.class); when(secureStorageClient.deleteStoredData(any())).thenReturn(CompletableFuture.completedFuture(null)); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java index 51d873d9e..c33d85389 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java @@ -86,8 +86,6 @@ public class AddRemoveDeviceIntegrationTest { final ClientPublicKeys clientPublicKeys = new ClientPublicKeys(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), DynamoDbExtensionSchema.Tables.CLIENT_PUBLIC_KEYS.tableName()); - clientPublicKeysManager = new ClientPublicKeysManager(clientPublicKeys); - final Accounts accounts = new Accounts( DYNAMO_DB_EXTENSION.getDynamoDbClient(), DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), @@ -103,6 +101,8 @@ public class AddRemoveDeviceIntegrationTest { final AccountLockManager accountLockManager = new AccountLockManager(DYNAMO_DB_EXTENSION.getDynamoDbClient(), DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS_LOCK.tableName()); + clientPublicKeysManager = new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor); + final SecureStorageClient secureStorageClient = mock(SecureStorageClient.class); when(secureStorageClient.deleteStoredData(any())).thenReturn(CompletableFuture.completedFuture(null)); @@ -229,8 +229,8 @@ public class AddRemoveDeviceIntegrationTest { final byte addedDeviceId = updatedAccountAndDevice.second().getId(); - clientPublicKeysManager.setPublicKey(account.getUuid(), Device.PRIMARY_ID, Curve.generateKeyPair().getPublicKey()).join(); - clientPublicKeysManager.setPublicKey(account.getUuid(), addedDeviceId, Curve.generateKeyPair().getPublicKey()).join(); + clientPublicKeysManager.setPublicKey(account, Device.PRIMARY_ID, Curve.generateKeyPair().getPublicKey()).join(); + clientPublicKeysManager.setPublicKey(account, addedDeviceId, Curve.generateKeyPair().getPublicKey()).join(); final Account updatedAccount = accountsManager.removeDevice(updatedAccountAndDevice.first(), addedDeviceId).join(); @@ -290,8 +290,8 @@ public class AddRemoveDeviceIntegrationTest { final Account retrievedAccount = accountsManager.getByAccountIdentifierAsync(aci).join().orElseThrow(); - clientPublicKeysManager.setPublicKey(retrievedAccount.getUuid(), Device.PRIMARY_ID, Curve.generateKeyPair().getPublicKey()).join(); - clientPublicKeysManager.setPublicKey(retrievedAccount.getUuid(), addedDeviceId, Curve.generateKeyPair().getPublicKey()).join(); + clientPublicKeysManager.setPublicKey(retrievedAccount, Device.PRIMARY_ID, Curve.generateKeyPair().getPublicKey()).join(); + clientPublicKeysManager.setPublicKey(retrievedAccount, addedDeviceId, Curve.generateKeyPair().getPublicKey()).join(); assertEquals(2, retrievedAccount.getDevices().size());