Create accounts transactionally

This commit is contained in:
Jon Chambers 2023-11-13 13:05:29 -05:00 committed by Jon Chambers
parent 07c04006df
commit c8033f875d
16 changed files with 854 additions and 265 deletions

View File

@ -342,8 +342,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getDynamoDbTables().getPhoneNumberIdentifiers().getTableName()); config.getDynamoDbTables().getPhoneNumberIdentifiers().getTableName());
Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient, Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient,
config.getDynamoDbTables().getProfiles().getTableName()); config.getDynamoDbTables().getProfiles().getTableName());
KeysManager keys = new KeysManager( KeysManager keysManager = new KeysManager(
dynamoDbAsyncClient, dynamoDbAsyncClient,
config.getDynamoDbTables().getEcKeys().getTableName(), config.getDynamoDbTables().getEcKeys().getTableName(),
config.getDynamoDbTables().getKemKeys().getTableName(), config.getDynamoDbTables().getKemKeys().getTableName(),
config.getDynamoDbTables().getEcSignedPreKeys().getTableName(), config.getDynamoDbTables().getEcSignedPreKeys().getTableName(),
@ -525,7 +525,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
AccountLockManager accountLockManager = new AccountLockManager(dynamoDbClient, AccountLockManager accountLockManager = new AccountLockManager(dynamoDbClient,
config.getDynamoDbTables().getDeletedAccountsLock().getTableName()); config.getDynamoDbTables().getDeletedAccountsLock().getTableName());
AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster, AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster,
accountLockManager, keys, messagesManager, profilesManager, accountLockManager, keysManager, messagesManager, profilesManager,
secureStorageClient, secureValueRecovery2Client, secureStorageClient, secureValueRecovery2Client,
clientPresenceManager, clientPresenceManager,
experimentEnrollmentManager, registrationRecoveryPasswordsManager, accountLockExecutor, clock); experimentEnrollmentManager, registrationRecoveryPasswordsManager, accountLockExecutor, clock);
@ -669,8 +669,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
.addService(new AccountsAnonymousGrpcService(accountsManager, rateLimiters)) .addService(new AccountsAnonymousGrpcService(accountsManager, rateLimiters))
.addService(ExternalServiceCredentialsGrpcService.createForAllExternalServices(config, rateLimiters)) .addService(ExternalServiceCredentialsGrpcService.createForAllExternalServices(config, rateLimiters))
.addService(ExternalServiceCredentialsAnonymousGrpcService.create(accountsManager, config)) .addService(ExternalServiceCredentialsAnonymousGrpcService.create(accountsManager, config))
.addService(ServerInterceptors.intercept(new KeysGrpcService(accountsManager, keys, rateLimiters), basicCredentialAuthenticationInterceptor)) .addService(ServerInterceptors.intercept(new KeysGrpcService(accountsManager, keysManager, rateLimiters), basicCredentialAuthenticationInterceptor))
.addService(new KeysAnonymousGrpcService(accountsManager, keys)) .addService(new KeysAnonymousGrpcService(accountsManager, keysManager))
.addService(new PaymentsGrpcService(currencyManager)) .addService(new PaymentsGrpcService(currencyManager))
.addService(ServerInterceptors.intercept(new ProfileGrpcService(clock, accountsManager, profilesManager, dynamicConfigurationManager, .addService(ServerInterceptors.intercept(new ProfileGrpcService(clock, accountsManager, profilesManager, dynamicConfigurationManager,
config.getBadges(), asyncCdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner, profileBadgeConverter, rateLimiters, zkProfileOperations, config.getCdnConfiguration().bucket()), basicCredentialAuthenticationInterceptor)) config.getBadges(), asyncCdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner, profileBadgeConverter, rateLimiters, zkProfileOperations, config.getCdnConfiguration().bucket()), basicCredentialAuthenticationInterceptor))
@ -725,7 +725,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
turnTokenGenerator, turnTokenGenerator,
registrationRecoveryPasswordsManager, usernameHashZkProofVerifier)); registrationRecoveryPasswordsManager, usernameHashZkProofVerifier));
environment.jersey().register(new KeysController(rateLimiters, keys, accountsManager)); environment.jersey().register(new KeysController(rateLimiters, keysManager, accountsManager));
boolean registeredSpamFilter = false; boolean registeredSpamFilter = false;
ReportSpamTokenProvider reportSpamTokenProvider = null; ReportSpamTokenProvider reportSpamTokenProvider = null;
@ -784,7 +784,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
new CallLinkController(rateLimiters, callingGenericZkSecretParams), new CallLinkController(rateLimiters, callingGenericZkSecretParams),
new CertificateController(new CertificateGenerator(config.getDeliveryCertificate().certificate().value(), config.getDeliveryCertificate().ecPrivateKey(), config.getDeliveryCertificate().expiresDays()), zkAuthOperations, callingGenericZkSecretParams, clock), new CertificateController(new CertificateGenerator(config.getDeliveryCertificate().certificate().value(), config.getDeliveryCertificate().ecPrivateKey(), config.getDeliveryCertificate().expiresDays()), zkAuthOperations, callingGenericZkSecretParams, clock),
new ChallengeController(rateLimitChallengeManager), new ChallengeController(rateLimitChallengeManager),
new DeviceController(config.getLinkDeviceSecretConfiguration().secret().value(), accountsManager, messagesManager, keys, rateLimiters, new DeviceController(config.getLinkDeviceSecretConfiguration().secret().value(), accountsManager, messagesManager, keysManager, rateLimiters,
rateLimitersCluster, config.getMaxDevices(), clock), rateLimitersCluster, config.getMaxDevices(), clock),
new DirectoryV2Controller(directoryV2CredentialsGenerator), new DirectoryV2Controller(directoryV2CredentialsGenerator),
new DonationController(clock, zkReceiptOperations, redeemedReceiptsManager, accountsManager, config.getBadges(), new DonationController(clock, zkReceiptOperations, redeemedReceiptsManager, accountsManager, config.getBadges(),
@ -799,7 +799,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getCdnConfiguration().bucket(), zkProfileOperations, batchIdentityCheckExecutor), config.getCdnConfiguration().bucket(), zkProfileOperations, batchIdentityCheckExecutor),
new ProvisioningController(rateLimiters, provisioningManager), new ProvisioningController(rateLimiters, provisioningManager),
new RegistrationController(accountsManager, phoneVerificationTokenManager, registrationLockVerificationManager, new RegistrationController(accountsManager, phoneVerificationTokenManager, registrationLockVerificationManager,
keys, rateLimiters), rateLimiters),
new RemoteConfigController(remoteConfigsManager, adminEventLogger, new RemoteConfigController(remoteConfigsManager, adminEventLogger,
config.getRemoteConfigConfiguration().authorizedUsers(), config.getRemoteConfigConfiguration().authorizedUsers(),
config.getRemoteConfigConfiguration().requiredHostedDomain(), config.getRemoteConfigConfiguration().requiredHostedDomain(),

View File

@ -20,9 +20,7 @@ import io.swagger.v3.oas.annotations.responses.ApiResponse;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import javax.ws.rs.Consumes; import javax.ws.rs.Consumes;
@ -45,8 +43,6 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
@ -69,18 +65,16 @@ public class RegistrationController {
private final AccountsManager accounts; private final AccountsManager accounts;
private final PhoneVerificationTokenManager phoneVerificationTokenManager; private final PhoneVerificationTokenManager phoneVerificationTokenManager;
private final RegistrationLockVerificationManager registrationLockVerificationManager; private final RegistrationLockVerificationManager registrationLockVerificationManager;
private final KeysManager keysManager;
private final RateLimiters rateLimiters; private final RateLimiters rateLimiters;
public RegistrationController(final AccountsManager accounts, public RegistrationController(final AccountsManager accounts,
final PhoneVerificationTokenManager phoneVerificationTokenManager, final PhoneVerificationTokenManager phoneVerificationTokenManager,
final RegistrationLockVerificationManager registrationLockVerificationManager, final RegistrationLockVerificationManager registrationLockVerificationManager,
final KeysManager keysManager,
final RateLimiters rateLimiters) { final RateLimiters rateLimiters) {
this.accounts = accounts; this.accounts = accounts;
this.phoneVerificationTokenManager = phoneVerificationTokenManager; this.phoneVerificationTokenManager = phoneVerificationTokenManager;
this.registrationLockVerificationManager = registrationLockVerificationManager; this.registrationLockVerificationManager = registrationLockVerificationManager;
this.keysManager = keysManager;
this.rateLimiters = rateLimiters; this.rateLimiters = rateLimiters;
} }
@ -141,37 +135,19 @@ public class RegistrationController {
userAgent, RegistrationLockVerificationManager.Flow.REGISTRATION, verificationType); userAgent, RegistrationLockVerificationManager.Flow.REGISTRATION, verificationType);
} }
Account account = accounts.create(number, password, signalAgent, registrationRequest.accountAttributes(), final Account account = accounts.create(number,
existingAccount.map(Account::getBadges).orElseGet(ArrayList::new)); password,
signalAgent,
account = accounts.update(account, a -> { registrationRequest.accountAttributes(),
a.setIdentityKey(registrationRequest.aciIdentityKey()); existingAccount.map(Account::getBadges).orElseGet(ArrayList::new),
a.setPhoneNumberIdentityKey(registrationRequest.pniIdentityKey()); registrationRequest.aciIdentityKey(),
registrationRequest.pniIdentityKey(),
final Device device = a.getPrimaryDevice().orElseThrow(); registrationRequest.deviceActivationRequest().aciSignedPreKey(),
registrationRequest.deviceActivationRequest().pniSignedPreKey(),
device.setSignedPreKey(registrationRequest.deviceActivationRequest().aciSignedPreKey()); registrationRequest.deviceActivationRequest().aciPqLastResortPreKey(),
device.setPhoneNumberIdentitySignedPreKey(registrationRequest.deviceActivationRequest().pniSignedPreKey()); registrationRequest.deviceActivationRequest().pniPqLastResortPreKey(),
registrationRequest.deviceActivationRequest().apnToken(),
registrationRequest.deviceActivationRequest().apnToken().ifPresent(apnRegistrationId -> { registrationRequest.deviceActivationRequest().gcmToken());
device.setApnId(apnRegistrationId.apnRegistrationId());
device.setVoipApnId(apnRegistrationId.voipRegistrationId());
});
registrationRequest.deviceActivationRequest().gcmToken().ifPresent(gcmRegistrationId ->
device.setGcmId(gcmRegistrationId.gcmRegistrationId()));
CompletableFuture.allOf(
keysManager.storeEcSignedPreKeys(a.getUuid(),
Map.of(Device.PRIMARY_ID, registrationRequest.deviceActivationRequest().aciSignedPreKey())),
keysManager.storePqLastResort(a.getUuid(),
Map.of(Device.PRIMARY_ID, registrationRequest.deviceActivationRequest().aciPqLastResortPreKey())),
keysManager.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(),
Map.of(Device.PRIMARY_ID, registrationRequest.deviceActivationRequest().pniSignedPreKey())),
keysManager.storePqLastResort(a.getPhoneNumberIdentifier(),
Map.of(Device.PRIMARY_ID, registrationRequest.deviceActivationRequest().pniPqLastResortPreKey())))
.join();
});
Metrics.counter(ACCOUNT_CREATED_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), Metrics.counter(ACCOUNT_CREATED_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(COUNTRY_CODE_TAG_NAME, Util.getCountryCode(number)), Tag.of(COUNTRY_CODE_TAG_NAME, Util.getCountryCode(number)),

View File

@ -19,6 +19,7 @@ import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -28,6 +29,7 @@ import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage; import java.util.concurrent.CompletionStage;
import java.util.function.Function;
import java.util.function.Predicate; import java.util.function.Predicate;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
@ -157,6 +159,7 @@ public class Accounts extends AbstractDynamoDbStore {
final String phoneNumberIdentifierConstraintTableName, final String phoneNumberIdentifierConstraintTableName,
final String usernamesConstraintTableName, final String usernamesConstraintTableName,
final String deletedAccountsTableName) { final String deletedAccountsTableName) {
super(client); super(client);
this.clock = clock; this.clock = clock;
this.asyncClient = asyncClient; this.asyncClient = asyncClient;
@ -175,12 +178,14 @@ public class Accounts extends AbstractDynamoDbStore {
final String phoneNumberIdentifierConstraintTableName, final String phoneNumberIdentifierConstraintTableName,
final String usernamesConstraintTableName, final String usernamesConstraintTableName,
final String deletedAccountsTableName) { final String deletedAccountsTableName) {
this(Clock.systemUTC(), client, asyncClient, accountsTableName, this(Clock.systemUTC(), client, asyncClient, accountsTableName,
phoneNumberConstraintTableName, phoneNumberIdentifierConstraintTableName, usernamesConstraintTableName, phoneNumberConstraintTableName, phoneNumberIdentifierConstraintTableName, usernamesConstraintTableName,
deletedAccountsTableName); deletedAccountsTableName);
} }
public boolean create(final Account account) { public boolean create(final Account account, final Function<Account, Collection<TransactWriteItem>> additionalWriteItemsFunction) {
return CREATE_TIMER.record(() -> { return CREATE_TIMER.record(() -> {
try { try {
final AttributeValue uuidAttr = AttributeValues.fromUUID(account.getUuid()); final AttributeValue uuidAttr = AttributeValues.fromUUID(account.getUuid());
@ -199,8 +204,13 @@ public class Accounts extends AbstractDynamoDbStore {
// the newly-created account. // the newly-created account.
final TransactWriteItem deletedAccountDelete = buildRemoveDeletedAccount(account.getNumber()); final TransactWriteItem deletedAccountDelete = buildRemoveDeletedAccount(account.getNumber());
final Collection<TransactWriteItem> writeItems = new ArrayList<>(
List.of(phoneNumberConstraintPut, phoneNumberIdentifierConstraintPut, accountPut, deletedAccountDelete));
writeItems.addAll(additionalWriteItemsFunction.apply(account));
final TransactWriteItemsRequest request = TransactWriteItemsRequest.builder() final TransactWriteItemsRequest request = TransactWriteItemsRequest.builder()
.transactItems(phoneNumberConstraintPut, phoneNumberIdentifierConstraintPut, accountPut, deletedAccountDelete) .transactItems(writeItems)
.build(); .build();
try { try {
@ -229,7 +239,8 @@ public class Accounts extends AbstractDynamoDbStore {
account.setUuid(UUIDUtil.fromByteBuffer(actualAccountUuid)); account.setUuid(UUIDUtil.fromByteBuffer(actualAccountUuid));
final Account existingAccount = getByAccountIdentifier(account.getUuid()).orElseThrow(); final Account existingAccount = getByAccountIdentifier(account.getUuid()).orElseThrow();
account.setNumber(existingAccount.getNumber(), existingAccount.getPhoneNumberIdentifier()); account.setNumber(existingAccount.getNumber(), existingAccount.getPhoneNumberIdentifier());
joinAndUnwrapUpdateFuture(reclaimAccount(existingAccount, account)); joinAndUnwrapUpdateFuture(reclaimAccount(existingAccount, account, additionalWriteItemsFunction.apply(account)));
return false; return false;
} }
@ -254,7 +265,7 @@ public class Accounts extends AbstractDynamoDbStore {
* @param existingAccount the existing account in the accounts table * @param existingAccount the existing account in the accounts table
* @param accountToCreate a new account, with the same number and identifier as existingAccount * @param accountToCreate a new account, with the same number and identifier as existingAccount
*/ */
private CompletionStage<Void> reclaimAccount(final Account existingAccount, final Account accountToCreate) { private CompletionStage<Void> reclaimAccount(final Account existingAccount, final Account accountToCreate, final Collection<TransactWriteItem> additionalWriteItems) {
if (!existingAccount.getUuid().equals(accountToCreate.getUuid()) || if (!existingAccount.getUuid().equals(accountToCreate.getUuid()) ||
!existingAccount.getNumber().equals(accountToCreate.getNumber())) { !existingAccount.getNumber().equals(accountToCreate.getNumber())) {
throw new IllegalArgumentException("reclaimed accounts must match"); throw new IllegalArgumentException("reclaimed accounts must match");
@ -310,6 +321,7 @@ public class Accounts extends AbstractDynamoDbStore {
.build()); .build());
} }
writeItems.add(UpdateAccountSpec.forAccount(accountsTableName, accountToCreate).transactItem()); writeItems.add(UpdateAccountSpec.forAccount(accountsTableName, accountToCreate).transactItem());
writeItems.addAll(additionalWriteItems);
return asyncClient.transactWriteItems(TransactWriteItemsRequest.builder().transactItems(writeItems).build()) return asyncClient.transactWriteItems(TransactWriteItemsRequest.builder().transactItems(writeItems).build())
.thenApply(response -> { .thenApply(response -> {

View File

@ -53,7 +53,9 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
@ -175,17 +177,26 @@ public class AccountsManager {
this.clock = requireNonNull(clock); this.clock = requireNonNull(clock);
} }
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public Account create(final String number, public Account create(final String number,
final String password, final String password,
final String signalAgent, final String signalAgent,
final AccountAttributes accountAttributes, final AccountAttributes accountAttributes,
final List<AccountBadge> accountBadges) throws InterruptedException { final List<AccountBadge> accountBadges,
final IdentityKey aciIdentityKey,
final IdentityKey pniIdentityKey,
final ECSignedPreKey aciSignedPreKey,
final ECSignedPreKey pniSignedPreKey,
final KEMSignedPreKey aciPqLastResortPreKey,
final KEMSignedPreKey pniPqLastResortPreKey,
final Optional<ApnRegistrationId> maybeApnRegistrationId,
final Optional<GcmRegistrationId> maybeGcmRegistrationId) throws InterruptedException {
try (Timer.Context ignored = createTimer.time()) { try (Timer.Context ignored = createTimer.time()) {
final Account account = new Account(); final Account account = new Account();
accountLockManager.withLock(List.of(number), () -> { accountLockManager.withLock(List.of(number), () -> {
Device device = new Device(); final Device device = new Device();
device.setId(Device.PRIMARY_ID); device.setId(Device.PRIMARY_ID);
device.setAuthTokenHash(SaltedTokenHash.generateFor(password)); device.setAuthTokenHash(SaltedTokenHash.generateFor(password));
device.setFetchesMessages(accountAttributes.getFetchesMessages()); device.setFetchesMessages(accountAttributes.getFetchesMessages());
@ -196,6 +207,16 @@ public class AccountsManager {
device.setCreated(System.currentTimeMillis()); device.setCreated(System.currentTimeMillis());
device.setLastSeen(Util.todayInMillis()); device.setLastSeen(Util.todayInMillis());
device.setUserAgent(signalAgent); device.setUserAgent(signalAgent);
device.setSignedPreKey(aciSignedPreKey);
device.setPhoneNumberIdentitySignedPreKey(pniSignedPreKey);
maybeApnRegistrationId.ifPresent(apnRegistrationId -> {
device.setApnId(apnRegistrationId.apnRegistrationId());
device.setVoipApnId(apnRegistrationId.voipRegistrationId());
});
maybeGcmRegistrationId.ifPresent(gcmRegistrationId ->
device.setGcmId(gcmRegistrationId.gcmRegistrationId()));
account.setNumber(number, phoneNumberIdentifiers.getPhoneNumberIdentifier(number)); account.setNumber(number, phoneNumberIdentifiers.getPhoneNumberIdentifier(number));
@ -205,6 +226,8 @@ public class AccountsManager {
// Reuse the ACI from any recently-deleted account with this number to cover cases where somebody is // Reuse the ACI from any recently-deleted account with this number to cover cases where somebody is
// re-registering. // re-registering.
account.setUuid(maybeRecentlyDeletedAccountIdentifier.orElseGet(UUID::randomUUID)); account.setUuid(maybeRecentlyDeletedAccountIdentifier.orElseGet(UUID::randomUUID));
account.setIdentityKey(aciIdentityKey);
account.setPhoneNumberIdentityKey(pniIdentityKey);
account.addDevice(device); account.addDevice(device);
account.setRegistrationLockFromAttributes(accountAttributes); account.setRegistrationLockFromAttributes(accountAttributes);
account.setUnidentifiedAccessKey(accountAttributes.getUnidentifiedAccessKey()); account.setUnidentifiedAccessKey(accountAttributes.getUnidentifiedAccessKey());
@ -214,7 +237,14 @@ public class AccountsManager {
final UUID originalUuid = account.getUuid(); final UUID originalUuid = account.getUuid();
boolean freshUser = accounts.create(account); final boolean freshUser = accounts.create(account,
a -> keysManager.buildWriteItemsForRepeatedUseKeys(a.getIdentifier(IdentityType.ACI),
a.getIdentifier(IdentityType.PNI),
Device.PRIMARY_ID,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey));
// create() sometimes updates the UUID, if there was a number conflict. // create() sometimes updates the UUID, if there was a number conflict.
// for metrics, we want secondary to run with the same original UUID // for metrics, we want secondary to run with the same original UUID
@ -235,9 +265,11 @@ public class AccountsManager {
// confident that everything has already been deleted. In the second case, though, we're taking over an existing // confident that everything has already been deleted. In the second case, though, we're taking over an existing
// account and need to clear out messages and keys that may have been stored for the old account. // account and need to clear out messages and keys that may have been stored for the old account.
if (!originalUuid.equals(actualUuid)) { if (!originalUuid.equals(actualUuid)) {
// We exclude the primary device's repeated-use keys from deletion because new keys were provided as part of
// the account creation process, and we don't want to delete the keys that just got added.
final CompletableFuture<Void> deleteKeysFuture = CompletableFuture.allOf( final CompletableFuture<Void> deleteKeysFuture = CompletableFuture.allOf(
keysManager.delete(actualUuid), keysManager.delete(actualUuid, true),
keysManager.delete(account.getPhoneNumberIdentifier())); keysManager.delete(account.getPhoneNumberIdentifier(), true));
messagesManager.clear(actualUuid).join(); messagesManager.clear(actualUuid).join();
profilesManager.deleteAll(actualUuid).join(); profilesManager.deleteAll(actualUuid).join();

View File

@ -18,6 +18,7 @@ import org.whispersystems.textsecuregcm.entities.ECPreKey;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem;
public class KeysManager { public class KeysManager {
@ -75,6 +76,20 @@ public class KeysManager {
return CompletableFuture.allOf(storeFutures.toArray(new CompletableFuture[0])); return CompletableFuture.allOf(storeFutures.toArray(new CompletableFuture[0]));
} }
public List<TransactWriteItem> buildWriteItemsForRepeatedUseKeys(final UUID accountIdentifier,
final UUID phoneNumberIdentifier,
final byte deviceId,
final ECSignedPreKey aciSignedPreKey,
final ECSignedPreKey pniSignedPreKey,
final KEMSignedPreKey aciPqLastResortPreKey,
final KEMSignedPreKey pniLastResortPreKey) {
return List.of(ecSignedPreKeys.buildTransactWriteItem(accountIdentifier, deviceId, aciSignedPreKey),
ecSignedPreKeys.buildTransactWriteItem(phoneNumberIdentifier, deviceId, pniSignedPreKey),
pqLastResortKeys.buildTransactWriteItem(accountIdentifier, deviceId, aciPqLastResortPreKey),
pqLastResortKeys.buildTransactWriteItem(phoneNumberIdentifier, deviceId, pniLastResortPreKey));
}
public CompletableFuture<Void> storeEcSignedPreKeys(final UUID identifier, final Map<Byte, ECSignedPreKey> keys) { public CompletableFuture<Void> storeEcSignedPreKeys(final UUID identifier, final Map<Byte, ECSignedPreKey> keys) {
if (dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().storeEcSignedPreKeys()) { if (dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().storeEcSignedPreKeys()) {
return ecSignedPreKeys.store(identifier, keys); return ecSignedPreKeys.store(identifier, keys);
@ -134,14 +149,18 @@ public class KeysManager {
return pqPreKeys.getCount(identifier, deviceId); return pqPreKeys.getCount(identifier, deviceId);
} }
public CompletableFuture<Void> delete(final UUID accountUuid) { public CompletableFuture<Void> delete(final UUID identifier) {
return delete(identifier, false);
}
public CompletableFuture<Void> delete(final UUID identifier, final boolean excludePrimaryDevice) {
return CompletableFuture.allOf( return CompletableFuture.allOf(
ecPreKeys.delete(accountUuid), ecPreKeys.delete(identifier),
pqPreKeys.delete(accountUuid), pqPreKeys.delete(identifier),
dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().deleteEcSignedPreKeys() dynamicConfigurationManager.getConfiguration().getEcPreKeyMigrationConfiguration().deleteEcSignedPreKeys()
? ecSignedPreKeys.delete(accountUuid) ? ecSignedPreKeys.delete(identifier, excludePrimaryDevice)
: CompletableFuture.completedFuture(null), : CompletableFuture.completedFuture(null),
pqLastResortKeys.delete(accountUuid)); pqLastResortKeys.delete(identifier, excludePrimaryDevice));
} }
public CompletableFuture<Void> delete(final UUID accountUuid, final byte deviceId) { public CompletableFuture<Void> delete(final UUID accountUuid, final byte deviceId) {

View File

@ -112,6 +112,15 @@ public abstract class RepeatedUseSignedPreKeyStore<K extends SignedPreKey<?>> {
.thenRun(() -> sample.stop(storeKeyBatchTimer)); .thenRun(() -> sample.stop(storeKeyBatchTimer));
} }
TransactWriteItem buildTransactWriteItem(final UUID identifier, final byte deviceId, final K preKey) {
return TransactWriteItem.builder()
.put(Put.builder()
.tableName(tableName)
.item(getItemFromPreKey(identifier, deviceId, preKey))
.build())
.build();
}
/** /**
* Finds a repeated-use pre-key for a specific device. * Finds a repeated-use pre-key for a specific device.
* *
@ -142,14 +151,19 @@ public abstract class RepeatedUseSignedPreKeyStore<K extends SignedPreKey<?>> {
* Clears all repeated-use pre-keys associated with the given account/identity. * Clears all repeated-use pre-keys associated with the given account/identity.
* *
* @param identifier the identifier for the account/identity for which to clear repeated-use pre-keys * @param identifier the identifier for the account/identity for which to clear repeated-use pre-keys
* @param excludePrimaryDevice whether to exclude the primary device from repeated-use key deletion; this is intended
* for cases when a user "re-registers" and displaces an existing account record and has
* provided new repeated-use keys for the primary device in the process of creating the
* new account
* *
* @return a future that completes once repeated-use pre-keys have been cleared from all devices associated with the * @return a future that completes once repeated-use pre-keys have been cleared from all devices associated with the
* target account/identity * target account/identity
*/ */
public CompletableFuture<Void> delete(final UUID identifier) { public CompletableFuture<Void> delete(final UUID identifier, final boolean excludePrimaryDevice) {
final Timer.Sample sample = Timer.start(); final Timer.Sample sample = Timer.start();
return getDeviceIdsWithKeys(identifier) return getDeviceIdsWithKeys(identifier)
.filter(deviceId -> deviceId != Device.PRIMARY_ID || !excludePrimaryDevice)
.map(deviceId -> DeleteItemRequest.builder() .map(deviceId -> DeleteItemRequest.builder()
.tableName(tableName) .tableName(tableName)
.key(getPrimaryKey(identifier, deviceId)) .key(getPrimaryKey(identifier, deviceId))

View File

@ -55,7 +55,6 @@ import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.signal.libsignal.usernames.BaseUsernameException; import org.signal.libsignal.usernames.BaseUsernameException;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
@ -215,15 +214,6 @@ class AccountControllerTest {
when(accountsManager.getByE164(eq(SENDER_PREAUTH))).thenReturn(Optional.empty()); when(accountsManager.getByE164(eq(SENDER_PREAUTH))).thenReturn(Optional.empty());
when(accountsManager.getByE164(eq(SENDER_HAS_STORAGE))).thenReturn(Optional.of(senderHasStorage)); when(accountsManager.getByE164(eq(SENDER_HAS_STORAGE))).thenReturn(Optional.of(senderHasStorage));
when(accountsManager.getByE164(eq(SENDER_TRANSFER))).thenReturn(Optional.of(senderTransfer)); when(accountsManager.getByE164(eq(SENDER_TRANSFER))).thenReturn(Optional.of(senderTransfer));
when(accountsManager.create(any(), any(), any(), any(), any())).thenAnswer((Answer<Account>) invocation -> {
final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(account.getNumber()).thenReturn(invocation.getArgument(0, String.class));
when(account.getBadges()).thenReturn(invocation.getArgument(4, List.class));
return account;
});
} }
@AfterEach @AfterEach

View File

@ -7,10 +7,11 @@ package org.whispersystems.textsecuregcm.controllers;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -21,10 +22,12 @@ import io.dropwizard.testing.junit5.ResourceExtension;
import java.io.UncheckedIOException; import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Duration; import java.time.Duration;
import java.util.Arrays;
import java.util.Base64; import java.util.Base64;
import java.util.Collections;
import java.util.EnumSet; import java.util.EnumSet;
import java.util.HashSet; import java.util.HashSet;
import java.util.Map; import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.UUID; import java.util.UUID;
@ -72,7 +75,6 @@ import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager; import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
@ -95,7 +97,6 @@ class RegistrationControllerTest {
RegistrationLockVerificationManager.class); RegistrationLockVerificationManager.class);
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager = mock( private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager = mock(
RegistrationRecoveryPasswordsManager.class); RegistrationRecoveryPasswordsManager.class);
private final KeysManager keysManager = mock(KeysManager.class);
private final RateLimiters rateLimiters = mock(RateLimiters.class); private final RateLimiters rateLimiters = mock(RateLimiters.class);
private final RateLimiter registrationLimiter = mock(RateLimiter.class); private final RateLimiter registrationLimiter = mock(RateLimiter.class);
@ -110,7 +111,7 @@ class RegistrationControllerTest {
.addResource( .addResource(
new RegistrationController(accountsManager, new RegistrationController(accountsManager,
new PhoneVerificationTokenManager(registrationServiceClient, registrationRecoveryPasswordsManager), new PhoneVerificationTokenManager(registrationServiceClient, registrationRecoveryPasswordsManager),
registrationLockVerificationManager, keysManager, rateLimiters)) registrationLockVerificationManager, rateLimiters))
.build(); .build();
@BeforeEach @BeforeEach
@ -125,11 +126,6 @@ class RegistrationControllerTest {
return invocation.getArgument(0); return invocation.getArgument(0);
}); });
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storeEcOneTimePreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storeKemOneTimePreKeys(any(), anyByte(), any())).thenReturn(CompletableFuture.completedFuture(null));
} }
@Test @Test
@ -171,7 +167,7 @@ class RegistrationControllerTest {
final Account account = mock(Account.class); final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class))); when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class)));
when(accountsManager.create(any(), any(), any(), any(), any())) when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
.thenReturn(account); .thenReturn(account);
final String json = requestJson("sessionId", new byte[0], true, registrationId.orElse(0), pniRegistrationId); final String json = requestJson("sessionId", new byte[0], true, registrationId.orElse(0), pniRegistrationId);
@ -294,7 +290,7 @@ class RegistrationControllerTest {
final Account account = mock(Account.class); final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class))); when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class)));
when(accountsManager.create(any(), any(), any(), any(), any())) when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
.thenReturn(account); .thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
@ -352,7 +348,7 @@ class RegistrationControllerTest {
final Account createdAccount = mock(Account.class); final Account createdAccount = mock(Account.class);
when(createdAccount.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class))); when(createdAccount.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class)));
when(accountsManager.create(any(), any(), any(), any(), any())) when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
.thenReturn(createdAccount); .thenReturn(createdAccount);
expectedStatus = 200; expectedStatus = 200;
@ -406,7 +402,8 @@ class RegistrationControllerTest {
final Account account = mock(Account.class); final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class))); when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class)));
when(accountsManager.create(any(), any(), any(), any(), any())).thenReturn(account); when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
.thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration") .target("/v1/registration")
@ -429,7 +426,7 @@ class RegistrationControllerTest {
final Account account = mock(Account.class); final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class))); when(account.getPrimaryDevice()).thenReturn(Optional.of(mock(Device.class)));
when(accountsManager.create(any(), any(), any(), any(), any())) when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
.thenReturn(account); .thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
@ -669,9 +666,8 @@ class RegistrationControllerTest {
final ECSignedPreKey expectedPniSignedPreKey, final ECSignedPreKey expectedPniSignedPreKey,
final KEMSignedPreKey expectedAciPqLastResortPreKey, final KEMSignedPreKey expectedAciPqLastResortPreKey,
final KEMSignedPreKey expectedPniPqLastResortPreKey, final KEMSignedPreKey expectedPniPqLastResortPreKey,
final Optional<String> expectedApnsToken, final Optional<ApnRegistrationId> expectedApnRegistrationId,
final Optional<String> expectedApnsVoipToken, final Optional<GcmRegistrationId> expectedGcmRegistrationId) throws InterruptedException {
final Optional<String> expectedGcmToken) throws InterruptedException {
when(registrationServiceClient.getSession(any(), any())) when(registrationServiceClient.getSession(any(), any()))
.thenReturn( .thenReturn(
@ -679,9 +675,6 @@ class RegistrationControllerTest {
Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null, Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS)))); SESSION_EXPIRATION_SECONDS))));
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
final UUID accountIdentifier = UUID.randomUUID(); final UUID accountIdentifier = UUID.randomUUID();
final UUID phoneNumberIdentifier = UUID.randomUUID(); final UUID phoneNumberIdentifier = UUID.randomUUID();
final Device device = mock(Device.class); final Device device = mock(Device.class);
@ -692,9 +685,8 @@ class RegistrationControllerTest {
when(a.getPrimaryDevice()).thenReturn(Optional.of(device)); when(a.getPrimaryDevice()).thenReturn(Optional.of(device));
}); });
when(accountsManager.create(any(), any(), any(), any(), any())).thenReturn(account); when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
.thenReturn(account);
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration") .target("/v1/registration")
@ -705,27 +697,33 @@ class RegistrationControllerTest {
assertEquals(200, response.getStatus()); assertEquals(200, response.getStatus());
} }
verify(accountsManager).create(any(), any(), any(), any(), any()); verify(accountsManager).create(
eq(NUMBER),
eq(PASSWORD),
isNull(),
argThat(attributes -> accountAttributesEqual(attributes, registrationRequest.accountAttributes())),
eq(Collections.emptyList()),
eq(expectedAciIdentityKey),
eq(expectedPniIdentityKey),
eq(expectedAciSignedPreKey),
eq(expectedPniSignedPreKey),
eq(expectedAciPqLastResortPreKey),
eq(expectedPniPqLastResortPreKey),
eq(expectedApnRegistrationId),
eq(expectedGcmRegistrationId));
}
verify(account).setIdentityKey(expectedAciIdentityKey); private static boolean accountAttributesEqual(final AccountAttributes a, final AccountAttributes b) {
verify(account).setPhoneNumberIdentityKey(expectedPniIdentityKey); return a.getFetchesMessages() == b.getFetchesMessages()
&& a.getRegistrationId() == b.getRegistrationId()
verify(device).setSignedPreKey(expectedAciSignedPreKey); && a.isUnrestrictedUnidentifiedAccess() == b.isUnrestrictedUnidentifiedAccess()
verify(device).setPhoneNumberIdentitySignedPreKey(expectedPniSignedPreKey); && a.isDiscoverableByPhoneNumber() == b.isDiscoverableByPhoneNumber()
&& Objects.equals(a.getPhoneNumberIdentityRegistrationId(), b.getPhoneNumberIdentityRegistrationId())
verify(keysManager).storeEcSignedPreKeys(accountIdentifier, Map.of(Device.PRIMARY_ID, expectedAciSignedPreKey)); && Objects.equals(a.getName(), b.getName())
verify(keysManager).storeEcSignedPreKeys(phoneNumberIdentifier, Map.of(Device.PRIMARY_ID, expectedPniSignedPreKey)); && Objects.equals(a.getRegistrationLock(), b.getRegistrationLock())
verify(keysManager).storePqLastResort(accountIdentifier, Map.of(Device.PRIMARY_ID, expectedAciPqLastResortPreKey)); && Arrays.equals(a.getUnidentifiedAccessKey(), b.getUnidentifiedAccessKey())
verify(keysManager).storePqLastResort(phoneNumberIdentifier, Map.of(Device.PRIMARY_ID, expectedPniPqLastResortPreKey)); && Objects.equals(a.getCapabilities(), b.getCapabilities())
&& Objects.equals(a.recoveryPassword(), b.recoveryPassword());
expectedApnsToken.ifPresentOrElse(expectedToken -> verify(device).setApnId(expectedToken),
() -> verify(device, never()).setApnId(any()));
expectedApnsVoipToken.ifPresentOrElse(expectedToken -> verify(device).setVoipApnId(expectedToken),
() -> verify(device, never()).setVoipApnId(any()));
expectedGcmToken.ifPresentOrElse(expectedToken -> verify(device).setGcmId(expectedToken),
() -> verify(device, never()).setGcmId(any()));
} }
private static Stream<Arguments> atomicAccountCreationSuccess() { private static Stream<Arguments> atomicAccountCreationSuccess() {
@ -800,8 +798,7 @@ class RegistrationControllerTest {
pniSignedPreKey, pniSignedPreKey,
aciPqLastResortPreKey, aciPqLastResortPreKey,
pniPqLastResortPreKey, pniPqLastResortPreKey,
Optional.of(apnsToken), Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
Optional.of(apnsVoipToken),
Optional.empty()), Optional.empty()),
// requires the request to be atomic // requires the request to be atomic
@ -823,8 +820,7 @@ class RegistrationControllerTest {
pniSignedPreKey, pniSignedPreKey,
aciPqLastResortPreKey, aciPqLastResortPreKey,
pniPqLastResortPreKey, pniPqLastResortPreKey,
Optional.of(apnsToken), Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
Optional.of(apnsVoipToken),
Optional.empty()), Optional.empty()),
// Fetches messages; no push tokens // Fetches messages; no push tokens
@ -847,8 +843,7 @@ class RegistrationControllerTest {
aciPqLastResortPreKey, aciPqLastResortPreKey,
pniPqLastResortPreKey, pniPqLastResortPreKey,
Optional.empty(), Optional.empty(),
Optional.empty(), Optional.of(new GcmRegistrationId(gcmToken))));
Optional.of(gcmToken)));
} }
/** /**

View File

@ -0,0 +1,424 @@
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junitpioneer.jupiter.cartesian.ArgumentSets;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
public class AccountCreationIntegrationTest {
@RegisterExtension
static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(
DynamoDbExtensionSchema.Tables.ACCOUNTS,
DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS,
DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS_LOCK,
DynamoDbExtensionSchema.Tables.NUMBERS,
DynamoDbExtensionSchema.Tables.PNI,
DynamoDbExtensionSchema.Tables.PNI_ASSIGNMENTS,
DynamoDbExtensionSchema.Tables.USERNAMES,
DynamoDbExtensionSchema.Tables.EC_KEYS,
DynamoDbExtensionSchema.Tables.PQ_KEYS,
DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS,
DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS);
@RegisterExtension
static final RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
private static final Clock CLOCK = Clock.fixed(Instant.now(), ZoneId.systemDefault());
private ExecutorService accountLockExecutor;
private AccountsManager accountsManager;
private KeysManager keysManager;
record DeliveryChannels(boolean fetchesMessages, String apnsToken, String apnsVoipToken, String fcmToken) {}
@BeforeEach
void setUp() {
@SuppressWarnings("unchecked") final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager =
mock(DynamicConfigurationManager.class);
DynamicConfiguration dynamicConfiguration = new DynamicConfiguration();
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
keysManager = new KeysManager(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.EC_KEYS.tableName(),
DynamoDbExtensionSchema.Tables.PQ_KEYS.tableName(),
DynamoDbExtensionSchema.Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName(),
DynamoDbExtensionSchema.Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName(),
dynamicConfigurationManager);
final Accounts accounts = new Accounts(
DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.ACCOUNTS.tableName(),
DynamoDbExtensionSchema.Tables.NUMBERS.tableName(),
DynamoDbExtensionSchema.Tables.PNI_ASSIGNMENTS.tableName(),
DynamoDbExtensionSchema.Tables.USERNAMES.tableName(),
DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS.tableName());
accountLockExecutor = Executors.newSingleThreadExecutor();
final AccountLockManager accountLockManager = new AccountLockManager(DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS_LOCK.tableName());
final SecureStorageClient secureStorageClient = mock(SecureStorageClient.class);
when(secureStorageClient.deleteStoredData(any())).thenReturn(CompletableFuture.completedFuture(null));
final SecureValueRecovery2Client svr2Client = mock(SecureValueRecovery2Client.class);
when(svr2Client.deleteBackups(any())).thenReturn(CompletableFuture.completedFuture(null));
final PhoneNumberIdentifiers phoneNumberIdentifiers =
new PhoneNumberIdentifiers(DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DynamoDbExtensionSchema.Tables.PNI.tableName());
final MessagesManager messagesManager = mock(MessagesManager.class);
when(messagesManager.clear(any())).thenReturn(CompletableFuture.completedFuture(null));
final ProfilesManager profilesManager = mock(ProfilesManager.class);
when(profilesManager.deleteAll(any())).thenReturn(CompletableFuture.completedFuture(null));
final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager =
mock(RegistrationRecoveryPasswordsManager.class);
when(registrationRecoveryPasswordsManager.removeForNumber(any()))
.thenReturn(CompletableFuture.completedFuture(null));
accountsManager = new AccountsManager(
accounts,
phoneNumberIdentifiers,
CACHE_CLUSTER_EXTENSION.getRedisCluster(),
accountLockManager,
keysManager,
messagesManager,
profilesManager,
secureStorageClient,
svr2Client,
mock(ClientPresenceManager.class),
mock(ExperimentEnrollmentManager.class),
registrationRecoveryPasswordsManager,
accountLockExecutor,
CLOCK);
}
@AfterEach
void tearDown() throws InterruptedException {
accountLockExecutor.shutdown();
//noinspection ResultOfMethodCallIgnored
accountLockExecutor.awaitTermination(1, TimeUnit.SECONDS);
}
@CartesianTest
@CartesianTest.MethodFactory("createAccount")
void createAccount(final DeliveryChannels deliveryChannels,
final boolean discoverableByPhoneNumber) throws InterruptedException {
final String number = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164);
final String password = RandomStringUtils.randomAlphanumeric(16);
final String signalAgent = RandomStringUtils.randomAlphabetic(3);
final int registrationId = ThreadLocalRandom.current().nextInt(Device.MAX_REGISTRATION_ID);
final int pniRegistrationId = ThreadLocalRandom.current().nextInt(Device.MAX_REGISTRATION_ID);
final String deviceName = RandomStringUtils.randomAlphabetic(16);
final String registrationLockSecret = RandomStringUtils.randomAlphanumeric(16);
final Device.DeviceCapabilities deviceCapabilities = new Device.DeviceCapabilities(
ThreadLocalRandom.current().nextBoolean(),
ThreadLocalRandom.current().nextBoolean(),
ThreadLocalRandom.current().nextBoolean(),
ThreadLocalRandom.current().nextBoolean());
final AccountAttributes accountAttributes = new AccountAttributes(deliveryChannels.fetchesMessages(),
registrationId,
deviceName,
registrationLockSecret,
discoverableByPhoneNumber,
deviceCapabilities);
accountAttributes.setPhoneNumberIdentityRegistrationId(pniRegistrationId);
final List<AccountBadge> badges = new ArrayList<>(List.of(new AccountBadge(
RandomStringUtils.randomAlphabetic(8),
CLOCK.instant().plus(Duration.ofDays(7)),
true)));
final ECKeyPair aciKeyPair = Curve.generateKeyPair();
final ECKeyPair pniKeyPair = Curve.generateKeyPair();
final ECSignedPreKey aciSignedPreKey = KeysHelper.signedECPreKey(1, aciKeyPair);
final ECSignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(2, pniKeyPair);
final KEMSignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciKeyPair);
final KEMSignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniKeyPair);
final Optional<ApnRegistrationId> maybeApnRegistrationId =
deliveryChannels.apnsToken() != null || deliveryChannels.apnsVoipToken() != null
? Optional.of(new ApnRegistrationId(deliveryChannels.apnsToken(), deliveryChannels.apnsVoipToken()))
: Optional.empty();
final Optional<GcmRegistrationId> maybeGcmRegistrationId = deliveryChannels.fcmToken() != null
? Optional.of(new GcmRegistrationId(deliveryChannels.fcmToken()))
: Optional.empty();
final Account account = accountsManager.create(number,
password,
signalAgent,
accountAttributes,
badges,
new IdentityKey(aciKeyPair.getPublicKey()),
new IdentityKey(pniKeyPair.getPublicKey()),
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
maybeApnRegistrationId,
maybeGcmRegistrationId);
assertExpectedStoredAccount(account,
number,
password,
signalAgent,
deliveryChannels,
registrationId,
pniRegistrationId,
deviceName,
discoverableByPhoneNumber,
deviceCapabilities,
badges,
maybeApnRegistrationId,
maybeGcmRegistrationId,
registrationLockSecret,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey);
}
@CartesianTest
@CartesianTest.MethodFactory("createAccount")
void reregisterAccount(final DeliveryChannels deliveryChannels,
final boolean discoverableByPhoneNumber) throws InterruptedException {
final String number = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164);
final UUID existingAccountUuid;
{
final ECKeyPair aciKeyPair = Curve.generateKeyPair();
final ECKeyPair pniKeyPair = Curve.generateKeyPair();
final ECSignedPreKey aciSignedPreKey = KeysHelper.signedECPreKey(1, aciKeyPair);
final ECSignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(2, pniKeyPair);
final KEMSignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciKeyPair);
final KEMSignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniKeyPair);
final Account originalAccount = accountsManager.create(number,
RandomStringUtils.randomAlphanumeric(16),
"OWI",
new AccountAttributes(true, 1, "name", "registration-lock", false, new Device.DeviceCapabilities(false, false, false, false)),
Collections.emptyList(),
new IdentityKey(aciKeyPair.getPublicKey()),
new IdentityKey(pniKeyPair.getPublicKey()),
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.empty(),
Optional.empty());
existingAccountUuid = originalAccount.getUuid();
}
final String password = RandomStringUtils.randomAlphanumeric(16);
final String signalAgent = RandomStringUtils.randomAlphabetic(3);
final int registrationId = ThreadLocalRandom.current().nextInt(Device.MAX_REGISTRATION_ID);
final int pniRegistrationId = ThreadLocalRandom.current().nextInt(Device.MAX_REGISTRATION_ID);
final String deviceName = RandomStringUtils.randomAlphabetic(16);
final String registrationLockSecret = RandomStringUtils.randomAlphanumeric(16);
final Device.DeviceCapabilities deviceCapabilities = new Device.DeviceCapabilities(
ThreadLocalRandom.current().nextBoolean(),
ThreadLocalRandom.current().nextBoolean(),
ThreadLocalRandom.current().nextBoolean(),
ThreadLocalRandom.current().nextBoolean());
final AccountAttributes accountAttributes = new AccountAttributes(deliveryChannels.fetchesMessages(),
registrationId,
deviceName,
registrationLockSecret,
discoverableByPhoneNumber,
deviceCapabilities);
accountAttributes.setPhoneNumberIdentityRegistrationId(pniRegistrationId);
final List<AccountBadge> badges = new ArrayList<>(List.of(new AccountBadge(
RandomStringUtils.randomAlphabetic(8),
CLOCK.instant().plus(Duration.ofDays(7)),
true)));
final ECKeyPair aciKeyPair = Curve.generateKeyPair();
final ECKeyPair pniKeyPair = Curve.generateKeyPair();
final ECSignedPreKey aciSignedPreKey = KeysHelper.signedECPreKey(1, aciKeyPair);
final ECSignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(2, pniKeyPair);
final KEMSignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciKeyPair);
final KEMSignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniKeyPair);
final Optional<ApnRegistrationId> maybeApnRegistrationId =
deliveryChannels.apnsToken() != null || deliveryChannels.apnsVoipToken() != null
? Optional.of(new ApnRegistrationId(deliveryChannels.apnsToken(), deliveryChannels.apnsVoipToken()))
: Optional.empty();
final Optional<GcmRegistrationId> maybeGcmRegistrationId = deliveryChannels.fcmToken() != null
? Optional.of(new GcmRegistrationId(deliveryChannels.fcmToken()))
: Optional.empty();
final Account reregisteredAccount = accountsManager.create(number,
password,
signalAgent,
accountAttributes,
badges,
new IdentityKey(aciKeyPair.getPublicKey()),
new IdentityKey(pniKeyPair.getPublicKey()),
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
maybeApnRegistrationId,
maybeGcmRegistrationId);
assertExpectedStoredAccount(reregisteredAccount,
number,
password,
signalAgent,
deliveryChannels,
registrationId,
pniRegistrationId,
deviceName,
discoverableByPhoneNumber,
deviceCapabilities,
badges,
maybeApnRegistrationId,
maybeGcmRegistrationId,
registrationLockSecret,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey);
assertEquals(existingAccountUuid, reregisteredAccount.getUuid());
}
@SuppressWarnings("unused")
static ArgumentSets createAccount() {
return ArgumentSets
// deliveryChannels
.argumentsForFirstParameter(
new DeliveryChannels(true, null, null, null),
new DeliveryChannels(false, "apns-token", null, null),
new DeliveryChannels(false, "apns-token", "apns-voip-token", null),
new DeliveryChannels(false, null, null, "fcm-token"))
// discoverableByPhoneNumber
.argumentsForNextParameter(true, false);
}
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
private void assertExpectedStoredAccount(final Account account,
final String number,
final String password,
final String signalAgent,
final DeliveryChannels deliveryChannels,
final int registrationId,
final int pniRegistrationId,
final String deviceName,
final boolean discoverableByPhoneNumber,
final Device.DeviceCapabilities deviceCapabilities,
final List<AccountBadge> badges,
final Optional<ApnRegistrationId> maybeApnRegistrationId,
final Optional<GcmRegistrationId> maybeGcmRegistrationId,
final String registrationLockSecret,
final ECSignedPreKey aciSignedPreKey,
final ECSignedPreKey pniSignedPreKey,
final KEMSignedPreKey aciPqLastResortPreKey,
final KEMSignedPreKey pniPqLastResortPreKey) {
final Device primaryDevice = account.getPrimaryDevice().orElseThrow();
assertEquals(number, account.getNumber());
assertEquals(signalAgent, primaryDevice.getUserAgent());
assertEquals(deliveryChannels.fetchesMessages(), primaryDevice.getFetchesMessages());
assertEquals(registrationId, primaryDevice.getRegistrationId());
assertEquals(pniRegistrationId, primaryDevice.getPhoneNumberIdentityRegistrationId().orElseThrow());
assertEquals(deviceName, primaryDevice.getName());
assertEquals(discoverableByPhoneNumber, account.isDiscoverableByPhoneNumber());
assertEquals(deviceCapabilities, primaryDevice.getCapabilities());
assertEquals(badges, account.getBadges());
maybeApnRegistrationId.ifPresentOrElse(apnRegistrationId -> {
assertEquals(apnRegistrationId.apnRegistrationId(), primaryDevice.getApnId());
assertEquals(apnRegistrationId.voipRegistrationId(), primaryDevice.getVoipApnId());
}, () -> {
assertNull(primaryDevice.getApnId());
assertNull(primaryDevice.getVoipApnId());
});
maybeGcmRegistrationId.ifPresentOrElse(gcmRegistrationId -> {
assertEquals(deliveryChannels.fcmToken(), primaryDevice.getGcmId());
}, () -> {
assertNull(primaryDevice.getGcmId());
});
assertTrue(account.getRegistrationLock().verify(registrationLockSecret));
assertTrue(primaryDevice.getAuthTokenHash().verify(password));
assertEquals(Optional.of(aciSignedPreKey), keysManager.getEcSignedPreKey(account.getIdentifier(IdentityType.ACI), Device.PRIMARY_ID).join());
assertEquals(Optional.of(pniSignedPreKey), keysManager.getEcSignedPreKey(account.getIdentifier(IdentityType.PNI), Device.PRIMARY_ID).join());
assertEquals(Optional.of(aciPqLastResortPreKey), keysManager.getLastResort(account.getIdentifier(IdentityType.ACI), Device.PRIMARY_ID).join());
assertEquals(Optional.of(pniPqLastResortPreKey), keysManager.getLastResort(account.getIdentifier(IdentityType.PNI), Device.PRIMARY_ID).join());
}
}

View File

@ -14,7 +14,6 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import java.time.Clock; import java.time.Clock;
import java.util.ArrayList;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.OptionalInt; import java.util.OptionalInt;
@ -41,6 +40,7 @@ import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
class AccountsManagerChangeNumberIntegrationTest { class AccountsManagerChangeNumberIntegrationTest {
@ -53,7 +53,11 @@ class AccountsManagerChangeNumberIntegrationTest {
Tables.NUMBERS, Tables.NUMBERS,
Tables.PNI, Tables.PNI,
Tables.PNI_ASSIGNMENTS, Tables.PNI_ASSIGNMENTS,
Tables.USERNAMES); Tables.USERNAMES,
Tables.EC_KEYS,
Tables.PQ_KEYS,
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS,
Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS);
@RegisterExtension @RegisterExtension
static final RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); static final RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@ -73,6 +77,14 @@ class AccountsManagerChangeNumberIntegrationTest {
DynamicConfiguration dynamicConfiguration = new DynamicConfiguration(); DynamicConfiguration dynamicConfiguration = new DynamicConfiguration();
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
final KeysManager keysManager = new KeysManager(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
Tables.EC_KEYS.tableName(),
Tables.PQ_KEYS.tableName(),
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName(),
Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName(),
dynamicConfigurationManager);
final Accounts accounts = new Accounts( final Accounts accounts = new Accounts(
DYNAMO_DB_EXTENSION.getDynamoDbClient(), DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
@ -98,9 +110,6 @@ class AccountsManagerChangeNumberIntegrationTest {
final PhoneNumberIdentifiers phoneNumberIdentifiers = final PhoneNumberIdentifiers phoneNumberIdentifiers =
new PhoneNumberIdentifiers(DYNAMO_DB_EXTENSION.getDynamoDbClient(), Tables.PNI.tableName()); new PhoneNumberIdentifiers(DYNAMO_DB_EXTENSION.getDynamoDbClient(), Tables.PNI.tableName());
final KeysManager keysManager = mock(KeysManager.class);
when(keysManager.delete(any())).thenReturn(CompletableFuture.completedFuture(null));
final MessagesManager messagesManager = mock(MessagesManager.class); final MessagesManager messagesManager = mock(MessagesManager.class);
when(messagesManager.clear(any())).thenReturn(CompletableFuture.completedFuture(null)); when(messagesManager.clear(any())).thenReturn(CompletableFuture.completedFuture(null));
@ -143,8 +152,8 @@ class AccountsManagerChangeNumberIntegrationTest {
void testChangeNumber() throws InterruptedException, MismatchedDevicesException { void testChangeNumber() throws InterruptedException, MismatchedDevicesException {
final String originalNumber = "+18005551111"; final String originalNumber = "+18005551111";
final String secondNumber = "+18005552222"; final String secondNumber = "+18005552222";
final Account account = AccountsHelper.createAccount(accountsManager, originalNumber);
final Account account = accountsManager.create(originalNumber, "password", null, new AccountAttributes(), new ArrayList<>());
final UUID originalUuid = account.getUuid(); final UUID originalUuid = account.getUuid();
final UUID originalPni = account.getPhoneNumberIdentifier(); final UUID originalPni = account.getPhoneNumberIdentifier();
@ -167,17 +176,17 @@ class AccountsManagerChangeNumberIntegrationTest {
final String originalNumber = "+18005551111"; final String originalNumber = "+18005551111";
final String secondNumber = "+18005552222"; final String secondNumber = "+18005552222";
final int rotatedPniRegistrationId = 17; final int rotatedPniRegistrationId = 17;
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair rotatedPniIdentityKeyPair = Curve.generateKeyPair();
final ECSignedPreKey rotatedSignedPreKey = KeysHelper.signedECPreKey(1L, pniIdentityKeyPair); final ECSignedPreKey rotatedSignedPreKey = KeysHelper.signedECPreKey(1L, rotatedPniIdentityKeyPair);
final AccountAttributes accountAttributes = new AccountAttributes(true, rotatedPniRegistrationId + 1, "test", null, true, new Device.DeviceCapabilities(false, false, false, false)); final AccountAttributes accountAttributes = new AccountAttributes(true, rotatedPniRegistrationId + 1, "test", null, true, new Device.DeviceCapabilities(false, false, false, false));
final Account account = accountsManager.create(originalNumber, "password", null, accountAttributes, new ArrayList<>()); final Account account = AccountsHelper.createAccount(accountsManager, originalNumber, accountAttributes);
account.getPrimaryDevice().orElseThrow().setSignedPreKey(KeysHelper.signedECPreKey(1, pniIdentityKeyPair));
account.getPrimaryDevice().orElseThrow().setSignedPreKey(KeysHelper.signedECPreKey(1, rotatedPniIdentityKeyPair));
final UUID originalUuid = account.getUuid(); final UUID originalUuid = account.getUuid();
final UUID originalPni = account.getPhoneNumberIdentifier(); final UUID originalPni = account.getPhoneNumberIdentifier();
final IdentityKey pniIdentityKey = new IdentityKey(pniIdentityKeyPair.getPublicKey()); final IdentityKey pniIdentityKey = new IdentityKey(rotatedPniIdentityKeyPair.getPublicKey());
final Map<Byte, ECSignedPreKey> preKeys = Map.of(Device.PRIMARY_ID, rotatedSignedPreKey); final Map<Byte, ECSignedPreKey> preKeys = Map.of(Device.PRIMARY_ID, rotatedSignedPreKey);
final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, rotatedPniRegistrationId); final Map<Byte, Integer> registrationIds = Map.of(Device.PRIMARY_ID, rotatedPniRegistrationId);
@ -207,7 +216,8 @@ class AccountsManagerChangeNumberIntegrationTest {
final String originalNumber = "+18005551111"; final String originalNumber = "+18005551111";
final String secondNumber = "+18005552222"; final String secondNumber = "+18005552222";
Account account = accountsManager.create(originalNumber, "password", null, new AccountAttributes(), new ArrayList<>()); Account account = AccountsHelper.createAccount(accountsManager, originalNumber);
final UUID originalUuid = account.getUuid(); final UUID originalUuid = account.getUuid();
final UUID originalPni = account.getPhoneNumberIdentifier(); final UUID originalPni = account.getPhoneNumberIdentifier();
@ -231,10 +241,12 @@ class AccountsManagerChangeNumberIntegrationTest {
final String originalNumber = "+18005551111"; final String originalNumber = "+18005551111";
final String secondNumber = "+18005552222"; final String secondNumber = "+18005552222";
final Account account = accountsManager.create(originalNumber, "password", null, new AccountAttributes(), new ArrayList<>()); final Account account = AccountsHelper.createAccount(accountsManager, originalNumber);
final UUID originalUuid = account.getUuid(); final UUID originalUuid = account.getUuid();
final Account existingAccount = accountsManager.create(secondNumber, "password", null, new AccountAttributes(), new ArrayList<>()); final Account existingAccount = AccountsHelper.createAccount(accountsManager, secondNumber);
final UUID existingAccountUuid = existingAccount.getUuid(); final UUID existingAccountUuid = existingAccount.getUuid();
accountsManager.changeNumber(account, secondNumber, null, null, null, null); accountsManager.changeNumber(account, secondNumber, null, null, null, null);
@ -253,8 +265,7 @@ class AccountsManagerChangeNumberIntegrationTest {
accountsManager.changeNumber(accountsManager.getByAccountIdentifier(originalUuid).orElseThrow(), originalNumber, null, null, null, null); accountsManager.changeNumber(accountsManager.getByAccountIdentifier(originalUuid).orElseThrow(), originalNumber, null, null, null, null);
final Account existingAccount2 = accountsManager.create(secondNumber, "password", null, new AccountAttributes(), final Account existingAccount2 = AccountsHelper.createAccount(accountsManager, secondNumber);
new ArrayList<>());
assertEquals(existingAccountUuid, existingAccount2.getUuid()); assertEquals(existingAccountUuid, existingAccount2.getUuid());
} }
@ -264,17 +275,19 @@ class AccountsManagerChangeNumberIntegrationTest {
final String originalNumber = "+18005551111"; final String originalNumber = "+18005551111";
final String secondNumber = "+18005552222"; final String secondNumber = "+18005552222";
final Account account = accountsManager.create(originalNumber, "password", null, new AccountAttributes(), new ArrayList<>()); final Account account = AccountsHelper.createAccount(accountsManager, originalNumber);
final UUID originalUuid = account.getUuid(); final UUID originalUuid = account.getUuid();
final UUID originalPni = account.getPhoneNumberIdentifier(); final UUID originalPni = account.getPhoneNumberIdentifier();
final Account existingAccount = accountsManager.create(secondNumber, "password", null, new AccountAttributes(), new ArrayList<>()); final Account existingAccount = AccountsHelper.createAccount(accountsManager, secondNumber);
final UUID existingAccountUuid = existingAccount.getUuid(); final UUID existingAccountUuid = existingAccount.getUuid();
final Account changedNumberAccount = accountsManager.changeNumber(account, secondNumber, null, null, null, null); final Account changedNumberAccount = accountsManager.changeNumber(account, secondNumber, null, null, null, null);
final UUID secondPni = changedNumberAccount.getPhoneNumberIdentifier(); final UUID secondPni = changedNumberAccount.getPhoneNumberIdentifier();
final Account reRegisteredAccount = accountsManager.create(originalNumber, "password", null, new AccountAttributes(), new ArrayList<>()); final Account reRegisteredAccount = AccountsHelper.createAccount(accountsManager, originalNumber);
assertEquals(existingAccountUuid, reRegisteredAccount.getUuid()); assertEquals(existingAccountUuid, reRegisteredAccount.getUuid());
assertEquals(originalPni, reRegisteredAccount.getPhoneNumberIdentifier()); assertEquals(originalPni, reRegisteredAccount.getPhoneNumberIdentifier());

View File

@ -23,6 +23,7 @@ import java.io.IOException;
import java.time.Clock; import java.time.Clock;
import java.time.Instant; import java.time.Instant;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
@ -39,6 +40,7 @@ import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil; import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
@ -51,6 +53,7 @@ import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import org.whispersystems.textsecuregcm.tests.util.JsonHelpers; import org.whispersystems.textsecuregcm.tests.util.JsonHelpers;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
@ -62,8 +65,11 @@ class AccountsManagerConcurrentModificationIntegrationTest {
Tables.ACCOUNTS, Tables.ACCOUNTS,
Tables.NUMBERS, Tables.NUMBERS,
Tables.PNI_ASSIGNMENTS, Tables.PNI_ASSIGNMENTS,
Tables.DELETED_ACCOUNTS Tables.DELETED_ACCOUNTS,
); Tables.EC_KEYS,
Tables.PQ_KEYS,
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS,
Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS);
private Accounts accounts; private Accounts accounts;
@ -80,6 +86,14 @@ class AccountsManagerConcurrentModificationIntegrationTest {
mock(DynamicConfigurationManager.class); mock(DynamicConfigurationManager.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration()); when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration());
final KeysManager keysManager = new KeysManager(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
Tables.EC_KEYS.tableName(),
Tables.PQ_KEYS.tableName(),
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName(),
Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName(),
dynamicConfigurationManager);
accounts = new Accounts( accounts = new Accounts(
DYNAMO_DB_EXTENSION.getDynamoDbClient(), DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
@ -134,11 +148,25 @@ class AccountsManagerConcurrentModificationIntegrationTest {
@Test @Test
void testConcurrentUpdate() throws IOException, InterruptedException { void testConcurrentUpdate() throws IOException, InterruptedException {
final UUID uuid; final UUID uuid;
{ {
final ECKeyPair aciKeyPair = Curve.generateKeyPair();
final ECKeyPair pniKeyPair = Curve.generateKeyPair();
final Account account = accountsManager.update( final Account account = accountsManager.update(
accountsManager.create("+14155551212", "password", null, new AccountAttributes(), new ArrayList<>()), accountsManager.create("+14155551212",
"password",
null,
new AccountAttributes(),
new ArrayList<>(),
new IdentityKey(aciKeyPair.getPublicKey()),
new IdentityKey(pniKeyPair.getPublicKey()),
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair),
Optional.empty(),
Optional.empty()),
a -> { a -> {
a.setUnidentifiedAccessKey(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); a.setUnidentifiedAccessKey(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
a.removeDevice(Device.PRIMARY_ID); a.removeDevice(Device.PRIMARY_ID);

View File

@ -13,6 +13,7 @@ import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
@ -201,6 +202,7 @@ class AccountsManagerTest {
when(registrationRecoveryPasswordsManager.removeForNumber(anyString())).thenReturn(CompletableFuture.completedFuture(null)); when(registrationRecoveryPasswordsManager.removeForNumber(anyString())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.delete(any())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.delete(any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.delete(any(), anyBoolean())).thenReturn(CompletableFuture.completedFuture(null));
when(messagesManager.clear(any())).thenReturn(CompletableFuture.completedFuture(null)); when(messagesManager.clear(any())).thenReturn(CompletableFuture.completedFuture(null));
when(profilesManager.deleteAll(any())).thenReturn(CompletableFuture.completedFuture(null)); when(profilesManager.deleteAll(any())).thenReturn(CompletableFuture.completedFuture(null));
@ -853,7 +855,7 @@ class AccountsManagerTest {
when(commands.get(eq("Account3::" + uuid))).thenReturn(null); when(commands.get(eq("Account3::" + uuid))).thenReturn(null);
when(accounts.getByAccountIdentifier(uuid)).thenReturn(Optional.empty()) when(accounts.getByAccountIdentifier(uuid)).thenReturn(Optional.empty())
.thenReturn(Optional.of(account)); .thenReturn(Optional.of(account));
when(accounts.create(any())).thenThrow(ContestedOptimisticLockException.class); when(accounts.create(any(), any())).thenThrow(ContestedOptimisticLockException.class);
accountsManager.update(account, a -> { accountsManager.update(account, a -> {
}); });
@ -930,14 +932,15 @@ class AccountsManagerTest {
@Test @Test
void testCreateFreshAccount() throws InterruptedException { void testCreateFreshAccount() throws InterruptedException {
when(accounts.create(any())).thenReturn(true); when(accounts.create(any(), any())).thenReturn(true);
final String e164 = "+18005550123"; final String e164 = "+18005550123";
final AccountAttributes attributes = new AccountAttributes(false, 0, null, null, true, null); final AccountAttributes attributes = new AccountAttributes(false, 0, null, null, true, null);
accountsManager.create(e164, "password", null, attributes, new ArrayList<>());
verify(accounts).create(argThat(account -> e164.equals(account.getNumber()))); createAccount(e164, attributes);
verifyNoInteractions(keysManager);
verify(accounts).create(argThat(account -> e164.equals(account.getNumber())), any());
verifyNoInteractions(messagesManager); verifyNoInteractions(messagesManager);
verifyNoInteractions(profilesManager); verifyNoInteractions(profilesManager);
} }
@ -946,22 +949,23 @@ class AccountsManagerTest {
void testReregisterAccount() throws InterruptedException { void testReregisterAccount() throws InterruptedException {
final UUID existingUuid = UUID.randomUUID(); final UUID existingUuid = UUID.randomUUID();
when(accounts.create(any())).thenAnswer(invocation -> { when(accounts.create(any(), any())).thenAnswer(invocation -> {
invocation.getArgument(0, Account.class).setUuid(existingUuid); invocation.getArgument(0, Account.class).setUuid(existingUuid);
return false; return false;
}); });
final String e164 = "+18005550123"; final String e164 = "+18005550123";
final AccountAttributes attributes = new AccountAttributes(false, 0, null, null, true, null); final AccountAttributes attributes = new AccountAttributes(false, 0, null, null, true, null);
accountsManager.create(e164, "password", null, attributes, new ArrayList<>());
createAccount(e164, attributes);
assertTrue(phoneNumberIdentifiersByE164.containsKey(e164)); assertTrue(phoneNumberIdentifiersByE164.containsKey(e164));
verify(accounts) verify(accounts)
.create(argThat(account -> e164.equals(account.getNumber()) && existingUuid.equals(account.getUuid()))); .create(argThat(account -> e164.equals(account.getNumber()) && existingUuid.equals(account.getUuid())), any());
verify(keysManager).delete(existingUuid); verify(keysManager).delete(existingUuid, true);
verify(keysManager).delete(phoneNumberIdentifiersByE164.get(e164)); verify(keysManager).delete(phoneNumberIdentifiersByE164.get(e164), true);
verify(messagesManager).clear(existingUuid); verify(messagesManager).clear(existingUuid);
verify(profilesManager).deleteAll(existingUuid); verify(profilesManager).deleteAll(existingUuid);
verify(clientPresenceManager).disconnectAllPresencesForUuid(existingUuid); verify(clientPresenceManager).disconnectAllPresencesForUuid(existingUuid);
@ -972,14 +976,17 @@ class AccountsManagerTest {
final UUID recentlyDeletedUuid = UUID.randomUUID(); final UUID recentlyDeletedUuid = UUID.randomUUID();
when(accounts.findRecentlyDeletedAccountIdentifier(anyString())).thenReturn(Optional.of(recentlyDeletedUuid)); when(accounts.findRecentlyDeletedAccountIdentifier(anyString())).thenReturn(Optional.of(recentlyDeletedUuid));
when(accounts.create(any())).thenReturn(true); when(accounts.create(any(), any())).thenReturn(true);
final String e164 = "+18005550123"; final String e164 = "+18005550123";
final AccountAttributes attributes = new AccountAttributes(false, 0, null, null, true, null); final AccountAttributes attributes = new AccountAttributes(false, 0, null, null, true, null);
accountsManager.create(e164, "password", null, attributes, new ArrayList<>());
createAccount(e164, attributes);
verify(accounts).create( verify(accounts).create(
argThat(account -> e164.equals(account.getNumber()) && recentlyDeletedUuid.equals(account.getUuid()))); argThat(account -> e164.equals(account.getNumber()) && recentlyDeletedUuid.equals(account.getUuid())),
any());
verifyNoInteractions(keysManager); verifyNoInteractions(keysManager);
verifyNoInteractions(messagesManager); verifyNoInteractions(messagesManager);
verifyNoInteractions(profilesManager); verifyNoInteractions(profilesManager);
@ -989,7 +996,7 @@ class AccountsManagerTest {
@ValueSource(booleans = {true, false}) @ValueSource(booleans = {true, false})
void testCreateWithDiscoverability(final boolean discoverable) throws InterruptedException { void testCreateWithDiscoverability(final boolean discoverable) throws InterruptedException {
final AccountAttributes attributes = new AccountAttributes(false, 0, null, null, discoverable, null); final AccountAttributes attributes = new AccountAttributes(false, 0, null, null, discoverable, null);
final Account account = accountsManager.create("+18005550123", "password", null, attributes, new ArrayList<>()); final Account account = createAccount("+18005550123", attributes);
assertEquals(discoverable, account.isDiscoverableByPhoneNumber()); assertEquals(discoverable, account.isDiscoverableByPhoneNumber());
} }
@ -1000,7 +1007,7 @@ class AccountsManagerTest {
final AccountAttributes attributes = new AccountAttributes(false, 0, null, null, true, final AccountAttributes attributes = new AccountAttributes(false, 0, null, null, true,
new DeviceCapabilities(hasStorage, false, false, false)); new DeviceCapabilities(hasStorage, false, false, false));
final Account account = accountsManager.create("+18005550123", "password", null, attributes, new ArrayList<>()); final Account account = createAccount("+18005550123", attributes);
assertEquals(hasStorage, account.isStorageSupported()); assertEquals(hasStorage, account.isStorageSupported());
} }
@ -1572,4 +1579,23 @@ class AccountsManagerTest {
return device; return device;
} }
private Account createAccount(final String e164, final AccountAttributes accountAttributes) throws InterruptedException {
final ECKeyPair aciKeyPair = Curve.generateKeyPair();
final ECKeyPair pniKeyPair = Curve.generateKeyPair();
return accountsManager.create(e164,
"password",
null,
accountAttributes,
new ArrayList<>(),
new IdentityKey(aciKeyPair.getPublicKey()),
new IdentityKey(pniKeyPair.getPublicKey()),
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair),
Optional.empty(),
Optional.empty());
}
} }

View File

@ -25,6 +25,7 @@ import java.time.Instant;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Base64; import java.util.Base64;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -39,13 +40,13 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.Mockito; import org.mockito.Mockito;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.AttributeValues;
import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil; import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
@ -70,7 +71,11 @@ class AccountsManagerUsernameIntegrationTest {
Tables.USERNAMES, Tables.USERNAMES,
Tables.DELETED_ACCOUNTS, Tables.DELETED_ACCOUNTS,
Tables.PNI, Tables.PNI,
Tables.PNI_ASSIGNMENTS); Tables.PNI_ASSIGNMENTS,
Tables.EC_KEYS,
Tables.PQ_KEYS,
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS,
Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS);
@RegisterExtension @RegisterExtension
static RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); static RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@ -91,6 +96,14 @@ class AccountsManagerUsernameIntegrationTest {
DynamicConfiguration dynamicConfiguration = new DynamicConfiguration(); DynamicConfiguration dynamicConfiguration = new DynamicConfiguration();
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
final KeysManager keysManager = new KeysManager(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
Tables.EC_KEYS.tableName(),
Tables.PQ_KEYS.tableName(),
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName(),
Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName(),
dynamicConfigurationManager);
accounts = Mockito.spy(new Accounts( accounts = Mockito.spy(new Accounts(
DYNAMO_DB_EXTENSION.getDynamoDbClient(), DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
@ -122,12 +135,13 @@ class AccountsManagerUsernameIntegrationTest {
final ExperimentEnrollmentManager experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class); final ExperimentEnrollmentManager experimentEnrollmentManager = mock(ExperimentEnrollmentManager.class);
when(experimentEnrollmentManager.isEnrolled(any(UUID.class), eq(AccountsManager.USERNAME_EXPERIMENT_NAME))) when(experimentEnrollmentManager.isEnrolled(any(UUID.class), eq(AccountsManager.USERNAME_EXPERIMENT_NAME)))
.thenReturn(true); .thenReturn(true);
accountsManager = new AccountsManager( accountsManager = new AccountsManager(
accounts, accounts,
phoneNumberIdentifiers, phoneNumberIdentifiers,
CACHE_CLUSTER_EXTENSION.getRedisCluster(), CACHE_CLUSTER_EXTENSION.getRedisCluster(),
accountLockManager, accountLockManager,
mock(KeysManager.class), keysManager,
mock(MessagesManager.class), mock(MessagesManager.class),
mock(ProfilesManager.class), mock(ProfilesManager.class),
mock(SecureStorageClient.class), mock(SecureStorageClient.class),
@ -141,8 +155,8 @@ class AccountsManagerUsernameIntegrationTest {
@Test @Test
void testNoUsernames() throws InterruptedException { void testNoUsernames() throws InterruptedException {
Account account = accountsManager.create("+18005551111", "password", null, new AccountAttributes(), final Account account = AccountsHelper.createAccount(accountsManager, "+18005551111");
new ArrayList<>());
List<byte[]> usernameHashes = List.of(USERNAME_HASH_1, USERNAME_HASH_2); List<byte[]> usernameHashes = List.of(USERNAME_HASH_1, USERNAME_HASH_2);
int i = 0; int i = 0;
for (byte[] hash : usernameHashes) { for (byte[] hash : usernameHashes) {
@ -169,8 +183,8 @@ class AccountsManagerUsernameIntegrationTest {
@Test @Test
void testReserveUsernameSnatched() throws InterruptedException, UsernameHashNotAvailableException { void testReserveUsernameSnatched() throws InterruptedException, UsernameHashNotAvailableException {
final Account account = accountsManager.create("+18005551111", "password", null, new AccountAttributes(), final Account account = AccountsHelper.createAccount(accountsManager, "+18005551111");
new ArrayList<>());
ArrayList<byte[]> usernameHashes = new ArrayList<>(Arrays.asList(USERNAME_HASH_1, USERNAME_HASH_2)); ArrayList<byte[]> usernameHashes = new ArrayList<>(Arrays.asList(USERNAME_HASH_1, USERNAME_HASH_2));
for (byte[] hash : usernameHashes) { for (byte[] hash : usernameHashes) {
DYNAMO_DB_EXTENSION.getDynamoDbClient().putItem(PutItemRequest.builder() DYNAMO_DB_EXTENSION.getDynamoDbClient().putItem(PutItemRequest.builder()
@ -205,10 +219,8 @@ class AccountsManagerUsernameIntegrationTest {
} }
@Test @Test
public void testReserveConfirmClear() public void testReserveConfirmClear() throws InterruptedException {
throws InterruptedException, UsernameHashNotAvailableException, UsernameReservationNotFoundException { Account account = AccountsHelper.createAccount(accountsManager, "+18005551111");
Account account = accountsManager.create("+18005551111", "password", null, new AccountAttributes(),
new ArrayList<>());
// reserve // reserve
AccountsManager.UsernameReservation reservation = AccountsManager.UsernameReservation reservation =
@ -236,11 +248,8 @@ class AccountsManagerUsernameIntegrationTest {
} }
@Test @Test
public void testReservationLapsed() public void testReservationLapsed() throws InterruptedException {
throws InterruptedException, UsernameHashNotAvailableException, UsernameReservationNotFoundException { final Account account = AccountsHelper.createAccount(accountsManager, "+18005551111");
final Account account = accountsManager.create("+18005551111", "password", null, new AccountAttributes(),
new ArrayList<>());
AccountsManager.UsernameReservation reservation1 = AccountsManager.UsernameReservation reservation1 =
accountsManager.reserveUsernameHash(account, List.of(USERNAME_HASH_1)).join(); accountsManager.reserveUsernameHash(account, List.of(USERNAME_HASH_1)).join();
@ -256,8 +265,8 @@ class AccountsManagerUsernameIntegrationTest {
.build()); .build());
// a different account should be able to reserve it // a different account should be able to reserve it
Account account2 = accountsManager.create("+18005552222", "password", null, new AccountAttributes(), Account account2 = AccountsHelper.createAccount(accountsManager, "+18005552222");
new ArrayList<>());
final AccountsManager.UsernameReservation reservation2 = final AccountsManager.UsernameReservation reservation2 =
accountsManager.reserveUsernameHash(account2, List.of(USERNAME_HASH_1)).join(); accountsManager.reserveUsernameHash(account2, List.of(USERNAME_HASH_1)).join();
assertArrayEquals(reservation2.reservedUsernameHash(), USERNAME_HASH_1); assertArrayEquals(reservation2.reservedUsernameHash(), USERNAME_HASH_1);
@ -271,8 +280,7 @@ class AccountsManagerUsernameIntegrationTest {
@Test @Test
void testUsernameSetReserveAnotherClearSetReserved() throws InterruptedException { void testUsernameSetReserveAnotherClearSetReserved() throws InterruptedException {
Account account = accountsManager.create("+18005551111", "password", null, new AccountAttributes(), Account account = AccountsHelper.createAccount(accountsManager, "+18005551111");
new ArrayList<>());
// Set username hash // Set username hash
final AccountsManager.UsernameReservation reservation1 = final AccountsManager.UsernameReservation reservation1 =
@ -303,9 +311,10 @@ class AccountsManagerUsernameIntegrationTest {
@Test @Test
public void testUsernameLinks() throws InterruptedException { public void testUsernameLinks() throws InterruptedException {
Account account = accountsManager.create("+18005551111", "password", null, new AccountAttributes(), new ArrayList<>()); final Account account = AccountsHelper.createAccount(accountsManager, "+18005551111");
account.setUsernameHash(RandomUtils.nextBytes(16)); account.setUsernameHash(RandomUtils.nextBytes(16));
accounts.create(account); accounts.create(account, ignored -> Collections.emptyList());
final UUID linkHandle = UUID.randomUUID(); final UUID linkHandle = UUID.randomUUID();
final byte[] encryptedUsername = RandomUtils.nextBytes(32); final byte[] encryptedUsername = RandomUtils.nextBytes(32);

View File

@ -21,12 +21,12 @@ import static org.mockito.Mockito.when;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Base64; import java.util.Base64;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
@ -114,7 +114,7 @@ class AccountsTest {
when(mockDynamicConfigManager.getConfiguration()) when(mockDynamicConfigManager.getConfiguration())
.thenReturn(new DynamicConfiguration()); .thenReturn(new DynamicConfiguration());
this.accounts = new Accounts( accounts = new Accounts(
clock, clock,
DYNAMO_DB_EXTENSION.getDynamoDbClient(), DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
@ -129,7 +129,7 @@ class AccountsTest {
public void testStoreAndLookupUsernameLink() { public void testStoreAndLookupUsernameLink() {
final Account account = nextRandomAccount(); final Account account = nextRandomAccount();
account.setUsernameHash(RandomUtils.nextBytes(16)); account.setUsernameHash(RandomUtils.nextBytes(16));
accounts.create(account); createAccount(account);
final BiConsumer<Optional<Account>, byte[]> validator = (maybeAccount, expectedEncryptedUsername) -> { final BiConsumer<Optional<Account>, byte[]> validator = (maybeAccount, expectedEncryptedUsername) -> {
assertTrue(maybeAccount.isPresent()); assertTrue(maybeAccount.isPresent());
@ -165,7 +165,7 @@ class AccountsTest {
Device device = generateDevice(DEVICE_ID_1); Device device = generateDevice(DEVICE_ID_1);
Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device)); Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device));
boolean freshUser = accounts.create(account); boolean freshUser = createAccount(account);
assertThat(freshUser).isTrue(); assertThat(freshUser).isTrue();
verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, true); verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, true);
@ -173,7 +173,7 @@ class AccountsTest {
assertPhoneNumberConstraintExists("+14151112222", account.getUuid()); assertPhoneNumberConstraintExists("+14151112222", account.getUuid());
assertPhoneNumberIdentifierConstraintExists(account.getPhoneNumberIdentifier(), account.getUuid()); assertPhoneNumberIdentifierConstraintExists(account.getPhoneNumberIdentifier(), account.getUuid());
freshUser = accounts.create(account); freshUser = createAccount(account);
assertThat(freshUser).isTrue(); assertThat(freshUser).isTrue();
verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, true); verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, true);
@ -188,7 +188,7 @@ class AccountsTest {
Device device = generateDevice(DEVICE_ID_1); Device device = generateDevice(DEVICE_ID_1);
Account account = generateAccount("+14151112222", originalUuid, UUID.randomUUID(), List.of(device)); Account account = generateAccount("+14151112222", originalUuid, UUID.randomUUID(), List.of(device));
boolean freshUser = accounts.create(account); boolean freshUser = createAccount(account);
assertThat(freshUser).isTrue(); assertThat(freshUser).isTrue();
verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, true); verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, true);
@ -199,7 +199,7 @@ class AccountsTest {
accounts.delete(originalUuid).join(); accounts.delete(originalUuid).join();
assertThat(accounts.findRecentlyDeletedAccountIdentifier(account.getNumber())).hasValue(originalUuid); assertThat(accounts.findRecentlyDeletedAccountIdentifier(account.getNumber())).hasValue(originalUuid);
freshUser = accounts.create(account); freshUser = createAccount(account);
assertThat(freshUser).isTrue(); assertThat(freshUser).isTrue();
verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, true); verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, true);
@ -214,7 +214,7 @@ class AccountsTest {
final List<Device> devices = List.of(generateDevice(DEVICE_ID_1), generateDevice(DEVICE_ID_2)); final List<Device> devices = List.of(generateDevice(DEVICE_ID_1), generateDevice(DEVICE_ID_2));
final Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), devices); final Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), devices);
accounts.create(account); createAccount(account);
verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, true); verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, true);
@ -236,8 +236,8 @@ class AccountsTest {
UUID pniSecond = UUID.randomUUID(); UUID pniSecond = UUID.randomUUID();
Account accountSecond = generateAccount("+14152221111", uuidSecond, pniSecond, devicesSecond); Account accountSecond = generateAccount("+14152221111", uuidSecond, pniSecond, devicesSecond);
accounts.create(accountFirst); createAccount(accountFirst);
accounts.create(accountSecond); createAccount(accountSecond);
Optional<Account> retrievedFirst = accounts.getByE164("+14151112222"); Optional<Account> retrievedFirst = accounts.getByE164("+14151112222");
Optional<Account> retrievedSecond = accounts.getByE164("+14152221111"); Optional<Account> retrievedSecond = accounts.getByE164("+14152221111");
@ -340,7 +340,7 @@ class AccountsTest {
UUID firstUuid = UUID.randomUUID(); UUID firstUuid = UUID.randomUUID();
UUID firstPni = UUID.randomUUID(); UUID firstPni = UUID.randomUUID();
Account account = generateAccount("+14151112222", firstUuid, firstPni, List.of(device)); Account account = generateAccount("+14151112222", firstUuid, firstPni, List.of(device));
accounts.create(account); createAccount(account);
final byte[] usernameHash = randomBytes(32); final byte[] usernameHash = randomBytes(32);
final byte[] encryptedUsername = randomBytes(32); final byte[] encryptedUsername = randomBytes(32);
@ -358,7 +358,7 @@ class AccountsTest {
// simulate a failed re-reg: we give the account a reclaimable username, but we'll try // simulate a failed re-reg: we give the account a reclaimable username, but we'll try
// re-registering again later in the test case // re-registering again later in the test case
account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1))); account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1)));
accounts.create(account); createAccount(account);
break; break;
case CONFIRMED: case CONFIRMED:
accounts.reserveUsernameHash(account, usernameHash, Duration.ofMinutes(1)).join(); accounts.reserveUsernameHash(account, usernameHash, Duration.ofMinutes(1)).join();
@ -370,7 +370,7 @@ class AccountsTest {
// re-register the account // re-register the account
account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1))); account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1)));
accounts.create(account); createAccount(account);
// If we had a username link, or we had previously saved a username link from another re-registration, make sure // If we had a username link, or we had previously saved a username link from another re-registration, make sure
// we preserve it // we preserve it
@ -401,7 +401,7 @@ class AccountsTest {
UUID firstPni = UUID.randomUUID(); UUID firstPni = UUID.randomUUID();
Account account = generateAccount("+14151112222", firstUuid, firstPni, List.of(device)); Account account = generateAccount("+14151112222", firstUuid, firstPni, List.of(device));
accounts.create(account); createAccount(account);
final byte[] usernameHash = randomBytes(32); final byte[] usernameHash = randomBytes(32);
final byte[] encryptedUsername = randomBytes(16); final byte[] encryptedUsername = randomBytes(16);
@ -422,7 +422,7 @@ class AccountsTest {
device = generateDevice(DEVICE_ID_1); device = generateDevice(DEVICE_ID_1);
account = generateAccount("+14151112222", secondUuid, UUID.randomUUID(), List.of(device)); account = generateAccount("+14151112222", secondUuid, UUID.randomUUID(), List.of(device));
final boolean freshUser = accounts.create(account); final boolean freshUser = createAccount(account);
assertThat(freshUser).isFalse(); assertThat(freshUser).isFalse();
// usernameHash should be unset // usernameHash should be unset
verifyStoredState("+14151112222", firstUuid, firstPni, null, account, true); verifyStoredState("+14151112222", firstUuid, firstPni, null, account, true);
@ -453,7 +453,8 @@ class AccountsTest {
device = generateDevice(DEVICE_ID_1); device = generateDevice(DEVICE_ID_1);
Account invalidAccount = generateAccount("+14151113333", firstUuid, UUID.randomUUID(), List.of(device)); Account invalidAccount = generateAccount("+14151113333", firstUuid, UUID.randomUUID(), List.of(device));
assertThatThrownBy(() -> accounts.create(invalidAccount));
assertThatThrownBy(() -> createAccount(invalidAccount));
} }
@Test @Test
@ -461,7 +462,7 @@ class AccountsTest {
Device device = generateDevice(DEVICE_ID_1); Device device = generateDevice(DEVICE_ID_1);
Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device)); Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device));
accounts.create(account); createAccount(account);
assertPhoneNumberConstraintExists("+14151112222", account.getUuid()); assertPhoneNumberConstraintExists("+14151112222", account.getUuid());
assertPhoneNumberIdentifierConstraintExists(account.getPhoneNumberIdentifier(), account.getUuid()); assertPhoneNumberIdentifierConstraintExists(account.getPhoneNumberIdentifier(), account.getUuid());
@ -511,8 +512,11 @@ class AccountsTest {
final DynamoDbAsyncClient dynamoDbAsyncClient = mock(DynamoDbAsyncClient.class); final DynamoDbAsyncClient dynamoDbAsyncClient = mock(DynamoDbAsyncClient.class);
accounts = new Accounts(mock(DynamoDbClient.class), accounts = new Accounts(mock(DynamoDbClient.class),
dynamoDbAsyncClient, Tables.ACCOUNTS.tableName(), dynamoDbAsyncClient,
Tables.NUMBERS.tableName(), Tables.PNI_ASSIGNMENTS.tableName(), Tables.USERNAMES.tableName(), Tables.ACCOUNTS.tableName(),
Tables.NUMBERS.tableName(),
Tables.PNI_ASSIGNMENTS.tableName(),
Tables.USERNAMES.tableName(),
Tables.DELETED_ACCOUNTS.tableName()); Tables.DELETED_ACCOUNTS.tableName());
Exception e = TransactionConflictException.builder().build(); Exception e = TransactionConflictException.builder().build();
@ -533,7 +537,7 @@ class AccountsTest {
for (int i = 1; i <= 100; i++) { for (int i = 1; i <= 100; i++) {
final Account account = generateAccount("+1" + String.format("%03d", i), UUID.randomUUID(), UUID.randomUUID()); final Account account = generateAccount("+1" + String.format("%03d", i), UUID.randomUUID(), UUID.randomUUID());
expectedAccounts.add(account); expectedAccounts.add(account);
accounts.create(account); createAccount(account);
} }
final List<Account> retrievedAccounts = final List<Account> retrievedAccounts =
@ -553,8 +557,8 @@ class AccountsTest {
final Account retainedAccount = generateAccount("+14151112345", UUID.randomUUID(), final Account retainedAccount = generateAccount("+14151112345", UUID.randomUUID(),
UUID.randomUUID(), List.of(retainedDevice)); UUID.randomUUID(), List.of(retainedDevice));
accounts.create(deletedAccount); createAccount(deletedAccount);
accounts.create(retainedAccount); createAccount(retainedAccount);
assertThat(accounts.findRecentlyDeletedAccountIdentifier(deletedAccount.getNumber())).isEmpty(); assertThat(accounts.findRecentlyDeletedAccountIdentifier(deletedAccount.getNumber())).isEmpty();
@ -581,7 +585,7 @@ class AccountsTest {
final Account recreatedAccount = generateAccount(deletedAccount.getNumber(), UUID.randomUUID(), final Account recreatedAccount = generateAccount(deletedAccount.getNumber(), UUID.randomUUID(),
UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1))); UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1)));
final boolean freshUser = accounts.create(recreatedAccount); final boolean freshUser = createAccount(recreatedAccount);
assertThat(freshUser).isTrue(); assertThat(freshUser).isTrue();
assertThat(accounts.getByAccountIdentifier(recreatedAccount.getUuid())).isPresent(); assertThat(accounts.getByAccountIdentifier(recreatedAccount.getUuid())).isPresent();
@ -598,7 +602,7 @@ class AccountsTest {
Device device = generateDevice(DEVICE_ID_1); Device device = generateDevice(DEVICE_ID_1);
Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device)); Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device));
accounts.create(account); createAccount(account);
Optional<Account> retrieved = accounts.getByE164("+11111111"); Optional<Account> retrieved = accounts.getByE164("+11111111");
assertThat(retrieved.isPresent()).isFalse(); assertThat(retrieved.isPresent()).isFalse();
@ -614,7 +618,7 @@ class AccountsTest {
final Account account = final Account account =
generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1))); generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1)));
accounts.create(account); createAccount(account);
assertThat(accounts.getByAccountIdentifierAsync(account.getUuid()).join()).isPresent(); assertThat(accounts.getByAccountIdentifierAsync(account.getUuid()).join()).isPresent();
} }
@ -626,7 +630,7 @@ class AccountsTest {
final Account account = final Account account =
generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1))); generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1)));
accounts.create(account); createAccount(account);
assertThat(accounts.getByPhoneNumberIdentifierAsync(account.getPhoneNumberIdentifier()).join()).isPresent(); assertThat(accounts.getByPhoneNumberIdentifierAsync(account.getPhoneNumberIdentifier()).join()).isPresent();
} }
@ -640,7 +644,7 @@ class AccountsTest {
final Account account = final Account account =
generateAccount(e164, UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1))); generateAccount(e164, UUID.randomUUID(), UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1)));
accounts.create(account); createAccount(account);
assertThat(accounts.getByE164Async(e164).join()).isPresent(); assertThat(accounts.getByE164Async(e164).join()).isPresent();
} }
@ -650,7 +654,7 @@ class AccountsTest {
Device device = generateDevice(DEVICE_ID_1); Device device = generateDevice(DEVICE_ID_1);
Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device)); Account account = generateAccount("+14151112222", UUID.randomUUID(), UUID.randomUUID(), List.of(device));
account.setDiscoverableByPhoneNumber(false); account.setDiscoverableByPhoneNumber(false);
accounts.create(account); createAccount(account);
verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, false); verifyStoredState("+14151112222", account.getUuid(), account.getPhoneNumberIdentifier(), null, account, false);
account.setDiscoverableByPhoneNumber(true); account.setDiscoverableByPhoneNumber(true);
accounts.update(account); accounts.update(account);
@ -673,7 +677,7 @@ class AccountsTest {
final Device device = generateDevice(DEVICE_ID_1); final Device device = generateDevice(DEVICE_ID_1);
final Account account = generateAccount(originalNumber, UUID.randomUUID(), originalPni, List.of(device)); final Account account = generateAccount(originalNumber, UUID.randomUUID(), originalPni, List.of(device));
accounts.create(account); createAccount(account);
assertThat(accounts.getByPhoneNumberIdentifier(originalPni)).isPresent(); assertThat(accounts.getByPhoneNumberIdentifier(originalPni)).isPresent();
@ -731,8 +735,8 @@ class AccountsTest {
final Device device = generateDevice(DEVICE_ID_1); final Device device = generateDevice(DEVICE_ID_1);
final Account account = generateAccount(originalNumber, UUID.randomUUID(), originalPni, List.of(device)); final Account account = generateAccount(originalNumber, UUID.randomUUID(), originalPni, List.of(device));
accounts.create(account); createAccount(account);
accounts.create(existingAccount); createAccount(existingAccount);
assertThrows(TransactionCanceledException.class, () -> accounts.changeNumber(account, targetNumber, targetPni, Optional.of(existingAccount.getUuid()))); assertThrows(TransactionCanceledException.class, () -> accounts.changeNumber(account, targetNumber, targetPni, Optional.of(existingAccount.getUuid())));
@ -750,7 +754,7 @@ class AccountsTest {
final Device device = generateDevice(DEVICE_ID_1); final Device device = generateDevice(DEVICE_ID_1);
final Account account = generateAccount(originalNumber, UUID.randomUUID(), UUID.randomUUID(), List.of(device)); final Account account = generateAccount(originalNumber, UUID.randomUUID(), UUID.randomUUID(), List.of(device));
accounts.create(account); createAccount(account);
final UUID existingAccountIdentifier = UUID.randomUUID(); final UUID existingAccountIdentifier = UUID.randomUUID();
final UUID existingPhoneNumberIdentifier = UUID.randomUUID(); final UUID existingPhoneNumberIdentifier = UUID.randomUUID();
@ -776,7 +780,7 @@ class AccountsTest {
@Test @Test
void testSwitchUsernameHashes() { void testSwitchUsernameHashes() {
final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID()); final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account); createAccount(account);
assertThat(accounts.getByUsernameHash(USERNAME_HASH_1).join()).isEmpty(); assertThat(accounts.getByUsernameHash(USERNAME_HASH_1).join()).isEmpty();
@ -822,8 +826,8 @@ class AccountsTest {
final Account firstAccount = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID()); final Account firstAccount = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID());
final Account secondAccount = generateAccount("+18005559876", UUID.randomUUID(), UUID.randomUUID()); final Account secondAccount = generateAccount("+18005559876", UUID.randomUUID(), UUID.randomUUID());
accounts.create(firstAccount); createAccount(firstAccount);
accounts.create(secondAccount); createAccount(secondAccount);
// first account reserves and confirms username hash // first account reserves and confirms username hash
assertThatNoException().isThrownBy(() -> { assertThatNoException().isThrownBy(() -> {
@ -855,7 +859,7 @@ class AccountsTest {
@Test @Test
void testConfirmUsernameHashVersionMismatch() { void testConfirmUsernameHashVersionMismatch() {
final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID()); final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account); createAccount(account);
accounts.reserveUsernameHash(account, USERNAME_HASH_1, Duration.ofDays(1)).join(); accounts.reserveUsernameHash(account, USERNAME_HASH_1, Duration.ofDays(1)).join();
account.setVersion(account.getVersion() + 77); account.setVersion(account.getVersion() + 77);
@ -868,7 +872,7 @@ class AccountsTest {
@Test @Test
void testClearUsername() { void testClearUsername() {
final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID()); final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account); createAccount(account);
accounts.reserveUsernameHash(account, USERNAME_HASH_1, Duration.ofDays(1)).join(); accounts.reserveUsernameHash(account, USERNAME_HASH_1, Duration.ofDays(1)).join();
accounts.confirmUsernameHash(account, USERNAME_HASH_1, ENCRYPTED_USERNAME_1).join(); accounts.confirmUsernameHash(account, USERNAME_HASH_1, ENCRYPTED_USERNAME_1).join();
@ -888,7 +892,7 @@ class AccountsTest {
@Test @Test
void testClearUsernameNoUsername() { void testClearUsernameNoUsername() {
final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID()); final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account); createAccount(account);
assertThatNoException().isThrownBy(() -> accounts.clearUsernameHash(account).join()); assertThatNoException().isThrownBy(() -> accounts.clearUsernameHash(account).join());
} }
@ -896,7 +900,7 @@ class AccountsTest {
@Test @Test
void testClearUsernameVersionMismatch() { void testClearUsernameVersionMismatch() {
final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID()); final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account); createAccount(account);
accounts.reserveUsernameHash(account, USERNAME_HASH_1, Duration.ofDays(1)).join(); accounts.reserveUsernameHash(account, USERNAME_HASH_1, Duration.ofDays(1)).join();
accounts.confirmUsernameHash(account, USERNAME_HASH_1, ENCRYPTED_USERNAME_1).join(); accounts.confirmUsernameHash(account, USERNAME_HASH_1, ENCRYPTED_USERNAME_1).join();
@ -912,9 +916,9 @@ class AccountsTest {
@Test @Test
void testReservedUsernameHash() { void testReservedUsernameHash() {
final Account account1 = generateAccount("+18005551111", UUID.randomUUID(), UUID.randomUUID()); final Account account1 = generateAccount("+18005551111", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account1); createAccount(account1);
final Account account2 = generateAccount("+18005552222", UUID.randomUUID(), UUID.randomUUID()); final Account account2 = generateAccount("+18005552222", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account2); createAccount(account2);
accounts.reserveUsernameHash(account1, USERNAME_HASH_1, Duration.ofDays(1)).join(); accounts.reserveUsernameHash(account1, USERNAME_HASH_1, Duration.ofDays(1)).join();
assertArrayEquals(account1.getReservedUsernameHash().orElseThrow(), USERNAME_HASH_1); assertArrayEquals(account1.getReservedUsernameHash().orElseThrow(), USERNAME_HASH_1);
@ -946,7 +950,7 @@ class AccountsTest {
@Test @Test
void testUsernameHashAvailable() { void testUsernameHashAvailable() {
final Account account1 = generateAccount("+18005551111", UUID.randomUUID(), UUID.randomUUID()); final Account account1 = generateAccount("+18005551111", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account1); createAccount(account1);
accounts.reserveUsernameHash(account1, USERNAME_HASH_1, Duration.ofDays(1)).join(); accounts.reserveUsernameHash(account1, USERNAME_HASH_1, Duration.ofDays(1)).join();
assertThat(accounts.usernameHashAvailable(USERNAME_HASH_1).join()).isFalse(); assertThat(accounts.usernameHashAvailable(USERNAME_HASH_1).join()).isFalse();
@ -964,9 +968,9 @@ class AccountsTest {
@Test @Test
void testConfirmReservedUsernameHashWrongAccountUuid() { void testConfirmReservedUsernameHashWrongAccountUuid() {
final Account account1 = generateAccount("+18005551111", UUID.randomUUID(), UUID.randomUUID()); final Account account1 = generateAccount("+18005551111", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account1); createAccount(account1);
final Account account2 = generateAccount("+18005552222", UUID.randomUUID(), UUID.randomUUID()); final Account account2 = generateAccount("+18005552222", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account2); createAccount(account2);
accounts.reserveUsernameHash(account1, USERNAME_HASH_1, Duration.ofDays(1)).join(); accounts.reserveUsernameHash(account1, USERNAME_HASH_1, Duration.ofDays(1)).join();
assertArrayEquals(account1.getReservedUsernameHash().orElseThrow(), USERNAME_HASH_1); assertArrayEquals(account1.getReservedUsernameHash().orElseThrow(), USERNAME_HASH_1);
@ -980,9 +984,9 @@ class AccountsTest {
@Test @Test
void testConfirmExpiredReservedUsernameHash() { void testConfirmExpiredReservedUsernameHash() {
final Account account1 = generateAccount("+18005551111", UUID.randomUUID(), UUID.randomUUID()); final Account account1 = generateAccount("+18005551111", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account1); createAccount(account1);
final Account account2 = generateAccount("+18005552222", UUID.randomUUID(), UUID.randomUUID()); final Account account2 = generateAccount("+18005552222", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account2); createAccount(account2);
accounts.reserveUsernameHash(account1, USERNAME_HASH_1, Duration.ofDays(2)).join(); accounts.reserveUsernameHash(account1, USERNAME_HASH_1, Duration.ofDays(2)).join();
@ -1009,7 +1013,7 @@ class AccountsTest {
@Test @Test
void testRetryReserveUsernameHash() { void testRetryReserveUsernameHash() {
final Account account = generateAccount("+18005551111", UUID.randomUUID(), UUID.randomUUID()); final Account account = generateAccount("+18005551111", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account); createAccount(account);
accounts.reserveUsernameHash(account, USERNAME_HASH_1, Duration.ofDays(2)).join(); accounts.reserveUsernameHash(account, USERNAME_HASH_1, Duration.ofDays(2)).join();
CompletableFutureTestUtil.assertFailsWithCause(ContestedOptimisticLockException.class, CompletableFutureTestUtil.assertFailsWithCause(ContestedOptimisticLockException.class,
@ -1020,7 +1024,7 @@ class AccountsTest {
@Test @Test
void testReserveConfirmUsernameHashVersionConflict() { void testReserveConfirmUsernameHashVersionConflict() {
final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID()); final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account); createAccount(account);
account.setVersion(account.getVersion() + 12); account.setVersion(account.getVersion() + 12);
CompletableFutureTestUtil.assertFailsWithCause(ContestedOptimisticLockException.class, CompletableFutureTestUtil.assertFailsWithCause(ContestedOptimisticLockException.class,
accounts.reserveUsernameHash(account, USERNAME_HASH_1, Duration.ofDays(1))); accounts.reserveUsernameHash(account, USERNAME_HASH_1, Duration.ofDays(1)));
@ -1035,7 +1039,7 @@ class AccountsTest {
final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID()); final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID());
account.setUsernameHash(RandomUtils.nextBytes(32)); account.setUsernameHash(RandomUtils.nextBytes(32));
account.setUsernameLinkDetails(UUID.randomUUID(), RandomUtils.nextBytes(32)); account.setUsernameLinkDetails(UUID.randomUUID(), RandomUtils.nextBytes(32));
accounts.create(account); createAccount(account);
final Map<String, AttributeValue> accountRecord = DYNAMO_DB_EXTENSION.getDynamoDbClient() final Map<String, AttributeValue> accountRecord = DYNAMO_DB_EXTENSION.getDynamoDbClient()
.getItem(GetItemRequest.builder() .getItem(GetItemRequest.builder()
.tableName(Tables.ACCOUNTS.tableName()) .tableName(Tables.ACCOUNTS.tableName())
@ -1053,7 +1057,7 @@ class AccountsTest {
assertThat(accounts.getByUsernameHash(USERNAME_HASH_1).join()).isEmpty(); assertThat(accounts.getByUsernameHash(USERNAME_HASH_1).join()).isEmpty();
final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID()); final Account account = generateAccount("+18005551234", UUID.randomUUID(), UUID.randomUUID());
accounts.create(account); createAccount(account);
assertThat(accounts.getByUsernameHash(USERNAME_HASH_1).join()).isEmpty(); assertThat(accounts.getByUsernameHash(USERNAME_HASH_1).join()).isEmpty();
@ -1069,7 +1073,7 @@ class AccountsTest {
final Device device2 = generateDevice((byte) 64); final Device device2 = generateDevice((byte) 64);
account.addDevice(device2); account.addDevice(device2);
accounts.create(account); createAccount(account);
final GetItemResponse response = DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient().getItem(GetItemRequest.builder() final GetItemResponse response = DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient().getItem(GetItemRequest.builder()
.tableName(Tables.ACCOUNTS.tableName()) .tableName(Tables.ACCOUNTS.tableName())
@ -1108,6 +1112,10 @@ class AccountsTest {
return DevicesHelper.createDevice(id); return DevicesHelper.createDevice(id);
} }
private boolean createAccount(final Account account) {
return accounts.create(account, ignored -> Collections.emptyList());
}
private static Account nextRandomAccount() { private static Account nextRandomAccount() {
final String nextNumber = "+1800%07d".formatted(ACCOUNT_COUNTER.getAndIncrement()); final String nextNumber = "+1800%07d".formatted(ACCOUNT_COUNTER.getAndIncrement());
return generateAccount(nextNumber, UUID.randomUUID(), UUID.randomUUID()); return generateAccount(nextNumber, UUID.randomUUID(), UUID.randomUUID());

View File

@ -5,14 +5,16 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.entities.SignedPreKey; import static org.junit.jupiter.api.Assertions.*;
abstract class RepeatedUseSignedPreKeyStoreTest<K extends SignedPreKey<?>> { abstract class RepeatedUseSignedPreKeyStoreTest<K extends SignedPreKey<?>> {
@ -50,38 +52,47 @@ abstract class RepeatedUseSignedPreKeyStoreTest<K extends SignedPreKey<?>> {
} }
@Test @Test
void delete() { void deleteForDevice() {
final RepeatedUseSignedPreKeyStore<K> keys = getKeyStore(); final RepeatedUseSignedPreKeyStore<K> keys = getKeyStore();
assertDoesNotThrow(() -> keys.delete(UUID.randomUUID()).join()); final UUID identifier = UUID.randomUUID();
final byte deviceId2 = 2; final byte deviceId2 = 2;
{ final Map<Byte, K> signedPreKeys = Map.of(
final UUID identifier = UUID.randomUUID(); Device.PRIMARY_ID, generateSignedPreKey(),
final Map<Byte, K> signedPreKeys = Map.of( deviceId2, generateSignedPreKey()
Device.PRIMARY_ID, generateSignedPreKey(), );
deviceId2, generateSignedPreKey()
);
keys.store(identifier, signedPreKeys).join(); keys.store(identifier, signedPreKeys).join();
keys.delete(identifier, Device.PRIMARY_ID).join(); keys.delete(identifier, Device.PRIMARY_ID).join();
assertEquals(Optional.empty(), keys.find(identifier, Device.PRIMARY_ID).join());
assertEquals(Optional.of(signedPreKeys.get(deviceId2)), keys.find(identifier, deviceId2).join());
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void deleteForAllDevices(final boolean excludePrimaryDevice) {
final RepeatedUseSignedPreKeyStore<K> keys = getKeyStore();
assertDoesNotThrow(() -> keys.delete(UUID.randomUUID(), excludePrimaryDevice).join());
final byte deviceId2 = Device.PRIMARY_ID + 1;
final UUID identifier = UUID.randomUUID();
final Map<Byte, K> signedPreKeys = Map.of(
Device.PRIMARY_ID, generateSignedPreKey(),
deviceId2, generateSignedPreKey()
);
keys.store(identifier, signedPreKeys).join();
keys.delete(identifier, excludePrimaryDevice).join();
if (excludePrimaryDevice) {
assertTrue(keys.find(identifier, Device.PRIMARY_ID).join().isPresent());
} else {
assertEquals(Optional.empty(), keys.find(identifier, Device.PRIMARY_ID).join()); assertEquals(Optional.empty(), keys.find(identifier, Device.PRIMARY_ID).join());
assertEquals(Optional.of(signedPreKeys.get(deviceId2)), keys.find(identifier, deviceId2).join());
} }
{ assertEquals(Optional.empty(), keys.find(identifier, deviceId2).join());
final UUID identifier = UUID.randomUUID();
final Map<Byte, K> signedPreKeys = Map.of(
Device.PRIMARY_ID, generateSignedPreKey(),
deviceId2, generateSignedPreKey()
);
keys.store(identifier, signedPreKeys).join();
keys.delete(identifier).join();
assertEquals(Optional.empty(), keys.find(identifier, Device.PRIMARY_ID).join());
assertEquals(Optional.empty(), keys.find(identifier, deviceId2).join());
}
} }
} }

View File

@ -15,13 +15,19 @@ import static org.mockito.Mockito.when;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.function.Consumer; import java.util.function.Consumer;
import org.mockito.MockingDetails; import org.mockito.MockingDetails;
import org.mockito.stubbing.Stubbing; import org.mockito.stubbing.Stubbing;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
@ -150,4 +156,30 @@ public class AccountsHelper {
return argThat(other -> other.getUuid().equals(value.getUuid())); return argThat(other -> other.getUuid().equals(value.getUuid()));
} }
public static Account createAccount(final AccountsManager accountsManager, final String e164)
throws InterruptedException {
return createAccount(accountsManager, e164, new AccountAttributes());
}
public static Account createAccount(final AccountsManager accountsManager, final String e164, final AccountAttributes accountAttributes)
throws InterruptedException {
final ECKeyPair aciKeyPair = Curve.generateKeyPair();
final ECKeyPair pniKeyPair = Curve.generateKeyPair();
return accountsManager.create(e164,
"password",
null,
accountAttributes,
new ArrayList<>(),
new IdentityKey(aciKeyPair.getPublicKey()),
new IdentityKey(pniKeyPair.getPublicKey()),
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair),
Optional.empty(),
Optional.empty());
}
} }