From c8033f875db4256370ce2ea46b07d4d45b229db7 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Mon, 13 Nov 2023 13:05:29 -0500 Subject: [PATCH] Create accounts transactionally --- .../textsecuregcm/WhisperServerService.java | 16 +- .../controllers/RegistrationController.java | 52 +-- .../textsecuregcm/storage/Accounts.java | 20 +- .../storage/AccountsManager.java | 42 +- .../textsecuregcm/storage/KeysManager.java | 33 +- .../storage/RepeatedUseSignedPreKeyStore.java | 16 +- .../controllers/AccountControllerTest.java | 10 - .../RegistrationControllerTest.java | 97 ++-- .../AccountCreationIntegrationTest.java | 424 ++++++++++++++++++ ...ntsManagerChangeNumberIntegrationTest.java | 53 ++- ...ConcurrentModificationIntegrationTest.java | 36 +- .../storage/AccountsManagerTest.java | 56 ++- ...ccountsManagerUsernameIntegrationTest.java | 53 ++- .../textsecuregcm/storage/AccountsTest.java | 110 ++--- .../RepeatedUseSignedPreKeyStoreTest.java | 69 +-- .../tests/util/AccountsHelper.java | 32 ++ 16 files changed, 854 insertions(+), 265 deletions(-) create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationIntegrationTest.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 2234c95ef..aae8350c3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -342,8 +342,8 @@ public class WhisperServerService extends Application { - a.setIdentityKey(registrationRequest.aciIdentityKey()); - a.setPhoneNumberIdentityKey(registrationRequest.pniIdentityKey()); - - final Device device = a.getPrimaryDevice().orElseThrow(); - - device.setSignedPreKey(registrationRequest.deviceActivationRequest().aciSignedPreKey()); - device.setPhoneNumberIdentitySignedPreKey(registrationRequest.deviceActivationRequest().pniSignedPreKey()); - - registrationRequest.deviceActivationRequest().apnToken().ifPresent(apnRegistrationId -> { - device.setApnId(apnRegistrationId.apnRegistrationId()); - device.setVoipApnId(apnRegistrationId.voipRegistrationId()); - }); - - registrationRequest.deviceActivationRequest().gcmToken().ifPresent(gcmRegistrationId -> - device.setGcmId(gcmRegistrationId.gcmRegistrationId())); - - CompletableFuture.allOf( - keysManager.storeEcSignedPreKeys(a.getUuid(), - Map.of(Device.PRIMARY_ID, registrationRequest.deviceActivationRequest().aciSignedPreKey())), - keysManager.storePqLastResort(a.getUuid(), - Map.of(Device.PRIMARY_ID, registrationRequest.deviceActivationRequest().aciPqLastResortPreKey())), - keysManager.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(), - Map.of(Device.PRIMARY_ID, registrationRequest.deviceActivationRequest().pniSignedPreKey())), - keysManager.storePqLastResort(a.getPhoneNumberIdentifier(), - Map.of(Device.PRIMARY_ID, registrationRequest.deviceActivationRequest().pniPqLastResortPreKey()))) - .join(); - }); + final Account account = accounts.create(number, + password, + signalAgent, + registrationRequest.accountAttributes(), + existingAccount.map(Account::getBadges).orElseGet(ArrayList::new), + registrationRequest.aciIdentityKey(), + registrationRequest.pniIdentityKey(), + registrationRequest.deviceActivationRequest().aciSignedPreKey(), + registrationRequest.deviceActivationRequest().pniSignedPreKey(), + registrationRequest.deviceActivationRequest().aciPqLastResortPreKey(), + registrationRequest.deviceActivationRequest().pniPqLastResortPreKey(), + registrationRequest.deviceActivationRequest().apnToken(), + registrationRequest.deviceActivationRequest().gcmToken()); Metrics.counter(ACCOUNT_CREATED_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), Tag.of(COUNTRY_CODE_TAG_NAME, Util.getCountryCode(number)), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java index 2378e3c16..2f07926ae 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java @@ -19,6 +19,7 @@ import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -28,6 +29,7 @@ import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; +import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Collectors; import javax.annotation.Nonnull; @@ -157,6 +159,7 @@ public class Accounts extends AbstractDynamoDbStore { final String phoneNumberIdentifierConstraintTableName, final String usernamesConstraintTableName, final String deletedAccountsTableName) { + super(client); this.clock = clock; this.asyncClient = asyncClient; @@ -175,12 +178,14 @@ public class Accounts extends AbstractDynamoDbStore { final String phoneNumberIdentifierConstraintTableName, final String usernamesConstraintTableName, final String deletedAccountsTableName) { + this(Clock.systemUTC(), client, asyncClient, accountsTableName, phoneNumberConstraintTableName, phoneNumberIdentifierConstraintTableName, usernamesConstraintTableName, deletedAccountsTableName); } - public boolean create(final Account account) { + public boolean create(final Account account, final Function> additionalWriteItemsFunction) { + return CREATE_TIMER.record(() -> { try { final AttributeValue uuidAttr = AttributeValues.fromUUID(account.getUuid()); @@ -199,8 +204,13 @@ public class Accounts extends AbstractDynamoDbStore { // the newly-created account. final TransactWriteItem deletedAccountDelete = buildRemoveDeletedAccount(account.getNumber()); + final Collection writeItems = new ArrayList<>( + List.of(phoneNumberConstraintPut, phoneNumberIdentifierConstraintPut, accountPut, deletedAccountDelete)); + + writeItems.addAll(additionalWriteItemsFunction.apply(account)); + final TransactWriteItemsRequest request = TransactWriteItemsRequest.builder() - .transactItems(phoneNumberConstraintPut, phoneNumberIdentifierConstraintPut, accountPut, deletedAccountDelete) + .transactItems(writeItems) .build(); try { @@ -229,7 +239,8 @@ public class Accounts extends AbstractDynamoDbStore { account.setUuid(UUIDUtil.fromByteBuffer(actualAccountUuid)); final Account existingAccount = getByAccountIdentifier(account.getUuid()).orElseThrow(); account.setNumber(existingAccount.getNumber(), existingAccount.getPhoneNumberIdentifier()); - joinAndUnwrapUpdateFuture(reclaimAccount(existingAccount, account)); + joinAndUnwrapUpdateFuture(reclaimAccount(existingAccount, account, additionalWriteItemsFunction.apply(account))); + return false; } @@ -254,7 +265,7 @@ public class Accounts extends AbstractDynamoDbStore { * @param existingAccount the existing account in the accounts table * @param accountToCreate a new account, with the same number and identifier as existingAccount */ - private CompletionStage reclaimAccount(final Account existingAccount, final Account accountToCreate) { + private CompletionStage reclaimAccount(final Account existingAccount, final Account accountToCreate, final Collection additionalWriteItems) { if (!existingAccount.getUuid().equals(accountToCreate.getUuid()) || !existingAccount.getNumber().equals(accountToCreate.getNumber())) { throw new IllegalArgumentException("reclaimed accounts must match"); @@ -310,6 +321,7 @@ public class Accounts extends AbstractDynamoDbStore { .build()); } writeItems.add(UpdateAccountSpec.forAccount(accountsTableName, accountToCreate).transactItem()); + writeItems.addAll(additionalWriteItems); return asyncClient.transactWriteItems(TransactWriteItemsRequest.builder().transactItems(writeItems).build()) .thenApply(response -> { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index d4890d40d..624e394fa 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -53,7 +53,9 @@ import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.entities.AccountAttributes; +import org.whispersystems.textsecuregcm.entities.ApnRegistrationId; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; +import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.identity.IdentityType; @@ -175,17 +177,26 @@ public class AccountsManager { this.clock = requireNonNull(clock); } + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") public Account create(final String number, final String password, final String signalAgent, final AccountAttributes accountAttributes, - final List accountBadges) throws InterruptedException { + final List accountBadges, + final IdentityKey aciIdentityKey, + final IdentityKey pniIdentityKey, + final ECSignedPreKey aciSignedPreKey, + final ECSignedPreKey pniSignedPreKey, + final KEMSignedPreKey aciPqLastResortPreKey, + final KEMSignedPreKey pniPqLastResortPreKey, + final Optional maybeApnRegistrationId, + final Optional maybeGcmRegistrationId) throws InterruptedException { try (Timer.Context ignored = createTimer.time()) { final Account account = new Account(); accountLockManager.withLock(List.of(number), () -> { - Device device = new Device(); + final Device device = new Device(); device.setId(Device.PRIMARY_ID); device.setAuthTokenHash(SaltedTokenHash.generateFor(password)); device.setFetchesMessages(accountAttributes.getFetchesMessages()); @@ -196,6 +207,16 @@ public class AccountsManager { device.setCreated(System.currentTimeMillis()); device.setLastSeen(Util.todayInMillis()); device.setUserAgent(signalAgent); + device.setSignedPreKey(aciSignedPreKey); + device.setPhoneNumberIdentitySignedPreKey(pniSignedPreKey); + + maybeApnRegistrationId.ifPresent(apnRegistrationId -> { + device.setApnId(apnRegistrationId.apnRegistrationId()); + device.setVoipApnId(apnRegistrationId.voipRegistrationId()); + }); + + maybeGcmRegistrationId.ifPresent(gcmRegistrationId -> + device.setGcmId(gcmRegistrationId.gcmRegistrationId())); account.setNumber(number, phoneNumberIdentifiers.getPhoneNumberIdentifier(number)); @@ -205,6 +226,8 @@ public class AccountsManager { // Reuse the ACI from any recently-deleted account with this number to cover cases where somebody is // re-registering. account.setUuid(maybeRecentlyDeletedAccountIdentifier.orElseGet(UUID::randomUUID)); + account.setIdentityKey(aciIdentityKey); + account.setPhoneNumberIdentityKey(pniIdentityKey); account.addDevice(device); account.setRegistrationLockFromAttributes(accountAttributes); account.setUnidentifiedAccessKey(accountAttributes.getUnidentifiedAccessKey()); @@ -214,7 +237,14 @@ public class AccountsManager { final UUID originalUuid = account.getUuid(); - boolean freshUser = accounts.create(account); + final boolean freshUser = accounts.create(account, + a -> keysManager.buildWriteItemsForRepeatedUseKeys(a.getIdentifier(IdentityType.ACI), + a.getIdentifier(IdentityType.PNI), + Device.PRIMARY_ID, + aciSignedPreKey, + pniSignedPreKey, + aciPqLastResortPreKey, + pniPqLastResortPreKey)); // create() sometimes updates the UUID, if there was a number conflict. // for metrics, we want secondary to run with the same original UUID @@ -235,9 +265,11 @@ public class AccountsManager { // confident that everything has already been deleted. In the second case, though, we're taking over an existing // account and need to clear out messages and keys that may have been stored for the old account. if (!originalUuid.equals(actualUuid)) { + // We exclude the primary device's repeated-use keys from deletion because new keys were provided as part of + // the account creation process, and we don't want to delete the keys that just got added. final CompletableFuture deleteKeysFuture = CompletableFuture.allOf( - keysManager.delete(actualUuid), - keysManager.delete(account.getPhoneNumberIdentifier())); + keysManager.delete(actualUuid, true), + keysManager.delete(account.getPhoneNumberIdentifier(), true)); messagesManager.clear(actualUuid).join(); profilesManager.deleteAll(actualUuid).join(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java index 41cef17ee..c44c993a6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java @@ -18,6 +18,7 @@ import org.whispersystems.textsecuregcm.entities.ECPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; +import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem; public class KeysManager { @@ -75,6 +76,20 @@ public class KeysManager { return CompletableFuture.allOf(storeFutures.toArray(new CompletableFuture[0])); } + public List buildWriteItemsForRepeatedUseKeys(final UUID accountIdentifier, + final UUID phoneNumberIdentifier, + final byte deviceId, + final ECSignedPreKey aciSignedPreKey, + final ECSignedPreKey pniSignedPreKey, + final KEMSignedPreKey aciPqLastResortPreKey, + final KEMSignedPreKey pniLastResortPreKey) { + + return List.of(ecSignedPreKeys.buildTransactWriteItem(accountIdentifier, deviceId, aciSignedPreKey), + ecSignedPreKeys.buildTransactWriteItem(phoneNumberIdentifier, deviceId, pniSignedPreKey), + pqLastResortKeys.buildTransactWriteItem(accountIdentifier, deviceId, aciPqLastResortPreKey), + pqLastResortKeys.buildTransactWriteItem(phoneNumberIdentifier, deviceId, pniLastResortPreKey)); + } + public CompletableFuture storeEcSignedPreKeys(final UUID identifier, final Map keys) { if (dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().storeEcSignedPreKeys()) { return ecSignedPreKeys.store(identifier, keys); @@ -134,14 +149,18 @@ public class KeysManager { return pqPreKeys.getCount(identifier, deviceId); } - public CompletableFuture delete(final UUID accountUuid) { + public CompletableFuture delete(final UUID identifier) { + return delete(identifier, false); + } + + public CompletableFuture delete(final UUID identifier, final boolean excludePrimaryDevice) { return CompletableFuture.allOf( - ecPreKeys.delete(accountUuid), - pqPreKeys.delete(accountUuid), - dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().deleteEcSignedPreKeys() - ? ecSignedPreKeys.delete(accountUuid) - : CompletableFuture.completedFuture(null), - pqLastResortKeys.delete(accountUuid)); + ecPreKeys.delete(identifier), + pqPreKeys.delete(identifier), + dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().deleteEcSignedPreKeys() + ? ecSignedPreKeys.delete(identifier, excludePrimaryDevice) + : CompletableFuture.completedFuture(null), + pqLastResortKeys.delete(identifier, excludePrimaryDevice)); } public CompletableFuture delete(final UUID accountUuid, final byte deviceId) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java index 2e61e2fc7..9fe0966c8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java @@ -112,6 +112,15 @@ public abstract class RepeatedUseSignedPreKeyStore> { .thenRun(() -> sample.stop(storeKeyBatchTimer)); } + TransactWriteItem buildTransactWriteItem(final UUID identifier, final byte deviceId, final K preKey) { + return TransactWriteItem.builder() + .put(Put.builder() + .tableName(tableName) + .item(getItemFromPreKey(identifier, deviceId, preKey)) + .build()) + .build(); + } + /** * Finds a repeated-use pre-key for a specific device. * @@ -142,14 +151,19 @@ public abstract class RepeatedUseSignedPreKeyStore> { * Clears all repeated-use pre-keys associated with the given account/identity. * * @param identifier the identifier for the account/identity for which to clear repeated-use pre-keys + * @param excludePrimaryDevice whether to exclude the primary device from repeated-use key deletion; this is intended + * for cases when a user "re-registers" and displaces an existing account record and has + * provided new repeated-use keys for the primary device in the process of creating the + * new account * * @return a future that completes once repeated-use pre-keys have been cleared from all devices associated with the * target account/identity */ - public CompletableFuture delete(final UUID identifier) { + public CompletableFuture delete(final UUID identifier, final boolean excludePrimaryDevice) { final Timer.Sample sample = Timer.start(); return getDeviceIdsWithKeys(identifier) + .filter(deviceId -> deviceId != Device.PRIMARY_ID || !excludePrimaryDevice) .map(deviceId -> DeleteItemRequest.builder() .tableName(tableName) .key(getPrimaryKey(identifier, deviceId)) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java index 4d1194b20..af09851ae 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java @@ -55,7 +55,6 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentCaptor; -import org.mockito.stubbing.Answer; import org.signal.libsignal.usernames.BaseUsernameException; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; @@ -215,15 +214,6 @@ class AccountControllerTest { when(accountsManager.getByE164(eq(SENDER_PREAUTH))).thenReturn(Optional.empty()); when(accountsManager.getByE164(eq(SENDER_HAS_STORAGE))).thenReturn(Optional.of(senderHasStorage)); when(accountsManager.getByE164(eq(SENDER_TRANSFER))).thenReturn(Optional.of(senderTransfer)); - - when(accountsManager.create(any(), any(), any(), any(), any())).thenAnswer((Answer) invocation -> { - final Account account = mock(Account.class); - when(account.getUuid()).thenReturn(UUID.randomUUID()); - when(account.getNumber()).thenReturn(invocation.getArgument(0, String.class)); - when(account.getBadges()).thenReturn(invocation.getArgument(4, List.class)); - - return account; - }); } @AfterEach diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java index adf25c345..ef85cd730 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java @@ -7,10 +7,11 @@ package org.whispersystems.textsecuregcm.controllers; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyByte; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -21,10 +22,12 @@ import io.dropwizard.testing.junit5.ResourceExtension; import java.io.UncheckedIOException; import java.nio.charset.StandardCharsets; import java.time.Duration; +import java.util.Arrays; import java.util.Base64; +import java.util.Collections; import java.util.EnumSet; import java.util.HashSet; -import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.UUID; @@ -72,7 +75,6 @@ import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.storage.KeysManager; import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; @@ -95,7 +97,6 @@ class RegistrationControllerTest { RegistrationLockVerificationManager.class); private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager = mock( RegistrationRecoveryPasswordsManager.class); - private final KeysManager keysManager = mock(KeysManager.class); private final RateLimiters rateLimiters = mock(RateLimiters.class); private final RateLimiter registrationLimiter = mock(RateLimiter.class); @@ -110,7 +111,7 @@ class RegistrationControllerTest { .addResource( new RegistrationController(accountsManager, new PhoneVerificationTokenManager(registrationServiceClient, registrationRecoveryPasswordsManager), - registrationLockVerificationManager, keysManager, rateLimiters)) + registrationLockVerificationManager, rateLimiters)) .build(); @BeforeEach @@ -125,11 +126,6 @@ class RegistrationControllerTest { return invocation.getArgument(0); }); - - when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); - when(keysManager.storeEcOneTimePreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null)); - when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); - when(keysManager.storeKemOneTimePreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null)); } @Test @@ -171,7 +167,7 @@ class RegistrationControllerTest { final Account account = mock(Account.class); when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class))); - when(accountsManager.create(any(), any(), any(), any(), any())) + when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any())) .thenReturn(account); final String json = requestJson("sessionId", new byte[0], true, registrationId.orElse(0), pniRegistrationId); @@ -294,7 +290,7 @@ class RegistrationControllerTest { final Account account = mock(Account.class); when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class))); - when(accountsManager.create(any(), any(), any(), any(), any())) + when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any())) .thenReturn(account); final Invocation.Builder request = resources.getJerseyTest() @@ -352,7 +348,7 @@ class RegistrationControllerTest { final Account createdAccount = mock(Account.class); when(createdAccount.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class))); - when(accountsManager.create(any(), any(), any(), any(), any())) + when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any())) .thenReturn(createdAccount); expectedStatus = 200; @@ -406,7 +402,8 @@ class RegistrationControllerTest { final Account account = mock(Account.class); when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class))); - when(accountsManager.create(any(), any(), any(), any(), any())).thenReturn(account); + when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any())) + .thenReturn(account); final Invocation.Builder request = resources.getJerseyTest() .target("/v1/registration") @@ -429,7 +426,7 @@ class RegistrationControllerTest { final Account account = mock(Account.class); when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class))); - when(accountsManager.create(any(), any(), any(), any(), any())) + when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any())) .thenReturn(account); final Invocation.Builder request = resources.getJerseyTest() @@ -669,9 +666,8 @@ class RegistrationControllerTest { final ECSignedPreKey expectedPniSignedPreKey, final KEMSignedPreKey expectedAciPqLastResortPreKey, final KEMSignedPreKey expectedPniPqLastResortPreKey, - final Optional expectedApnsToken, - final Optional expectedApnsVoipToken, - final Optional expectedGcmToken) throws InterruptedException { + final Optional expectedApnRegistrationId, + final Optional expectedGcmRegistrationId) throws InterruptedException { when(registrationServiceClient.getSession(any(), any())) .thenReturn( @@ -679,9 +675,6 @@ class RegistrationControllerTest { Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null, SESSION_EXPIRATION_SECONDS)))); - when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); - when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); - final UUID accountIdentifier = UUID.randomUUID(); final UUID phoneNumberIdentifier = UUID.randomUUID(); final Device device = mock(Device.class); @@ -692,9 +685,8 @@ class RegistrationControllerTest { when(a.getPrimaryDevice()).thenReturn(Optional.of(device)); }); - when(accountsManager.create(any(), any(), any(), any(), any())).thenReturn(account); - - when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any())) + .thenReturn(account); final Invocation.Builder request = resources.getJerseyTest() .target("/v1/registration") @@ -705,27 +697,33 @@ class RegistrationControllerTest { assertEquals(200, response.getStatus()); } - verify(accountsManager).create(any(), any(), any(), any(), any()); + verify(accountsManager).create( + eq(NUMBER), + eq(PASSWORD), + isNull(), + argThat(attributes -> accountAttributesEqual(attributes, registrationRequest.accountAttributes())), + eq(Collections.emptyList()), + eq(expectedAciIdentityKey), + eq(expectedPniIdentityKey), + eq(expectedAciSignedPreKey), + eq(expectedPniSignedPreKey), + eq(expectedAciPqLastResortPreKey), + eq(expectedPniPqLastResortPreKey), + eq(expectedApnRegistrationId), + eq(expectedGcmRegistrationId)); + } - verify(account).setIdentityKey(expectedAciIdentityKey); - verify(account).setPhoneNumberIdentityKey(expectedPniIdentityKey); - - verify(device).setSignedPreKey(expectedAciSignedPreKey); - verify(device).setPhoneNumberIdentitySignedPreKey(expectedPniSignedPreKey); - - verify(keysManager).storeEcSignedPreKeys(accountIdentifier, Map.of(Device.PRIMARY_ID, expectedAciSignedPreKey)); - verify(keysManager).storeEcSignedPreKeys(phoneNumberIdentifier, Map.of(Device.PRIMARY_ID, expectedPniSignedPreKey)); - verify(keysManager).storePqLastResort(accountIdentifier, Map.of(Device.PRIMARY_ID, expectedAciPqLastResortPreKey)); - verify(keysManager).storePqLastResort(phoneNumberIdentifier, Map.of(Device.PRIMARY_ID, expectedPniPqLastResortPreKey)); - - expectedApnsToken.ifPresentOrElse(expectedToken -> verify(device).setApnId(expectedToken), - () -> verify(device, never()).setApnId(any())); - - expectedApnsVoipToken.ifPresentOrElse(expectedToken -> verify(device).setVoipApnId(expectedToken), - () -> verify(device, never()).setVoipApnId(any())); - - expectedGcmToken.ifPresentOrElse(expectedToken -> verify(device).setGcmId(expectedToken), - () -> verify(device, never()).setGcmId(any())); + private static boolean accountAttributesEqual(final AccountAttributes a, final AccountAttributes b) { + return a.getFetchesMessages() == b.getFetchesMessages() + && a.getRegistrationId() == b.getRegistrationId() + && a.isUnrestrictedUnidentifiedAccess() == b.isUnrestrictedUnidentifiedAccess() + && a.isDiscoverableByPhoneNumber() == b.isDiscoverableByPhoneNumber() + && Objects.equals(a.getPhoneNumberIdentityRegistrationId(), b.getPhoneNumberIdentityRegistrationId()) + && Objects.equals(a.getName(), b.getName()) + && Objects.equals(a.getRegistrationLock(), b.getRegistrationLock()) + && Arrays.equals(a.getUnidentifiedAccessKey(), b.getUnidentifiedAccessKey()) + && Objects.equals(a.getCapabilities(), b.getCapabilities()) + && Objects.equals(a.recoveryPassword(), b.recoveryPassword()); } private static Stream atomicAccountCreationSuccess() { @@ -800,8 +798,7 @@ class RegistrationControllerTest { pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, - Optional.of(apnsToken), - Optional.of(apnsVoipToken), + Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), Optional.empty()), // requires the request to be atomic @@ -823,8 +820,7 @@ class RegistrationControllerTest { pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, - Optional.of(apnsToken), - Optional.of(apnsVoipToken), + Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), Optional.empty()), // Fetches messages; no push tokens @@ -847,8 +843,7 @@ class RegistrationControllerTest { aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), - Optional.empty(), - Optional.of(gcmToken))); + Optional.of(new GcmRegistrationId(gcmToken)))); } /** diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationIntegrationTest.java new file mode 100644 index 000000000..ed24f3196 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationIntegrationTest.java @@ -0,0 +1,424 @@ +package org.whispersystems.textsecuregcm.storage; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.i18n.phonenumbers.PhoneNumberUtil; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.time.ZoneId; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junitpioneer.jupiter.cartesian.ArgumentSets; +import org.junitpioneer.jupiter.cartesian.CartesianTest; +import org.signal.libsignal.protocol.IdentityKey; +import org.signal.libsignal.protocol.ecc.Curve; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.entities.AccountAttributes; +import org.whispersystems.textsecuregcm.entities.ApnRegistrationId; +import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; +import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; +import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; +import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; +import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; +import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; +import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; +import org.whispersystems.textsecuregcm.tests.util.KeysHelper; + +public class AccountCreationIntegrationTest { + + @RegisterExtension + static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension( + DynamoDbExtensionSchema.Tables.ACCOUNTS, + DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS, + DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS_LOCK, + DynamoDbExtensionSchema.Tables.NUMBERS, + DynamoDbExtensionSchema.Tables.PNI, + DynamoDbExtensionSchema.Tables.PNI_ASSIGNMENTS, + DynamoDbExtensionSchema.Tables.USERNAMES, + DynamoDbExtensionSchema.Tables.EC_KEYS, + DynamoDbExtensionSchema.Tables.PQ_KEYS, + DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS, + DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS); + + @RegisterExtension + static final RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); + + private static final Clock CLOCK = Clock.fixed(Instant.now(), ZoneId.systemDefault()); + + private ExecutorService accountLockExecutor; + + private AccountsManager accountsManager; + private KeysManager keysManager; + + record DeliveryChannels(boolean fetchesMessages, String apnsToken, String apnsVoipToken, String fcmToken) {} + + @BeforeEach + void setUp() { + @SuppressWarnings("unchecked") final DynamicConfigurationManager dynamicConfigurationManager = + mock(DynamicConfigurationManager.class); + + DynamicConfiguration dynamicConfiguration = new DynamicConfiguration(); + when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); + + keysManager = new KeysManager( + DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), + DynamoDbExtensionSchema.Tables.EC_KEYS.tableName(), + DynamoDbExtensionSchema.Tables.PQ_KEYS.tableName(), + DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName(), + DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName(), + dynamicConfigurationManager); + + final Accounts accounts = new Accounts( + DYNAMO_DB_EXTENSION.getDynamoDbClient(), + DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), + DynamoDbExtensionSchema.Tables.ACCOUNTS.tableName(), + DynamoDbExtensionSchema.Tables.NUMBERS.tableName(), + DynamoDbExtensionSchema.Tables.PNI_ASSIGNMENTS.tableName(), + DynamoDbExtensionSchema.Tables.USERNAMES.tableName(), + DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS.tableName()); + + accountLockExecutor = Executors.newSingleThreadExecutor(); + + final AccountLockManager accountLockManager = new AccountLockManager(DYNAMO_DB_EXTENSION.getDynamoDbClient(), + DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS_LOCK.tableName()); + + final SecureStorageClient secureStorageClient = mock(SecureStorageClient.class); + when(secureStorageClient.deleteStoredData(any())).thenReturn(CompletableFuture.completedFuture(null)); + + final SecureValueRecovery2Client svr2Client = mock(SecureValueRecovery2Client.class); + when(svr2Client.deleteBackups(any())).thenReturn(CompletableFuture.completedFuture(null)); + + final PhoneNumberIdentifiers phoneNumberIdentifiers = + new PhoneNumberIdentifiers(DYNAMO_DB_EXTENSION.getDynamoDbClient(), + DynamoDbExtensionSchema.Tables.PNI.tableName()); + + final MessagesManager messagesManager = mock(MessagesManager.class); + when(messagesManager.clear(any())).thenReturn(CompletableFuture.completedFuture(null)); + + final ProfilesManager profilesManager = mock(ProfilesManager.class); + when(profilesManager.deleteAll(any())).thenReturn(CompletableFuture.completedFuture(null)); + + final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager = + mock(RegistrationRecoveryPasswordsManager.class); + + when(registrationRecoveryPasswordsManager.removeForNumber(any())) + .thenReturn(CompletableFuture.completedFuture(null)); + + accountsManager = new AccountsManager( + accounts, + phoneNumberIdentifiers, + CACHE_CLUSTER_EXTENSION.getRedisCluster(), + accountLockManager, + keysManager, + messagesManager, + profilesManager, + secureStorageClient, + svr2Client, + mock(ClientPresenceManager.class), + mock(ExperimentEnrollmentManager.class), + registrationRecoveryPasswordsManager, + accountLockExecutor, + CLOCK); + } + + @AfterEach + void tearDown() throws InterruptedException { + accountLockExecutor.shutdown(); + + //noinspection ResultOfMethodCallIgnored + accountLockExecutor.awaitTermination(1, TimeUnit.SECONDS); + } + + @CartesianTest + @CartesianTest.MethodFactory("createAccount") + void createAccount(final DeliveryChannels deliveryChannels, + final boolean discoverableByPhoneNumber) throws InterruptedException { + + final String number = PhoneNumberUtil.getInstance().format( + PhoneNumberUtil.getInstance().getExampleNumber("US"), + PhoneNumberUtil.PhoneNumberFormat.E164); + + final String password = RandomStringUtils.randomAlphanumeric(16); + final String signalAgent = RandomStringUtils.randomAlphabetic(3); + final int registrationId = ThreadLocalRandom.current().nextInt(Device.MAX_REGISTRATION_ID); + final int pniRegistrationId = ThreadLocalRandom.current().nextInt(Device.MAX_REGISTRATION_ID); + final String deviceName = RandomStringUtils.randomAlphabetic(16); + final String registrationLockSecret = RandomStringUtils.randomAlphanumeric(16); + + final Device.DeviceCapabilities deviceCapabilities = new Device.DeviceCapabilities( + ThreadLocalRandom.current().nextBoolean(), + ThreadLocalRandom.current().nextBoolean(), + ThreadLocalRandom.current().nextBoolean(), + ThreadLocalRandom.current().nextBoolean()); + + final AccountAttributes accountAttributes = new AccountAttributes(deliveryChannels.fetchesMessages(), + registrationId, + deviceName, + registrationLockSecret, + discoverableByPhoneNumber, + deviceCapabilities); + + accountAttributes.setPhoneNumberIdentityRegistrationId(pniRegistrationId); + + final List badges = new ArrayList<>(List.of(new AccountBadge( + RandomStringUtils.randomAlphabetic(8), + CLOCK.instant().plus(Duration.ofDays(7)), + true))); + + final ECKeyPair aciKeyPair = Curve.generateKeyPair(); + final ECKeyPair pniKeyPair = Curve.generateKeyPair(); + + final ECSignedPreKey aciSignedPreKey = KeysHelper.signedECPreKey(1, aciKeyPair); + final ECSignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(2, pniKeyPair); + final KEMSignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciKeyPair); + final KEMSignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniKeyPair); + + final Optional maybeApnRegistrationId = + deliveryChannels.apnsToken() != null || deliveryChannels.apnsVoipToken() != null + ? Optional.of(new ApnRegistrationId(deliveryChannels.apnsToken(), deliveryChannels.apnsVoipToken())) + : Optional.empty(); + + final Optional maybeGcmRegistrationId = deliveryChannels.fcmToken() != null + ? Optional.of(new GcmRegistrationId(deliveryChannels.fcmToken())) + : Optional.empty(); + + final Account account = accountsManager.create(number, + password, + signalAgent, + accountAttributes, + badges, + new IdentityKey(aciKeyPair.getPublicKey()), + new IdentityKey(pniKeyPair.getPublicKey()), + aciSignedPreKey, + pniSignedPreKey, + aciPqLastResortPreKey, + pniPqLastResortPreKey, + maybeApnRegistrationId, + maybeGcmRegistrationId); + + assertExpectedStoredAccount(account, + number, + password, + signalAgent, + deliveryChannels, + registrationId, + pniRegistrationId, + deviceName, + discoverableByPhoneNumber, + deviceCapabilities, + badges, + maybeApnRegistrationId, + maybeGcmRegistrationId, + registrationLockSecret, + aciSignedPreKey, + pniSignedPreKey, + aciPqLastResortPreKey, + pniPqLastResortPreKey); + } + + @CartesianTest + @CartesianTest.MethodFactory("createAccount") + void reregisterAccount(final DeliveryChannels deliveryChannels, + final boolean discoverableByPhoneNumber) throws InterruptedException { + + final String number = PhoneNumberUtil.getInstance().format( + PhoneNumberUtil.getInstance().getExampleNumber("US"), + PhoneNumberUtil.PhoneNumberFormat.E164); + + final UUID existingAccountUuid; + { + final ECKeyPair aciKeyPair = Curve.generateKeyPair(); + final ECKeyPair pniKeyPair = Curve.generateKeyPair(); + + final ECSignedPreKey aciSignedPreKey = KeysHelper.signedECPreKey(1, aciKeyPair); + final ECSignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(2, pniKeyPair); + final KEMSignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciKeyPair); + final KEMSignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniKeyPair); + + final Account originalAccount = accountsManager.create(number, + RandomStringUtils.randomAlphanumeric(16), + "OWI", + new AccountAttributes(true, 1, "name", "registration-lock", false, new Device.DeviceCapabilities(false, false, false, false)), + Collections.emptyList(), + new IdentityKey(aciKeyPair.getPublicKey()), + new IdentityKey(pniKeyPair.getPublicKey()), + aciSignedPreKey, + pniSignedPreKey, + aciPqLastResortPreKey, + pniPqLastResortPreKey, + Optional.empty(), + Optional.empty()); + + existingAccountUuid = originalAccount.getUuid(); + } + + final String password = RandomStringUtils.randomAlphanumeric(16); + final String signalAgent = RandomStringUtils.randomAlphabetic(3); + final int registrationId = ThreadLocalRandom.current().nextInt(Device.MAX_REGISTRATION_ID); + final int pniRegistrationId = ThreadLocalRandom.current().nextInt(Device.MAX_REGISTRATION_ID); + final String deviceName = RandomStringUtils.randomAlphabetic(16); + final String registrationLockSecret = RandomStringUtils.randomAlphanumeric(16); + + final Device.DeviceCapabilities deviceCapabilities = new Device.DeviceCapabilities( + ThreadLocalRandom.current().nextBoolean(), + ThreadLocalRandom.current().nextBoolean(), + ThreadLocalRandom.current().nextBoolean(), + ThreadLocalRandom.current().nextBoolean()); + + final AccountAttributes accountAttributes = new AccountAttributes(deliveryChannels.fetchesMessages(), + registrationId, + deviceName, + registrationLockSecret, + discoverableByPhoneNumber, + deviceCapabilities); + + accountAttributes.setPhoneNumberIdentityRegistrationId(pniRegistrationId); + + final List badges = new ArrayList<>(List.of(new AccountBadge( + RandomStringUtils.randomAlphabetic(8), + CLOCK.instant().plus(Duration.ofDays(7)), + true))); + + final ECKeyPair aciKeyPair = Curve.generateKeyPair(); + final ECKeyPair pniKeyPair = Curve.generateKeyPair(); + + final ECSignedPreKey aciSignedPreKey = KeysHelper.signedECPreKey(1, aciKeyPair); + final ECSignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(2, pniKeyPair); + final KEMSignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciKeyPair); + final KEMSignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniKeyPair); + + final Optional maybeApnRegistrationId = + deliveryChannels.apnsToken() != null || deliveryChannels.apnsVoipToken() != null + ? Optional.of(new ApnRegistrationId(deliveryChannels.apnsToken(), deliveryChannels.apnsVoipToken())) + : Optional.empty(); + + final Optional maybeGcmRegistrationId = deliveryChannels.fcmToken() != null + ? Optional.of(new GcmRegistrationId(deliveryChannels.fcmToken())) + : Optional.empty(); + + final Account reregisteredAccount = accountsManager.create(number, + password, + signalAgent, + accountAttributes, + badges, + new IdentityKey(aciKeyPair.getPublicKey()), + new IdentityKey(pniKeyPair.getPublicKey()), + aciSignedPreKey, + pniSignedPreKey, + aciPqLastResortPreKey, + pniPqLastResortPreKey, + maybeApnRegistrationId, + maybeGcmRegistrationId); + + assertExpectedStoredAccount(reregisteredAccount, + number, + password, + signalAgent, + deliveryChannels, + registrationId, + pniRegistrationId, + deviceName, + discoverableByPhoneNumber, + deviceCapabilities, + badges, + maybeApnRegistrationId, + maybeGcmRegistrationId, + registrationLockSecret, + aciSignedPreKey, + pniSignedPreKey, + aciPqLastResortPreKey, + pniPqLastResortPreKey); + + assertEquals(existingAccountUuid, reregisteredAccount.getUuid()); + } + + @SuppressWarnings("unused") + static ArgumentSets createAccount() { + return ArgumentSets + // deliveryChannels + .argumentsForFirstParameter( + new DeliveryChannels(true, null, null, null), + new DeliveryChannels(false, "apns-token", null, null), + new DeliveryChannels(false, "apns-token", "apns-voip-token", null), + new DeliveryChannels(false, null, null, "fcm-token")) + + // discoverableByPhoneNumber + .argumentsForNextParameter(true, false); + } + + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + private void assertExpectedStoredAccount(final Account account, + final String number, + final String password, + final String signalAgent, + final DeliveryChannels deliveryChannels, + final int registrationId, + final int pniRegistrationId, + final String deviceName, + final boolean discoverableByPhoneNumber, + final Device.DeviceCapabilities deviceCapabilities, + final List badges, + final Optional maybeApnRegistrationId, + final Optional maybeGcmRegistrationId, + final String registrationLockSecret, + final ECSignedPreKey aciSignedPreKey, + final ECSignedPreKey pniSignedPreKey, + final KEMSignedPreKey aciPqLastResortPreKey, + final KEMSignedPreKey pniPqLastResortPreKey) { + + final Device primaryDevice = account.getPrimaryDevice().orElseThrow(); + + assertEquals(number, account.getNumber()); + assertEquals(signalAgent, primaryDevice.getUserAgent()); + assertEquals(deliveryChannels.fetchesMessages(), primaryDevice.getFetchesMessages()); + assertEquals(registrationId, primaryDevice.getRegistrationId()); + assertEquals(pniRegistrationId, primaryDevice.getPhoneNumberIdentityRegistrationId().orElseThrow()); + assertEquals(deviceName, primaryDevice.getName()); + assertEquals(discoverableByPhoneNumber, account.isDiscoverableByPhoneNumber()); + assertEquals(deviceCapabilities, primaryDevice.getCapabilities()); + assertEquals(badges, account.getBadges()); + + maybeApnRegistrationId.ifPresentOrElse(apnRegistrationId -> { + assertEquals(apnRegistrationId.apnRegistrationId(), primaryDevice.getApnId()); + assertEquals(apnRegistrationId.voipRegistrationId(), primaryDevice.getVoipApnId()); + }, () -> { + assertNull(primaryDevice.getApnId()); + assertNull(primaryDevice.getVoipApnId()); + }); + + maybeGcmRegistrationId.ifPresentOrElse(gcmRegistrationId -> { + assertEquals(deliveryChannels.fcmToken(), primaryDevice.getGcmId()); + }, () -> { + assertNull(primaryDevice.getGcmId()); + }); + + assertTrue(account.getRegistrationLock().verify(registrationLockSecret)); + assertTrue(primaryDevice.getAuthTokenHash().verify(password)); + + assertEquals(Optional.of(aciSignedPreKey), keysManager.getEcSignedPreKey(account.getIdentifier(IdentityType.ACI), Device.PRIMARY_ID).join()); + assertEquals(Optional.of(pniSignedPreKey), keysManager.getEcSignedPreKey(account.getIdentifier(IdentityType.PNI), Device.PRIMARY_ID).join()); + assertEquals(Optional.of(aciPqLastResortPreKey), keysManager.getLastResort(account.getIdentifier(IdentityType.ACI), Device.PRIMARY_ID).join()); + assertEquals(Optional.of(pniPqLastResortPreKey), keysManager.getLastResort(account.getIdentifier(IdentityType.PNI), Device.PRIMARY_ID).join()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java index 44f61aea7..6d9d78d9f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java @@ -14,7 +14,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.time.Clock; -import java.util.ArrayList; import java.util.Map; import java.util.Optional; import java.util.OptionalInt; @@ -41,6 +40,7 @@ import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; +import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; class AccountsManagerChangeNumberIntegrationTest { @@ -53,7 +53,11 @@ class AccountsManagerChangeNumberIntegrationTest { Tables.NUMBERS, Tables.PNI, Tables.PNI_ASSIGNMENTS, - Tables.USERNAMES); + Tables.USERNAMES, + Tables.EC_KEYS, + Tables.PQ_KEYS, + Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS, + Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS); @RegisterExtension static final RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); @@ -73,6 +77,14 @@ class AccountsManagerChangeNumberIntegrationTest { DynamicConfiguration dynamicConfiguration = new DynamicConfiguration(); when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); + final KeysManager keysManager = new KeysManager( + DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), + Tables.EC_KEYS.tableName(), + Tables.PQ_KEYS.tableName(), + Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName(), + Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName(), + dynamicConfigurationManager); + final Accounts accounts = new Accounts( DYNAMO_DB_EXTENSION.getDynamoDbClient(), DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), @@ -98,9 +110,6 @@ class AccountsManagerChangeNumberIntegrationTest { final PhoneNumberIdentifiers phoneNumberIdentifiers = new PhoneNumberIdentifiers(DYNAMO_DB_EXTENSION.getDynamoDbClient(), Tables.PNI.tableName()); - final KeysManager keysManager = mock(KeysManager.class); - when(keysManager.delete(any())).thenReturn(CompletableFuture.completedFuture(null)); - final MessagesManager messagesManager = mock(MessagesManager.class); when(messagesManager.clear(any())).thenReturn(CompletableFuture.completedFuture(null)); @@ -143,8 +152,8 @@ class AccountsManagerChangeNumberIntegrationTest { void testChangeNumber() throws InterruptedException, MismatchedDevicesException { final String originalNumber = "+18005551111"; final String secondNumber = "+18005552222"; + final Account account = AccountsHelper.createAccount(accountsManager, originalNumber); - final Account account = accountsManager.create(originalNumber, "password", null, new AccountAttributes(), new ArrayList<>()); final UUID originalUuid = account.getUuid(); final UUID originalPni = account.getPhoneNumberIdentifier(); @@ -167,17 +176,17 @@ class AccountsManagerChangeNumberIntegrationTest { final String originalNumber = "+18005551111"; final String secondNumber = "+18005552222"; final int rotatedPniRegistrationId = 17; - final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); - final ECSignedPreKey rotatedSignedPreKey = KeysHelper.signedECPreKey(1L, pniIdentityKeyPair); - + final ECKeyPair rotatedPniIdentityKeyPair = Curve.generateKeyPair(); + final ECSignedPreKey rotatedSignedPreKey = KeysHelper.signedECPreKey(1L, rotatedPniIdentityKeyPair); final AccountAttributes accountAttributes = new AccountAttributes(true, rotatedPniRegistrationId + 1, "test", null, true, new Device.DeviceCapabilities(false, false, false, false)); - final Account account = accountsManager.create(originalNumber, "password", null, accountAttributes, new ArrayList<>()); - account.getPrimaryDevice().orElseThrow().setSignedPreKey(KeysHelper.signedECPreKey(1, pniIdentityKeyPair)); + final Account account = AccountsHelper.createAccount(accountsManager, originalNumber, accountAttributes); + + account.getPrimaryDevice().orElseThrow().setSignedPreKey(KeysHelper.signedECPreKey(1, rotatedPniIdentityKeyPair)); final UUID originalUuid = account.getUuid(); final UUID originalPni = account.getPhoneNumberIdentifier(); - final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); + final IdentityKey pniIdentityKey = new IdentityKey(rotatedPniIdentityKeyPair.getPublicKey()); final Map preKeys = Map.of(Device.PRIMARY_ID, rotatedSignedPreKey); final Map registrationIds = Map.of(Device.PRIMARY_ID, rotatedPniRegistrationId); @@ -207,7 +216,8 @@ class AccountsManagerChangeNumberIntegrationTest { final String originalNumber = "+18005551111"; final String secondNumber = "+18005552222"; - Account account = accountsManager.create(originalNumber, "password", null, new AccountAttributes(), new ArrayList<>()); + Account account = AccountsHelper.createAccount(accountsManager, originalNumber); + final UUID originalUuid = account.getUuid(); final UUID originalPni = account.getPhoneNumberIdentifier(); @@ -231,10 +241,12 @@ class AccountsManagerChangeNumberIntegrationTest { final String originalNumber = "+18005551111"; final String secondNumber = "+18005552222"; - final Account account = accountsManager.create(originalNumber, "password", null, new AccountAttributes(), new ArrayList<>()); + final Account account = AccountsHelper.createAccount(accountsManager, originalNumber); + final UUID originalUuid = account.getUuid(); - final Account existingAccount = accountsManager.create(secondNumber, "password", null, new AccountAttributes(), new ArrayList<>()); + final Account existingAccount = AccountsHelper.createAccount(accountsManager, secondNumber); + final UUID existingAccountUuid = existingAccount.getUuid(); accountsManager.changeNumber(account, secondNumber, null, null, null, null); @@ -253,8 +265,7 @@ class AccountsManagerChangeNumberIntegrationTest { accountsManager.changeNumber(accountsManager.getByAccountIdentifier(originalUuid).orElseThrow(), originalNumber, null, null, null, null); - final Account existingAccount2 = accountsManager.create(secondNumber, "password", null, new AccountAttributes(), - new ArrayList<>()); + final Account existingAccount2 = AccountsHelper.createAccount(accountsManager, secondNumber); assertEquals(existingAccountUuid, existingAccount2.getUuid()); } @@ -264,17 +275,19 @@ class AccountsManagerChangeNumberIntegrationTest { final String originalNumber = "+18005551111"; final String secondNumber = "+18005552222"; - final Account account = accountsManager.create(originalNumber, "password", null, new AccountAttributes(), new ArrayList<>()); + final Account account = AccountsHelper.createAccount(accountsManager, originalNumber); + final UUID originalUuid = account.getUuid(); final UUID originalPni = account.getPhoneNumberIdentifier(); - final Account existingAccount = accountsManager.create(secondNumber, "password", null, new AccountAttributes(), new ArrayList<>()); + final Account existingAccount = AccountsHelper.createAccount(accountsManager, secondNumber); + final UUID existingAccountUuid = existingAccount.getUuid(); final Account changedNumberAccount = accountsManager.changeNumber(account, secondNumber, null, null, null, null); final UUID secondPni = changedNumberAccount.getPhoneNumberIdentifier(); - final Account reRegisteredAccount = accountsManager.create(originalNumber, "password", null, new AccountAttributes(), new ArrayList<>()); + final Account reRegisteredAccount = AccountsHelper.createAccount(accountsManager, originalNumber); assertEquals(existingAccountUuid, reRegisteredAccount.getUuid()); assertEquals(originalPni, reRegisteredAccount.getPhoneNumberIdentifier()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java index f2063a4f3..3db244ec0 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java @@ -23,6 +23,7 @@ import java.io.IOException; import java.time.Clock; import java.time.Instant; import java.util.ArrayList; +import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; @@ -39,6 +40,7 @@ import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.Curve; +import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; @@ -51,6 +53,7 @@ import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2 import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; import org.whispersystems.textsecuregcm.tests.util.JsonHelpers; +import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; import org.whispersystems.textsecuregcm.util.Pair; @@ -62,8 +65,11 @@ class AccountsManagerConcurrentModificationIntegrationTest { Tables.ACCOUNTS, Tables.NUMBERS, Tables.PNI_ASSIGNMENTS, - Tables.DELETED_ACCOUNTS - ); + Tables.DELETED_ACCOUNTS, + Tables.EC_KEYS, + Tables.PQ_KEYS, + Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS, + Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS); private Accounts accounts; @@ -80,6 +86,14 @@ class AccountsManagerConcurrentModificationIntegrationTest { mock(DynamicConfigurationManager.class); when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration()); + final KeysManager keysManager = new KeysManager( + DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), + Tables.EC_KEYS.tableName(), + Tables.PQ_KEYS.tableName(), + Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName(), + Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName(), + dynamicConfigurationManager); + accounts = new Accounts( DYNAMO_DB_EXTENSION.getDynamoDbClient(), DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), @@ -134,11 +148,25 @@ class AccountsManagerConcurrentModificationIntegrationTest { @Test void testConcurrentUpdate() throws IOException, InterruptedException { - final UUID uuid; { + final ECKeyPair aciKeyPair = Curve.generateKeyPair(); + final ECKeyPair pniKeyPair = Curve.generateKeyPair(); + final Account account = accountsManager.update( - accountsManager.create("+14155551212", "password", null, new AccountAttributes(), new ArrayList<>()), + accountsManager.create("+14155551212", + "password", + null, + new AccountAttributes(), + new ArrayList<>(), + new IdentityKey(aciKeyPair.getPublicKey()), + new IdentityKey(pniKeyPair.getPublicKey()), + KeysHelper.signedECPreKey(1, aciKeyPair), + KeysHelper.signedECPreKey(2, pniKeyPair), + KeysHelper.signedKEMPreKey(3, aciKeyPair), + KeysHelper.signedKEMPreKey(4, pniKeyPair), + Optional.empty(), + Optional.empty()), a -> { a.setUnidentifiedAccessKey(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); a.removeDevice(Device.PRIMARY_ID); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index 12d8ac851..e4fd6301b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -13,6 +13,7 @@ import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; @@ -201,6 +202,7 @@ class AccountsManagerTest { when(registrationRecoveryPasswordsManager.removeForNumber(anyString())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.delete(any())).thenReturn(CompletableFuture.completedFuture(null)); + when(keysManager.delete(any(), anyBoolean())).thenReturn(CompletableFuture.completedFuture(null)); when(messagesManager.clear(any())).thenReturn(CompletableFuture.completedFuture(null)); when(profilesManager.deleteAll(any())).thenReturn(CompletableFuture.completedFuture(null)); @@ -853,7 +855,7 @@ class AccountsManagerTest { when(commands.get(eq("Account3::" + uuid))).thenReturn(null); when(accounts.getByAccountIdentifier(uuid)).thenReturn(Optional.empty()) .thenReturn(Optional.of(account)); - when(accounts.create(any())).thenThrow(ContestedOptimisticLockException.class); + when(accounts.create(any(), any())).thenThrow(ContestedOptimisticLockException.class); accountsManager.update(account, a -> { }); @@ -930,14 +932,15 @@ class AccountsManagerTest { @Test void testCreateFreshAccount() throws InterruptedException { - when(accounts.create(any())).thenReturn(true); + when(accounts.create(any(), any())).thenReturn(true); final String e164 = "+18005550123"; final AccountAttributes attributes = new AccountAttributes(false, 0, null, null, true, null); - accountsManager.create(e164, "password", null, attributes, new ArrayList<>()); - verify(accounts).create(argThat(account -> e164.equals(account.getNumber()))); - verifyNoInteractions(keysManager); + createAccount(e164, attributes); + + verify(accounts).create(argThat(account -> e164.equals(account.getNumber())), any()); + verifyNoInteractions(messagesManager); verifyNoInteractions(profilesManager); } @@ -946,22 +949,23 @@ class AccountsManagerTest { void testReregisterAccount() throws InterruptedException { final UUID existingUuid = UUID.randomUUID(); - when(accounts.create(any())).thenAnswer(invocation -> { + when(accounts.create(any(), any())).thenAnswer(invocation -> { invocation.getArgument(0, Account.class).setUuid(existingUuid); return false; }); final String e164 = "+18005550123"; final AccountAttributes attributes = new AccountAttributes(false, 0, null, null, true, null); - accountsManager.create(e164, "password", null, attributes, new ArrayList<>()); + + createAccount(e164, attributes); assertTrue(phoneNumberIdentifiersByE164.containsKey(e164)); verify(accounts) - .create(argThat(account -> e164.equals(account.getNumber()) && existingUuid.equals(account.getUuid()))); + .create(argThat(account -> e164.equals(account.getNumber()) && existingUuid.equals(account.getUuid())), any()); - verify(keysManager).delete(existingUuid); - verify(keysManager).delete(phoneNumberIdentifiersByE164.get(e164)); + verify(keysManager).delete(existingUuid, true); + verify(keysManager).delete(phoneNumberIdentifiersByE164.get(e164), true); verify(messagesManager).clear(existingUuid); verify(profilesManager).deleteAll(existingUuid); verify(clientPresenceManager).disconnectAllPresencesForUuid(existingUuid); @@ -972,14 +976,17 @@ class AccountsManagerTest { final UUID recentlyDeletedUuid = UUID.randomUUID(); when(accounts.findRecentlyDeletedAccountIdentifier(anyString())).thenReturn(Optional.of(recentlyDeletedUuid)); - when(accounts.create(any())).thenReturn(true); + when(accounts.create(any(), any())).thenReturn(true); final String e164 = "+18005550123"; final AccountAttributes attributes = new AccountAttributes(false, 0, null, null, true, null); - accountsManager.create(e164, "password", null, attributes, new ArrayList<>()); + + createAccount(e164, attributes); verify(accounts).create( - argThat(account -> e164.equals(account.getNumber()) && recentlyDeletedUuid.equals(account.getUuid()))); + argThat(account -> e164.equals(account.getNumber()) && recentlyDeletedUuid.equals(account.getUuid())), + any()); + verifyNoInteractions(keysManager); verifyNoInteractions(messagesManager); verifyNoInteractions(profilesManager); @@ -989,7 +996,7 @@ class AccountsManagerTest { @ValueSource(booleans = {true, false}) void testCreateWithDiscoverability(final boolean discoverable) throws InterruptedException { final AccountAttributes attributes = new AccountAttributes(false, 0, null, null, discoverable, null); - final Account account = accountsManager.create("+18005550123", "password", null, attributes, new ArrayList<>()); + final Account account = createAccount("+18005550123", attributes); assertEquals(discoverable, account.isDiscoverableByPhoneNumber()); } @@ -1000,7 +1007,7 @@ class AccountsManagerTest { final AccountAttributes attributes = new AccountAttributes(false, 0, null, null, true, new DeviceCapabilities(hasStorage, false, false, false)); - final Account account = accountsManager.create("+18005550123", "password", null, attributes, new ArrayList<>()); + final Account account = createAccount("+18005550123", attributes); assertEquals(hasStorage, account.isStorageSupported()); } @@ -1572,4 +1579,23 @@ class AccountsManagerTest { return device; } + + private Account createAccount(final String e164, final AccountAttributes accountAttributes) throws InterruptedException { + final ECKeyPair aciKeyPair = Curve.generateKeyPair(); + final ECKeyPair pniKeyPair = Curve.generateKeyPair(); + + return accountsManager.create(e164, + "password", + null, + accountAttributes, + new ArrayList<>(), + new IdentityKey(aciKeyPair.getPublicKey()), + new IdentityKey(pniKeyPair.getPublicKey()), + KeysHelper.signedECPreKey(1, aciKeyPair), + KeysHelper.signedECPreKey(2, pniKeyPair), + KeysHelper.signedKEMPreKey(3, aciKeyPair), + KeysHelper.signedKEMPreKey(4, pniKeyPair), + Optional.empty(), + Optional.empty()); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java index 22719f827..2eb649a91 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java @@ -25,6 +25,7 @@ import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -39,13 +40,13 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.mockito.Mockito; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; -import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; +import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; @@ -70,7 +71,11 @@ class AccountsManagerUsernameIntegrationTest { Tables.USERNAMES, Tables.DELETED_ACCOUNTS, Tables.PNI, - Tables.PNI_ASSIGNMENTS); + Tables.PNI_ASSIGNMENTS, + Tables.EC_KEYS, + Tables.PQ_KEYS, + Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS, + Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS); @RegisterExtension static RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); @@ -91,6 +96,14 @@ class AccountsManagerUsernameIntegrationTest { DynamicConfiguration dynamicConfiguration = new DynamicConfiguration(); when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); + final KeysManager keysManager = new KeysManager( + DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), + Tables.EC_KEYS.tableName(), + Tables.PQ_KEYS.tableName(), + Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName(), + Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName(), + dynamicConfigurationManager); + accounts = Mockito.spy(new Accounts( DYNAMO_DB_EXTENSION.getDynamoDbClient(), DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), @@ -122,12 +135,13 @@ class AccountsManagerUsernameIntegrationTest { final ExperimentEnrollmentManager experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class); when(experimentEnrollmentManager.isEnrolled(any(UUID.class), eq(AccountsManager.USERNAME_EXPERIMENT_NAME))) .thenReturn(true); + accountsManager = new AccountsManager( accounts, phoneNumberIdentifiers, CACHE_CLUSTER_EXTENSION.getRedisCluster(), accountLockManager, - mock(KeysManager.class), + keysManager, mock(MessagesManager.class), mock(ProfilesManager.class), mock(SecureStorageClient.class), @@ -141,8 +155,8 @@ class AccountsManagerUsernameIntegrationTest { @Test void testNoUsernames() throws InterruptedException { - Account account = accountsManager.create("+18005551111", "password", null, new AccountAttributes(), - new ArrayList<>()); + final Account account = AccountsHelper.createAccount(accountsManager, "+18005551111"); + List usernameHashes = List.of(USERNAME_HASH_1, USERNAME_HASH_2); int i = 0; for (byte[] hash : usernameHashes) { @@ -169,8 +183,8 @@ class AccountsManagerUsernameIntegrationTest { @Test void testReserveUsernameSnatched() throws InterruptedException, UsernameHashNotAvailableException { - final Account account = accountsManager.create("+18005551111", "password", null, new AccountAttributes(), - new ArrayList<>()); + final Account account = AccountsHelper.createAccount(accountsManager, "+18005551111"); + ArrayList usernameHashes = new ArrayList<>(Arrays.asList(USERNAME_HASH_1, USERNAME_HASH_2)); for (byte[] hash : usernameHashes) { DYNAMO_DB_EXTENSION.getDynamoDbClient().putItem(PutItemRequest.builder() @@ -205,10 +219,8 @@ class AccountsManagerUsernameIntegrationTest { } @Test - public void testReserveConfirmClear() - throws InterruptedException, UsernameHashNotAvailableException, UsernameReservationNotFoundException { - Account account = accountsManager.create("+18005551111", "password", null, new AccountAttributes(), - new ArrayList<>()); + public void testReserveConfirmClear() throws InterruptedException { + Account account = AccountsHelper.createAccount(accountsManager, "+18005551111"); // reserve AccountsManager.UsernameReservation reservation = @@ -236,11 +248,8 @@ class AccountsManagerUsernameIntegrationTest { } @Test - public void testReservationLapsed() - throws InterruptedException, UsernameHashNotAvailableException, UsernameReservationNotFoundException { - - final Account account = accountsManager.create("+18005551111", "password", null, new AccountAttributes(), - new ArrayList<>()); + public void testReservationLapsed() throws InterruptedException { + final Account account = AccountsHelper.createAccount(accountsManager, "+18005551111"); AccountsManager.UsernameReservation reservation1 = accountsManager.reserveUsernameHash(account, List.of(USERNAME_HASH_1)).join(); @@ -256,8 +265,8 @@ class AccountsManagerUsernameIntegrationTest { .build()); // a different account should be able to reserve it - Account account2 = accountsManager.create("+18005552222", "password", null, new AccountAttributes(), - new ArrayList<>()); + Account account2 = AccountsHelper.createAccount(accountsManager, "+18005552222"); + final AccountsManager.UsernameReservation reservation2 = accountsManager.reserveUsernameHash(account2, List.of(USERNAME_HASH_1)).join(); assertArrayEquals(reservation2.reservedUsernameHash(), USERNAME_HASH_1); @@ -271,8 +280,7 @@ class AccountsManagerUsernameIntegrationTest { @Test void testUsernameSetReserveAnotherClearSetReserved() throws InterruptedException { - Account account = accountsManager.create("+18005551111", "password", null, new AccountAttributes(), - new ArrayList<>()); + Account account = AccountsHelper.createAccount(accountsManager, "+18005551111"); // Set username hash final AccountsManager.UsernameReservation reservation1 = @@ -303,9 +311,10 @@ class AccountsManagerUsernameIntegrationTest { @Test public void testUsernameLinks() throws InterruptedException { - Account account = accountsManager.create("+18005551111", "password", null, new AccountAttributes(), new ArrayList<>()); + final Account account = AccountsHelper.createAccount(accountsManager, "+18005551111"); + account.setUsernameHash(RandomUtils.nextBytes(16)); - accounts.create(account); + accounts.create(account, ignored -> Collections.emptyList()); final UUID linkHandle = UUID.randomUUID(); final byte[] encryptedUsername = RandomUtils.nextBytes(32); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java index e5da5aa40..4813bb6e5 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java @@ -21,12 +21,12 @@ import static org.mockito.Mockito.when; import com.fasterxml.jackson.core.JsonProcessingException; import java.nio.charset.StandardCharsets; -import java.security.SecureRandom; import java.time.Duration; import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; @@ -114,7 +114,7 @@ class AccountsTest { when(mockDynamicConfigManager.getConfiguration()) .thenReturn(new DynamicConfiguration()); - this.accounts = new Accounts( + accounts = new Accounts( clock, DYNAMO_DB_EXTENSION.getDynamoDbClient(), DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), @@ -129,7 +129,7 @@ class AccountsTest { public void testStoreAndLookupUsernameLink() { final Account account = nextRandomAccount(); account.setUsernameHash(RandomUtils.nextBytes(16)); - accounts.create(account); + createAccount(account); final BiConsumer, byte[]> validator = (maybeAccount, expectedEncryptedUsername) -> { assertTrue(maybeAccount.isPresent()); @@ -165,7 +165,7 @@ class AccountsTest { Device device = generateDevice(DEVICE_ID_1); Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device)); - boolean freshUser = accounts.create(account); + boolean freshUser = createAccount(account); assertThat(freshUser).isTrue(); verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, true); @@ -173,7 +173,7 @@ class AccountsTest { assertPhoneNumberConstraintExists("+14151112222", account.getUuid()); assertPhoneNumberIdentifierConstraintExists(account.getPhoneNumberIdentifier(), account.getUuid()); - freshUser = accounts.create(account); + freshUser = createAccount(account); assertThat(freshUser).isTrue(); verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, true); @@ -188,7 +188,7 @@ class AccountsTest { Device device = generateDevice(DEVICE_ID_1); Account account = generateAccount("+14151112222", originalUuid, UUID.randomUUID(), List.of(device)); - boolean freshUser = accounts.create(account); + boolean freshUser = createAccount(account); assertThat(freshUser).isTrue(); verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, true); @@ -199,7 +199,7 @@ class AccountsTest { accounts.delete(originalUuid).join(); assertThat(accounts.findRecentlyDeletedAccountIdentifier(account.getNumber())).hasValue(originalUuid); - freshUser = accounts.create(account); + freshUser = createAccount(account); assertThat(freshUser).isTrue(); verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, true); @@ -214,7 +214,7 @@ class AccountsTest { final List devices = List.of(generateDevice(DEVICE_ID_1), generateDevice(DEVICE_ID_2)); final Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), devices); - accounts.create(account); + createAccount(account); verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, true); @@ -236,8 +236,8 @@ class AccountsTest { UUID pniSecond = UUID.randomUUID(); Account accountSecond = generateAccount("+14152221111", uuidSecond, pniSecond, devicesSecond); - accounts.create(accountFirst); - accounts.create(accountSecond); + createAccount(accountFirst); + createAccount(accountSecond); Optional retrievedFirst = accounts.getByE164("+14151112222"); Optional retrievedSecond = accounts.getByE164("+14152221111"); @@ -340,7 +340,7 @@ class AccountsTest { UUID firstUuid = UUID.randomUUID(); UUID firstPni = UUID.randomUUID(); Account account = generateAccount("+14151112222", firstUuid, firstPni, List.of(device)); - accounts.create(account); + createAccount(account); final byte[] usernameHash = randomBytes(32); final byte[] encryptedUsername = randomBytes(32); @@ -358,7 +358,7 @@ class AccountsTest { // simulate a failed re-reg: we give the account a reclaimable username, but we'll try // re-registering again later in the test case account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1))); - accounts.create(account); + createAccount(account); break; case CONFIRMED: accounts.reserveUsernameHash(account, usernameHash, Duration.ofMinutes(1)).join(); @@ -370,7 +370,7 @@ class AccountsTest { // re-register the account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1))); - accounts.create(account); + createAccount(account); // If we had a username link, or we had previously saved a username link from another re-registration, make sure // we preserve it @@ -401,7 +401,7 @@ class AccountsTest { UUID firstPni = UUID.randomUUID(); Account account = generateAccount("+14151112222", firstUuid, firstPni, List.of(device)); - accounts.create(account); + createAccount(account); final byte[] usernameHash = randomBytes(32); final byte[] encryptedUsername = randomBytes(16); @@ -422,7 +422,7 @@ class AccountsTest { device = generateDevice(DEVICE_ID_1); account = generateAccount("+14151112222", secondUuid, UUID.randomUUID(), List.of(device)); - final boolean freshUser = accounts.create(account); + final boolean freshUser = createAccount(account); assertThat(freshUser).isFalse(); // usernameHash should be unset verifyStoredState("+14151112222", firstUuid, firstPni, null, account, true); @@ -453,7 +453,8 @@ class AccountsTest { device = generateDevice(DEVICE_ID_1); Account invalidAccount = generateAccount("+14151113333", firstUuid, UUID.randomUUID(), List.of(device)); - assertThatThrownBy(() -> accounts.create(invalidAccount)); + + assertThatThrownBy(() -> createAccount(invalidAccount)); } @Test @@ -461,7 +462,7 @@ class AccountsTest { Device device = generateDevice(DEVICE_ID_1); Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device)); - accounts.create(account); + createAccount(account); assertPhoneNumberConstraintExists("+14151112222", account.getUuid()); assertPhoneNumberIdentifierConstraintExists(account.getPhoneNumberIdentifier(), account.getUuid()); @@ -511,8 +512,11 @@ class AccountsTest { final DynamoDbAsyncClient dynamoDbAsyncClient = mock(DynamoDbAsyncClient.class); accounts = new Accounts(mock(DynamoDbClient.class), - dynamoDbAsyncClient, Tables.ACCOUNTS.tableName(), - Tables.NUMBERS.tableName(), Tables.PNI_ASSIGNMENTS.tableName(), Tables.USERNAMES.tableName(), + dynamoDbAsyncClient, + Tables.ACCOUNTS.tableName(), + Tables.NUMBERS.tableName(), + Tables.PNI_ASSIGNMENTS.tableName(), + Tables.USERNAMES.tableName(), Tables.DELETED_ACCOUNTS.tableName()); Exception e = TransactionConflictException.builder().build(); @@ -533,7 +537,7 @@ class AccountsTest { for (int i = 1; i <= 100; i++) { final Account account = generateAccount("+1" + String.format("%03d", i), UUID.randomUUID(), UUID.randomUUID()); expectedAccounts.add(account); - accounts.create(account); + createAccount(account); } final List retrievedAccounts = @@ -553,8 +557,8 @@ class AccountsTest { final Account retainedAccount = generateAccount("+14151112345", UUID.randomUUID(), UUID.randomUUID(), List.of(retainedDevice)); - accounts.create(deletedAccount); - accounts.create(retainedAccount); + createAccount(deletedAccount); + createAccount(retainedAccount); assertThat(accounts.findRecentlyDeletedAccountIdentifier(deletedAccount.getNumber())).isEmpty(); @@ -581,7 +585,7 @@ class AccountsTest { final Account recreatedAccount = generateAccount(deletedAccount.getNumber(), UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1))); - final boolean freshUser = accounts.create(recreatedAccount); + final boolean freshUser = createAccount(recreatedAccount); assertThat(freshUser).isTrue(); assertThat(accounts.getByAccountIdentifier(recreatedAccount.getUuid())).isPresent(); @@ -598,7 +602,7 @@ class AccountsTest { Device device = generateDevice(DEVICE_ID_1); Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device)); - accounts.create(account); + createAccount(account); Optional retrieved = accounts.getByE164("+11111111"); assertThat(retrieved.isPresent()).isFalse(); @@ -614,7 +618,7 @@ class AccountsTest { final Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1))); - accounts.create(account); + createAccount(account); assertThat(accounts.getByAccountIdentifierAsync(account.getUuid()).join()).isPresent(); } @@ -626,7 +630,7 @@ class AccountsTest { final Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1))); - accounts.create(account); + createAccount(account); assertThat(accounts.getByPhoneNumberIdentifierAsync(account.getPhoneNumberIdentifier()).join()).isPresent(); } @@ -640,7 +644,7 @@ class AccountsTest { final Account account = generateAccount(e164, UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1))); - accounts.create(account); + createAccount(account); assertThat(accounts.getByE164Async(e164).join()).isPresent(); } @@ -650,7 +654,7 @@ class AccountsTest { Device device = generateDevice(DEVICE_ID_1); Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device)); account.setDiscoverableByPhoneNumber(false); - accounts.create(account); + createAccount(account); verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, false); account.setDiscoverableByPhoneNumber(true); accounts.update(account); @@ -673,7 +677,7 @@ class AccountsTest { final Device device = generateDevice(DEVICE_ID_1); final Account account = generateAccount(originalNumber, UUID.randomUUID(), originalPni, List.of(device)); - accounts.create(account); + createAccount(account); assertThat(accounts.getByPhoneNumberIdentifier(originalPni)).isPresent(); @@ -731,8 +735,8 @@ class AccountsTest { final Device device = generateDevice(DEVICE_ID_1); final Account account = generateAccount(originalNumber, UUID.randomUUID(), originalPni, List.of(device)); - accounts.create(account); - accounts.create(existingAccount); + createAccount(account); + createAccount(existingAccount); assertThrows(TransactionCanceledException.class, () -> accounts.changeNumber(account, targetNumber, targetPni, Optional.of(existingAccount.getUuid()))); @@ -750,7 +754,7 @@ class AccountsTest { final Device device = generateDevice(DEVICE_ID_1); final Account account = generateAccount(originalNumber, UUID.randomUUID(), UUID.randomUUID(), List.of(device)); - accounts.create(account); + createAccount(account); final UUID existingAccountIdentifier = UUID.randomUUID(); final UUID existingPhoneNumberIdentifier = UUID.randomUUID(); @@ -776,7 +780,7 @@ class AccountsTest { @Test void testSwitchUsernameHashes() { final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID()); - accounts.create(account); + createAccount(account); assertThat(accounts.getByUsernameHash(USERNAME_HASH_1).join()).isEmpty(); @@ -822,8 +826,8 @@ class AccountsTest { final Account firstAccount = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID()); final Account secondAccount = generateAccount("+18005559876", UUID.randomUUID(), UUID.randomUUID()); - accounts.create(firstAccount); - accounts.create(secondAccount); + createAccount(firstAccount); + createAccount(secondAccount); // first account reserves and confirms username hash assertThatNoException().isThrownBy(() -> { @@ -855,7 +859,7 @@ class AccountsTest { @Test void testConfirmUsernameHashVersionMismatch() { final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID()); - accounts.create(account); + createAccount(account); accounts.reserveUsernameHash(account, USERNAME_HASH_1, Duration.ofDays(1)).join(); account.setVersion(account.getVersion() + 77); @@ -868,7 +872,7 @@ class AccountsTest { @Test void testClearUsername() { final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID()); - accounts.create(account); + createAccount(account); accounts.reserveUsernameHash(account, USERNAME_HASH_1, Duration.ofDays(1)).join(); accounts.confirmUsernameHash(account, USERNAME_HASH_1, ENCRYPTED_USERNAME_1).join(); @@ -888,7 +892,7 @@ class AccountsTest { @Test void testClearUsernameNoUsername() { final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID()); - accounts.create(account); + createAccount(account); assertThatNoException().isThrownBy(() -> accounts.clearUsernameHash(account).join()); } @@ -896,7 +900,7 @@ class AccountsTest { @Test void testClearUsernameVersionMismatch() { final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID()); - accounts.create(account); + createAccount(account); accounts.reserveUsernameHash(account, USERNAME_HASH_1, Duration.ofDays(1)).join(); accounts.confirmUsernameHash(account, USERNAME_HASH_1, ENCRYPTED_USERNAME_1).join(); @@ -912,9 +916,9 @@ class AccountsTest { @Test void testReservedUsernameHash() { final Account account1 = generateAccount("+18005551111", UUID.randomUUID(), UUID.randomUUID()); - accounts.create(account1); + createAccount(account1); final Account account2 = generateAccount("+18005552222", UUID.randomUUID(), UUID.randomUUID()); - accounts.create(account2); + createAccount(account2); accounts.reserveUsernameHash(account1, USERNAME_HASH_1, Duration.ofDays(1)).join(); assertArrayEquals(account1.getReservedUsernameHash().orElseThrow(), USERNAME_HASH_1); @@ -946,7 +950,7 @@ class AccountsTest { @Test void testUsernameHashAvailable() { final Account account1 = generateAccount("+18005551111", UUID.randomUUID(), UUID.randomUUID()); - accounts.create(account1); + createAccount(account1); accounts.reserveUsernameHash(account1, USERNAME_HASH_1, Duration.ofDays(1)).join(); assertThat(accounts.usernameHashAvailable(USERNAME_HASH_1).join()).isFalse(); @@ -964,9 +968,9 @@ class AccountsTest { @Test void testConfirmReservedUsernameHashWrongAccountUuid() { final Account account1 = generateAccount("+18005551111", UUID.randomUUID(), UUID.randomUUID()); - accounts.create(account1); + createAccount(account1); final Account account2 = generateAccount("+18005552222", UUID.randomUUID(), UUID.randomUUID()); - accounts.create(account2); + createAccount(account2); accounts.reserveUsernameHash(account1, USERNAME_HASH_1, Duration.ofDays(1)).join(); assertArrayEquals(account1.getReservedUsernameHash().orElseThrow(), USERNAME_HASH_1); @@ -980,9 +984,9 @@ class AccountsTest { @Test void testConfirmExpiredReservedUsernameHash() { final Account account1 = generateAccount("+18005551111", UUID.randomUUID(), UUID.randomUUID()); - accounts.create(account1); + createAccount(account1); final Account account2 = generateAccount("+18005552222", UUID.randomUUID(), UUID.randomUUID()); - accounts.create(account2); + createAccount(account2); accounts.reserveUsernameHash(account1, USERNAME_HASH_1, Duration.ofDays(2)).join(); @@ -1009,7 +1013,7 @@ class AccountsTest { @Test void testRetryReserveUsernameHash() { final Account account = generateAccount("+18005551111", UUID.randomUUID(), UUID.randomUUID()); - accounts.create(account); + createAccount(account); accounts.reserveUsernameHash(account, USERNAME_HASH_1, Duration.ofDays(2)).join(); CompletableFutureTestUtil.assertFailsWithCause(ContestedOptimisticLockException.class, @@ -1020,7 +1024,7 @@ class AccountsTest { @Test void testReserveConfirmUsernameHashVersionConflict() { final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID()); - accounts.create(account); + createAccount(account); account.setVersion(account.getVersion() + 12); CompletableFutureTestUtil.assertFailsWithCause(ContestedOptimisticLockException.class, accounts.reserveUsernameHash(account, USERNAME_HASH_1, Duration.ofDays(1))); @@ -1035,7 +1039,7 @@ class AccountsTest { final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID()); account.setUsernameHash(RandomUtils.nextBytes(32)); account.setUsernameLinkDetails(UUID.randomUUID(), RandomUtils.nextBytes(32)); - accounts.create(account); + createAccount(account); final Map accountRecord = DYNAMO_DB_EXTENSION.getDynamoDbClient() .getItem(GetItemRequest.builder() .tableName(Tables.ACCOUNTS.tableName()) @@ -1053,7 +1057,7 @@ class AccountsTest { assertThat(accounts.getByUsernameHash(USERNAME_HASH_1).join()).isEmpty(); final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID()); - accounts.create(account); + createAccount(account); assertThat(accounts.getByUsernameHash(USERNAME_HASH_1).join()).isEmpty(); @@ -1069,7 +1073,7 @@ class AccountsTest { final Device device2 = generateDevice((byte) 64); account.addDevice(device2); - accounts.create(account); + createAccount(account); final GetItemResponse response = DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient().getItem(GetItemRequest.builder() .tableName(Tables.ACCOUNTS.tableName()) @@ -1108,6 +1112,10 @@ class AccountsTest { return DevicesHelper.createDevice(id); } + private boolean createAccount(final Account account) { + return accounts.create(account, ignored -> Collections.emptyList()); + } + private static Account nextRandomAccount() { final String nextNumber = "+1800%07d".formatted(ACCOUNT_COUNTER.getAndIncrement()); return generateAccount(nextNumber, UUID.randomUUID(), UUID.randomUUID()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java index 3488b4abd..17fbd3152 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java @@ -5,14 +5,16 @@ package org.whispersystems.textsecuregcm.storage; -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.whispersystems.textsecuregcm.entities.SignedPreKey; import java.util.Map; import java.util.Optional; import java.util.UUID; -import org.junit.jupiter.api.Test; -import org.whispersystems.textsecuregcm.entities.SignedPreKey; + +import static org.junit.jupiter.api.Assertions.*; abstract class RepeatedUseSignedPreKeyStoreTest> { @@ -50,38 +52,47 @@ abstract class RepeatedUseSignedPreKeyStoreTest> { } @Test - void delete() { + void deleteForDevice() { final RepeatedUseSignedPreKeyStore keys = getKeyStore(); - assertDoesNotThrow(() -> keys.delete(UUID.randomUUID()).join()); - + final UUID identifier = UUID.randomUUID(); final byte deviceId2 = 2; - { - final UUID identifier = UUID.randomUUID(); - final Map signedPreKeys = Map.of( - Device.PRIMARY_ID, generateSignedPreKey(), - deviceId2, generateSignedPreKey() - ); + final Map signedPreKeys = Map.of( + Device.PRIMARY_ID, generateSignedPreKey(), + deviceId2, generateSignedPreKey() + ); - keys.store(identifier, signedPreKeys).join(); - keys.delete(identifier, Device.PRIMARY_ID).join(); + keys.store(identifier, signedPreKeys).join(); + keys.delete(identifier, Device.PRIMARY_ID).join(); + assertEquals(Optional.empty(), keys.find(identifier, Device.PRIMARY_ID).join()); + assertEquals(Optional.of(signedPreKeys.get(deviceId2)), keys.find(identifier, deviceId2).join()); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void deleteForAllDevices(final boolean excludePrimaryDevice) { + final RepeatedUseSignedPreKeyStore keys = getKeyStore(); + + assertDoesNotThrow(() -> keys.delete(UUID.randomUUID(), excludePrimaryDevice).join()); + + final byte deviceId2 = Device.PRIMARY_ID + 1; + + final UUID identifier = UUID.randomUUID(); + final Map signedPreKeys = Map.of( + Device.PRIMARY_ID, generateSignedPreKey(), + deviceId2, generateSignedPreKey() + ); + + keys.store(identifier, signedPreKeys).join(); + keys.delete(identifier, excludePrimaryDevice).join(); + + if (excludePrimaryDevice) { + assertTrue(keys.find(identifier, Device.PRIMARY_ID).join().isPresent()); + } else { assertEquals(Optional.empty(), keys.find(identifier, Device.PRIMARY_ID).join()); - assertEquals(Optional.of(signedPreKeys.get(deviceId2)), keys.find(identifier, deviceId2).join()); } - { - final UUID identifier = UUID.randomUUID(); - final Map signedPreKeys = Map.of( - Device.PRIMARY_ID, generateSignedPreKey(), - deviceId2, generateSignedPreKey() - ); - - keys.store(identifier, signedPreKeys).join(); - keys.delete(identifier).join(); - - assertEquals(Optional.empty(), keys.find(identifier, Device.PRIMARY_ID).join()); - assertEquals(Optional.empty(), keys.find(identifier, deviceId2).join()); - } + assertEquals(Optional.empty(), keys.find(identifier, deviceId2).join()); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java index 1efdbd6d0..3d942ea95 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java @@ -15,13 +15,19 @@ import static org.mockito.Mockito.when; import com.fasterxml.jackson.databind.ObjectMapper; import java.io.IOException; +import java.util.ArrayList; import java.util.List; +import java.util.Optional; import java.util.Set; import java.util.UUID; import java.util.function.Consumer; import org.mockito.MockingDetails; import org.mockito.stubbing.Stubbing; +import org.signal.libsignal.protocol.IdentityKey; +import org.signal.libsignal.protocol.ecc.Curve; +import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; +import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; @@ -150,4 +156,30 @@ public class AccountsHelper { return argThat(other -> other.getUuid().equals(value.getUuid())); } + public static Account createAccount(final AccountsManager accountsManager, final String e164) + throws InterruptedException { + + return createAccount(accountsManager, e164, new AccountAttributes()); + } + + public static Account createAccount(final AccountsManager accountsManager, final String e164, final AccountAttributes accountAttributes) + throws InterruptedException { + + final ECKeyPair aciKeyPair = Curve.generateKeyPair(); + final ECKeyPair pniKeyPair = Curve.generateKeyPair(); + + return accountsManager.create(e164, + "password", + null, + accountAttributes, + new ArrayList<>(), + new IdentityKey(aciKeyPair.getPublicKey()), + new IdentityKey(pniKeyPair.getPublicKey()), + KeysHelper.signedECPreKey(1, aciKeyPair), + KeysHelper.signedECPreKey(2, pniKeyPair), + KeysHelper.signedKEMPreKey(3, aciKeyPair), + KeysHelper.signedKEMPreKey(4, pniKeyPair), + Optional.empty(), + Optional.empty()); + } }