Create accounts transactionally

This commit is contained in:
Jon Chambers 2023-11-13 13:05:29 -05:00 committed by Jon Chambers
parent 07c04006df
commit c8033f875d
16 changed files with 854 additions and 265 deletions

View File

@ -342,8 +342,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getDynamoDbTables().getPhoneNumberIdentifiers().getTableName());
Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient,
config.getDynamoDbTables().getProfiles().getTableName());
KeysManager keys = new KeysManager(
dynamoDbAsyncClient,
KeysManager keysManager = new KeysManager(
dynamoDbAsyncClient,
config.getDynamoDbTables().getEcKeys().getTableName(),
config.getDynamoDbTables().getKemKeys().getTableName(),
config.getDynamoDbTables().getEcSignedPreKeys().getTableName(),
@ -525,7 +525,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
AccountLockManager accountLockManager = new AccountLockManager(dynamoDbClient,
config.getDynamoDbTables().getDeletedAccountsLock().getTableName());
AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster,
accountLockManager, keys, messagesManager, profilesManager,
accountLockManager, keysManager, messagesManager, profilesManager,
secureStorageClient, secureValueRecovery2Client,
clientPresenceManager,
experimentEnrollmentManager, registrationRecoveryPasswordsManager, accountLockExecutor, clock);
@ -669,8 +669,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
.addService(new AccountsAnonymousGrpcService(accountsManager, rateLimiters))
.addService(ExternalServiceCredentialsGrpcService.createForAllExternalServices(config, rateLimiters))
.addService(ExternalServiceCredentialsAnonymousGrpcService.create(accountsManager, config))
.addService(ServerInterceptors.intercept(new KeysGrpcService(accountsManager, keys, rateLimiters), basicCredentialAuthenticationInterceptor))
.addService(new KeysAnonymousGrpcService(accountsManager, keys))
.addService(ServerInterceptors.intercept(new KeysGrpcService(accountsManager, keysManager, rateLimiters), basicCredentialAuthenticationInterceptor))
.addService(new KeysAnonymousGrpcService(accountsManager, keysManager))
.addService(new PaymentsGrpcService(currencyManager))
.addService(ServerInterceptors.intercept(new ProfileGrpcService(clock, accountsManager, profilesManager, dynamicConfigurationManager,
config.getBadges(), asyncCdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner, profileBadgeConverter, rateLimiters, zkProfileOperations, config.getCdnConfiguration().bucket()), basicCredentialAuthenticationInterceptor))
@ -725,7 +725,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
turnTokenGenerator,
registrationRecoveryPasswordsManager, usernameHashZkProofVerifier));
environment.jersey().register(new KeysController(rateLimiters, keys, accountsManager));
environment.jersey().register(new KeysController(rateLimiters, keysManager, accountsManager));
boolean registeredSpamFilter = false;
ReportSpamTokenProvider reportSpamTokenProvider = null;
@ -784,7 +784,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
new CallLinkController(rateLimiters, callingGenericZkSecretParams),
new CertificateController(new CertificateGenerator(config.getDeliveryCertificate().certificate().value(), config.getDeliveryCertificate().ecPrivateKey(), config.getDeliveryCertificate().expiresDays()), zkAuthOperations, callingGenericZkSecretParams, clock),
new ChallengeController(rateLimitChallengeManager),
new DeviceController(config.getLinkDeviceSecretConfiguration().secret().value(), accountsManager, messagesManager, keys, rateLimiters,
new DeviceController(config.getLinkDeviceSecretConfiguration().secret().value(), accountsManager, messagesManager, keysManager, rateLimiters,
rateLimitersCluster, config.getMaxDevices(), clock),
new DirectoryV2Controller(directoryV2CredentialsGenerator),
new DonationController(clock, zkReceiptOperations, redeemedReceiptsManager, accountsManager, config.getBadges(),
@ -799,7 +799,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getCdnConfiguration().bucket(), zkProfileOperations, batchIdentityCheckExecutor),
new ProvisioningController(rateLimiters, provisioningManager),
new RegistrationController(accountsManager, phoneVerificationTokenManager, registrationLockVerificationManager,
keys, rateLimiters),
rateLimiters),
new RemoteConfigController(remoteConfigsManager, adminEventLogger,
config.getRemoteConfigConfiguration().authorizedUsers(),
config.getRemoteConfigConfiguration().requiredHostedDomain(),

View File

@ -20,9 +20,7 @@ import io.swagger.v3.oas.annotations.responses.ApiResponse;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import javax.ws.rs.Consumes;
@ -45,8 +43,6 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util;
@ -69,18 +65,16 @@ public class RegistrationController {
private final AccountsManager accounts;
private final PhoneVerificationTokenManager phoneVerificationTokenManager;
private final RegistrationLockVerificationManager registrationLockVerificationManager;
private final KeysManager keysManager;
private final RateLimiters rateLimiters;
public RegistrationController(final AccountsManager accounts,
final PhoneVerificationTokenManager phoneVerificationTokenManager,
final RegistrationLockVerificationManager registrationLockVerificationManager,
final KeysManager keysManager,
final RateLimiters rateLimiters) {
this.accounts = accounts;
this.phoneVerificationTokenManager = phoneVerificationTokenManager;
this.registrationLockVerificationManager = registrationLockVerificationManager;
this.keysManager = keysManager;
this.rateLimiters = rateLimiters;
}
@ -141,37 +135,19 @@ public class RegistrationController {
userAgent, RegistrationLockVerificationManager.Flow.REGISTRATION, verificationType);
}
Account account = accounts.create(number, password, signalAgent, registrationRequest.accountAttributes(),
existingAccount.map(Account::getBadges).orElseGet(ArrayList::new));
account = accounts.update(account, a -> {
a.setIdentityKey(registrationRequest.aciIdentityKey());
a.setPhoneNumberIdentityKey(registrationRequest.pniIdentityKey());
final Device device = a.getPrimaryDevice().orElseThrow();
device.setSignedPreKey(registrationRequest.deviceActivationRequest().aciSignedPreKey());
device.setPhoneNumberIdentitySignedPreKey(registrationRequest.deviceActivationRequest().pniSignedPreKey());
registrationRequest.deviceActivationRequest().apnToken().ifPresent(apnRegistrationId -> {
device.setApnId(apnRegistrationId.apnRegistrationId());
device.setVoipApnId(apnRegistrationId.voipRegistrationId());
});
registrationRequest.deviceActivationRequest().gcmToken().ifPresent(gcmRegistrationId ->
device.setGcmId(gcmRegistrationId.gcmRegistrationId()));
CompletableFuture.allOf(
keysManager.storeEcSignedPreKeys(a.getUuid(),
Map.of(Device.PRIMARY_ID, registrationRequest.deviceActivationRequest().aciSignedPreKey())),
keysManager.storePqLastResort(a.getUuid(),
Map.of(Device.PRIMARY_ID, registrationRequest.deviceActivationRequest().aciPqLastResortPreKey())),
keysManager.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(),
Map.of(Device.PRIMARY_ID, registrationRequest.deviceActivationRequest().pniSignedPreKey())),
keysManager.storePqLastResort(a.getPhoneNumberIdentifier(),
Map.of(Device.PRIMARY_ID, registrationRequest.deviceActivationRequest().pniPqLastResortPreKey())))
.join();
});
final Account account = accounts.create(number,
password,
signalAgent,
registrationRequest.accountAttributes(),
existingAccount.map(Account::getBadges).orElseGet(ArrayList::new),
registrationRequest.aciIdentityKey(),
registrationRequest.pniIdentityKey(),
registrationRequest.deviceActivationRequest().aciSignedPreKey(),
registrationRequest.deviceActivationRequest().pniSignedPreKey(),
registrationRequest.deviceActivationRequest().aciPqLastResortPreKey(),
registrationRequest.deviceActivationRequest().pniPqLastResortPreKey(),
registrationRequest.deviceActivationRequest().apnToken(),
registrationRequest.deviceActivationRequest().gcmToken());
Metrics.counter(ACCOUNT_CREATED_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(COUNTRY_CODE_TAG_NAME, Util.getCountryCode(number)),

View File

@ -19,6 +19,7 @@ import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@ -28,6 +29,7 @@ import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
@ -157,6 +159,7 @@ public class Accounts extends AbstractDynamoDbStore {
final String phoneNumberIdentifierConstraintTableName,
final String usernamesConstraintTableName,
final String deletedAccountsTableName) {
super(client);
this.clock = clock;
this.asyncClient = asyncClient;
@ -175,12 +178,14 @@ public class Accounts extends AbstractDynamoDbStore {
final String phoneNumberIdentifierConstraintTableName,
final String usernamesConstraintTableName,
final String deletedAccountsTableName) {
this(Clock.systemUTC(), client, asyncClient, accountsTableName,
phoneNumberConstraintTableName, phoneNumberIdentifierConstraintTableName, usernamesConstraintTableName,
deletedAccountsTableName);
}
public boolean create(final Account account) {
public boolean create(final Account account, final Function<Account, Collection<TransactWriteItem>> additionalWriteItemsFunction) {
return CREATE_TIMER.record(() -> {
try {
final AttributeValue uuidAttr = AttributeValues.fromUUID(account.getUuid());
@ -199,8 +204,13 @@ public class Accounts extends AbstractDynamoDbStore {
// the newly-created account.
final TransactWriteItem deletedAccountDelete = buildRemoveDeletedAccount(account.getNumber());
final Collection<TransactWriteItem> writeItems = new ArrayList<>(
List.of(phoneNumberConstraintPut, phoneNumberIdentifierConstraintPut, accountPut, deletedAccountDelete));
writeItems.addAll(additionalWriteItemsFunction.apply(account));
final TransactWriteItemsRequest request = TransactWriteItemsRequest.builder()
.transactItems(phoneNumberConstraintPut, phoneNumberIdentifierConstraintPut, accountPut, deletedAccountDelete)
.transactItems(writeItems)
.build();
try {
@ -229,7 +239,8 @@ public class Accounts extends AbstractDynamoDbStore {
account.setUuid(UUIDUtil.fromByteBuffer(actualAccountUuid));
final Account existingAccount = getByAccountIdentifier(account.getUuid()).orElseThrow();
account.setNumber(existingAccount.getNumber(), existingAccount.getPhoneNumberIdentifier());
joinAndUnwrapUpdateFuture(reclaimAccount(existingAccount, account));
joinAndUnwrapUpdateFuture(reclaimAccount(existingAccount, account, additionalWriteItemsFunction.apply(account)));
return false;
}
@ -254,7 +265,7 @@ public class Accounts extends AbstractDynamoDbStore {
* @param existingAccount the existing account in the accounts table
* @param accountToCreate a new account, with the same number and identifier as existingAccount
*/
private CompletionStage<Void> reclaimAccount(final Account existingAccount, final Account accountToCreate) {
private CompletionStage<Void> reclaimAccount(final Account existingAccount, final Account accountToCreate, final Collection<TransactWriteItem> additionalWriteItems) {
if (!existingAccount.getUuid().equals(accountToCreate.getUuid()) ||
!existingAccount.getNumber().equals(accountToCreate.getNumber())) {
throw new IllegalArgumentException("reclaimed accounts must match");
@ -310,6 +321,7 @@ public class Accounts extends AbstractDynamoDbStore {
.build());
}
writeItems.add(UpdateAccountSpec.forAccount(accountsTableName, accountToCreate).transactItem());
writeItems.addAll(additionalWriteItems);
return asyncClient.transactWriteItems(TransactWriteItemsRequest.builder().transactItems(writeItems).build())
.thenApply(response -> {

View File

@ -53,7 +53,9 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType;
@ -175,17 +177,26 @@ public class AccountsManager {
this.clock = requireNonNull(clock);
}
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public Account create(final String number,
final String password,
final String signalAgent,
final AccountAttributes accountAttributes,
final List<AccountBadge> accountBadges) throws InterruptedException {
final List<AccountBadge> accountBadges,
final IdentityKey aciIdentityKey,
final IdentityKey pniIdentityKey,
final ECSignedPreKey aciSignedPreKey,
final ECSignedPreKey pniSignedPreKey,
final KEMSignedPreKey aciPqLastResortPreKey,
final KEMSignedPreKey pniPqLastResortPreKey,
final Optional<ApnRegistrationId> maybeApnRegistrationId,
final Optional<GcmRegistrationId> maybeGcmRegistrationId) throws InterruptedException {
try (Timer.Context ignored = createTimer.time()) {
final Account account = new Account();
accountLockManager.withLock(List.of(number), () -> {
Device device = new Device();
final Device device = new Device();
device.setId(Device.PRIMARY_ID);
device.setAuthTokenHash(SaltedTokenHash.generateFor(password));
device.setFetchesMessages(accountAttributes.getFetchesMessages());
@ -196,6 +207,16 @@ public class AccountsManager {
device.setCreated(System.currentTimeMillis());
device.setLastSeen(Util.todayInMillis());
device.setUserAgent(signalAgent);
device.setSignedPreKey(aciSignedPreKey);
device.setPhoneNumberIdentitySignedPreKey(pniSignedPreKey);
maybeApnRegistrationId.ifPresent(apnRegistrationId -> {
device.setApnId(apnRegistrationId.apnRegistrationId());
device.setVoipApnId(apnRegistrationId.voipRegistrationId());
});
maybeGcmRegistrationId.ifPresent(gcmRegistrationId ->
device.setGcmId(gcmRegistrationId.gcmRegistrationId()));
account.setNumber(number, phoneNumberIdentifiers.getPhoneNumberIdentifier(number));
@ -205,6 +226,8 @@ public class AccountsManager {
// Reuse the ACI from any recently-deleted account with this number to cover cases where somebody is
// re-registering.
account.setUuid(maybeRecentlyDeletedAccountIdentifier.orElseGet(UUID::randomUUID));
account.setIdentityKey(aciIdentityKey);
account.setPhoneNumberIdentityKey(pniIdentityKey);
account.addDevice(device);
account.setRegistrationLockFromAttributes(accountAttributes);
account.setUnidentifiedAccessKey(accountAttributes.getUnidentifiedAccessKey());
@ -214,7 +237,14 @@ public class AccountsManager {
final UUID originalUuid = account.getUuid();
boolean freshUser = accounts.create(account);
final boolean freshUser = accounts.create(account,
a -> keysManager.buildWriteItemsForRepeatedUseKeys(a.getIdentifier(IdentityType.ACI),
a.getIdentifier(IdentityType.PNI),
Device.PRIMARY_ID,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey));
// create() sometimes updates the UUID, if there was a number conflict.
// for metrics, we want secondary to run with the same original UUID
@ -235,9 +265,11 @@ public class AccountsManager {
// confident that everything has already been deleted. In the second case, though, we're taking over an existing
// account and need to clear out messages and keys that may have been stored for the old account.
if (!originalUuid.equals(actualUuid)) {
// We exclude the primary device's repeated-use keys from deletion because new keys were provided as part of
// the account creation process, and we don't want to delete the keys that just got added.
final CompletableFuture<Void> deleteKeysFuture = CompletableFuture.allOf(
keysManager.delete(actualUuid),
keysManager.delete(account.getPhoneNumberIdentifier()));
keysManager.delete(actualUuid, true),
keysManager.delete(account.getPhoneNumberIdentifier(), true));
messagesManager.clear(actualUuid).join();
profilesManager.deleteAll(actualUuid).join();

View File

@ -18,6 +18,7 @@ import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem;
public class KeysManager {
@ -75,6 +76,20 @@ public class KeysManager {
return CompletableFuture.allOf(storeFutures.toArray(new CompletableFuture[0]));
}
public List<TransactWriteItem> buildWriteItemsForRepeatedUseKeys(final UUID accountIdentifier,
final UUID phoneNumberIdentifier,
final byte deviceId,
final ECSignedPreKey aciSignedPreKey,
final ECSignedPreKey pniSignedPreKey,
final KEMSignedPreKey aciPqLastResortPreKey,
final KEMSignedPreKey pniLastResortPreKey) {
return List.of(ecSignedPreKeys.buildTransactWriteItem(accountIdentifier, deviceId, aciSignedPreKey),
ecSignedPreKeys.buildTransactWriteItem(phoneNumberIdentifier, deviceId, pniSignedPreKey),
pqLastResortKeys.buildTransactWriteItem(accountIdentifier, deviceId, aciPqLastResortPreKey),
pqLastResortKeys.buildTransactWriteItem(phoneNumberIdentifier, deviceId, pniLastResortPreKey));
}
public CompletableFuture<Void> storeEcSignedPreKeys(final UUID identifier, final Map<Byte, ECSignedPreKey> keys) {
if (dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().storeEcSignedPreKeys()) {
return ecSignedPreKeys.store(identifier, keys);
@ -134,14 +149,18 @@ public class KeysManager {
return pqPreKeys.getCount(identifier, deviceId);
}
public CompletableFuture<Void> delete(final UUID accountUuid) {
public CompletableFuture<Void> delete(final UUID identifier) {
return delete(identifier, false);
}
public CompletableFuture<Void> delete(final UUID identifier, final boolean excludePrimaryDevice) {
return CompletableFuture.allOf(
ecPreKeys.delete(accountUuid),
pqPreKeys.delete(accountUuid),
dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().deleteEcSignedPreKeys()
? ecSignedPreKeys.delete(accountUuid)
: CompletableFuture.completedFuture(null),
pqLastResortKeys.delete(accountUuid));
ecPreKeys.delete(identifier),
pqPreKeys.delete(identifier),
dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().deleteEcSignedPreKeys()
? ecSignedPreKeys.delete(identifier, excludePrimaryDevice)
: CompletableFuture.completedFuture(null),
pqLastResortKeys.delete(identifier, excludePrimaryDevice));
}
public CompletableFuture<Void> delete(final UUID accountUuid, final byte deviceId) {

View File

@ -112,6 +112,15 @@ public abstract class RepeatedUseSignedPreKeyStore<K extends SignedPreKey<?>> {
.thenRun(() -> sample.stop(storeKeyBatchTimer));
}
TransactWriteItem buildTransactWriteItem(final UUID identifier, final byte deviceId, final K preKey) {
return TransactWriteItem.builder()
.put(Put.builder()
.tableName(tableName)
.item(getItemFromPreKey(identifier, deviceId, preKey))
.build())
.build();
}
/**
* Finds a repeated-use pre-key for a specific device.
*
@ -142,14 +151,19 @@ public abstract class RepeatedUseSignedPreKeyStore<K extends SignedPreKey<?>> {
* Clears all repeated-use pre-keys associated with the given account/identity.
*
* @param identifier the identifier for the account/identity for which to clear repeated-use pre-keys
* @param excludePrimaryDevice whether to exclude the primary device from repeated-use key deletion; this is intended
* for cases when a user "re-registers" and displaces an existing account record and has
* provided new repeated-use keys for the primary device in the process of creating the
* new account
*
* @return a future that completes once repeated-use pre-keys have been cleared from all devices associated with the
* target account/identity
*/
public CompletableFuture<Void> delete(final UUID identifier) {
public CompletableFuture<Void> delete(final UUID identifier, final boolean excludePrimaryDevice) {
final Timer.Sample sample = Timer.start();
return getDeviceIdsWithKeys(identifier)
.filter(deviceId -> deviceId != Device.PRIMARY_ID || !excludePrimaryDevice)
.map(deviceId -> DeleteItemRequest.builder()
.tableName(tableName)
.key(getPrimaryKey(identifier, deviceId))

View File

@ -55,7 +55,6 @@ import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.signal.libsignal.usernames.BaseUsernameException;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
@ -215,15 +214,6 @@ class AccountControllerTest {
when(accountsManager.getByE164(eq(SENDER_PREAUTH))).thenReturn(Optional.empty());
when(accountsManager.getByE164(eq(SENDER_HAS_STORAGE))).thenReturn(Optional.of(senderHasStorage));
when(accountsManager.getByE164(eq(SENDER_TRANSFER))).thenReturn(Optional.of(senderTransfer));
when(accountsManager.create(any(), any(), any(), any(), any())).thenAnswer((Answer<Account>) invocation -> {
final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(account.getNumber()).thenReturn(invocation.getArgument(0, String.class));
when(account.getBadges()).thenReturn(invocation.getArgument(4, List.class));
return account;
});
}
@AfterEach

View File

@ -7,10 +7,11 @@ package org.whispersystems.textsecuregcm.controllers;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@ -21,10 +22,12 @@ import io.dropwizard.testing.junit5.ResourceExtension;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
@ -72,7 +75,6 @@ import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
@ -95,7 +97,6 @@ class RegistrationControllerTest {
RegistrationLockVerificationManager.class);
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager = mock(
RegistrationRecoveryPasswordsManager.class);
private final KeysManager keysManager = mock(KeysManager.class);
private final RateLimiters rateLimiters = mock(RateLimiters.class);
private final RateLimiter registrationLimiter = mock(RateLimiter.class);
@ -110,7 +111,7 @@ class RegistrationControllerTest {
.addResource(
new RegistrationController(accountsManager,
new PhoneVerificationTokenManager(registrationServiceClient, registrationRecoveryPasswordsManager),
registrationLockVerificationManager, keysManager, rateLimiters))
registrationLockVerificationManager, rateLimiters))
.build();
@BeforeEach
@ -125,11 +126,6 @@ class RegistrationControllerTest {
return invocation.getArgument(0);
});
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storeEcOneTimePreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storeKemOneTimePreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null));
}
@Test
@ -171,7 +167,7 @@ class RegistrationControllerTest {
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class)));
when(accountsManager.create(any(), any(), any(), any(), any()))
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
.thenReturn(account);
final String json = requestJson("sessionId", new byte[0], true, registrationId.orElse(0), pniRegistrationId);
@ -294,7 +290,7 @@ class RegistrationControllerTest {
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class)));
when(accountsManager.create(any(), any(), any(), any(), any()))
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
.thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest()
@ -352,7 +348,7 @@ class RegistrationControllerTest {
final Account createdAccount = mock(Account.class);
when(createdAccount.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class)));
when(accountsManager.create(any(), any(), any(), any(), any()))
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
.thenReturn(createdAccount);
expectedStatus = 200;
@ -406,7 +402,8 @@ class RegistrationControllerTest {
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class)));
when(accountsManager.create(any(), any(), any(), any(), any())).thenReturn(account);
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
.thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
@ -429,7 +426,7 @@ class RegistrationControllerTest {
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class)));
when(accountsManager.create(any(), any(), any(), any(), any()))
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
.thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest()
@ -669,9 +666,8 @@ class RegistrationControllerTest {
final ECSignedPreKey expectedPniSignedPreKey,
final KEMSignedPreKey expectedAciPqLastResortPreKey,
final KEMSignedPreKey expectedPniPqLastResortPreKey,
final Optional<String> expectedApnsToken,
final Optional<String> expectedApnsVoipToken,
final Optional<String> expectedGcmToken) throws InterruptedException {
final Optional<ApnRegistrationId> expectedApnRegistrationId,
final Optional<GcmRegistrationId> expectedGcmRegistrationId) throws InterruptedException {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(
@ -679,9 +675,6 @@ class RegistrationControllerTest {
Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS))));
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
final UUID accountIdentifier = UUID.randomUUID();
final UUID phoneNumberIdentifier = UUID.randomUUID();
final Device device = mock(Device.class);
@ -692,9 +685,8 @@ class RegistrationControllerTest {
when(a.getPrimaryDevice()).thenReturn(Optional.of(device));
});
when(accountsManager.create(any(), any(), any(), any(), any())).thenReturn(account);
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
.thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
@ -705,27 +697,33 @@ class RegistrationControllerTest {
assertEquals(200, response.getStatus());
}
verify(accountsManager).create(any(), any(), any(), any(), any());
verify(accountsManager).create(
eq(NUMBER),
eq(PASSWORD),
isNull(),
argThat(attributes -> accountAttributesEqual(attributes, registrationRequest.accountAttributes())),
eq(Collections.emptyList()),
eq(expectedAciIdentityKey),
eq(expectedPniIdentityKey),
eq(expectedAciSignedPreKey),
eq(expectedPniSignedPreKey),
eq(expectedAciPqLastResortPreKey),
eq(expectedPniPqLastResortPreKey),
eq(expectedApnRegistrationId),
eq(expectedGcmRegistrationId));
}
verify(account).setIdentityKey(expectedAciIdentityKey);
verify(account).setPhoneNumberIdentityKey(expectedPniIdentityKey);
verify(device).setSignedPreKey(expectedAciSignedPreKey);
verify(device).setPhoneNumberIdentitySignedPreKey(expectedPniSignedPreKey);
verify(keysManager).storeEcSignedPreKeys(accountIdentifier, Map.of(Device.PRIMARY_ID, expectedAciSignedPreKey));
verify(keysManager).storeEcSignedPreKeys(phoneNumberIdentifier, Map.of(Device.PRIMARY_ID, expectedPniSignedPreKey));
verify(keysManager).storePqLastResort(accountIdentifier, Map.of(Device.PRIMARY_ID, expectedAciPqLastResortPreKey));
verify(keysManager).storePqLastResort(phoneNumberIdentifier, Map.of(Device.PRIMARY_ID, expectedPniPqLastResortPreKey));
expectedApnsToken.ifPresentOrElse(expectedToken -> verify(device).setApnId(expectedToken),
() -> verify(device, never()).setApnId(any()));
expectedApnsVoipToken.ifPresentOrElse(expectedToken -> verify(device).setVoipApnId(expectedToken),
() -> verify(device, never()).setVoipApnId(any()));
expectedGcmToken.ifPresentOrElse(expectedToken -> verify(device).setGcmId(expectedToken),
() -> verify(device, never()).setGcmId(any()));
private static boolean accountAttributesEqual(final AccountAttributes a, final AccountAttributes b) {
return a.getFetchesMessages() == b.getFetchesMessages()
&& a.getRegistrationId() == b.getRegistrationId()
&& a.isUnrestrictedUnidentifiedAccess() == b.isUnrestrictedUnidentifiedAccess()
&& a.isDiscoverableByPhoneNumber() == b.isDiscoverableByPhoneNumber()
&& Objects.equals(a.getPhoneNumberIdentityRegistrationId(), b.getPhoneNumberIdentityRegistrationId())
&& Objects.equals(a.getName(), b.getName())
&& Objects.equals(a.getRegistrationLock(), b.getRegistrationLock())
&& Arrays.equals(a.getUnidentifiedAccessKey(), b.getUnidentifiedAccessKey())
&& Objects.equals(a.getCapabilities(), b.getCapabilities())
&& Objects.equals(a.recoveryPassword(), b.recoveryPassword());
}
private static Stream<Arguments> atomicAccountCreationSuccess() {
@ -800,8 +798,7 @@ class RegistrationControllerTest {
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.of(apnsToken),
Optional.of(apnsVoipToken),
Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
Optional.empty()),
// requires the request to be atomic
@ -823,8 +820,7 @@ class RegistrationControllerTest {
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.of(apnsToken),
Optional.of(apnsVoipToken),
Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
Optional.empty()),
// Fetches messages; no push tokens
@ -847,8 +843,7 @@ class RegistrationControllerTest {
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.empty(),
Optional.empty(),
Optional.of(gcmToken)));
Optional.of(new GcmRegistrationId(gcmToken))));
}
/**

View File

@ -0,0 +1,424 @@
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junitpioneer.jupiter.cartesian.ArgumentSets;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
public class AccountCreationIntegrationTest {
@RegisterExtension
static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(
DynamoDbExtensionSchema.Tables.ACCOUNTS,
DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS,
DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS_LOCK,
DynamoDbExtensionSchema.Tables.NUMBERS,
DynamoDbExtensionSchema.Tables.PNI,
DynamoDbExtensionSchema.Tables.PNI_ASSIGNMENTS,
DynamoDbExtensionSchema.Tables.USERNAMES,
DynamoDbExtensionSchema.Tables.EC_KEYS,
DynamoDbExtensionSchema.Tables.PQ_KEYS,
DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS,
DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS);
@RegisterExtension
static final RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
private static final Clock CLOCK = Clock.fixed(Instant.now(), ZoneId.systemDefault());
private ExecutorService accountLockExecutor;
private AccountsManager accountsManager;
private KeysManager keysManager;
record DeliveryChannels(boolean fetchesMessages, String apnsToken, String apnsVoipToken, String fcmToken) {}
@BeforeEach
void setUp() {
@SuppressWarnings("unchecked") final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager =
mock(DynamicConfigurationManager.class);
DynamicConfiguration dynamicConfiguration = new DynamicConfiguration();
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
keysManager = new KeysManager(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.EC_KEYS.tableName(),
DynamoDbExtensionSchema.Tables.PQ_KEYS.tableName(),
DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName(),
DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName(),
dynamicConfigurationManager);
final Accounts accounts = new Accounts(
DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.ACCOUNTS.tableName(),
DynamoDbExtensionSchema.Tables.NUMBERS.tableName(),
DynamoDbExtensionSchema.Tables.PNI_ASSIGNMENTS.tableName(),
DynamoDbExtensionSchema.Tables.USERNAMES.tableName(),
DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS.tableName());
accountLockExecutor = Executors.newSingleThreadExecutor();
final AccountLockManager accountLockManager = new AccountLockManager(DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS_LOCK.tableName());
final SecureStorageClient secureStorageClient = mock(SecureStorageClient.class);
when(secureStorageClient.deleteStoredData(any())).thenReturn(CompletableFuture.completedFuture(null));
final SecureValueRecovery2Client svr2Client = mock(SecureValueRecovery2Client.class);
when(svr2Client.deleteBackups(any())).thenReturn(CompletableFuture.completedFuture(null));
final PhoneNumberIdentifiers phoneNumberIdentifiers =
new PhoneNumberIdentifiers(DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DynamoDbExtensionSchema.Tables.PNI.tableName());
final MessagesManager messagesManager = mock(MessagesManager.class);
when(messagesManager.clear(any())).thenReturn(CompletableFuture.completedFuture(null));
final ProfilesManager profilesManager = mock(ProfilesManager.class);
when(profilesManager.deleteAll(any())).thenReturn(CompletableFuture.completedFuture(null));
final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager =
mock(RegistrationRecoveryPasswordsManager.class);
when(registrationRecoveryPasswordsManager.removeForNumber(any()))
.thenReturn(CompletableFuture.completedFuture(null));
accountsManager = new AccountsManager(
accounts,
phoneNumberIdentifiers,
CACHE_CLUSTER_EXTENSION.getRedisCluster(),
accountLockManager,
keysManager,
messagesManager,
profilesManager,
secureStorageClient,
svr2Client,
mock(ClientPresenceManager.class),
mock(ExperimentEnrollmentManager.class),
registrationRecoveryPasswordsManager,
accountLockExecutor,
CLOCK);
}
@AfterEach
void tearDown() throws InterruptedException {
accountLockExecutor.shutdown();
//noinspection ResultOfMethodCallIgnored
accountLockExecutor.awaitTermination(1, TimeUnit.SECONDS);
}
@CartesianTest
@CartesianTest.MethodFactory("createAccount")
void createAccount(final DeliveryChannels deliveryChannels,
final boolean discoverableByPhoneNumber) throws InterruptedException {
final String number = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164);
final String password = RandomStringUtils.randomAlphanumeric(16);
final String signalAgent = RandomStringUtils.randomAlphabetic(3);
final int registrationId = ThreadLocalRandom.current().nextInt(Device.MAX_REGISTRATION_ID);
final int pniRegistrationId = ThreadLocalRandom.current().nextInt(Device.MAX_REGISTRATION_ID);
final String deviceName = RandomStringUtils.randomAlphabetic(16);
final String registrationLockSecret = RandomStringUtils.randomAlphanumeric(16);
final Device.DeviceCapabilities deviceCapabilities = new Device.DeviceCapabilities(
ThreadLocalRandom.current().nextBoolean(),
ThreadLocalRandom.current().nextBoolean(),
ThreadLocalRandom.current().nextBoolean(),
ThreadLocalRandom.current().nextBoolean());
final AccountAttributes accountAttributes = new AccountAttributes(deliveryChannels.fetchesMessages(),
registrationId,
deviceName,
registrationLockSecret,
discoverableByPhoneNumber,
deviceCapabilities);
accountAttributes.setPhoneNumberIdentityRegistrationId(pniRegistrationId);
final List<AccountBadge> badges = new ArrayList<>(List.of(new AccountBadge(
RandomStringUtils.randomAlphabetic(8),
CLOCK.instant().plus(Duration.ofDays(7)),
true)));
final ECKeyPair aciKeyPair = Curve.generateKeyPair();
final ECKeyPair pniKeyPair = Curve.generateKeyPair();
final ECSignedPreKey aciSignedPreKey = KeysHelper.signedECPreKey(1, aciKeyPair);
final ECSignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(2, pniKeyPair);
final KEMSignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciKeyPair);
final KEMSignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniKeyPair);
final Optional<ApnRegistrationId> maybeApnRegistrationId =
deliveryChannels.apnsToken() != null || deliveryChannels.apnsVoipToken() != null
? Optional.of(new ApnRegistrationId(deliveryChannels.apnsToken(), deliveryChannels.apnsVoipToken()))
: Optional.empty();
final Optional<GcmRegistrationId> maybeGcmRegistrationId = deliveryChannels.fcmToken() != null
? Optional.of(new GcmRegistrationId(deliveryChannels.fcmToken()))
: Optional.empty();
final Account account = accountsManager.create(number,
password,
signalAgent,
accountAttributes,
badges,
new IdentityKey(aciKeyPair.getPublicKey()),
new IdentityKey(pniKeyPair.getPublicKey()),
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
maybeApnRegistrationId,
maybeGcmRegistrationId);
assertExpectedStoredAccount(account,
number,
password,
signalAgent,
deliveryChannels,
registrationId,
pniRegistrationId,
deviceName,
discoverableByPhoneNumber,
deviceCapabilities,
badges,
maybeApnRegistrationId,
maybeGcmRegistrationId,
registrationLockSecret,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey);
}
@CartesianTest
@CartesianTest.MethodFactory("createAccount")
void reregisterAccount(final DeliveryChannels deliveryChannels,
final boolean discoverableByPhoneNumber) throws InterruptedException {
final String number = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164);
final UUID existingAccountUuid;
{
final ECKeyPair aciKeyPair = Curve.generateKeyPair();
final ECKeyPair pniKeyPair = Curve.generateKeyPair();
final ECSignedPreKey aciSignedPreKey = KeysHelper.signedECPreKey(1, aciKeyPair);
final ECSignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(2, pniKeyPair);
final KEMSignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciKeyPair);
final KEMSignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniKeyPair);
final Account originalAccount = accountsManager.create(number,
RandomStringUtils.randomAlphanumeric(16),
"OWI",
new AccountAttributes(true, 1, "name", "registration-lock", false, new Device.DeviceCapabilities(false, false, false, false)),
Collections.emptyList(),
new IdentityKey(aciKeyPair.getPublicKey()),
new IdentityKey(pniKeyPair.getPublicKey()),
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.empty(),
Optional.empty());
existingAccountUuid = originalAccount.getUuid();
}
final String password = RandomStringUtils.randomAlphanumeric(16);
final String signalAgent = RandomStringUtils.randomAlphabetic(3);
final int registrationId = ThreadLocalRandom.current().nextInt(Device.MAX_REGISTRATION_ID);
final int pniRegistrationId = ThreadLocalRandom.current().nextInt(Device.MAX_REGISTRATION_ID);
final String deviceName = RandomStringUtils.randomAlphabetic(16);
final String registrationLockSecret = RandomStringUtils.randomAlphanumeric(16);
final Device.DeviceCapabilities deviceCapabilities = new Device.DeviceCapabilities(
ThreadLocalRandom.current().nextBoolean(),
ThreadLocalRandom.current().nextBoolean(),
ThreadLocalRandom.current().nextBoolean(),
ThreadLocalRandom.current().nextBoolean());
final AccountAttributes accountAttributes = new AccountAttributes(deliveryChannels.fetchesMessages(),
registrationId,
deviceName,
registrationLockSecret,
discoverableByPhoneNumber,
deviceCapabilities);
accountAttributes.setPhoneNumberIdentityRegistrationId(pniRegistrationId);
final List<AccountBadge> badges = new ArrayList<>(List.of(new AccountBadge(
RandomStringUtils.randomAlphabetic(8),
CLOCK.instant().plus(Duration.ofDays(7)),
true)));
final ECKeyPair aciKeyPair = Curve.generateKeyPair();
final ECKeyPair pniKeyPair = Curve.generateKeyPair();
final ECSignedPreKey aciSignedPreKey = KeysHelper.signedECPreKey(1, aciKeyPair);
final ECSignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(2, pniKeyPair);
final KEMSignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciKeyPair);
final KEMSignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniKeyPair);
final Optional<ApnRegistrationId> maybeApnRegistrationId =
deliveryChannels.apnsToken() != null || deliveryChannels.apnsVoipToken() != null
? Optional.of(new ApnRegistrationId(deliveryChannels.apnsToken(), deliveryChannels.apnsVoipToken()))
: Optional.empty();
final Optional<GcmRegistrationId> maybeGcmRegistrationId = deliveryChannels.fcmToken() != null
? Optional.of(new GcmRegistrationId(deliveryChannels.fcmToken()))
: Optional.empty();
final Account reregisteredAccount = accountsManager.create(number,
password,
signalAgent,
accountAttributes,
badges,
new IdentityKey(aciKeyPair.getPublicKey()),
new IdentityKey(pniKeyPair.getPublicKey()),
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
maybeApnRegistrationId,
maybeGcmRegistrationId);
assertExpectedStoredAccount(reregisteredAccount,
number,
password,
signalAgent,
deliveryChannels,
registrationId,
pniRegistrationId,
deviceName,
discoverableByPhoneNumber,
deviceCapabilities,
badges,
maybeApnRegistrationId,
maybeGcmRegistrationId,
registrationLockSecret,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey);
assertEquals(existingAccountUuid, reregisteredAccount.getUuid());
}
@SuppressWarnings("unused")
static ArgumentSets createAccount() {
return ArgumentSets
// deliveryChannels
.argumentsForFirstParameter(
new DeliveryChannels(true, null, null, null),
new DeliveryChannels(false, "apns-token", null, null),
new DeliveryChannels(false, "apns-token", "apns-voip-token", null),
new DeliveryChannels(false, null, null, "fcm-token"))
// discoverableByPhoneNumber
.argumentsForNextParameter(true, false);
}
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
private void assertExpectedStoredAccount(final Account account,
final String number,
final String password,
final String signalAgent,
final DeliveryChannels deliveryChannels,
final int registrationId,
final int pniRegistrationId,
final String deviceName,
final boolean discoverableByPhoneNumber,
final Device.DeviceCapabilities deviceCapabilities,
final List<AccountBadge> badges,
final Optional<ApnRegistrationId> maybeApnRegistrationId,
final Optional<GcmRegistrationId> maybeGcmRegistrationId,
final String registrationLockSecret,
final ECSignedPreKey aciSignedPreKey,
final ECSignedPreKey pniSignedPreKey,
final KEMSignedPreKey aciPqLastResortPreKey,
final KEMSignedPreKey pniPqLastResortPreKey) {
final Device primaryDevice = account.getPrimaryDevice().orElseThrow();
assertEquals(number, account.getNumber());
assertEquals(signalAgent, primaryDevice.getUserAgent());
assertEquals(deliveryChannels.fetchesMessages(), primaryDevice.getFetchesMessages());
assertEquals(registrationId, primaryDevice.getRegistrationId());
assertEquals(pniRegistrationId, primaryDevice.getPhoneNumberIdentityRegistrationId().orElseThrow());
assertEquals(deviceName, primaryDevice.getName());
assertEquals(discoverableByPhoneNumber, account.isDiscoverableByPhoneNumber());
assertEquals(deviceCapabilities, primaryDevice.getCapabilities());
assertEquals(badges, account.getBadges());
maybeApnRegistrationId.ifPresentOrElse(apnRegistrationId -> {
assertEquals(apnRegistrationId.apnRegistrationId(), primaryDevice.getApnId());
assertEquals(apnRegistrationId.voipRegistrationId(), primaryDevice.getVoipApnId());
}, () -> {
assertNull(primaryDevice.getApnId());
assertNull(primaryDevice.getVoipApnId());
});
maybeGcmRegistrationId.ifPresentOrElse(gcmRegistrationId -> {
assertEquals(deliveryChannels.fcmToken(), primaryDevice.getGcmId());
}, () -> {
assertNull(primaryDevice.getGcmId());
});
assertTrue(account.getRegistrationLock().verify(registrationLockSecret));
assertTrue(primaryDevice.getAuthTokenHash().verify(password));
assertEquals(Optional.of(aciSignedPreKey), keysManager.getEcSignedPreKey(account.getIdentifier(IdentityType.ACI), Device.PRIMARY_ID).join());
assertEquals(Optional.of(pniSignedPreKey), keysManager.getEcSignedPreKey(account.getIdentifier(IdentityType.PNI), Device.PRIMARY_ID).join());
assertEquals(Optional.of(aciPqLastResortPreKey), keysManager.getLastResort(account.getIdentifier(IdentityType.ACI), Device.PRIMARY_ID).join());
assertEquals(Optional.of(pniPqLastResortPreKey), keysManager.getLastResort(account.getIdentifier(IdentityType.PNI), Device.PRIMARY_ID).join());
}
}

View File

@ -14,7 +14,6 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.time.Clock;
import java.util.ArrayList;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
@ -41,6 +40,7 @@ import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
class AccountsManagerChangeNumberIntegrationTest {
@ -53,7 +53,11 @@ class AccountsManagerChangeNumberIntegrationTest {
Tables.NUMBERS,
Tables.PNI,
Tables.PNI_ASSIGNMENTS,
Tables.USERNAMES);
Tables.USERNAMES,
Tables.EC_KEYS,
Tables.PQ_KEYS,
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS,
Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS);
@RegisterExtension
static final RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@ -73,6 +77,14 @@ class AccountsManagerChangeNumberIntegrationTest {
DynamicConfiguration dynamicConfiguration = new DynamicConfiguration();
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
final KeysManager keysManager = new KeysManager(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
Tables.EC_KEYS.tableName(),
Tables.PQ_KEYS.tableName(),
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName(),
Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName(),
dynamicConfigurationManager);
final Accounts accounts = new Accounts(
DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
@ -98,9 +110,6 @@ class AccountsManagerChangeNumberIntegrationTest {
final PhoneNumberIdentifiers phoneNumberIdentifiers =
new PhoneNumberIdentifiers(DYNAMO_DB_EXTENSION.getDynamoDbClient(), Tables.PNI.tableName());
final KeysManager keysManager = mock(KeysManager.class);
when(keysManager.delete(any())).thenReturn(CompletableFuture.completedFuture(null));
final MessagesManager messagesManager = mock(MessagesManager.class);
when(messagesManager.clear(any())).thenReturn(CompletableFuture.completedFuture(null));
@ -143,8 +152,8 @@ class AccountsManagerChangeNumberIntegrationTest {
void testChangeNumber() throws InterruptedException, MismatchedDevicesException {
final String originalNumber = "+18005551111";
final String secondNumber = "+18005552222";
final Account account = AccountsHelper.createAccount(accountsManager, originalNumber);
final Account account = accountsManager.create(originalNumber, "password", null, new AccountAttributes(), new ArrayList<>());
final UUID originalUuid = account.getUuid();
final UUID originalPni = account.getPhoneNumberIdentifier();
@ -167,17 +176,17 @@ class AccountsManagerChangeNumberIntegrationTest {
final String originalNumber = "+18005551111";
final String secondNumber = "+18005552222";
final int rotatedPniRegistrationId = 17;
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
final ECSignedPreKey rotatedSignedPreKey = KeysHelper.signedECPreKey(1L, pniIdentityKeyPair);
final ECKeyPair rotatedPniIdentityKeyPair = Curve.generateKeyPair();
final ECSignedPreKey rotatedSignedPreKey = KeysHelper.signedECPreKey(1L, rotatedPniIdentityKeyPair);
final AccountAttributes accountAttributes = new AccountAttributes(true, rotatedPniRegistrationId + 1, "test", null, true, new Device.DeviceCapabilities(false, false, false, false));
final Account account = accountsManager.create(originalNumber, "password", null, accountAttributes, new ArrayList<>());
account.getPrimaryDevice().orElseThrow().setSignedPreKey(KeysHelper.signedECPreKey(1, pniIdentityKeyPair));
final Account account = AccountsHelper.createAccount(accountsManager, originalNumber, accountAttributes);
account.getPrimaryDevice().orElseThrow().setSignedPreKey(KeysHelper.signedECPreKey(1, rotatedPniIdentityKeyPair));
final UUID originalUuid = account.getUuid();
final UUID originalPni = account.getPhoneNumberIdentifier();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey());
final IdentityKey pniIdentityKey = new IdentityKey(rotatedPniIdentityKeyPair.getPublicKey());
final Map<Byte, ECSignedPreKey> preKeys = Map.of(Device.PRIMARY_ID, rotatedSignedPreKey);
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, rotatedPniRegistrationId);
@ -207,7 +216,8 @@ class AccountsManagerChangeNumberIntegrationTest {
final String originalNumber = "+18005551111";
final String secondNumber = "+18005552222";
Account account = accountsManager.create(originalNumber, "password", null, new AccountAttributes(), new ArrayList<>());
Account account = AccountsHelper.createAccount(accountsManager, originalNumber);
final UUID originalUuid = account.getUuid();
final UUID originalPni = account.getPhoneNumberIdentifier();
@ -231,10 +241,12 @@ class AccountsManagerChangeNumberIntegrationTest {
final String originalNumber = "+18005551111";
final String secondNumber = "+18005552222";
final Account account = accountsManager.create(originalNumber, "password", null, new AccountAttributes(), new ArrayList<>());
final Account account = AccountsHelper.createAccount(accountsManager, originalNumber);
final UUID originalUuid = account.getUuid();
final Account existingAccount = accountsManager.create(secondNumber, "password", null, new AccountAttributes(), new ArrayList<>());
final Account existingAccount = AccountsHelper.createAccount(accountsManager, secondNumber);
final UUID existingAccountUuid = existingAccount.getUuid();
accountsManager.changeNumber(account, secondNumber, null, null, null, null);
@ -253,8 +265,7 @@ class AccountsManagerChangeNumberIntegrationTest {
accountsManager.changeNumber(accountsManager.getByAccountIdentifier(originalUuid).orElseThrow(), originalNumber, null, null, null, null);
final Account existingAccount2 = accountsManager.create(secondNumber, "password", null, new AccountAttributes(),
new ArrayList<>());
final Account existingAccount2 = AccountsHelper.createAccount(accountsManager, secondNumber);
assertEquals(existingAccountUuid, existingAccount2.getUuid());
}
@ -264,17 +275,19 @@ class AccountsManagerChangeNumberIntegrationTest {
final String originalNumber = "+18005551111";
final String secondNumber = "+18005552222";
final Account account = accountsManager.create(originalNumber, "password", null, new AccountAttributes(), new ArrayList<>());
final Account account = AccountsHelper.createAccount(accountsManager, originalNumber);
final UUID originalUuid = account.getUuid();
final UUID originalPni = account.getPhoneNumberIdentifier();
final Account existingAccount = accountsManager.create(secondNumber, "password", null, new AccountAttributes(), new ArrayList<>());
final Account existingAccount = AccountsHelper.createAccount(accountsManager, secondNumber);
final UUID existingAccountUuid = existingAccount.getUuid();
final Account changedNumberAccount = accountsManager.changeNumber(account, secondNumber, null, null, null, null);
final UUID secondPni = changedNumberAccount.getPhoneNumberIdentifier();
final Account reRegisteredAccount = accountsManager.create(originalNumber, "password", null, new AccountAttributes(), new ArrayList<>());
final Account reRegisteredAccount = AccountsHelper.createAccount(accountsManager, originalNumber);
assertEquals(existingAccountUuid, reRegisteredAccount.getUuid());
assertEquals(originalPni, reRegisteredAccount.getPhoneNumberIdentifier());

View File

@ -23,6 +23,7 @@ import java.io.IOException;
import java.time.Clock;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
@ -39,6 +40,7 @@ import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
@ -51,6 +53,7 @@ import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import org.whispersystems.textsecuregcm.tests.util.JsonHelpers;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import org.whispersystems.textsecuregcm.util.Pair;
@ -62,8 +65,11 @@ class AccountsManagerConcurrentModificationIntegrationTest {
Tables.ACCOUNTS,
Tables.NUMBERS,
Tables.PNI_ASSIGNMENTS,
Tables.DELETED_ACCOUNTS
);
Tables.DELETED_ACCOUNTS,
Tables.EC_KEYS,
Tables.PQ_KEYS,
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS,
Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS);
private Accounts accounts;
@ -80,6 +86,14 @@ class AccountsManagerConcurrentModificationIntegrationTest {
mock(DynamicConfigurationManager.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration());
final KeysManager keysManager = new KeysManager(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
Tables.EC_KEYS.tableName(),
Tables.PQ_KEYS.tableName(),
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName(),
Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName(),
dynamicConfigurationManager);
accounts = new Accounts(
DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
@ -134,11 +148,25 @@ class AccountsManagerConcurrentModificationIntegrationTest {
@Test
void testConcurrentUpdate() throws IOException, InterruptedException {
final UUID uuid;
{
final ECKeyPair aciKeyPair = Curve.generateKeyPair();
final ECKeyPair pniKeyPair = Curve.generateKeyPair();
final Account account = accountsManager.update(
accountsManager.create("+14155551212", "password", null, new AccountAttributes(), new ArrayList<>()),
accountsManager.create("+14155551212",
"password",
null,
new AccountAttributes(),
new ArrayList<>(),
new IdentityKey(aciKeyPair.getPublicKey()),
new IdentityKey(pniKeyPair.getPublicKey()),
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair),
Optional.empty(),
Optional.empty()),
a -> {
a.setUnidentifiedAccessKey(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
a.removeDevice(Device.PRIMARY_ID);

View File

@ -13,6 +13,7 @@ import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
@ -201,6 +202,7 @@ class AccountsManagerTest {
when(registrationRecoveryPasswordsManager.removeForNumber(anyString())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.delete(any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.delete(any(), anyBoolean())).thenReturn(CompletableFuture.completedFuture(null));
when(messagesManager.clear(any())).thenReturn(CompletableFuture.completedFuture(null));
when(profilesManager.deleteAll(any())).thenReturn(CompletableFuture.completedFuture(null));
@ -853,7 +855,7 @@ class AccountsManagerTest {
when(commands.get(eq("Account3::" + uuid))).thenReturn(null);
when(accounts.getByAccountIdentifier(uuid)).thenReturn(Optional.empty())
.thenReturn(Optional.of(account));
when(accounts.create(any())).thenThrow(ContestedOptimisticLockException.class);
when(accounts.create(any(), any())).thenThrow(ContestedOptimisticLockException.class);
accountsManager.update(account, a -> {
});
@ -930,14 +932,15 @@ class AccountsManagerTest {
@Test
void testCreateFreshAccount() throws InterruptedException {
when(accounts.create(any())).thenReturn(true);
when(accounts.create(any(), any())).thenReturn(true);
final String e164 = "+18005550123";
final AccountAttributes attributes = new AccountAttributes(false, 0, null, null, true, null);
accountsManager.create(e164, "password", null, attributes, new ArrayList<>());
verify(accounts).create(argThat(account -> e164.equals(account.getNumber())));
verifyNoInteractions(keysManager);
createAccount(e164, attributes);
verify(accounts).create(argThat(account -> e164.equals(account.getNumber())), any());
verifyNoInteractions(messagesManager);
verifyNoInteractions(profilesManager);
}
@ -946,22 +949,23 @@ class AccountsManagerTest {
void testReregisterAccount() throws InterruptedException {
final UUID existingUuid = UUID.randomUUID();
when(accounts.create(any())).thenAnswer(invocation -> {
when(accounts.create(any(), any())).thenAnswer(invocation -> {
invocation.getArgument(0, Account.class).setUuid(existingUuid);
return false;
});
final String e164 = "+18005550123";
final AccountAttributes attributes = new AccountAttributes(false, 0, null, null, true, null);
accountsManager.create(e164, "password", null, attributes, new ArrayList<>());
createAccount(e164, attributes);
assertTrue(phoneNumberIdentifiersByE164.containsKey(e164));
verify(accounts)
.create(argThat(account -> e164.equals(account.getNumber()) && existingUuid.equals(account.getUuid())));
.create(argThat(account -> e164.equals(account.getNumber()) && existingUuid.equals(account.getUuid())), any());
verify(keysManager).delete(existingUuid);
verify(keysManager).delete(phoneNumberIdentifiersByE164.get(e164));
verify(keysManager).delete(existingUuid, true);
verify(keysManager).delete(phoneNumberIdentifiersByE164.get(e164), true);
verify(messagesManager).clear(existingUuid);
verify(profilesManager).deleteAll(existingUuid);
verify(clientPresenceManager).disconnectAllPresencesForUuid(existingUuid);
@ -972,14 +976,17 @@ class AccountsManagerTest {
final UUID recentlyDeletedUuid = UUID.randomUUID();
when(accounts.findRecentlyDeletedAccountIdentifier(anyString())).thenReturn(Optional.of(recentlyDeletedUuid));
when(accounts.create(any())).thenReturn(true);
when(accounts.create(any(), any())).thenReturn(true);
final String e164 = "+18005550123";
final AccountAttributes attributes = new AccountAttributes(false, 0, null, null, true, null);
accountsManager.create(e164, "password", null, attributes, new ArrayList<>());
createAccount(e164, attributes);
verify(accounts).create(
argThat(account -> e164.equals(account.getNumber()) && recentlyDeletedUuid.equals(account.getUuid())));
argThat(account -> e164.equals(account.getNumber()) && recentlyDeletedUuid.equals(account.getUuid())),
any());
verifyNoInteractions(keysManager);
verifyNoInteractions(messagesManager);
verifyNoInteractions(profilesManager);
@ -989,7 +996,7 @@ class AccountsManagerTest {
@ValueSource(booleans = {true, false})
void testCreateWithDiscoverability(final boolean discoverable) throws InterruptedException {
final AccountAttributes attributes = new AccountAttributes(false, 0, null, null, discoverable, null);
final Account account = accountsManager.create("+18005550123", "password", null, attributes, new ArrayList<>());
final Account account = createAccount("+18005550123", attributes);
assertEquals(discoverable, account.isDiscoverableByPhoneNumber());
}
@ -1000,7 +1007,7 @@ class AccountsManagerTest {
final AccountAttributes attributes = new AccountAttributes(false, 0, null, null, true,
new DeviceCapabilities(hasStorage, false, false, false));
final Account account = accountsManager.create("+18005550123", "password", null, attributes, new ArrayList<>());
final Account account = createAccount("+18005550123", attributes);
assertEquals(hasStorage, account.isStorageSupported());
}
@ -1572,4 +1579,23 @@ class AccountsManagerTest {
return device;
}
private Account createAccount(final String e164, final AccountAttributes accountAttributes) throws InterruptedException {
final ECKeyPair aciKeyPair = Curve.generateKeyPair();
final ECKeyPair pniKeyPair = Curve.generateKeyPair();
return accountsManager.create(e164,
"password",
null,
accountAttributes,
new ArrayList<>(),
new IdentityKey(aciKeyPair.getPublicKey()),
new IdentityKey(pniKeyPair.getPublicKey()),
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair),
Optional.empty(),
Optional.empty());
}
}

View File

@ -25,6 +25,7 @@ import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@ -39,13 +40,13 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.Mockito;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
@ -70,7 +71,11 @@ class AccountsManagerUsernameIntegrationTest {
Tables.USERNAMES,
Tables.DELETED_ACCOUNTS,
Tables.PNI,
Tables.PNI_ASSIGNMENTS);
Tables.PNI_ASSIGNMENTS,
Tables.EC_KEYS,
Tables.PQ_KEYS,
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS,
Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS);
@RegisterExtension
static RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@ -91,6 +96,14 @@ class AccountsManagerUsernameIntegrationTest {
DynamicConfiguration dynamicConfiguration = new DynamicConfiguration();
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
final KeysManager keysManager = new KeysManager(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
Tables.EC_KEYS.tableName(),
Tables.PQ_KEYS.tableName(),
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName(),
Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName(),
dynamicConfigurationManager);
accounts = Mockito.spy(new Accounts(
DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
@ -122,12 +135,13 @@ class AccountsManagerUsernameIntegrationTest {
final ExperimentEnrollmentManager experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class);
when(experimentEnrollmentManager.isEnrolled(any(UUID.class), eq(AccountsManager.USERNAME_EXPERIMENT_NAME)))
.thenReturn(true);
accountsManager = new AccountsManager(
accounts,
phoneNumberIdentifiers,
CACHE_CLUSTER_EXTENSION.getRedisCluster(),
accountLockManager,
mock(KeysManager.class),
keysManager,
mock(MessagesManager.class),
mock(ProfilesManager.class),
mock(SecureStorageClient.class),
@ -141,8 +155,8 @@ class AccountsManagerUsernameIntegrationTest {
@Test
void testNoUsernames() throws InterruptedException {
Account account = accountsManager.create("+18005551111", "password", null, new AccountAttributes(),
new ArrayList<>());
final Account account = AccountsHelper.createAccount(accountsManager, "+18005551111");
List<byte[]> usernameHashes = List.of(USERNAME_HASH_1, USERNAME_HASH_2);
int i = 0;
for (byte[] hash : usernameHashes) {
@ -169,8 +183,8 @@ class AccountsManagerUsernameIntegrationTest {
@Test
void testReserveUsernameSnatched() throws InterruptedException, UsernameHashNotAvailableException {
final Account account = accountsManager.create("+18005551111", "password", null, new AccountAttributes(),
new ArrayList<>());
final Account account = AccountsHelper.createAccount(accountsManager, "+18005551111");
ArrayList<byte[]> usernameHashes = new ArrayList<>(Arrays.asList(USERNAME_HASH_1, USERNAME_HASH_2));
for (byte[] hash : usernameHashes) {
DYNAMO_DB_EXTENSION.getDynamoDbClient().putItem(PutItemRequest.builder()
@ -205,10 +219,8 @@ class AccountsManagerUsernameIntegrationTest {
}
@Test
public void testReserveConfirmClear()
throws InterruptedException, UsernameHashNotAvailableException, UsernameReservationNotFoundException {
Account account = accountsManager.create("+18005551111", "password", null, new AccountAttributes(),
new ArrayList<>());
public void testReserveConfirmClear() throws InterruptedException {
Account account = AccountsHelper.createAccount(accountsManager, "+18005551111");
// reserve
AccountsManager.UsernameReservation reservation =
@ -236,11 +248,8 @@ class AccountsManagerUsernameIntegrationTest {
}
@Test
public void testReservationLapsed()
throws InterruptedException, UsernameHashNotAvailableException, UsernameReservationNotFoundException {
final Account account = accountsManager.create("+18005551111", "password", null, new AccountAttributes(),
new ArrayList<>());
public void testReservationLapsed() throws InterruptedException {
final Account account = AccountsHelper.createAccount(accountsManager, "+18005551111");
AccountsManager.UsernameReservation reservation1 =
accountsManager.reserveUsernameHash(account, List.of(USERNAME_HASH_1)).join();
@ -256,8 +265,8 @@ class AccountsManagerUsernameIntegrationTest {
.build());
// a different account should be able to reserve it
Account account2 = accountsManager.create("+18005552222", "password", null, new AccountAttributes(),
new ArrayList<>());
Account account2 = AccountsHelper.createAccount(accountsManager, "+18005552222");
final AccountsManager.UsernameReservation reservation2 =
accountsManager.reserveUsernameHash(account2, List.of(USERNAME_HASH_1)).join();
assertArrayEquals(reservation2.reservedUsernameHash(), USERNAME_HASH_1);
@ -271,8 +280,7 @@ class AccountsManagerUsernameIntegrationTest {
@Test
void testUsernameSetReserveAnotherClearSetReserved() throws InterruptedException {
Account account = accountsManager.create("+18005551111", "password", null, new AccountAttributes(),
new ArrayList<>());
Account account = AccountsHelper.createAccount(accountsManager, "+18005551111");
// Set username hash
final AccountsManager.UsernameReservation reservation1 =
@ -303,9 +311,10 @@ class AccountsManagerUsernameIntegrationTest {
@Test
public void testUsernameLinks() throws InterruptedException {
Account account = accountsManager.create("+18005551111", "password", null, new AccountAttributes(), new ArrayList<>());
final Account account = AccountsHelper.createAccount(accountsManager, "+18005551111");
account.setUsernameHash(RandomUtils.nextBytes(16));
accounts.create(account);
accounts.create(account, ignored -> Collections.emptyList());
final UUID linkHandle = UUID.randomUUID();
final byte[] encryptedUsername = RandomUtils.nextBytes(32);

View File

@ -21,12 +21,12 @@ import static org.mockito.Mockito.when;
import com.fasterxml.jackson.core.JsonProcessingException;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@ -114,7 +114,7 @@ class AccountsTest {
when(mockDynamicConfigManager.getConfiguration())
.thenReturn(new DynamicConfiguration());
this.accounts = new Accounts(
accounts = new Accounts(
clock,
DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
@ -129,7 +129,7 @@ class AccountsTest {
public void testStoreAndLookupUsernameLink() {
final Account account = nextRandomAccount();
account.setUsernameHash(RandomUtils.nextBytes(16));
accounts.create(account);
createAccount(account);
final BiConsumer<Optional<Account>, byte[]> validator = (maybeAccount, expectedEncryptedUsername) -> {
assertTrue(maybeAccount.isPresent());
@ -165,7 +165,7 @@ class AccountsTest {
Device device = generateDevice(DEVICE_ID_1);
Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device));
boolean freshUser = accounts.create(account);
boolean freshUser = createAccount(account);
assertThat(freshUser).isTrue();
verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, true);
@ -173,7 +173,7 @@ class AccountsTest {
assertPhoneNumberConstraintExists("+14151112222", account.getUuid());
assertPhoneNumberIdentifierConstraintExists(account.getPhoneNumberIdentifier(), account.getUuid());
freshUser = accounts.create(account);
freshUser = createAccount(account);
assertThat(freshUser).isTrue();
verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, true);
@ -188,7 +188,7 @@ class AccountsTest {
Device device = generateDevice(DEVICE_ID_1);
Account account = generateAccount("+14151112222", originalUuid, UUID.randomUUID(), List.of(device));
boolean freshUser = accounts.create(account);
boolean freshUser = createAccount(account);
assertThat(freshUser).isTrue();
verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, true);
@ -199,7 +199,7 @@ class AccountsTest {
accounts.delete(originalUuid).join();
assertThat(accounts.findRecentlyDeletedAccountIdentifier(account.getNumber())).hasValue(originalUuid);
freshUser = accounts.create(account);
freshUser = createAccount(account);
assertThat(freshUser).isTrue();
verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, true);
@ -214,7 +214,7 @@ class AccountsTest {
final List<Device> devices = List.of(generateDevice(DEVICE_ID_1), generateDevice(DEVICE_ID_2));
final Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), devices);
accounts.create(account);
createAccount(account);
verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, true);
@ -236,8 +236,8 @@ class AccountsTest {
UUID pniSecond = UUID.randomUUID();
Account accountSecond = generateAccount("+14152221111", uuidSecond, pniSecond, devicesSecond);
accounts.create(accountFirst);
accounts.create(accountSecond);
createAccount(accountFirst);
createAccount(accountSecond);
Optional<Account> retrievedFirst = accounts.getByE164("+14151112222");
Optional<Account> retrievedSecond = accounts.getByE164("+14152221111");
@ -340,7 +340,7 @@ class AccountsTest {
UUID firstUuid = UUID.randomUUID();
UUID firstPni = UUID.randomUUID();
Account account = generateAccount("+14151112222", firstUuid, firstPni, List.of(device));
accounts.create(account);
createAccount(account);
final byte[] usernameHash = randomBytes(32);
final byte[] encryptedUsername = randomBytes(32);
@ -358,7 +358,7 @@ class AccountsTest {
// simulate a failed re-reg: we give the account a reclaimable username, but we'll try
// re-registering again later in the test case
account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1)));
accounts.create(account);
createAccount(account);
break;
case CONFIRMED:
accounts.reserveUsernameHash(account, usernameHash, Duration.ofMinutes(1)).join();
@ -370,7 +370,7 @@ class AccountsTest {
// re-register the account
account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1)));
accounts.create(account);
createAccount(account);
// If we had a username link, or we had previously saved a username link from another re-registration, make sure
// we preserve it
@ -401,7 +401,7 @@ class AccountsTest {
UUID firstPni = UUID.randomUUID();
Account account = generateAccount("+14151112222", firstUuid, firstPni, List.of(device));
accounts.create(account);
createAccount(account);
final byte[] usernameHash = randomBytes(32);
final byte[] encryptedUsername = randomBytes(16);
@ -422,7 +422,7 @@ class AccountsTest {
device = generateDevice(DEVICE_ID_1);
account = generateAccount("+14151112222", secondUuid, UUID.randomUUID(), List.of(device));
final boolean freshUser = accounts.create(account);
final boolean freshUser = createAccount(account);
assertThat(freshUser).isFalse();
// usernameHash should be unset
verifyStoredState("+14151112222", firstUuid, firstPni, null, account, true);
@ -453,7 +453,8 @@ class AccountsTest {
device = generateDevice(DEVICE_ID_1);
Account invalidAccount = generateAccount("+14151113333", firstUuid, UUID.randomUUID(), List.of(device));
assertThatThrownBy(() -> accounts.create(invalidAccount));
assertThatThrownBy(() -> createAccount(invalidAccount));
}
@Test
@ -461,7 +462,7 @@ class AccountsTest {
Device device = generateDevice(DEVICE_ID_1);
Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device));
accounts.create(account);
createAccount(account);
assertPhoneNumberConstraintExists("+14151112222", account.getUuid());
assertPhoneNumberIdentifierConstraintExists(account.getPhoneNumberIdentifier(), account.getUuid());
@ -511,8 +512,11 @@ class AccountsTest {
final DynamoDbAsyncClient dynamoDbAsyncClient = mock(DynamoDbAsyncClient.class);
accounts = new Accounts(mock(DynamoDbClient.class),
dynamoDbAsyncClient, Tables.ACCOUNTS.tableName(),
Tables.NUMBERS.tableName(), Tables.PNI_ASSIGNMENTS.tableName(), Tables.USERNAMES.tableName(),
dynamoDbAsyncClient,
Tables.ACCOUNTS.tableName(),
Tables.NUMBERS.tableName(),
Tables.PNI_ASSIGNMENTS.tableName(),
Tables.USERNAMES.tableName(),
Tables.DELETED_ACCOUNTS.tableName());
Exception e = TransactionConflictException.builder().build();
@ -533,7 +537,7 @@ class AccountsTest {
for (int i = 1; i <= 100; i++) {
final Account account = generateAccount("+1" + String.format("%03d", i), UUID.randomUUID(), UUID.randomUUID());
expectedAccounts.add(account);
accounts.create(account);
createAccount(account);
}
final List<Account> retrievedAccounts =
@ -553,8 +557,8 @@ class AccountsTest {
final Account retainedAccount = generateAccount("+14151112345", UUID.randomUUID(),
UUID.randomUUID(), List.of(retainedDevice));
accounts.create(deletedAccount);
accounts.create(retainedAccount);
createAccount(deletedAccount);
createAccount(retainedAccount);
assertThat(accounts.findRecentlyDeletedAccountIdentifier(deletedAccount.getNumber())).isEmpty();
@ -581,7 +585,7 @@ class AccountsTest {
final Account recreatedAccount = generateAccount(deletedAccount.getNumber(), UUID.randomUUID(),
UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1)));
final boolean freshUser = accounts.create(recreatedAccount);
final boolean freshUser = createAccount(recreatedAccount);
assertThat(freshUser).isTrue();
assertThat(accounts.getByAccountIdentifier(recreatedAccount.getUuid())).isPresent();
@ -598,7 +602,7 @@ class AccountsTest {
Device device = generateDevice(DEVICE_ID_1);
Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device));
accounts.create(account);
createAccount(account);
Optional<Account> retrieved = accounts.getByE164("+11111111");
assertThat(retrieved.isPresent()).isFalse();
@ -614,7 +618,7 @@ class AccountsTest {
final Account account =
generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1)));
accounts.create(account);
createAccount(account);
assertThat(accounts.getByAccountIdentifierAsync(account.getUuid()).join()).isPresent();
}
@ -626,7 +630,7 @@ class AccountsTest {
final Account account =
generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1)));
accounts.create(account);
createAccount(account);
assertThat(accounts.getByPhoneNumberIdentifierAsync(account.getPhoneNumberIdentifier()).join()).isPresent();
}
@ -640,7 +644,7 @@ class AccountsTest {
final Account account =
generateAccount(e164, UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1)));
accounts.create(account);
createAccount(account);
assertThat(accounts.getByE164Async(e164).join()).isPresent();
}
@ -650,7 +654,7 @@ class AccountsTest {
Device device = generateDevice(DEVICE_ID_1);
Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device));
account.setDiscoverableByPhoneNumber(false);
accounts.create(account);
createAccount(account);
verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, false);
account.setDiscoverableByPhoneNumber(true);
accounts.update(account);
@ -673,7 +677,7 @@ class AccountsTest {
final Device device = generateDevice(DEVICE_ID_1);
final Account account = generateAccount(originalNumber, UUID.randomUUID(), originalPni, List.of(device));
accounts.create(account);
createAccount(account);
assertThat(accounts.getByPhoneNumberIdentifier(originalPni)).isPresent();
@ -731,8 +735,8 @@ class AccountsTest {
final Device device = generateDevice(DEVICE_ID_1);
final Account account = generateAccount(originalNumber, UUID.randomUUID(), originalPni, List.of(device));
accounts.create(account);
accounts.create(existingAccount);
createAccount(account);
createAccount(existingAccount);
assertThrows(TransactionCanceledException.class, () -> accounts.changeNumber(account, targetNumber, targetPni, Optional.of(existingAccount.getUuid())));
@ -750,7 +754,7 @@ class AccountsTest {
final Device device = generateDevice(DEVICE_ID_1);
final Account account = generateAccount(originalNumber, UUID.randomUUID(), UUID.randomUUID(), List.of(device));
accounts.create(account);
createAccount(account);
final UUID existingAccountIdentifier = UUID.randomUUID();
final UUID existingPhoneNumberIdentifier = UUID.randomUUID();
@ -776,7 +780,7 @@ class AccountsTest {
@Test
void testSwitchUsernameHashes() {
final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account);
createAccount(account);
assertThat(accounts.getByUsernameHash(USERNAME_HASH_1).join()).isEmpty();
@ -822,8 +826,8 @@ class AccountsTest {
final Account firstAccount = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID());
final Account secondAccount = generateAccount("+18005559876", UUID.randomUUID(), UUID.randomUUID());
accounts.create(firstAccount);
accounts.create(secondAccount);
createAccount(firstAccount);
createAccount(secondAccount);
// first account reserves and confirms username hash
assertThatNoException().isThrownBy(() -> {
@ -855,7 +859,7 @@ class AccountsTest {
@Test
void testConfirmUsernameHashVersionMismatch() {
final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account);
createAccount(account);
accounts.reserveUsernameHash(account, USERNAME_HASH_1, Duration.ofDays(1)).join();
account.setVersion(account.getVersion() + 77);
@ -868,7 +872,7 @@ class AccountsTest {
@Test
void testClearUsername() {
final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account);
createAccount(account);
accounts.reserveUsernameHash(account, USERNAME_HASH_1, Duration.ofDays(1)).join();
accounts.confirmUsernameHash(account, USERNAME_HASH_1, ENCRYPTED_USERNAME_1).join();
@ -888,7 +892,7 @@ class AccountsTest {
@Test
void testClearUsernameNoUsername() {
final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account);
createAccount(account);
assertThatNoException().isThrownBy(() -> accounts.clearUsernameHash(account).join());
}
@ -896,7 +900,7 @@ class AccountsTest {
@Test
void testClearUsernameVersionMismatch() {
final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account);
createAccount(account);
accounts.reserveUsernameHash(account, USERNAME_HASH_1, Duration.ofDays(1)).join();
accounts.confirmUsernameHash(account, USERNAME_HASH_1, ENCRYPTED_USERNAME_1).join();
@ -912,9 +916,9 @@ class AccountsTest {
@Test
void testReservedUsernameHash() {
final Account account1 = generateAccount("+18005551111", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account1);
createAccount(account1);
final Account account2 = generateAccount("+18005552222", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account2);
createAccount(account2);
accounts.reserveUsernameHash(account1, USERNAME_HASH_1, Duration.ofDays(1)).join();
assertArrayEquals(account1.getReservedUsernameHash().orElseThrow(), USERNAME_HASH_1);
@ -946,7 +950,7 @@ class AccountsTest {
@Test
void testUsernameHashAvailable() {
final Account account1 = generateAccount("+18005551111", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account1);
createAccount(account1);
accounts.reserveUsernameHash(account1, USERNAME_HASH_1, Duration.ofDays(1)).join();
assertThat(accounts.usernameHashAvailable(USERNAME_HASH_1).join()).isFalse();
@ -964,9 +968,9 @@ class AccountsTest {
@Test
void testConfirmReservedUsernameHashWrongAccountUuid() {
final Account account1 = generateAccount("+18005551111", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account1);
createAccount(account1);
final Account account2 = generateAccount("+18005552222", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account2);
createAccount(account2);
accounts.reserveUsernameHash(account1, USERNAME_HASH_1, Duration.ofDays(1)).join();
assertArrayEquals(account1.getReservedUsernameHash().orElseThrow(), USERNAME_HASH_1);
@ -980,9 +984,9 @@ class AccountsTest {
@Test
void testConfirmExpiredReservedUsernameHash() {
final Account account1 = generateAccount("+18005551111", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account1);
createAccount(account1);
final Account account2 = generateAccount("+18005552222", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account2);
createAccount(account2);
accounts.reserveUsernameHash(account1, USERNAME_HASH_1, Duration.ofDays(2)).join();
@ -1009,7 +1013,7 @@ class AccountsTest {
@Test
void testRetryReserveUsernameHash() {
final Account account = generateAccount("+18005551111", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account);
createAccount(account);
accounts.reserveUsernameHash(account, USERNAME_HASH_1, Duration.ofDays(2)).join();
CompletableFutureTestUtil.assertFailsWithCause(ContestedOptimisticLockException.class,
@ -1020,7 +1024,7 @@ class AccountsTest {
@Test
void testReserveConfirmUsernameHashVersionConflict() {
final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account);
createAccount(account);
account.setVersion(account.getVersion() + 12);
CompletableFutureTestUtil.assertFailsWithCause(ContestedOptimisticLockException.class,
accounts.reserveUsernameHash(account, USERNAME_HASH_1, Duration.ofDays(1)));
@ -1035,7 +1039,7 @@ class AccountsTest {
final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID());
account.setUsernameHash(RandomUtils.nextBytes(32));
account.setUsernameLinkDetails(UUID.randomUUID(), RandomUtils.nextBytes(32));
accounts.create(account);
createAccount(account);
final Map<String, AttributeValue> accountRecord = DYNAMO_DB_EXTENSION.getDynamoDbClient()
.getItem(GetItemRequest.builder()
.tableName(Tables.ACCOUNTS.tableName())
@ -1053,7 +1057,7 @@ class AccountsTest {
assertThat(accounts.getByUsernameHash(USERNAME_HASH_1).join()).isEmpty();
final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account);
createAccount(account);
assertThat(accounts.getByUsernameHash(USERNAME_HASH_1).join()).isEmpty();
@ -1069,7 +1073,7 @@ class AccountsTest {
final Device device2 = generateDevice((byte) 64);
account.addDevice(device2);
accounts.create(account);
createAccount(account);
final GetItemResponse response = DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient().getItem(GetItemRequest.builder()
.tableName(Tables.ACCOUNTS.tableName())
@ -1108,6 +1112,10 @@ class AccountsTest {
return DevicesHelper.createDevice(id);
}
private boolean createAccount(final Account account) {
return accounts.create(account, ignored -> Collections.emptyList());
}
private static Account nextRandomAccount() {
final String nextNumber = "+1800%07d".formatted(ACCOUNT_COUNTER.getAndIncrement());
return generateAccount(nextNumber, UUID.randomUUID(), UUID.randomUUID());

View File

@ -5,14 +5,16 @@
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import static org.junit.jupiter.api.Assertions.*;
abstract class RepeatedUseSignedPreKeyStoreTest<K extends SignedPreKey<?>> {
@ -50,38 +52,47 @@ abstract class RepeatedUseSignedPreKeyStoreTest<K extends SignedPreKey<?>> {
}
@Test
void delete() {
void deleteForDevice() {
final RepeatedUseSignedPreKeyStore<K> keys = getKeyStore();
assertDoesNotThrow(() -> keys.delete(UUID.randomUUID()).join());
final UUID identifier = UUID.randomUUID();
final byte deviceId2 = 2;
{
final UUID identifier = UUID.randomUUID();
final Map<Byte, K> signedPreKeys = Map.of(
Device.PRIMARY_ID, generateSignedPreKey(),
deviceId2, generateSignedPreKey()
);
final Map<Byte, K> signedPreKeys = Map.of(
Device.PRIMARY_ID, generateSignedPreKey(),
deviceId2, generateSignedPreKey()
);
keys.store(identifier, signedPreKeys).join();
keys.delete(identifier, Device.PRIMARY_ID).join();
keys.store(identifier, signedPreKeys).join();
keys.delete(identifier, Device.PRIMARY_ID).join();
assertEquals(Optional.empty(), keys.find(identifier, Device.PRIMARY_ID).join());
assertEquals(Optional.of(signedPreKeys.get(deviceId2)), keys.find(identifier, deviceId2).join());
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void deleteForAllDevices(final boolean excludePrimaryDevice) {
final RepeatedUseSignedPreKeyStore<K> keys = getKeyStore();
assertDoesNotThrow(() -> keys.delete(UUID.randomUUID(), excludePrimaryDevice).join());
final byte deviceId2 = Device.PRIMARY_ID + 1;
final UUID identifier = UUID.randomUUID();
final Map<Byte, K> signedPreKeys = Map.of(
Device.PRIMARY_ID, generateSignedPreKey(),
deviceId2, generateSignedPreKey()
);
keys.store(identifier, signedPreKeys).join();
keys.delete(identifier, excludePrimaryDevice).join();
if (excludePrimaryDevice) {
assertTrue(keys.find(identifier, Device.PRIMARY_ID).join().isPresent());
} else {
assertEquals(Optional.empty(), keys.find(identifier, Device.PRIMARY_ID).join());
assertEquals(Optional.of(signedPreKeys.get(deviceId2)), keys.find(identifier, deviceId2).join());
}
{
final UUID identifier = UUID.randomUUID();
final Map<Byte, K> signedPreKeys = Map.of(
Device.PRIMARY_ID, generateSignedPreKey(),
deviceId2, generateSignedPreKey()
);
keys.store(identifier, signedPreKeys).join();
keys.delete(identifier).join();
assertEquals(Optional.empty(), keys.find(identifier, Device.PRIMARY_ID).join());
assertEquals(Optional.empty(), keys.find(identifier, deviceId2).join());
}
assertEquals(Optional.empty(), keys.find(identifier, deviceId2).join());
}
}

View File

@ -15,13 +15,19 @@ import static org.mockito.Mockito.when;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.function.Consumer;
import org.mockito.MockingDetails;
import org.mockito.stubbing.Stubbing;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
@ -150,4 +156,30 @@ public class AccountsHelper {
return argThat(other -> other.getUuid().equals(value.getUuid()));
}
public static Account createAccount(final AccountsManager accountsManager, final String e164)
throws InterruptedException {
return createAccount(accountsManager, e164, new AccountAttributes());
}
public static Account createAccount(final AccountsManager accountsManager, final String e164, final AccountAttributes accountAttributes)
throws InterruptedException {
final ECKeyPair aciKeyPair = Curve.generateKeyPair();
final ECKeyPair pniKeyPair = Curve.generateKeyPair();
return accountsManager.create(e164,
"password",
null,
accountAttributes,
new ArrayList<>(),
new IdentityKey(aciKeyPair.getPublicKey()),
new IdentityKey(pniKeyPair.getPublicKey()),
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair),
Optional.empty(),
Optional.empty());
}
}