From df421e01822eeb41a02de797e92be0c1b261e6f7 Mon Sep 17 00:00:00 2001 From: Jon Chambers <63609320+jon-signal@users.noreply.github.com> Date: Tue, 5 Dec 2023 14:20:16 -0500 Subject: [PATCH] Update signed pre-keys in transactions --- .../controllers/KeysController.java | 73 +++++-- .../textsecuregcm/storage/Accounts.java | 36 +++- .../storage/AccountsManager.java | 153 ++++++++++---- .../textsecuregcm/storage/KeysManager.java | 38 +--- .../controllers/KeysControllerTest.java | 58 +++-- .../storage/AccountsManagerTest.java | 33 +-- .../textsecuregcm/storage/AccountsTest.java | 77 ++++++- .../storage/KeysManagerTest.java | 199 +++++++----------- 8 files changed, 415 insertions(+), 252 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index ca312ff90..5a865bedd 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -14,9 +14,11 @@ import io.swagger.v3.oas.annotations.parameters.RequestBody; import io.swagger.v3.oas.annotations.responses.ApiResponse; import io.swagger.v3.oas.annotations.tags.Tag; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.UUID; import java.util.concurrent.CompletableFuture; import javax.validation.Valid; import javax.validation.constraints.NotNull; @@ -67,6 +69,8 @@ public class KeysController { private final AccountsManager accounts; private final Experiment compareSignedEcPreKeysExperiment = new Experiment("compareSignedEcPreKeys"); + private static final CompletableFuture[] EMPTY_FUTURE_ARRAY = new CompletableFuture[0]; + public KeysController(RateLimiters rateLimiters, KeysManager keys, AccountsManager accounts) { this.rateLimiters = rateLimiters; this.keys = keys; @@ -110,24 +114,51 @@ public class KeysController { description="whether this operation applies to the account (aci) or phone-number (pni) identity") @QueryParam("identity") @DefaultValue("aci") final IdentityType identityType) { - Account account = disabledPermittedAuth.getAccount(); + final Account account = disabledPermittedAuth.getAccount(); final Device device = disabledPermittedAuth.getAuthenticatedDevice(); + final UUID identifier = account.getIdentifier(identityType); checkSignedPreKeySignatures(setKeysRequest, account.getIdentityKey(identityType)); + final CompletableFuture updateAccountFuture; + if (setKeysRequest.signedPreKey() != null && !setKeysRequest.signedPreKey().equals(device.getSignedPreKey(identityType))) { - account = accounts.update(account, a -> a.getDevice(device.getId()).ifPresent(d -> { - switch (identityType) { - case ACI -> d.setSignedPreKey(setKeysRequest.signedPreKey()); - case PNI -> d.setPhoneNumberIdentitySignedPreKey(setKeysRequest.signedPreKey()); - } - })); + updateAccountFuture = accounts.updateDeviceTransactionallyAsync(account, + device.getId(), + d -> { + switch (identityType) { + case ACI -> d.setSignedPreKey(setKeysRequest.signedPreKey()); + case PNI -> d.setPhoneNumberIdentitySignedPreKey(setKeysRequest.signedPreKey()); + } + }, + d -> keys.buildWriteItemForEcSignedPreKey(identifier, d.getId(), setKeysRequest.signedPreKey()) + .map(List::of) + .orElseGet(Collections::emptyList)) + .toCompletableFuture(); + } else { + updateAccountFuture = CompletableFuture.completedFuture(account); } - return keys.store(account.getIdentifier(identityType), device.getId(), - setKeysRequest.preKeys(), setKeysRequest.pqPreKeys(), setKeysRequest.signedPreKey(), setKeysRequest.pqLastResortPreKey()) + return updateAccountFuture.thenCompose(updatedAccount -> { + final List> storeFutures = new ArrayList<>(3); + + if (setKeysRequest.preKeys() != null) { + storeFutures.add(keys.storeEcOneTimePreKeys(identifier, device.getId(), setKeysRequest.preKeys())); + } + + if (setKeysRequest.pqPreKeys() != null) { + storeFutures.add(keys.storeKemOneTimePreKeys(identifier, device.getId(), setKeysRequest.pqPreKeys())); + } + + if (setKeysRequest.pqLastResortPreKey() != null) { + storeFutures.add( + keys.storePqLastResort(identifier, Map.of(device.getId(), setKeysRequest.pqLastResortPreKey()))); + } + + return CompletableFuture.allOf(storeFutures.toArray(EMPTY_FUTURE_ARRAY)); + }) .thenApply(Util.ASYNC_EMPTY_RESPONSE); } @@ -265,17 +296,21 @@ public class KeysController { @Valid final ECSignedPreKey signedPreKey, @QueryParam("identity") @DefaultValue("aci") final IdentityType identityType) { - Device device = auth.getAuthenticatedDevice(); + final UUID identifier = auth.getAccount().getIdentifier(identityType); + final byte deviceId = auth.getAuthenticatedDevice().getId(); - accounts.updateDevice(auth.getAccount(), device.getId(), d -> { - switch (identityType) { - case ACI -> d.setSignedPreKey(signedPreKey); - case PNI -> d.setPhoneNumberIdentitySignedPreKey(signedPreKey); - } - }); - - return keys.storeEcSignedPreKeys(auth.getAccount().getIdentifier(identityType), - Map.of(device.getId(), signedPreKey)) + return accounts.updateDeviceTransactionallyAsync(auth.getAccount(), + deviceId, + d -> { + switch (identityType) { + case ACI -> d.setSignedPreKey(signedPreKey); + case PNI -> d.setPhoneNumberIdentitySignedPreKey(signedPreKey); + } + }, + d -> keys.buildWriteItemForEcSignedPreKey(identifier, d.getId(), signedPreKey) + .map(List::of) + .orElseGet(Collections::emptyList)) + .toCompletableFuture() .thenApply(Util.ASYNC_EMPTY_RESPONSE); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java index 36ed7707d..6d753b852 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java @@ -95,6 +95,7 @@ public class Accounts extends AbstractDynamoDbStore { private static final Timer RESERVE_USERNAME_TIMER = Metrics.timer(name(Accounts.class, "reserveUsername")); private static final Timer CLEAR_USERNAME_HASH_TIMER = Metrics.timer(name(Accounts.class, "clearUsernameHash")); private static final Timer UPDATE_TIMER = Metrics.timer(name(Accounts.class, "update")); + private static final Timer UPDATE_TRANSACTIONALLY_TIMER = Metrics.timer(name(Accounts.class, "updateTransactionally")); private static final Timer RECLAIM_TIMER = Metrics.timer(name(Accounts.class, "reclaim")); private static final Timer GET_BY_NUMBER_TIMER = Metrics.timer(name(Accounts.class, "getByNumber")); private static final Timer GET_BY_USERNAME_HASH_TIMER = Metrics.timer(name(Accounts.class, "getByUsernameHash")); @@ -277,6 +278,7 @@ public class Accounts extends AbstractDynamoDbStore { !existingAccount.getNumber().equals(accountToCreate.getNumber())) { throw new IllegalArgumentException("reclaimed accounts must match"); } + return AsyncTimerUtil.record(RECLAIM_TIMER, () -> { accountToCreate.setVersion(existingAccount.getVersion()); @@ -364,7 +366,8 @@ public class Accounts extends AbstractDynamoDbStore { public void changeNumber(final Account account, final String number, final UUID phoneNumberIdentifier, - final Optional maybeDisplacedAccountIdentifier) { + final Optional maybeDisplacedAccountIdentifier, + final Collection additionalWriteItems) { CHANGE_NUMBER_TIMER.record(() -> { final String originalNumber = account.getNumber(); @@ -413,6 +416,8 @@ public class Accounts extends AbstractDynamoDbStore { .build()) .build()); + writeItems.addAll(additionalWriteItems); + final TransactWriteItemsRequest request = TransactWriteItemsRequest.builder() .transactItems(writeItems) .build(); @@ -863,6 +868,35 @@ public class Accounts extends AbstractDynamoDbStore { joinAndUnwrapUpdateFuture(updateAsync(account)); } + public CompletionStage updateTransactionallyAsync(final Account account, + final Collection additionalWriteItems) { + + return AsyncTimerUtil.record(UPDATE_TRANSACTIONALLY_TIMER, () -> { + final List writeItems = new ArrayList<>(additionalWriteItems.size() + 1); + writeItems.add(UpdateAccountSpec.forAccount(accountsTableName, account).transactItem()); + writeItems.addAll(additionalWriteItems); + + return asyncClient.transactWriteItems(TransactWriteItemsRequest.builder() + .transactItems(writeItems) + .build()) + .thenApply(response -> { + account.setVersion(account.getVersion() + 1); + return (Void) null; + }) + .exceptionally(throwable -> { + final Throwable unwrapped = ExceptionUtils.unwrap(throwable); + + if (unwrapped instanceof TransactionCanceledException transactionCanceledException) { + if ("ConditionalCheckFailed".equals(transactionCanceledException.cancellationReasons().get(0).code())) { + throw new ContestedOptimisticLockException(); + } + } + + throw CompletableFutureUtils.errorAsCompletionException(throwable); + }); + }); + } + public CompletableFuture usernameHashAvailable(final byte[] username) { return usernameHashAvailable(Optional.empty(), username); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index 9d73cb056..1df84efef 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -27,6 +27,7 @@ import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; +import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; @@ -39,10 +40,10 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; -import java.util.stream.Collectors; import javax.annotation.Nullable; import org.apache.commons.lang3.ObjectUtils; import org.apache.commons.lang3.StringUtils; @@ -71,6 +72,7 @@ import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.Util; import reactor.core.publisher.ParallelFlux; import reactor.core.scheduler.Scheduler; +import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem; public class AccountsManager { @@ -365,39 +367,25 @@ public class AccountsManager { final UUID uuid = account.getUuid(); final UUID phoneNumberIdentifier = phoneNumberIdentifiers.getPhoneNumberIdentifier(targetNumber); - final Account numberChangedAccount; - - numberChangedAccount = updateWithRetries( - account, - a -> { - setPniKeys(account, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds); - return true; - }, - a -> accounts.changeNumber(a, targetNumber, phoneNumberIdentifier, maybeDisplacedUuid), - () -> accounts.getByAccountIdentifier(uuid).orElseThrow(), - AccountChangeValidator.NUMBER_CHANGE_VALIDATOR); - - updatedAccount.set(numberChangedAccount); - CompletableFuture.allOf( keysManager.delete(phoneNumberIdentifier), keysManager.delete(originalPhoneNumberIdentifier)) .join(); - keysManager.storeEcSignedPreKeys(phoneNumberIdentifier, pniSignedPreKeys); + final Collection keyWriteItems = + buildKeyWriteItems(uuid, phoneNumberIdentifier, pniSignedPreKeys, pniPqLastResortPreKeys); - if (pniPqLastResortPreKeys != null) { - keysManager.getPqEnabledDevices(uuid).thenCompose( - deviceIds -> keysManager.storePqLastResort( - phoneNumberIdentifier, - deviceIds.stream() - .filter(pniPqLastResortPreKeys::containsKey) - .collect( - Collectors.toMap( - Function.identity(), - pniPqLastResortPreKeys::get)))) - .join(); - } + final Account numberChangedAccount = updateWithRetries( + account, + a -> { + setPniKeys(account, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds); + return true; + }, + a -> accounts.changeNumber(a, targetNumber, phoneNumberIdentifier, maybeDisplacedUuid, keyWriteItems), + () -> accounts.getByAccountIdentifier(uuid).orElseThrow(), + AccountChangeValidator.NUMBER_CHANGE_VALIDATOR); + + updatedAccount.set(numberChangedAccount); }); return updatedAccount.get(); @@ -410,31 +398,58 @@ public class AccountsManager { final Map pniRegistrationIds) throws MismatchedDevicesException { validateDevices(account, pniSignedPreKeys, pniPqLastResortPreKeys, pniRegistrationIds); - final UUID pni = account.getPhoneNumberIdentifier(); - final Account updatedAccount = update(account, a -> { return setPniKeys(a, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds); }); + final UUID aci = account.getIdentifier(IdentityType.ACI); + final UUID pni = account.getIdentifier(IdentityType.PNI); + + final Collection keyWriteItems = + buildKeyWriteItems(pni, pni, pniSignedPreKeys, pniPqLastResortPreKeys); + + return redisDeleteAsync(account) + .thenCompose(ignored -> keysManager.delete(pni)) + .thenCompose(ignored -> updateTransactionallyWithRetriesAsync(account, + a -> setPniKeys(a, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds), + accounts::updateTransactionallyAsync, + () -> accounts.getByAccountIdentifierAsync(aci).thenApply(Optional::orElseThrow), + a -> keyWriteItems, + AccountChangeValidator.GENERAL_CHANGE_VALIDATOR, + MAX_UPDATE_ATTEMPTS)) + .join(); + } + + private Collection buildKeyWriteItems( + final UUID enabledDevicesIdentifier, + final UUID phoneNumberIdentifier, + @Nullable final Map pniSignedPreKeys, + @Nullable final Map pniPqLastResortPreKeys) { + + final List keyWriteItems = new ArrayList<>(); + + if (pniSignedPreKeys != null) { + pniSignedPreKeys.forEach((deviceId, signedPreKey) -> + keysManager.buildWriteItemForEcSignedPreKey(phoneNumberIdentifier, deviceId, signedPreKey) + .ifPresent(keyWriteItems::add)); + } - keysManager.delete(pni); - keysManager.storeEcSignedPreKeys(pni, pniSignedPreKeys).join(); if (pniPqLastResortPreKeys != null) { - keysManager.getPqEnabledDevices(pni) - .thenCompose( - deviceIds -> keysManager.storePqLastResort( - pni, - deviceIds.stream() - .filter(pniPqLastResortPreKeys::containsKey) - .collect(Collectors.toMap(Function.identity(), pniPqLastResortPreKeys::get)))) + keysManager.getPqEnabledDevices(enabledDevicesIdentifier) + .thenAccept(deviceIds -> deviceIds.stream() + .filter(pniPqLastResortPreKeys::containsKey) + .map(deviceId -> keysManager.buildWriteItemForLastResortKey(phoneNumberIdentifier, + deviceId, + pniPqLastResortPreKeys.get(deviceId))) + .forEach(keyWriteItems::add)) .join(); } - return updatedAccount; + return keyWriteItems; } - private boolean setPniKeys(final Account account, + private void setPniKeys(final Account account, @Nullable final IdentityKey pniIdentityKey, @Nullable final Map pniSignedPreKeys, @Nullable final Map pniRegistrationIds) { if (ObjectUtils.allNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) { - return false; + return; } else if (!ObjectUtils.allNotNull(pniIdentityKey, pniSignedPreKeys, pniRegistrationIds)) { throw new IllegalArgumentException("PNI identity key, signed pre-keys, and registration IDs must be all null or all non-null"); } @@ -455,8 +470,6 @@ public class AccountsManager { } account.setPhoneNumberIdentityKey(pniIdentityKey); - - return changed; } private void validateDevices(final Account account, @@ -777,6 +790,42 @@ public class AccountsManager { return CompletableFuture.failedFuture(new OptimisticLockRetryLimitExceededException()); } + private CompletionStage updateTransactionallyWithRetriesAsync(final Account account, + final Consumer updater, + final BiFunction, CompletionStage> persister, + final Supplier> retriever, + final Function> additionalWriteItemProvider, + final AccountChangeValidator changeValidator, + final int remainingTries) { + + final Account originalAccount = AccountUtil.cloneAccountAsNotStale(account); + + final Collection additionalWriteItems = additionalWriteItemProvider.apply(account); + updater.accept(account); + + if (remainingTries > 0) { + return persister.apply(account, additionalWriteItems) + .thenApply(ignored -> { + final Account updatedAccount = AccountUtil.cloneAccountAsNotStale(account); + account.markStale(); + + changeValidator.validateChange(originalAccount, updatedAccount); + + return updatedAccount; + }) + .exceptionallyCompose(throwable -> { + if (ExceptionUtils.unwrap(throwable) instanceof ContestedOptimisticLockException) { + return retriever.get().thenCompose(refreshedAccount -> + updateTransactionallyWithRetriesAsync(refreshedAccount, updater, persister, retriever, additionalWriteItemProvider, changeValidator, remainingTries - 1)); + } else { + throw ExceptionUtils.wrap(throwable); + } + }); + } + + return CompletableFuture.failedFuture(new OptimisticLockRetryLimitExceededException()); + } + public Account updateDevice(Account account, byte deviceId, Consumer deviceUpdater) { return update(account, a -> { a.getDevice(deviceId).ifPresent(deviceUpdater); @@ -794,6 +843,22 @@ public class AccountsManager { }); } + public CompletionStage updateDeviceTransactionallyAsync(final Account account, + final byte deviceId, + final Consumer deviceUpdater, + final Function> additionalWriteItemProvider) { + + final UUID uuid = account.getUuid(); + + return redisDeleteAsync(account).thenCompose(ignored -> updateTransactionallyWithRetriesAsync(account, + a -> a.getDevice(deviceId).ifPresent(deviceUpdater), + accounts::updateTransactionallyAsync, + () -> accounts.getByAccountIdentifierAsync(uuid).thenApply(Optional::orElseThrow), + a -> additionalWriteItemProvider.apply(a.getDevice(deviceId).orElseThrow()), + AccountChangeValidator.GENERAL_CHANGE_VALIDATOR, + MAX_UPDATE_ATTEMPTS)); + } + public Optional getByE164(final String number) { return checkRedisThenAccounts( getByNumberTimer, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java index b541499f9..c7d27084b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java @@ -6,13 +6,11 @@ package org.whispersystems.textsecuregcm.storage; import com.google.common.annotations.VisibleForTesting; -import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; -import javax.annotation.Nullable; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.entities.ECPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; @@ -43,34 +41,20 @@ public class KeysManager { this.dynamicConfigurationManager = dynamicConfigurationManager; } - public CompletableFuture store( - final UUID identifier, final byte deviceId, - @Nullable final List ecKeys, - @Nullable final List pqKeys, - @Nullable final ECSignedPreKey ecSignedPreKey, - @Nullable final KEMSignedPreKey pqLastResortKey) { + public Optional buildWriteItemForEcSignedPreKey(final UUID identifier, + final byte deviceId, + final ECSignedPreKey ecSignedPreKey) { - final List> storeFutures = new ArrayList<>(); + return dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().storeEcSignedPreKeys() + ? Optional.of(ecSignedPreKeys.buildTransactWriteItem(identifier, deviceId, ecSignedPreKey)) + : Optional.empty(); + } - if (ecKeys != null && !ecKeys.isEmpty()) { - storeFutures.add(ecPreKeys.store(identifier, deviceId, ecKeys)); - } + public TransactWriteItem buildWriteItemForLastResortKey(final UUID identifier, + final byte deviceId, + final KEMSignedPreKey lastResortSignedPreKey) { - if (pqKeys != null && !pqKeys.isEmpty()) { - storeFutures.add(pqPreKeys.store(identifier, deviceId, pqKeys)); - } - - if (ecSignedPreKey != null - && dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().storeEcSignedPreKeys()) { - - storeFutures.add(ecSignedPreKeys.store(identifier, deviceId, ecSignedPreKey)); - } - - if (pqLastResortKey != null) { - storeFutures.add(pqLastResortKeys.store(identifier, deviceId, pqLastResortKey)); - } - - return CompletableFuture.allOf(storeFutures.toArray(new CompletableFuture[0])); + return pqLastResortKeys.buildTransactWriteItem(identifier, deviceId, lastResortSignedPreKey); } public List buildWriteItemsForRepeatedUseKeys(final UUID accountIdentifier, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java index 2cf0c4e58..6bf1d3ae4 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java @@ -10,7 +10,6 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.eq; @@ -35,6 +34,7 @@ import java.util.Optional; import java.util.OptionalInt; import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; import javax.ws.rs.client.Entity; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; @@ -236,9 +236,27 @@ class KeysControllerTest { when(accounts.getByServiceIdentifier(new AciServiceIdentifier(EXISTS_UUID))).thenReturn(Optional.of(existsAccount)); when(accounts.getByServiceIdentifier(new PniServiceIdentifier(EXISTS_PNI))).thenReturn(Optional.of(existsAccount)); + when(accounts.updateDeviceTransactionallyAsync(any(), anyByte(), any(), any())).thenAnswer(invocation -> { + final Account account = invocation.getArgument(0); + final byte deviceId = invocation.getArgument(1); + final Consumer deviceUpdater = invocation.getArgument(2); + + deviceUpdater.accept(account.getDevice(deviceId).orElseThrow()); + + return CompletableFuture.completedFuture(account); + }); + when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter); - when(KEYS.store(any(), anyByte(), any(), any(), any(), any())).thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null)); + when(KEYS.storeEcOneTimePreKeys(any(), anyByte(), any())) + .thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null)); + + when(KEYS.storeKemOneTimePreKeys(any(), anyByte(), any())) + .thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null)); + + when(KEYS.storePqLastResort(any(), any())) + .thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null)); + when(KEYS.getEcSignedPreKey(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); when(KEYS.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null)); @@ -301,8 +319,7 @@ class KeysControllerTest { verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(test)); verify(AuthHelper.VALID_DEVICE, never()).setPhoneNumberIdentitySignedPreKey(any()); - verify(accounts).updateDevice(eq(AuthHelper.VALID_ACCOUNT), anyByte(), any()); - verify(KEYS).storeEcSignedPreKeys(AuthHelper.VALID_UUID, Map.of(Device.PRIMARY_ID, test)); + verify(accounts).updateDeviceTransactionallyAsync(eq(AuthHelper.VALID_ACCOUNT), anyByte(), any(), any()); } @Test @@ -320,8 +337,7 @@ class KeysControllerTest { verify(AuthHelper.VALID_DEVICE).setPhoneNumberIdentitySignedPreKey(eq(replacementKey)); verify(AuthHelper.VALID_DEVICE, never()).setSignedPreKey(any()); - verify(accounts).updateDevice(eq(AuthHelper.VALID_ACCOUNT), anyByte(), any()); - verify(KEYS).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(Device.PRIMARY_ID, replacementKey)); + verify(accounts).updateDeviceTransactionallyAsync(eq(AuthHelper.VALID_ACCOUNT), anyByte(), any(), any()); } @Test @@ -748,13 +764,13 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(204); ArgumentCaptor> listCaptor = ArgumentCaptor.forClass(List.class); - verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(SAMPLE_DEVICE_ID), listCaptor.capture(), isNull(), - eq(signedPreKey), isNull()); + + verify(KEYS).storeEcOneTimePreKeys(eq(AuthHelper.VALID_UUID), eq(SAMPLE_DEVICE_ID), listCaptor.capture()); assertThat(listCaptor.getValue()).containsExactly(preKey); verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey)); - verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT), any()); + verify(accounts).updateDeviceTransactionallyAsync(eq(AuthHelper.VALID_ACCOUNT), eq(SAMPLE_DEVICE_ID), any(), any()); } @Test @@ -782,14 +798,15 @@ class KeysControllerTest { ArgumentCaptor> ecCaptor = ArgumentCaptor.forClass(List.class); ArgumentCaptor> pqCaptor = ArgumentCaptor.forClass(List.class); - verify(KEYS).store(eq(AuthHelper.VALID_UUID), eq(SAMPLE_DEVICE_ID), ecCaptor.capture(), pqCaptor.capture(), - eq(signedPreKey), eq(pqLastResortPreKey)); + verify(KEYS).storeEcOneTimePreKeys(eq(AuthHelper.VALID_UUID), eq(SAMPLE_DEVICE_ID), ecCaptor.capture()); + verify(KEYS).storeKemOneTimePreKeys(eq(AuthHelper.VALID_UUID), eq(SAMPLE_DEVICE_ID), pqCaptor.capture()); + verify(KEYS).storePqLastResort(AuthHelper.VALID_UUID, Map.of(SAMPLE_DEVICE_ID, pqLastResortPreKey)); assertThat(ecCaptor.getValue()).containsExactly(preKey); assertThat(pqCaptor.getValue()).containsExactly(pqPreKey); verify(AuthHelper.VALID_DEVICE).setSignedPreKey(eq(signedPreKey)); - verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT), any()); + verify(accounts).updateDeviceTransactionallyAsync(eq(AuthHelper.VALID_ACCOUNT), eq(SAMPLE_DEVICE_ID), any(), any()); } @Test @@ -886,13 +903,12 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(204); ArgumentCaptor> listCaptor = ArgumentCaptor.forClass(List.class); - verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(SAMPLE_DEVICE_ID), listCaptor.capture(), isNull(), eq(signedPreKey), - isNull()); + verify(KEYS).storeEcOneTimePreKeys(eq(AuthHelper.VALID_PNI), eq(SAMPLE_DEVICE_ID), listCaptor.capture()); assertThat(listCaptor.getValue()).containsExactly(preKey); verify(AuthHelper.VALID_DEVICE).setPhoneNumberIdentitySignedPreKey(eq(signedPreKey)); - verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT), any()); + verify(accounts).updateDeviceTransactionallyAsync(eq(AuthHelper.VALID_ACCOUNT), eq(SAMPLE_DEVICE_ID), any(), any()); } @Test @@ -921,14 +937,15 @@ class KeysControllerTest { ArgumentCaptor> ecCaptor = ArgumentCaptor.forClass(List.class); ArgumentCaptor> pqCaptor = ArgumentCaptor.forClass(List.class); - verify(KEYS).store(eq(AuthHelper.VALID_PNI), eq(SAMPLE_DEVICE_ID), ecCaptor.capture(), pqCaptor.capture(), - eq(signedPreKey), eq(pqLastResortPreKey)); + verify(KEYS).storeEcOneTimePreKeys(eq(AuthHelper.VALID_PNI), eq(SAMPLE_DEVICE_ID), ecCaptor.capture()); + verify(KEYS).storeKemOneTimePreKeys(eq(AuthHelper.VALID_PNI), eq(SAMPLE_DEVICE_ID), pqCaptor.capture()); + verify(KEYS).storePqLastResort(AuthHelper.VALID_PNI, Map.of(SAMPLE_DEVICE_ID, pqLastResortPreKey)); assertThat(ecCaptor.getValue()).containsExactly(preKey); assertThat(pqCaptor.getValue()).containsExactly(pqPreKey); verify(AuthHelper.VALID_DEVICE).setPhoneNumberIdentitySignedPreKey(eq(signedPreKey)); - verify(accounts).update(eq(AuthHelper.VALID_ACCOUNT), any()); + verify(accounts).updateDeviceTransactionallyAsync(eq(AuthHelper.VALID_ACCOUNT), eq(SAMPLE_DEVICE_ID), any(), any()); } @Test @@ -967,8 +984,7 @@ class KeysControllerTest { assertThat(response.getStatus()).isEqualTo(204); ArgumentCaptor> listCaptor = ArgumentCaptor.forClass(List.class); - verify(KEYS).store(eq(AuthHelper.DISABLED_UUID), eq(SAMPLE_DEVICE_ID), listCaptor.capture(), isNull(), - eq(signedPreKey), isNull()); + verify(KEYS).storeEcOneTimePreKeys(eq(AuthHelper.DISABLED_UUID), eq(SAMPLE_DEVICE_ID), listCaptor.capture()); List capturedList = listCaptor.getValue(); assertThat(capturedList.size()).isEqualTo(1); @@ -976,6 +992,6 @@ class KeysControllerTest { assertThat(capturedList.get(0).publicKey()).isEqualTo(preKey.publicKey()); verify(AuthHelper.DISABLED_DEVICE).setSignedPreKey(eq(signedPreKey)); - verify(accounts).update(eq(AuthHelper.DISABLED_ACCOUNT), any()); + verify(accounts).updateDeviceTransactionallyAsync(eq(AuthHelper.DISABLED_ACCOUNT), eq(SAMPLE_DEVICE_ID), any(), any()); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index 8ed161a37..011370b65 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -154,6 +154,7 @@ class AccountsManagerTest { when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); when(accounts.updateAsync(any())).thenReturn(CompletableFuture.completedFuture(null)); + when(accounts.updateTransactionallyAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(accounts.delete(any())).thenReturn(CompletableFuture.completedFuture(null)); doAnswer((Answer) invocation -> { @@ -164,7 +165,7 @@ class AccountsManagerTest { account.setNumber(number, phoneNumberIdentifier); return null; - }).when(accounts).changeNumber(any(), anyString(), any(), any()); + }).when(accounts).changeNumber(any(), anyString(), any(), any(), any()); final SecureStorageClient storageClient = mock(SecureStorageClient.class); when(storageClient.deleteStoredData(any())).thenReturn(CompletableFuture.completedFuture(null)); @@ -1163,7 +1164,6 @@ class AccountsManagerTest { verify(keysManager).delete(originalPni); verify(keysManager, atLeastOnce()).delete(targetPni); verify(keysManager).delete(newPni); - verify(keysManager).storeEcSignedPreKeys(eq(newPni), any()); verifyNoMoreInteractions(keysManager); } @@ -1209,9 +1209,9 @@ class AccountsManagerTest { verify(keysManager).delete(newPni); verify(keysManager).delete(originalPni); verify(keysManager).getPqEnabledDevices(uuid); - verify(keysManager).storeEcSignedPreKeys(newPni, newSignedKeys); - verify(keysManager).storePqLastResort(eq(newPni), - eq(Map.of(Device.PRIMARY_ID, newSignedPqKeys.get(Device.PRIMARY_ID)))); + verify(keysManager).buildWriteItemForEcSignedPreKey(eq(newPni), eq(Device.PRIMARY_ID), any()); + verify(keysManager).buildWriteItemForEcSignedPreKey(eq(newPni), eq(deviceId2), any()); + verify(keysManager).buildWriteItemForLastResortKey(eq(newPni), eq(Device.PRIMARY_ID), any()); verifyNoMoreInteractions(keysManager); } @@ -1304,9 +1304,12 @@ class AccountsManagerTest { assertEquals(newRegistrationIds, updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, d -> d.getPhoneNumberIdentityRegistrationId().getAsInt()))); - verify(accounts).update(any()); + verify(accounts).updateTransactionallyAsync(any(), any()); verify(keysManager).delete(oldPni); + verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(Device.PRIMARY_ID), any()); + verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(deviceId2), any()); + verify(keysManager, never()).buildWriteItemForLastResortKey(any(), anyByte(), any()); } @Test @@ -1360,14 +1363,13 @@ class AccountsManagerTest { assertEquals(newRegistrationIds, updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, d -> d.getPhoneNumberIdentityRegistrationId().getAsInt()))); - verify(accounts).update(any()); + verify(accounts).updateTransactionallyAsync(any(), any()); verify(keysManager).delete(oldPni); - verify(keysManager).storeEcSignedPreKeys(oldPni, newSignedKeys); - - // only the pq key for the already-pq-enabled device should be saved - verify(keysManager).storePqLastResort(eq(oldPni), - eq(Map.of(Device.PRIMARY_ID, newSignedPqKeys.get(Device.PRIMARY_ID)))); + verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(Device.PRIMARY_ID), any()); + verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(deviceId2), any()); + verify(keysManager).buildWriteItemForLastResortKey(eq(oldPni), eq(Device.PRIMARY_ID), any()); + verify(keysManager, never()).buildWriteItemForLastResortKey(eq(oldPni), eq(deviceId2), any()); } @Test @@ -1420,11 +1422,12 @@ class AccountsManagerTest { assertEquals(newRegistrationIds, updatedAccount.getDevices().stream().collect(Collectors.toMap(Device::getId, d -> d.getPhoneNumberIdentityRegistrationId().getAsInt()))); - verify(accounts).update(any()); + verify(accounts).updateTransactionallyAsync(any(), any()); verify(keysManager).delete(oldPni); - verify(keysManager).storeEcSignedPreKeys(oldPni, newSignedKeys); - verify(keysManager).storePqLastResort(any(), argThat(Map::isEmpty)); + verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(Device.PRIMARY_ID), any()); + verify(keysManager).buildWriteItemForEcSignedPreKey(eq(oldPni), eq(deviceId2), any()); + verify(keysManager, never()).buildWriteItemForLastResortKey(any(), anyByte(), any()); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java index 748cb3327..3c9671613 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java @@ -12,6 +12,7 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -99,7 +100,10 @@ class AccountsTest { Tables.NUMBERS, Tables.PNI_ASSIGNMENTS, Tables.USERNAMES, - Tables.DELETED_ACCOUNTS); + Tables.DELETED_ACCOUNTS, + + // This is an unrelated table used to test "tag-along" transactional updates + Tables.CLIENT_RELEASES); private final TestClock clock = TestClock.pinned(Instant.EPOCH); private DynamicConfigurationManager mockDynamicConfigManager; @@ -558,6 +562,71 @@ class AccountsTest { assertThatThrownBy(() -> accounts.update(account)).isInstanceOfAny(ContestedOptimisticLockException.class); } + @Test + void testUpdateTransactionally() { + final Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID()); + createAccount(account); + + final String deviceName = "device-name"; + + assertNotEquals(deviceName, + accounts.getByAccountIdentifier(account.getUuid()).orElseThrow().getPrimaryDevice().orElseThrow().getName()); + + assertFalse(DYNAMO_DB_EXTENSION.getDynamoDbClient().getItem(GetItemRequest.builder() + .tableName(Tables.CLIENT_RELEASES.tableName()) + .key(Map.of( + ClientReleases.ATTR_PLATFORM, AttributeValues.fromString("test"), + ClientReleases.ATTR_VERSION, AttributeValues.fromString("test") + )) + .build()) + .hasItem()); + + account.getPrimaryDevice().orElseThrow().setName(deviceName); + + accounts.updateTransactionallyAsync(account, List.of(TransactWriteItem.builder() + .put(Put.builder() + .tableName(Tables.CLIENT_RELEASES.tableName()) + .item(Map.of( + ClientReleases.ATTR_PLATFORM, AttributeValues.fromString("test"), + ClientReleases.ATTR_VERSION, AttributeValues.fromString("test") + )) + .build()) + .build())).toCompletableFuture().join(); + + assertEquals(deviceName, + accounts.getByAccountIdentifier(account.getUuid()).orElseThrow().getPrimaryDevice().orElseThrow().getName()); + + assertTrue(DYNAMO_DB_EXTENSION.getDynamoDbClient().getItem(GetItemRequest.builder() + .tableName(Tables.CLIENT_RELEASES.tableName()) + .key(Map.of( + ClientReleases.ATTR_PLATFORM, AttributeValues.fromString("test"), + ClientReleases.ATTR_VERSION, AttributeValues.fromString("test") + )) + .build()) + .hasItem()); + } + + @Test + void testUpdateTransactionallyContestedLock() { + final Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID()); + createAccount(account); + + account.setVersion(account.getVersion() - 1); + + final CompletionException completionException = assertThrows(CompletionException.class, + () -> accounts.updateTransactionallyAsync(account, List.of(TransactWriteItem.builder() + .put(Put.builder() + .tableName(Tables.CLIENT_RELEASES.tableName()) + .item(Map.of( + ClientReleases.ATTR_PLATFORM, AttributeValues.fromString("test"), + ClientReleases.ATTR_VERSION, AttributeValues.fromString("test") + )) + .build()) + .build())).toCompletableFuture().join()); + + assertTrue(completionException.getCause() instanceof ContestedOptimisticLockException); + } + @Test void testGetAll() { final List expectedAccounts = new ArrayList<>(); @@ -719,7 +788,7 @@ class AccountsTest { verifyStoredState(originalNumber, account.getUuid(), account.getPhoneNumberIdentifier(), null, retrieved.get(), account); } - accounts.changeNumber(account, targetNumber, targetPni, maybeDisplacedAccountIdentifier); + accounts.changeNumber(account, targetNumber, targetPni, maybeDisplacedAccountIdentifier, Collections.emptyList()); assertThat(accounts.getByE164(originalNumber)).isEmpty(); assertThat(accounts.getByAccountIdentifier(originalPni)).isEmpty(); @@ -766,7 +835,7 @@ class AccountsTest { createAccount(account); createAccount(existingAccount); - assertThrows(TransactionCanceledException.class, () -> accounts.changeNumber(account, targetNumber, targetPni, Optional.of(existingAccount.getUuid()))); + assertThrows(TransactionCanceledException.class, () -> accounts.changeNumber(account, targetNumber, targetPni, Optional.of(existingAccount.getUuid()), Collections.emptyList())); assertPhoneNumberConstraintExists(originalNumber, account.getUuid()); assertPhoneNumberIdentifierConstraintExists(originalPni, account.getUuid()); @@ -802,7 +871,7 @@ class AccountsTest { Map.of(":uuid", AttributeValues.fromUUID(existingAccountIdentifier))) .build()); - assertThrows(TransactionCanceledException.class, () -> accounts.changeNumber(account, targetNumber, existingPhoneNumberIdentifier, Optional.empty())); + assertThrows(TransactionCanceledException.class, () -> accounts.changeNumber(account, targetNumber, existingPhoneNumberIdentifier, Optional.empty(), Collections.emptyList())); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java index da6a5ef04..f5449e6cf 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java @@ -20,6 +20,8 @@ import java.util.UUID; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; @@ -65,54 +67,39 @@ class KeysManagerTest { } @Test - void testStore() { + void storeEcOneTimePreKeys() { assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(), "Initial pre-key count for an account should be zero"); - assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(), - "Initial pre-key count for an account should be zero"); - assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent(), - "Initial last-resort pre-key for an account should be missing"); - keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), null, null, null).join(); + keysManager.storeEcOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1))).join(); assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join()); - keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), null, null, null).join(); + keysManager.storeEcOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1))).join(); assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(), "Repeatedly storing same key should have no effect"); + } - keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(generateTestKEMSignedPreKey(1)), null, null).join(); - assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(), - "Uploading new PQ prekeys should have no effect on EC prekeys"); + @Test + void storeKemOneTimePreKeys() { + assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(), + "Initial pre-key count for an account should be zero"); + + keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestKEMSignedPreKey(1))).join(); assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); - keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, null, null, generateTestKEMSignedPreKey(1001)).join(); - assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(), - "Uploading new PQ last-resort prekey should have no effect on EC prekeys"); - assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(), - "Uploading new PQ last-resort prekey should have no effect on one-time PQ prekeys"); - assertEquals(1001, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().get().keyId()); + keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestKEMSignedPreKey(1))).join(); + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); + } - keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(2)), null, null, null).join(); - assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(), - "Inserting a new key should overwrite all prior keys of the same type for the given account/device"); - assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(), - "Uploading new EC prekeys should have no effect on PQ prekeys"); + @Test + void storeEcSignedPreKeys() { + assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isEmpty()); - keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(3)), List.of(generateTestKEMSignedPreKey(2)), null, null).join(); - assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(), - "Inserting a new key should overwrite all prior keys of the same type for the given account/device"); - assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(), - "Inserting a new key should overwrite all prior keys of the same type for the given account/device"); + final ECSignedPreKey signedPreKey = generateTestECSignedPreKey(1); - keysManager.store(ACCOUNT_UUID, DEVICE_ID, - List.of(generateTestPreKey(4), generateTestPreKey(5)), - List.of(generateTestKEMSignedPreKey(6), generateTestKEMSignedPreKey(7)), null, generateTestKEMSignedPreKey(1002)).join(); - assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join(), - "Inserting multiple new keys should overwrite all prior keys for the given account/device"); - assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join(), - "Inserting multiple new keys should overwrite all prior keys for the given account/device"); - assertEquals(1002, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().get().keyId(), - "Uploading new last-resort key should overwrite prior last-resort key for the account/device"); + keysManager.storeEcSignedPreKeys(ACCOUNT_UUID, Map.of(DEVICE_ID, signedPreKey)).join(); + + assertEquals(Optional.of(signedPreKey), keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join()); } @Test @@ -121,7 +108,8 @@ class KeysManagerTest { final ECPreKey preKey = generateTestPreKey(1); - keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, generateTestPreKey(2)), null, null, null).join(); + keysManager.storeEcOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, generateTestPreKey(2))).join(); + final Optional takenKey = keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID).join(); assertEquals(Optional.of(preKey), takenKey); assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join()); @@ -135,7 +123,8 @@ class KeysManagerTest { final KEMSignedPreKey preKey2 = generateTestKEMSignedPreKey(2); final KEMSignedPreKey preKeyLast = generateTestKEMSignedPreKey(1001); - keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(preKey1, preKey2), null, preKeyLast).join(); + keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(preKey1, preKey2)).join(); + keysManager.storePqLastResort(ACCOUNT_UUID, Map.of(DEVICE_ID, preKeyLast)).join(); assertEquals(Optional.of(preKey1), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID).join()); assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); @@ -150,79 +139,51 @@ class KeysManagerTest { assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); } - @Test - void testGetCount() { - assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join()); - assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); - - keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), List.of(generateTestKEMSignedPreKey(1)), null, null).join(); - assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join()); - assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); - } - @Test void testDeleteByAccount() { - keysManager.store(ACCOUNT_UUID, DEVICE_ID, - List.of(generateTestPreKey(1), generateTestPreKey(2)), - List.of(generateTestKEMSignedPreKey(3), generateTestKEMSignedPreKey(4)), - generateTestECSignedPreKey(5), - generateTestKEMSignedPreKey(6)) - .join(); + int keyId = 1; - final byte deviceId2 = DEVICE_ID + 1; - keysManager.store(ACCOUNT_UUID, deviceId2, - List.of(generateTestPreKey(7)), - List.of(generateTestKEMSignedPreKey(8)), - generateTestECSignedPreKey(9), - generateTestKEMSignedPreKey(10)) - .join(); + for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) { + keysManager.storeEcOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestPreKey(keyId++))).join(); + keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestKEMSignedPreKey(keyId++))).join(); + keysManager.storeEcSignedPreKeys(ACCOUNT_UUID, Map.of(deviceId, generateTestECSignedPreKey(keyId++))).join(); + keysManager.storePqLastResort(ACCOUNT_UUID, Map.of(deviceId, generateTestKEMSignedPreKey(keyId++))).join(); + } - assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join()); - assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); - assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); - assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); - assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, deviceId2).join()); - assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, deviceId2).join()); - assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId2).join().isPresent()); - assertTrue(keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().isPresent()); + for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) { + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, deviceId).join()); + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, deviceId).join()); + assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId).join().isPresent()); + assertTrue(keysManager.getLastResort(ACCOUNT_UUID, deviceId).join().isPresent()); + } keysManager.delete(ACCOUNT_UUID).join(); - assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join()); - assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); - assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); - assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); - assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, deviceId2).join()); - assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, deviceId2).join()); - assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId2).join().isPresent()); - assertFalse(keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().isPresent()); + for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) { + assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, deviceId).join()); + assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, deviceId).join()); + assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId).join().isPresent()); + assertFalse(keysManager.getLastResort(ACCOUNT_UUID, deviceId).join().isPresent()); + } } @Test void testDeleteByAccountAndDevice() { - keysManager.store(ACCOUNT_UUID, DEVICE_ID, - List.of(generateTestPreKey(1), generateTestPreKey(2)), - List.of(generateTestKEMSignedPreKey(3), generateTestKEMSignedPreKey(4)), - generateTestECSignedPreKey(5), - generateTestKEMSignedPreKey(6)) - .join(); + int keyId = 1; - final byte deviceId2 = DEVICE_ID + 1; - keysManager.store(ACCOUNT_UUID, deviceId2, - List.of(generateTestPreKey(7)), - List.of(generateTestKEMSignedPreKey(8)), - generateTestECSignedPreKey(9), - generateTestKEMSignedPreKey(10)) - .join(); + for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) { + keysManager.storeEcOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestPreKey(keyId++))).join(); + keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, deviceId, List.of(generateTestKEMSignedPreKey(keyId++))).join(); + keysManager.storeEcSignedPreKeys(ACCOUNT_UUID, Map.of(deviceId, generateTestECSignedPreKey(keyId++))).join(); + keysManager.storePqLastResort(ACCOUNT_UUID, Map.of(deviceId, generateTestKEMSignedPreKey(keyId++))).join(); + } - assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join()); - assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); - assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); - assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); - assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, deviceId2).join()); - assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, deviceId2).join()); - assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId2).join().isPresent()); - assertTrue(keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().isPresent()); + for (byte deviceId : new byte[] {DEVICE_ID, DEVICE_ID + 1}) { + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, deviceId).join()); + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, deviceId).join()); + assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId).join().isPresent()); + assertTrue(keysManager.getLastResort(ACCOUNT_UUID, deviceId).join().isPresent()); + } keysManager.delete(ACCOUNT_UUID, DEVICE_ID).join(); @@ -230,10 +191,11 @@ class KeysManagerTest { assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); - assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, deviceId2).join()); - assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, deviceId2).join()); - assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, deviceId2).join().isPresent()); - assertTrue(keysManager.getLastResort(ACCOUNT_UUID, deviceId2).join().isPresent()); + + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, (byte) (DEVICE_ID + 1)).join()); + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, (byte) (DEVICE_ID + 1)).join()); + assertTrue(keysManager.getEcSignedPreKey(ACCOUNT_UUID, (byte) (DEVICE_ID + 1)).join().isPresent()); + assertTrue(keysManager.getLastResort(ACCOUNT_UUID, (byte) (DEVICE_ID + 1)).join().isPresent()); } @Test @@ -272,15 +234,13 @@ class KeysManagerTest { @Test void testGetPqEnabledDevices() { - final ECKeyPair identityKeyPair = Curve.generateKeyPair(); + keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestKEMSignedPreKey(1))).join(); + + keysManager.storePqLastResort(ACCOUNT_UUID, Map.of((byte) (DEVICE_ID + 1), generateTestKEMSignedPreKey(2))).join(); + + keysManager.storeKemOneTimePreKeys(ACCOUNT_UUID, (byte) (DEVICE_ID + 2), List.of(generateTestKEMSignedPreKey(3))).join(); + keysManager.storePqLastResort(ACCOUNT_UUID, Map.of((byte) (DEVICE_ID + 2), generateTestKEMSignedPreKey(4))).join(); - keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(KeysHelper.signedKEMPreKey(1, identityKeyPair)), null, null).join(); - keysManager.store(ACCOUNT_UUID, (byte) (DEVICE_ID + 1), null, null, null, - KeysHelper.signedKEMPreKey(2, identityKeyPair)).join(); - keysManager.store(ACCOUNT_UUID, (byte) (DEVICE_ID + 2), null, - List.of(KeysHelper.signedKEMPreKey(3, identityKeyPair)), null, KeysHelper.signedKEMPreKey(4, identityKeyPair)) - .join(); - keysManager.store(ACCOUNT_UUID, (byte) (DEVICE_ID + 3), null, null, null, null).join(); assertIterableEquals( Set.of((byte) (DEVICE_ID + 1), (byte) (DEVICE_ID + 2)), Set.copyOf(keysManager.getPqEnabledDevices(ACCOUNT_UUID).join())); @@ -290,21 +250,18 @@ class KeysManagerTest { void testStoreEcSignedPreKeyDisabled() { when(ecPreKeyMigrationConfiguration.storeEcSignedPreKeys()).thenReturn(false); - final ECKeyPair identityKeyPair = Curve.generateKeyPair(); - - keysManager.store(ACCOUNT_UUID, DEVICE_ID, - List.of(generateTestPreKey(1)), - List.of(KeysHelper.signedKEMPreKey(2, identityKeyPair)), - KeysHelper.signedECPreKey(3, identityKeyPair), - KeysHelper.signedKEMPreKey(4, identityKeyPair)) - .join(); - - assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID).join()); - assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID).join()); - assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); + keysManager.storeEcSignedPreKeys(ACCOUNT_UUID, Map.of(DEVICE_ID, generateTestECSignedPreKey(1))).join(); assertFalse(keysManager.getEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID).join().isPresent()); } + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void buildWriteItemForEcSignedPreKey(final boolean enableSignedPreKeyWrite) { + when(ecPreKeyMigrationConfiguration.storeEcSignedPreKeys()).thenReturn(enableSignedPreKeyWrite); + assertEquals(enableSignedPreKeyWrite, + keysManager.buildWriteItemForEcSignedPreKey(ACCOUNT_UUID, DEVICE_ID, generateTestECSignedPreKey(1)).isPresent()); + } + private static ECPreKey generateTestPreKey(final long keyId) { return new ECPreKey(keyId, Curve.generateKeyPair().getPublicKey()); }