Add devices to accounts transactionally

This commit is contained in:
Jon Chambers 2023-12-07 11:19:40 -05:00 committed by GitHub
parent e084a9f2b6
commit 50d92265ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 520 additions and 268 deletions

View File

@ -68,6 +68,7 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities; import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.KeysManager; import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
@ -403,60 +404,63 @@ public class DeviceController {
throw new WebApplicationException(Response.status(409).build()); throw new WebApplicationException(Response.status(409).build());
} }
final Device device = new Device(); return maybeDeviceActivationRequest.map(deviceActivationRequest -> {
device.setName(accountAttributes.getName()); final String signalAgent;
device.setAuthTokenHash(SaltedTokenHash.generateFor(password));
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setRegistrationId(accountAttributes.getRegistrationId());
device.setPhoneNumberIdentityRegistrationId(accountAttributes.getPhoneNumberIdentityRegistrationId());
device.setLastSeen(Util.todayInMillis());
device.setCreated(System.currentTimeMillis());
device.setCapabilities(accountAttributes.getCapabilities());
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> { if (deviceActivationRequest.apnToken().isPresent()) {
device.setSignedPreKey(deviceActivationRequest.aciSignedPreKey()); signalAgent = "OWP";
device.setPhoneNumberIdentitySignedPreKey(deviceActivationRequest.pniSignedPreKey()); } else if (deviceActivationRequest.gcmToken().isPresent()) {
signalAgent = "OWA";
} else {
signalAgent = "OWD";
}
deviceActivationRequest.apnToken().ifPresent(apnRegistrationId -> { return accounts.addDevice(account, new DeviceSpec(accountAttributes.getName(),
device.setApnId(apnRegistrationId.apnRegistrationId()); password,
device.setVoipApnId(apnRegistrationId.voipRegistrationId()); signalAgent,
}); capabilities,
accountAttributes.getRegistrationId(),
accountAttributes.getPhoneNumberIdentityRegistrationId(),
accountAttributes.getFetchesMessages(),
deviceActivationRequest.apnToken(),
deviceActivationRequest.gcmToken(),
deviceActivationRequest.aciSignedPreKey(),
deviceActivationRequest.pniSignedPreKey(),
deviceActivationRequest.aciPqLastResortPreKey(),
deviceActivationRequest.pniPqLastResortPreKey()))
.thenCompose(a -> usedTokenCluster.withCluster(connection -> connection.async()
.set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION)))
.thenApply(ignored -> a))
.join();
})
.orElseGet(() -> {
final Device device = new Device();
device.setName(accountAttributes.getName());
device.setAuthTokenHash(SaltedTokenHash.generateFor(password));
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setRegistrationId(accountAttributes.getRegistrationId());
device.setPhoneNumberIdentityRegistrationId(accountAttributes.getPhoneNumberIdentityRegistrationId());
device.setLastSeen(Util.todayInMillis());
device.setCreated(System.currentTimeMillis());
device.setCapabilities(accountAttributes.getCapabilities());
deviceActivationRequest.gcmToken().ifPresent(gcmRegistrationId -> final Account updatedAccount = accounts.update(account, a -> {
device.setGcmId(gcmRegistrationId.gcmRegistrationId())); device.setId(a.getNextDeviceId());
});
final Account updatedAccount = accounts.update(account, a -> { CompletableFuture.allOf(
device.setId(a.getNextDeviceId()); keys.delete(a.getUuid(), device.getId()),
keys.delete(a.getPhoneNumberIdentifier(), device.getId()),
messages.clear(a.getUuid(), device.getId()))
.join();
final CompletableFuture<Void> deleteKeysFuture = CompletableFuture.allOf( a.addDevice(device);
keys.delete(a.getUuid(), device.getId()), });
keys.delete(a.getPhoneNumberIdentifier(), device.getId()));
messages.clear(a.getUuid(), device.getId()).join(); usedTokenCluster.useCluster(connection ->
connection.sync().set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION)));
deleteKeysFuture.join(); return new Pair<>(updatedAccount, device);
});
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> CompletableFuture.allOf(
keys.storeEcSignedPreKeys(a.getUuid(),
Map.of(device.getId(), deviceActivationRequest.aciSignedPreKey())),
keys.storePqLastResort(a.getUuid(),
Map.of(device.getId(), deviceActivationRequest.aciPqLastResortPreKey())),
keys.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), deviceActivationRequest.pniSignedPreKey())),
keys.storePqLastResort(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), deviceActivationRequest.pniPqLastResortPreKey())))
.join());
a.addDevice(device);
});
if (maybeAciFromToken.isPresent()) {
usedTokenCluster.useCluster(connection ->
connection.sync().set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION)));
}
return new Pair<>(updatedAccount, device);
} }
private static String getUsedTokenKey(final String token) { private static String getUsedTokenKey(final String token) {

View File

@ -43,6 +43,7 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
@ -140,18 +141,24 @@ public class RegistrationController {
} }
final Account account = accounts.create(number, final Account account = accounts.create(number,
password,
signalAgent,
registrationRequest.accountAttributes(), registrationRequest.accountAttributes(),
existingAccount.map(Account::getBadges).orElseGet(ArrayList::new), existingAccount.map(Account::getBadges).orElseGet(ArrayList::new),
registrationRequest.aciIdentityKey(), registrationRequest.aciIdentityKey(),
registrationRequest.pniIdentityKey(), registrationRequest.pniIdentityKey(),
registrationRequest.deviceActivationRequest().aciSignedPreKey(), new DeviceSpec(
registrationRequest.deviceActivationRequest().pniSignedPreKey(), registrationRequest.accountAttributes().getName(),
registrationRequest.deviceActivationRequest().aciPqLastResortPreKey(), password,
registrationRequest.deviceActivationRequest().pniPqLastResortPreKey(), signalAgent,
registrationRequest.deviceActivationRequest().apnToken(), registrationRequest.accountAttributes().getCapabilities(),
registrationRequest.deviceActivationRequest().gcmToken()); registrationRequest.accountAttributes().getRegistrationId(),
registrationRequest.accountAttributes().getPhoneNumberIdentityRegistrationId(),
registrationRequest.accountAttributes().getFetchesMessages(),
registrationRequest.deviceActivationRequest().apnToken(),
registrationRequest.deviceActivationRequest().gcmToken(),
registrationRequest.deviceActivationRequest().aciSignedPreKey(),
registrationRequest.deviceActivationRequest().pniSignedPreKey(),
registrationRequest.deviceActivationRequest().aciPqLastResortPreKey(),
registrationRequest.deviceActivationRequest().pniPqLastResortPreKey()));
Metrics.counter(ACCOUNT_CREATED_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), Metrics.counter(ACCOUNT_CREATED_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(COUNTRY_CODE_TAG_NAME, Util.getCountryCode(number)), Tag.of(COUNTRY_CODE_TAG_NAME, Util.getCountryCode(number)),

View File

@ -53,9 +53,7 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
@ -68,6 +66,7 @@ import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator; import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import org.whispersystems.textsecuregcm.util.ExceptionUtils; import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.ParallelFlux; import reactor.core.publisher.ParallelFlux;
@ -132,11 +131,6 @@ public class AccountsManager {
private static final int MAX_UPDATE_ATTEMPTS = 10; private static final int MAX_UPDATE_ATTEMPTS = 10;
@FunctionalInterface
private interface AccountPersister {
void persistAccount(Account account) throws UsernameHashNotAvailableException;
}
public enum DeletionReason { public enum DeletionReason {
ADMIN_DELETED("admin"), ADMIN_DELETED("admin"),
EXPIRED ("expired"), EXPIRED ("expired"),
@ -181,46 +175,18 @@ public class AccountsManager {
this.clock = requireNonNull(clock); this.clock = requireNonNull(clock);
} }
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public Account create(final String number, public Account create(final String number,
final String password,
final String signalAgent,
final AccountAttributes accountAttributes, final AccountAttributes accountAttributes,
final List<AccountBadge> accountBadges, final List<AccountBadge> accountBadges,
final IdentityKey aciIdentityKey, final IdentityKey aciIdentityKey,
final IdentityKey pniIdentityKey, final IdentityKey pniIdentityKey,
final ECSignedPreKey aciSignedPreKey, final DeviceSpec primaryDeviceSpec) throws InterruptedException {
final ECSignedPreKey pniSignedPreKey,
final KEMSignedPreKey aciPqLastResortPreKey,
final KEMSignedPreKey pniPqLastResortPreKey,
final Optional<ApnRegistrationId> maybeApnRegistrationId,
final Optional<GcmRegistrationId> maybeGcmRegistrationId) throws InterruptedException {
try (Timer.Context ignored = createTimer.time()) { try (Timer.Context ignored = createTimer.time()) {
final Account account = new Account(); final Account account = new Account();
accountLockManager.withLock(List.of(number), () -> { accountLockManager.withLock(List.of(number), () -> {
final Device device = new Device(); final Device device = primaryDeviceSpec.toDevice(Device.PRIMARY_ID, clock);
device.setId(Device.PRIMARY_ID);
device.setAuthTokenHash(SaltedTokenHash.generateFor(password));
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setRegistrationId(accountAttributes.getRegistrationId());
device.setPhoneNumberIdentityRegistrationId(accountAttributes.getPhoneNumberIdentityRegistrationId());
device.setName(accountAttributes.getName());
device.setCapabilities(accountAttributes.getCapabilities());
device.setCreated(System.currentTimeMillis());
device.setLastSeen(Util.todayInMillis());
device.setUserAgent(signalAgent);
device.setSignedPreKey(aciSignedPreKey);
device.setPhoneNumberIdentitySignedPreKey(pniSignedPreKey);
maybeApnRegistrationId.ifPresent(apnRegistrationId -> {
device.setApnId(apnRegistrationId.apnRegistrationId());
device.setVoipApnId(apnRegistrationId.voipRegistrationId());
});
maybeGcmRegistrationId.ifPresent(gcmRegistrationId ->
device.setGcmId(gcmRegistrationId.gcmRegistrationId()));
account.setNumber(number, phoneNumberIdentifiers.getPhoneNumberIdentifier(number)); account.setNumber(number, phoneNumberIdentifiers.getPhoneNumberIdentifier(number));
@ -245,10 +211,10 @@ public class AccountsManager {
a -> keysManager.buildWriteItemsForRepeatedUseKeys(a.getIdentifier(IdentityType.ACI), a -> keysManager.buildWriteItemsForRepeatedUseKeys(a.getIdentifier(IdentityType.ACI),
a.getIdentifier(IdentityType.PNI), a.getIdentifier(IdentityType.PNI),
Device.PRIMARY_ID, Device.PRIMARY_ID,
aciSignedPreKey, primaryDeviceSpec.aciSignedPreKey(),
pniSignedPreKey, primaryDeviceSpec.pniSignedPreKey(),
aciPqLastResortPreKey, primaryDeviceSpec.aciPqLastResortPreKey(),
pniPqLastResortPreKey), primaryDeviceSpec.pniPqLastResortPreKey()),
(aci, pni) -> CompletableFuture.allOf( (aci, pni) -> CompletableFuture.allOf(
keysManager.delete(aci), keysManager.delete(aci),
keysManager.delete(pni), keysManager.delete(pni),
@ -299,6 +265,42 @@ public class AccountsManager {
} }
} }
public CompletableFuture<Pair<Account, Device>> addDevice(final Account account, final DeviceSpec deviceSpec) {
return addDevice(account.getIdentifier(IdentityType.ACI), deviceSpec, MAX_UPDATE_ATTEMPTS);
}
private CompletableFuture<Pair<Account, Device>> addDevice(final UUID accountIdentifier, final DeviceSpec deviceSpec, final int retries) {
return accounts.getByAccountIdentifierAsync(accountIdentifier)
.thenApply(maybeAccount -> maybeAccount.orElseThrow(ContestedOptimisticLockException::new))
.thenCompose(account -> {
final byte nextDeviceId = account.getNextDeviceId();
account.addDevice(deviceSpec.toDevice(nextDeviceId, clock));
final List<TransactWriteItem> additionalWriteItems = keysManager.buildWriteItemsForRepeatedUseKeys(
account.getIdentifier(IdentityType.ACI),
account.getIdentifier(IdentityType.PNI),
nextDeviceId,
deviceSpec.aciSignedPreKey(),
deviceSpec.pniSignedPreKey(),
deviceSpec.aciPqLastResortPreKey(),
deviceSpec.pniPqLastResortPreKey());
return CompletableFuture.allOf(
keysManager.delete(account.getUuid(), nextDeviceId),
keysManager.delete(account.getPhoneNumberIdentifier(), nextDeviceId),
messagesManager.clear(account.getUuid(), nextDeviceId))
.thenCompose(ignored -> accounts.updateTransactionallyAsync(account, additionalWriteItems))
.thenApply(ignored -> new Pair<>(account, account.getDevice(nextDeviceId).orElseThrow()));
})
.exceptionallyCompose(throwable -> {
if (ExceptionUtils.unwrap(throwable) instanceof ContestedOptimisticLockException && retries > 0) {
return addDevice(accountIdentifier, deviceSpec, retries - 1);
}
return CompletableFuture.failedFuture(throwable);
});
}
public CompletableFuture<Account> removeDevice(final Account account, final byte deviceId) { public CompletableFuture<Account> removeDevice(final Account account, final byte deviceId) {
if (deviceId == Device.PRIMARY_ID) { if (deviceId == Device.PRIMARY_ID) {
throw new IllegalArgumentException("Cannot remove primary device"); throw new IllegalArgumentException("Cannot remove primary device");
@ -705,19 +707,6 @@ public class AccountsManager {
final Consumer<Account> persister, final Consumer<Account> persister,
final Supplier<Account> retriever, final Supplier<Account> retriever,
final AccountChangeValidator changeValidator) { final AccountChangeValidator changeValidator) {
try {
return failableUpdateWithRetries(account, updater, persister::accept, retriever, changeValidator);
} catch (UsernameHashNotAvailableException e) {
// not possible
throw new IllegalStateException(e);
}
}
private Account failableUpdateWithRetries(Account account,
final Function<Account, Boolean> updater,
final AccountPersister persister,
final Supplier<Account> retriever,
final AccountChangeValidator changeValidator) throws UsernameHashNotAvailableException {
Account originalAccount = AccountUtil.cloneAccountAsNotStale(account); Account originalAccount = AccountUtil.cloneAccountAsNotStale(account);
@ -731,7 +720,7 @@ public class AccountsManager {
while (tries < maxTries) { while (tries < maxTries) {
try { try {
persister.persistAccount(account); persister.accept(account);
final Account updatedAccount = AccountUtil.cloneAccountAsNotStale(account); final Account updatedAccount = AccountUtil.cloneAccountAsNotStale(account);
account.markStale(); account.markStale();

View File

@ -0,0 +1,90 @@
package org.whispersystems.textsecuregcm.storage;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.util.Util;
import java.time.Clock;
import java.util.Arrays;
import java.util.Objects;
import java.util.Optional;
public record DeviceSpec(
byte[] deviceNameCiphertext,
String password,
String signalAgent,
Device.DeviceCapabilities capabilities,
int aciRegistrationId,
int pniRegistrationId,
boolean fetchesMessages,
Optional<ApnRegistrationId> apnRegistrationId,
Optional<GcmRegistrationId> gcmRegistrationId,
ECSignedPreKey aciSignedPreKey,
ECSignedPreKey pniSignedPreKey,
KEMSignedPreKey aciPqLastResortPreKey,
KEMSignedPreKey pniPqLastResortPreKey) {
public Device toDevice(final byte deviceId, final Clock clock) {
final Device device = new Device();
device.setId(deviceId);
device.setAuthTokenHash(SaltedTokenHash.generateFor(password()));
device.setFetchesMessages(fetchesMessages());
device.setRegistrationId(aciRegistrationId());
device.setPhoneNumberIdentityRegistrationId(pniRegistrationId());
device.setName(deviceNameCiphertext());
device.setCapabilities(capabilities());
device.setCreated(clock.millis());
device.setLastSeen(Util.todayInMillis());
device.setUserAgent(signalAgent());
device.setSignedPreKey(aciSignedPreKey());
device.setPhoneNumberIdentitySignedPreKey(pniSignedPreKey());
apnRegistrationId().ifPresent(apnRegistrationId -> {
device.setApnId(apnRegistrationId.apnRegistrationId());
device.setVoipApnId(apnRegistrationId.voipRegistrationId());
});
gcmRegistrationId().ifPresent(gcmRegistrationId ->
device.setGcmId(gcmRegistrationId.gcmRegistrationId()));
return device;
}
@Override
public boolean equals(final Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
final DeviceSpec that = (DeviceSpec) o;
return aciRegistrationId == that.aciRegistrationId
&& pniRegistrationId == that.pniRegistrationId
&& fetchesMessages == that.fetchesMessages
&& Arrays.equals(deviceNameCiphertext, that.deviceNameCiphertext)
&& Objects.equals(password, that.password)
&& Objects.equals(signalAgent, that.signalAgent)
&& Objects.equals(capabilities, that.capabilities)
&& Objects.equals(apnRegistrationId, that.apnRegistrationId)
&& Objects.equals(gcmRegistrationId, that.gcmRegistrationId)
&& Objects.equals(aciSignedPreKey, that.aciSignedPreKey)
&& Objects.equals(pniSignedPreKey, that.pniSignedPreKey)
&& Objects.equals(aciPqLastResortPreKey, that.aciPqLastResortPreKey)
&& Objects.equals(pniPqLastResortPreKey, that.pniPqLastResortPreKey);
}
@Override
public int hashCode() {
int result = Objects.hash(password, signalAgent, capabilities, aciRegistrationId, pniRegistrationId,
fetchesMessages, apnRegistrationId, gcmRegistrationId, aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey,
pniPqLastResortPreKey);
result = 31 * result + Arrays.hashCode(deviceNameCiphertext);
return result;
}
}

View File

@ -24,6 +24,7 @@ import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension; import io.dropwizard.testing.junit5.ResourceExtension;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Instant; import java.time.Instant;
@ -38,6 +39,7 @@ import java.util.stream.Stream;
import javax.ws.rs.client.Entity; import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import org.glassfish.jersey.server.ServerProperties;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
@ -72,12 +74,15 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities; import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.KeysManager; import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.TestClock; import org.whispersystems.textsecuregcm.util.TestClock;
import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.VerificationCode; import org.whispersystems.textsecuregcm.util.VerificationCode;
@ -91,6 +96,7 @@ class DeviceControllerTest {
private static RateLimiters rateLimiters = mock(RateLimiters.class); private static RateLimiters rateLimiters = mock(RateLimiters.class);
private static RateLimiter rateLimiter = mock(RateLimiter.class); private static RateLimiter rateLimiter = mock(RateLimiter.class);
private static RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class); private static RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
private static RedisAdvancedClusterAsyncCommands<String, String> asyncCommands = mock(RedisAdvancedClusterAsyncCommands.class);
private static Account account = mock(Account.class); private static Account account = mock(Account.class);
private static Account maxedAccount = mock(Account.class); private static Account maxedAccount = mock(Account.class);
private static Device primaryDevice = mock(Device.class); private static Device primaryDevice = mock(Device.class);
@ -106,7 +112,10 @@ class DeviceControllerTest {
messagesManager, messagesManager,
keysManager, keysManager,
rateLimiters, rateLimiters,
RedisClusterHelper.builder().stringCommands(commands).build(), RedisClusterHelper.builder()
.stringCommands(commands)
.stringAsyncCommands(asyncCommands)
.build(),
deviceConfiguration, deviceConfiguration,
testClock); testClock);
@ -114,6 +123,7 @@ class DeviceControllerTest {
public static final AuthHelper.AuthFilterExtension AUTH_FILTER_EXTENSION = new AuthHelper.AuthFilterExtension(); public static final AuthHelper.AuthFilterExtension AUTH_FILTER_EXTENSION = new AuthHelper.AuthFilterExtension();
private static final ResourceExtension resources = ResourceExtension.builder() private static final ResourceExtension resources = ResourceExtension.builder()
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
.addProvider(AuthHelper.getAuthFilter()) .addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
@ -166,6 +176,7 @@ class DeviceControllerTest {
rateLimiters, rateLimiters,
rateLimiter, rateLimiter,
commands, commands,
asyncCommands,
account, account,
maxedAccount, maxedAccount,
primaryDevice, primaryDevice,
@ -300,11 +311,22 @@ class DeviceControllerTest {
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
when(accountsManager.addDevice(any(), any())).thenAnswer(invocation -> {
final Account a = invocation.getArgument(0);
final DeviceSpec deviceSpec = invocation.getArgument(1);
return CompletableFuture.completedFuture(new Pair<>(a, deviceSpec.toDevice(NEXT_DEVICE_ID, testClock)));
});
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null));
final AccountAttributes accountAttributes = new AccountAttributes(fetchesMessages, 1234, 5678, null, null, true, null);
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(), final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(),
new AccountAttributes(fetchesMessages, 1234, 5678, null, null, true, null), accountAttributes,
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnRegistrationId, gcmRegistrationId)); new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnRegistrationId, gcmRegistrationId));
final DeviceResponse response = resources.getJerseyTest() final DeviceResponse response = resources.getJerseyTest()
@ -315,10 +337,10 @@ class DeviceControllerTest {
assertThat(response.getDeviceId()).isEqualTo(NEXT_DEVICE_ID); assertThat(response.getDeviceId()).isEqualTo(NEXT_DEVICE_ID);
final ArgumentCaptor<Device> deviceCaptor = ArgumentCaptor.forClass(Device.class); final ArgumentCaptor<DeviceSpec> deviceSpecCaptor = ArgumentCaptor.forClass(DeviceSpec.class);
verify(account).addDevice(deviceCaptor.capture()); verify(accountsManager).addDevice(eq(account), deviceSpecCaptor.capture());
final Device device = deviceCaptor.getValue(); final Device device = deviceSpecCaptor.getValue().toDevice(NEXT_DEVICE_ID, testClock);
assertEquals(aciSignedPreKey, device.getSignedPreKey(IdentityType.ACI)); assertEquals(aciSignedPreKey, device.getSignedPreKey(IdentityType.ACI));
assertEquals(pniSignedPreKey, device.getSignedPreKey(IdentityType.PNI)); assertEquals(pniSignedPreKey, device.getSignedPreKey(IdentityType.PNI));
@ -333,14 +355,9 @@ class DeviceControllerTest {
expectedGcmToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getGcmId()), expectedGcmToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getGcmId()),
() -> assertNull(device.getGcmId())); () -> assertNull(device.getGcmId()));
verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(NEXT_DEVICE_ID)); verify(asyncCommands).set(anyString(), anyString(), any());
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciSignedPreKey));
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniSignedPreKey));
verify(keysManager).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey));
verify(keysManager).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey));
verify(commands).set(anyString(), anyString(), any());
} }
private static Stream<Arguments> linkDeviceAtomic() { private static Stream<Arguments> linkDeviceAtomic() {
final String apnsToken = "apns-token"; final String apnsToken = "apns-token";
final String apnsVoipToken = "apns-voip-token"; final String apnsVoipToken = "apns-voip-token";
@ -596,9 +613,18 @@ class DeviceControllerTest {
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
when(accountsManager.addDevice(any(), any())).thenAnswer(invocation -> {
final Account a = invocation.getArgument(0);
final DeviceSpec deviceSpec = invocation.getArgument(1);
return CompletableFuture.completedFuture(new Pair<>(a, deviceSpec.toDevice(NEXT_DEVICE_ID, testClock)));
});
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(), final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(),
new AccountAttributes(false, registrationId, pniRegistrationId, null, null, true, null), new AccountAttributes(false, registrationId, pniRegistrationId, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.of(new ApnRegistrationId("apn", null)), Optional.empty())); new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.of(new ApnRegistrationId("apn", null)), Optional.empty()));
@ -719,35 +745,66 @@ class DeviceControllerTest {
verifyNoMoreInteractions(messagesManager); verifyNoMoreInteractions(messagesManager);
} }
@Test @ParameterizedTest
void deviceDowngradePniTest() { @MethodSource
DeviceCapabilities deviceCapabilities = new DeviceCapabilities(true, true, void deviceDowngradePniTest(final boolean accountSupportsPni, final boolean deviceSupportsPni, final int expectedStatus) {
false, true); when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
AccountAttributes accountAttributes =
new AccountAttributes(false, 1234, 5678, null, null, true, deviceCapabilities);
final String verificationToken = deviceController.generateVerificationToken(AuthHelper.VALID_UUID); final Device primaryDevice = mock(Device.class);
when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(primaryDevice));
Response response = resources final ECSignedPreKey aciSignedPreKey;
.getJerseyTest() final ECSignedPreKey pniSignedPreKey;
.target("/v1/devices/" + verificationToken) final KEMSignedPreKey aciPqLastResortPreKey;
final KEMSignedPreKey pniPqLastResortPreKey;
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair);
pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair);
aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
when(account.isPniSupported()).thenReturn(accountSupportsPni);
when(accountsManager.addDevice(any(), any())).thenAnswer(invocation -> {
final Account a = invocation.getArgument(0);
final DeviceSpec deviceSpec = invocation.getArgument(1);
return CompletableFuture.completedFuture(new Pair<>(a, deviceSpec.toDevice(NEXT_DEVICE_ID, testClock)));
});
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null));
final AccountAttributes accountAttributes = new AccountAttributes(false, 1234, 5678, null, null, true, new DeviceCapabilities(true, true, deviceSupportsPni, true));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceController.generateVerificationToken(AuthHelper.VALID_UUID),
accountAttributes,
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id"))));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/link")
.request() .request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1"))
.header(HttpHeaders.USER_AGENT, "Signal-Android/5.42.8675309 Android/30") .put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE))) {
.put(Entity.entity(accountAttributes, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(409);
deviceCapabilities = new DeviceCapabilities(true, true, true, true); assertEquals(expectedStatus, response.getStatus());
accountAttributes = new AccountAttributes(false, 1234, 5678, null, null, true, deviceCapabilities); }
response = resources }
.getJerseyTest()
.target("/v1/devices/" + verificationToken) private static List<Arguments> deviceDowngradePniTest() {
.request() return List.of(
.header("Authorization", Arguments.of(true, true, 200),
AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD)) Arguments.of(true, false, 409),
.header(HttpHeaders.USER_AGENT, "Signal-Android/5.42.8675309 Android/30") Arguments.of(false, true, 200),
.put(Entity.entity(accountAttributes, MediaType.APPLICATION_JSON_TYPE)); Arguments.of(false, false, 200));
assertThat(response.getStatus()).isEqualTo(200);
} }
@Test @Test

View File

@ -9,7 +9,6 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -74,6 +73,7 @@ import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper
import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient; import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager; import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
@ -167,7 +167,7 @@ class RegistrationControllerTest {
final Account account = mock(Account.class); final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(mock(Device.class)); when(account.getPrimaryDevice()).thenReturn(mock(Device.class));
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any())) when(accountsManager.create(any(), any(), any(), any(), any(), any()))
.thenReturn(account); .thenReturn(account);
final String json = requestJson("sessionId", new byte[0], true, registrationId.orElse(0), pniRegistrationId.orElse(0)); final String json = requestJson("sessionId", new byte[0], true, registrationId.orElse(0), pniRegistrationId.orElse(0));
@ -290,7 +290,7 @@ class RegistrationControllerTest {
final Account account = mock(Account.class); final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(mock(Device.class)); when(account.getPrimaryDevice()).thenReturn(mock(Device.class));
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any())) when(accountsManager.create(any(), any(), any(), any(), any(), any()))
.thenReturn(account); .thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
@ -348,7 +348,7 @@ class RegistrationControllerTest {
final Account createdAccount = mock(Account.class); final Account createdAccount = mock(Account.class);
when(createdAccount.getPrimaryDevice()).thenReturn(mock(Device.class)); when(createdAccount.getPrimaryDevice()).thenReturn(mock(Device.class));
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any())) when(accountsManager.create(any(), any(), any(), any(), any(), any()))
.thenReturn(createdAccount); .thenReturn(createdAccount);
expectedStatus = 200; expectedStatus = 200;
@ -402,7 +402,7 @@ class RegistrationControllerTest {
final Account account = mock(Account.class); final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(mock(Device.class)); when(account.getPrimaryDevice()).thenReturn(mock(Device.class));
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any())) when(accountsManager.create(any(), any(), any(), any(), any(), any()))
.thenReturn(account); .thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
@ -426,7 +426,7 @@ class RegistrationControllerTest {
final Account account = mock(Account.class); final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(mock(Device.class)); when(account.getPrimaryDevice()).thenReturn(mock(Device.class));
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any())) when(accountsManager.create(any(), any(), any(), any(), any(), any()))
.thenReturn(account); .thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
@ -658,16 +658,10 @@ class RegistrationControllerTest {
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
void atomicAccountCreationSuccess(final RegistrationRequest registrationRequest, void atomicAccountCreationSuccess(final RegistrationRequest registrationRequest,
final IdentityKey expectedAciIdentityKey, final IdentityKey expectedAciIdentityKey,
final IdentityKey expectedPniIdentityKey, final IdentityKey expectedPniIdentityKey,
final ECSignedPreKey expectedAciSignedPreKey, final DeviceSpec expectedDeviceSpec) throws InterruptedException {
final ECSignedPreKey expectedPniSignedPreKey,
final KEMSignedPreKey expectedAciPqLastResortPreKey,
final KEMSignedPreKey expectedPniPqLastResortPreKey,
final Optional<ApnRegistrationId> expectedApnRegistrationId,
final Optional<GcmRegistrationId> expectedGcmRegistrationId) throws InterruptedException {
when(registrationServiceClient.getSession(any(), any())) when(registrationServiceClient.getSession(any(), any()))
.thenReturn( .thenReturn(
@ -685,7 +679,7 @@ class RegistrationControllerTest {
when(a.getPrimaryDevice()).thenReturn(device); when(a.getPrimaryDevice()).thenReturn(device);
}); });
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any())) when(accountsManager.create(any(), any(), any(), any(), any(), any()))
.thenReturn(account); .thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
@ -699,18 +693,11 @@ class RegistrationControllerTest {
verify(accountsManager).create( verify(accountsManager).create(
eq(NUMBER), eq(NUMBER),
eq(PASSWORD),
isNull(),
argThat(attributes -> accountAttributesEqual(attributes, registrationRequest.accountAttributes())), argThat(attributes -> accountAttributesEqual(attributes, registrationRequest.accountAttributes())),
eq(Collections.emptyList()), eq(Collections.emptyList()),
eq(expectedAciIdentityKey), eq(expectedAciIdentityKey),
eq(expectedPniIdentityKey), eq(expectedPniIdentityKey),
eq(expectedAciSignedPreKey), eq(expectedDeviceSpec));
eq(expectedPniSignedPreKey),
eq(expectedAciPqLastResortPreKey),
eq(expectedPniPqLastResortPreKey),
eq(expectedApnRegistrationId),
eq(expectedGcmRegistrationId));
} }
private static boolean accountAttributesEqual(final AccountAttributes a, final AccountAttributes b) { private static boolean accountAttributesEqual(final AccountAttributes a, final AccountAttributes b) {
@ -745,11 +732,17 @@ class RegistrationControllerTest {
pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair); pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
} }
final byte[] deviceName = "test".getBytes(StandardCharsets.UTF_8);
final int registrationId = 1;
final int pniRegistrationId = 2;
final Device.DeviceCapabilities deviceCapabilities = new Device.DeviceCapabilities(false, false, false, false);
final AccountAttributes fetchesMessagesAccountAttributes = final AccountAttributes fetchesMessagesAccountAttributes =
new AccountAttributes(true, 1, 1, "test".getBytes(StandardCharsets.UTF_8), null, true, new Device.DeviceCapabilities(false, false, false, false)); new AccountAttributes(true, registrationId, pniRegistrationId, "test".getBytes(StandardCharsets.UTF_8), null, true, new Device.DeviceCapabilities(false, false, false, false));
final AccountAttributes pushAccountAttributes = final AccountAttributes pushAccountAttributes =
new AccountAttributes(false, 1, 1, "test".getBytes(StandardCharsets.UTF_8), null, true, new Device.DeviceCapabilities(false, false, false, false)); new AccountAttributes(false, registrationId, pniRegistrationId, "test".getBytes(StandardCharsets.UTF_8), null, true, new Device.DeviceCapabilities(false, false, false, false));
final String apnsToken = "apns-token"; final String apnsToken = "apns-token";
final String apnsVoipToken = "apns-voip-token"; final String apnsVoipToken = "apns-voip-token";
@ -771,13 +764,20 @@ class RegistrationControllerTest {
Optional.empty()), Optional.empty()),
aciIdentityKey, aciIdentityKey,
pniIdentityKey, pniIdentityKey,
aciSignedPreKey, new DeviceSpec(
pniSignedPreKey, deviceName,
aciPqLastResortPreKey, PASSWORD,
pniPqLastResortPreKey, null,
Optional.empty(), deviceCapabilities,
Optional.empty(), registrationId,
Optional.empty()), pniRegistrationId,
true,
Optional.empty(),
Optional.empty(),
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey)),
// Has APNs tokens // Has APNs tokens
Arguments.of(new RegistrationRequest("session-id", Arguments.of(new RegistrationRequest("session-id",
@ -794,36 +794,22 @@ class RegistrationControllerTest {
Optional.empty()), Optional.empty()),
aciIdentityKey, aciIdentityKey,
pniIdentityKey, pniIdentityKey,
aciSignedPreKey, new DeviceSpec(
pniSignedPreKey, deviceName,
aciPqLastResortPreKey, PASSWORD,
pniPqLastResortPreKey, null,
Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), deviceCapabilities,
Optional.empty()), registrationId,
pniRegistrationId,
// requires the request to be atomic false,
Arguments.of(new RegistrationRequest("session-id", Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
new byte[0], Optional.empty(),
pushAccountAttributes,
true,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey, aciSignedPreKey,
pniSignedPreKey, pniSignedPreKey,
aciPqLastResortPreKey, aciPqLastResortPreKey,
pniPqLastResortPreKey, pniPqLastResortPreKey)),
Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
Optional.empty()),
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
Optional.empty()),
// Fetches messages; no push tokens // Has GCM token
Arguments.of(new RegistrationRequest("session-id", Arguments.of(new RegistrationRequest("session-id",
new byte[0], new byte[0],
pushAccountAttributes, pushAccountAttributes,
@ -838,12 +824,21 @@ class RegistrationControllerTest {
Optional.of(new GcmRegistrationId(gcmToken))), Optional.of(new GcmRegistrationId(gcmToken))),
aciIdentityKey, aciIdentityKey,
pniIdentityKey, pniIdentityKey,
aciSignedPreKey, new DeviceSpec(
pniSignedPreKey, deviceName,
aciPqLastResortPreKey, PASSWORD,
pniPqLastResortPreKey, null,
Optional.empty(), deviceCapabilities,
Optional.of(new GcmRegistrationId(gcmToken)))); registrationId,
pniRegistrationId,
false,
Optional.empty(),
Optional.of(new GcmRegistrationId(gcmToken)),
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey))
);
} }
/** /**

View File

@ -211,18 +211,24 @@ public class AccountCreationIntegrationTest {
: Optional.empty(); : Optional.empty();
final Account account = accountsManager.create(number, final Account account = accountsManager.create(number,
password,
signalAgent,
accountAttributes, accountAttributes,
badges, badges,
new IdentityKey(aciKeyPair.getPublicKey()), new IdentityKey(aciKeyPair.getPublicKey()),
new IdentityKey(pniKeyPair.getPublicKey()), new IdentityKey(pniKeyPair.getPublicKey()),
aciSignedPreKey, new DeviceSpec(
pniSignedPreKey, deviceName,
aciPqLastResortPreKey, password,
pniPqLastResortPreKey, signalAgent,
maybeApnRegistrationId, deviceCapabilities,
maybeGcmRegistrationId); registrationId,
pniRegistrationId,
deliveryChannels.fetchesMessages(),
maybeApnRegistrationId,
maybeGcmRegistrationId,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey));
assertExpectedStoredAccount(account, assertExpectedStoredAccount(account,
number, number,
@ -264,18 +270,23 @@ public class AccountCreationIntegrationTest {
final KEMSignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniKeyPair); final KEMSignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniKeyPair);
final Account originalAccount = accountsManager.create(number, final Account originalAccount = accountsManager.create(number,
RandomStringUtils.randomAlphanumeric(16),
"OWI",
new AccountAttributes(true, 1, 1, "name".getBytes(StandardCharsets.UTF_8), "registration-lock", false, new Device.DeviceCapabilities(false, false, false, false)), new AccountAttributes(true, 1, 1, "name".getBytes(StandardCharsets.UTF_8), "registration-lock", false, new Device.DeviceCapabilities(false, false, false, false)),
Collections.emptyList(), Collections.emptyList(),
new IdentityKey(aciKeyPair.getPublicKey()), new IdentityKey(aciKeyPair.getPublicKey()),
new IdentityKey(pniKeyPair.getPublicKey()), new IdentityKey(pniKeyPair.getPublicKey()),
aciSignedPreKey, new DeviceSpec(null,
pniSignedPreKey, "password?",
aciPqLastResortPreKey, "OWI",
pniPqLastResortPreKey, new Device.DeviceCapabilities(false, false, false, false),
Optional.empty(), 1,
Optional.empty()); 2,
true,
Optional.empty(),
Optional.empty(),
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey));
existingAccountUuid = originalAccount.getUuid(); existingAccountUuid = originalAccount.getUuid();
} }
@ -324,18 +335,23 @@ public class AccountCreationIntegrationTest {
: Optional.empty(); : Optional.empty();
final Account reregisteredAccount = accountsManager.create(number, final Account reregisteredAccount = accountsManager.create(number,
password,
signalAgent,
accountAttributes, accountAttributes,
badges, badges,
new IdentityKey(aciKeyPair.getPublicKey()), new IdentityKey(aciKeyPair.getPublicKey()),
new IdentityKey(pniKeyPair.getPublicKey()), new IdentityKey(pniKeyPair.getPublicKey()),
aciSignedPreKey, new DeviceSpec(deviceName,
pniSignedPreKey, password,
aciPqLastResortPreKey, signalAgent,
pniPqLastResortPreKey, deviceCapabilities,
maybeApnRegistrationId, registrationId,
maybeGcmRegistrationId); pniRegistrationId,
accountAttributes.getFetchesMessages(),
maybeApnRegistrationId,
maybeGcmRegistrationId,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey));
assertExpectedStoredAccount(reregisteredAccount, assertExpectedStoredAccount(reregisteredAccount,
number, number,

View File

@ -87,14 +87,6 @@ class AccountsManagerConcurrentModificationIntegrationTest {
mock(DynamicConfigurationManager.class); mock(DynamicConfigurationManager.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration()); when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration());
final KeysManager keysManager = new KeysManager(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
Tables.EC_KEYS.tableName(),
Tables.PQ_KEYS.tableName(),
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName(),
Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName(),
dynamicConfigurationManager);
accounts = new Accounts( accounts = new Accounts(
DYNAMO_DB_EXTENSION.getDynamoDbClient(), DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
@ -157,18 +149,24 @@ class AccountsManagerConcurrentModificationIntegrationTest {
final Account account = accountsManager.update( final Account account = accountsManager.update(
accountsManager.create("+14155551212", accountsManager.create("+14155551212",
"password",
null,
new AccountAttributes(), new AccountAttributes(),
new ArrayList<>(), new ArrayList<>(),
new IdentityKey(aciKeyPair.getPublicKey()), new IdentityKey(aciKeyPair.getPublicKey()),
new IdentityKey(pniKeyPair.getPublicKey()), new IdentityKey(pniKeyPair.getPublicKey()),
KeysHelper.signedECPreKey(1, aciKeyPair), new DeviceSpec(
KeysHelper.signedECPreKey(2, pniKeyPair), null,
KeysHelper.signedKEMPreKey(3, aciKeyPair), "password",
KeysHelper.signedKEMPreKey(4, pniKeyPair), null,
Optional.empty(), new Device.DeviceCapabilities(false, false, false, false),
Optional.empty()), 1,
2,
true,
Optional.empty(),
Optional.empty(),
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair))),
a -> { a -> {
a.setUnidentifiedAccessKey(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); a.setUnidentifiedAccessKey(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
a.removeDevice(Device.PRIMARY_ID); a.removeDevice(Device.PRIMARY_ID);

View File

@ -31,12 +31,12 @@ import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import io.lettuce.core.RedisException; import io.lettuce.core.RedisException;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands; import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.io.InputStream; import java.io.InputStream;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Base64; import java.util.Base64;
@ -85,6 +85,8 @@ import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture; import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil; import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.TestClock;
@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) @Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
class AccountsManagerTest { class AccountsManagerTest {
@ -109,6 +111,7 @@ class AccountsManagerTest {
private RedisAdvancedClusterCommands<String, String> commands; private RedisAdvancedClusterCommands<String, String> commands;
private RedisAdvancedClusterAsyncCommands<String, String> asyncCommands; private RedisAdvancedClusterAsyncCommands<String, String> asyncCommands;
private TestClock clock;
private AccountsManager accountsManager; private AccountsManager accountsManager;
private static final Answer<?> ACCOUNT_UPDATE_ANSWER = (answer) -> { private static final Answer<?> ACCOUNT_UPDATE_ANSWER = (answer) -> {
@ -219,6 +222,8 @@ class AccountsManagerTest {
when(messagesManager.clear(any())).thenReturn(CompletableFuture.completedFuture(null)); when(messagesManager.clear(any())).thenReturn(CompletableFuture.completedFuture(null));
when(profilesManager.deleteAll(any())).thenReturn(CompletableFuture.completedFuture(null)); when(profilesManager.deleteAll(any())).thenReturn(CompletableFuture.completedFuture(null));
clock = TestClock.now();
accountsManager = new AccountsManager( accountsManager = new AccountsManager(
accounts, accounts,
phoneNumberIdentifiers, phoneNumberIdentifiers,
@ -237,7 +242,7 @@ class AccountsManagerTest {
registrationRecoveryPasswordsManager, registrationRecoveryPasswordsManager,
mock(Executor.class), mock(Executor.class),
clientPresenceExecutor, clientPresenceExecutor,
mock(Clock.class)); clock);
} }
@Test @Test
@ -1074,6 +1079,84 @@ class AccountsManagerTest {
assertEquals(hasStorage, account.isStorageSupported()); assertEquals(hasStorage, account.isStorageSupported());
} }
@Test
void testAddDevice() {
final String phoneNumber =
PhoneNumberUtil.getInstance().format(PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164);
final Account account = AccountsHelper.generateTestAccount(phoneNumber, List.of(generateTestDevice(clock.millis())));
final UUID aci = account.getIdentifier(IdentityType.ACI);
final UUID pni = account.getIdentifier(IdentityType.PNI);
final byte nextDeviceId = account.getNextDeviceId();
final ECKeyPair aciKeyPair = Curve.generateKeyPair();
final ECKeyPair pniKeyPair = Curve.generateKeyPair();
final byte[] deviceNameCiphertext = "device-name".getBytes(StandardCharsets.UTF_8);
final String password = "password";
final String signalAgent = "OWT";
final DeviceCapabilities deviceCapabilities = new DeviceCapabilities(true, true, true, true);
final int aciRegistrationId = 17;
final int pniRegistrationId = 19;
final ECSignedPreKey aciSignedPreKey = KeysHelper.signedECPreKey(1, aciKeyPair);
final ECSignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(2, pniKeyPair);
final KEMSignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciKeyPair);
final KEMSignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniKeyPair);
when(keysManager.delete(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null));
when(messagesManager.clear(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null));
when(accounts.getByAccountIdentifierAsync(aci)).thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
when(accounts.updateTransactionallyAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
clock.pin(clock.instant().plusSeconds(60));
final Pair<Account, Device> updatedAccountAndDevice = accountsManager.addDevice(account, new DeviceSpec(
deviceNameCiphertext,
password,
signalAgent,
deviceCapabilities,
aciRegistrationId,
pniRegistrationId,
true,
Optional.empty(),
Optional.empty(),
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey))
.join();
verify(keysManager).delete(aci, nextDeviceId);
verify(keysManager).delete(pni, nextDeviceId);
verify(messagesManager).clear(aci, nextDeviceId);
verify(keysManager).buildWriteItemsForRepeatedUseKeys(
aci,
pni,
nextDeviceId,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey);
final Device device = updatedAccountAndDevice.second();
assertEquals(deviceNameCiphertext, device.getName());
assertTrue(device.getAuthTokenHash().verify(password));
assertEquals(signalAgent, device.getUserAgent());
assertEquals(deviceCapabilities, device.getCapabilities());
assertEquals(aciRegistrationId, device.getRegistrationId());
assertEquals(pniRegistrationId, device.getPhoneNumberIdentityRegistrationId().getAsInt());
assertTrue(device.getFetchesMessages());
assertNull(device.getApnId());
assertNull(device.getVoipApnId());
assertNull(device.getGcmId());
assertEquals(aciSignedPreKey, device.getSignedPreKey(IdentityType.ACI));
assertEquals(pniSignedPreKey, device.getSignedPreKey(IdentityType.PNI));
}
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void testUpdateDeviceLastSeen(final boolean expectUpdate, final long initialLastSeen, final long updatedLastSeen) { void testUpdateDeviceLastSeen(final boolean expectUpdate, final long initialLastSeen, final long updatedLastSeen) {
@ -1649,17 +1732,23 @@ class AccountsManagerTest {
final ECKeyPair pniKeyPair = Curve.generateKeyPair(); final ECKeyPair pniKeyPair = Curve.generateKeyPair();
return accountsManager.create(e164, return accountsManager.create(e164,
"password",
null,
accountAttributes, accountAttributes,
new ArrayList<>(), new ArrayList<>(),
new IdentityKey(aciKeyPair.getPublicKey()), new IdentityKey(aciKeyPair.getPublicKey()),
new IdentityKey(pniKeyPair.getPublicKey()), new IdentityKey(pniKeyPair.getPublicKey()),
KeysHelper.signedECPreKey(1, aciKeyPair), new DeviceSpec(
KeysHelper.signedECPreKey(2, pniKeyPair), accountAttributes.getName(),
KeysHelper.signedKEMPreKey(3, aciKeyPair), "password",
KeysHelper.signedKEMPreKey(4, pniKeyPair), null,
Optional.empty(), accountAttributes.getCapabilities(),
Optional.empty()); accountAttributes.getRegistrationId(),
accountAttributes.getPhoneNumberIdentityRegistrationId(),
accountAttributes.getFetchesMessages(),
Optional.empty(),
Optional.empty(),
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair)));
} }
} }

View File

@ -30,6 +30,7 @@ import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -171,17 +172,23 @@ public class AccountsHelper {
final ECKeyPair pniKeyPair = Curve.generateKeyPair(); final ECKeyPair pniKeyPair = Curve.generateKeyPair();
return accountsManager.create(e164, return accountsManager.create(e164,
"password",
null,
accountAttributes, accountAttributes,
new ArrayList<>(), new ArrayList<>(),
new IdentityKey(aciKeyPair.getPublicKey()), new IdentityKey(aciKeyPair.getPublicKey()),
new IdentityKey(pniKeyPair.getPublicKey()), new IdentityKey(pniKeyPair.getPublicKey()),
KeysHelper.signedECPreKey(1, aciKeyPair), new DeviceSpec(
KeysHelper.signedECPreKey(2, pniKeyPair), accountAttributes.getName(),
KeysHelper.signedKEMPreKey(3, aciKeyPair), "password",
KeysHelper.signedKEMPreKey(4, pniKeyPair), "OWT",
Optional.empty(), accountAttributes.getCapabilities(),
Optional.empty()); accountAttributes.getRegistrationId(),
accountAttributes.getPhoneNumberIdentityRegistrationId(),
accountAttributes.getFetchesMessages(),
Optional.empty(),
Optional.empty(),
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair)));
} }
} }