Add devices to accounts transactionally

This commit is contained in:
Jon Chambers 2023-12-07 11:19:40 -05:00 committed by GitHub
parent e084a9f2b6
commit 50d92265ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 520 additions and 268 deletions

View File

@ -68,6 +68,7 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.Pair;
@ -403,60 +404,63 @@ public class DeviceController {
throw new WebApplicationException(Response.status(409).build());
}
final Device device = new Device();
device.setName(accountAttributes.getName());
device.setAuthTokenHash(SaltedTokenHash.generateFor(password));
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setRegistrationId(accountAttributes.getRegistrationId());
device.setPhoneNumberIdentityRegistrationId(accountAttributes.getPhoneNumberIdentityRegistrationId());
device.setLastSeen(Util.todayInMillis());
device.setCreated(System.currentTimeMillis());
device.setCapabilities(accountAttributes.getCapabilities());
return maybeDeviceActivationRequest.map(deviceActivationRequest -> {
final String signalAgent;
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> {
device.setSignedPreKey(deviceActivationRequest.aciSignedPreKey());
device.setPhoneNumberIdentitySignedPreKey(deviceActivationRequest.pniSignedPreKey());
if (deviceActivationRequest.apnToken().isPresent()) {
signalAgent = "OWP";
} else if (deviceActivationRequest.gcmToken().isPresent()) {
signalAgent = "OWA";
} else {
signalAgent = "OWD";
}
deviceActivationRequest.apnToken().ifPresent(apnRegistrationId -> {
device.setApnId(apnRegistrationId.apnRegistrationId());
device.setVoipApnId(apnRegistrationId.voipRegistrationId());
});
return accounts.addDevice(account, new DeviceSpec(accountAttributes.getName(),
password,
signalAgent,
capabilities,
accountAttributes.getRegistrationId(),
accountAttributes.getPhoneNumberIdentityRegistrationId(),
accountAttributes.getFetchesMessages(),
deviceActivationRequest.apnToken(),
deviceActivationRequest.gcmToken(),
deviceActivationRequest.aciSignedPreKey(),
deviceActivationRequest.pniSignedPreKey(),
deviceActivationRequest.aciPqLastResortPreKey(),
deviceActivationRequest.pniPqLastResortPreKey()))
.thenCompose(a -> usedTokenCluster.withCluster(connection -> connection.async()
.set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION)))
.thenApply(ignored -> a))
.join();
})
.orElseGet(() -> {
final Device device = new Device();
device.setName(accountAttributes.getName());
device.setAuthTokenHash(SaltedTokenHash.generateFor(password));
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setRegistrationId(accountAttributes.getRegistrationId());
device.setPhoneNumberIdentityRegistrationId(accountAttributes.getPhoneNumberIdentityRegistrationId());
device.setLastSeen(Util.todayInMillis());
device.setCreated(System.currentTimeMillis());
device.setCapabilities(accountAttributes.getCapabilities());
deviceActivationRequest.gcmToken().ifPresent(gcmRegistrationId ->
device.setGcmId(gcmRegistrationId.gcmRegistrationId()));
});
final Account updatedAccount = accounts.update(account, a -> {
device.setId(a.getNextDeviceId());
final Account updatedAccount = accounts.update(account, a -> {
device.setId(a.getNextDeviceId());
CompletableFuture.allOf(
keys.delete(a.getUuid(), device.getId()),
keys.delete(a.getPhoneNumberIdentifier(), device.getId()),
messages.clear(a.getUuid(), device.getId()))
.join();
final CompletableFuture<Void> deleteKeysFuture = CompletableFuture.allOf(
keys.delete(a.getUuid(), device.getId()),
keys.delete(a.getPhoneNumberIdentifier(), device.getId()));
a.addDevice(device);
});
messages.clear(a.getUuid(), device.getId()).join();
usedTokenCluster.useCluster(connection ->
connection.sync().set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION)));
deleteKeysFuture.join();
maybeDeviceActivationRequest.ifPresent(deviceActivationRequest -> CompletableFuture.allOf(
keys.storeEcSignedPreKeys(a.getUuid(),
Map.of(device.getId(), deviceActivationRequest.aciSignedPreKey())),
keys.storePqLastResort(a.getUuid(),
Map.of(device.getId(), deviceActivationRequest.aciPqLastResortPreKey())),
keys.storeEcSignedPreKeys(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), deviceActivationRequest.pniSignedPreKey())),
keys.storePqLastResort(a.getPhoneNumberIdentifier(),
Map.of(device.getId(), deviceActivationRequest.pniPqLastResortPreKey())))
.join());
a.addDevice(device);
});
if (maybeAciFromToken.isPresent()) {
usedTokenCluster.useCluster(connection ->
connection.sync().set(getUsedTokenKey(verificationCode), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION)));
}
return new Pair<>(updatedAccount, device);
return new Pair<>(updatedAccount, device);
});
}
private static String getUsedTokenKey(final String token) {

View File

@ -43,6 +43,7 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util;
@ -140,18 +141,24 @@ public class RegistrationController {
}
final Account account = accounts.create(number,
password,
signalAgent,
registrationRequest.accountAttributes(),
existingAccount.map(Account::getBadges).orElseGet(ArrayList::new),
registrationRequest.aciIdentityKey(),
registrationRequest.pniIdentityKey(),
registrationRequest.deviceActivationRequest().aciSignedPreKey(),
registrationRequest.deviceActivationRequest().pniSignedPreKey(),
registrationRequest.deviceActivationRequest().aciPqLastResortPreKey(),
registrationRequest.deviceActivationRequest().pniPqLastResortPreKey(),
registrationRequest.deviceActivationRequest().apnToken(),
registrationRequest.deviceActivationRequest().gcmToken());
new DeviceSpec(
registrationRequest.accountAttributes().getName(),
password,
signalAgent,
registrationRequest.accountAttributes().getCapabilities(),
registrationRequest.accountAttributes().getRegistrationId(),
registrationRequest.accountAttributes().getPhoneNumberIdentityRegistrationId(),
registrationRequest.accountAttributes().getFetchesMessages(),
registrationRequest.deviceActivationRequest().apnToken(),
registrationRequest.deviceActivationRequest().gcmToken(),
registrationRequest.deviceActivationRequest().aciSignedPreKey(),
registrationRequest.deviceActivationRequest().pniSignedPreKey(),
registrationRequest.deviceActivationRequest().aciPqLastResortPreKey(),
registrationRequest.deviceActivationRequest().pniPqLastResortPreKey()));
Metrics.counter(ACCOUNT_CREATED_COUNTER_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(COUNTRY_CODE_TAG_NAME, Util.getCountryCode(number)),

View File

@ -53,9 +53,7 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType;
@ -68,6 +66,7 @@ import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.DestinationDeviceValidator;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.ParallelFlux;
@ -132,11 +131,6 @@ public class AccountsManager {
private static final int MAX_UPDATE_ATTEMPTS = 10;
@FunctionalInterface
private interface AccountPersister {
void persistAccount(Account account) throws UsernameHashNotAvailableException;
}
public enum DeletionReason {
ADMIN_DELETED("admin"),
EXPIRED ("expired"),
@ -181,46 +175,18 @@ public class AccountsManager {
this.clock = requireNonNull(clock);
}
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public Account create(final String number,
final String password,
final String signalAgent,
final AccountAttributes accountAttributes,
final List<AccountBadge> accountBadges,
final IdentityKey aciIdentityKey,
final IdentityKey pniIdentityKey,
final ECSignedPreKey aciSignedPreKey,
final ECSignedPreKey pniSignedPreKey,
final KEMSignedPreKey aciPqLastResortPreKey,
final KEMSignedPreKey pniPqLastResortPreKey,
final Optional<ApnRegistrationId> maybeApnRegistrationId,
final Optional<GcmRegistrationId> maybeGcmRegistrationId) throws InterruptedException {
final DeviceSpec primaryDeviceSpec) throws InterruptedException {
try (Timer.Context ignored = createTimer.time()) {
final Account account = new Account();
accountLockManager.withLock(List.of(number), () -> {
final Device device = new Device();
device.setId(Device.PRIMARY_ID);
device.setAuthTokenHash(SaltedTokenHash.generateFor(password));
device.setFetchesMessages(accountAttributes.getFetchesMessages());
device.setRegistrationId(accountAttributes.getRegistrationId());
device.setPhoneNumberIdentityRegistrationId(accountAttributes.getPhoneNumberIdentityRegistrationId());
device.setName(accountAttributes.getName());
device.setCapabilities(accountAttributes.getCapabilities());
device.setCreated(System.currentTimeMillis());
device.setLastSeen(Util.todayInMillis());
device.setUserAgent(signalAgent);
device.setSignedPreKey(aciSignedPreKey);
device.setPhoneNumberIdentitySignedPreKey(pniSignedPreKey);
maybeApnRegistrationId.ifPresent(apnRegistrationId -> {
device.setApnId(apnRegistrationId.apnRegistrationId());
device.setVoipApnId(apnRegistrationId.voipRegistrationId());
});
maybeGcmRegistrationId.ifPresent(gcmRegistrationId ->
device.setGcmId(gcmRegistrationId.gcmRegistrationId()));
final Device device = primaryDeviceSpec.toDevice(Device.PRIMARY_ID, clock);
account.setNumber(number, phoneNumberIdentifiers.getPhoneNumberIdentifier(number));
@ -245,10 +211,10 @@ public class AccountsManager {
a -> keysManager.buildWriteItemsForRepeatedUseKeys(a.getIdentifier(IdentityType.ACI),
a.getIdentifier(IdentityType.PNI),
Device.PRIMARY_ID,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey),
primaryDeviceSpec.aciSignedPreKey(),
primaryDeviceSpec.pniSignedPreKey(),
primaryDeviceSpec.aciPqLastResortPreKey(),
primaryDeviceSpec.pniPqLastResortPreKey()),
(aci, pni) -> CompletableFuture.allOf(
keysManager.delete(aci),
keysManager.delete(pni),
@ -299,6 +265,42 @@ public class AccountsManager {
}
}
public CompletableFuture<Pair<Account, Device>> addDevice(final Account account, final DeviceSpec deviceSpec) {
return addDevice(account.getIdentifier(IdentityType.ACI), deviceSpec, MAX_UPDATE_ATTEMPTS);
}
private CompletableFuture<Pair<Account, Device>> addDevice(final UUID accountIdentifier, final DeviceSpec deviceSpec, final int retries) {
return accounts.getByAccountIdentifierAsync(accountIdentifier)
.thenApply(maybeAccount -> maybeAccount.orElseThrow(ContestedOptimisticLockException::new))
.thenCompose(account -> {
final byte nextDeviceId = account.getNextDeviceId();
account.addDevice(deviceSpec.toDevice(nextDeviceId, clock));
final List<TransactWriteItem> additionalWriteItems = keysManager.buildWriteItemsForRepeatedUseKeys(
account.getIdentifier(IdentityType.ACI),
account.getIdentifier(IdentityType.PNI),
nextDeviceId,
deviceSpec.aciSignedPreKey(),
deviceSpec.pniSignedPreKey(),
deviceSpec.aciPqLastResortPreKey(),
deviceSpec.pniPqLastResortPreKey());
return CompletableFuture.allOf(
keysManager.delete(account.getUuid(), nextDeviceId),
keysManager.delete(account.getPhoneNumberIdentifier(), nextDeviceId),
messagesManager.clear(account.getUuid(), nextDeviceId))
.thenCompose(ignored -> accounts.updateTransactionallyAsync(account, additionalWriteItems))
.thenApply(ignored -> new Pair<>(account, account.getDevice(nextDeviceId).orElseThrow()));
})
.exceptionallyCompose(throwable -> {
if (ExceptionUtils.unwrap(throwable) instanceof ContestedOptimisticLockException && retries > 0) {
return addDevice(accountIdentifier, deviceSpec, retries - 1);
}
return CompletableFuture.failedFuture(throwable);
});
}
public CompletableFuture<Account> removeDevice(final Account account, final byte deviceId) {
if (deviceId == Device.PRIMARY_ID) {
throw new IllegalArgumentException("Cannot remove primary device");
@ -705,19 +707,6 @@ public class AccountsManager {
final Consumer<Account> persister,
final Supplier<Account> retriever,
final AccountChangeValidator changeValidator) {
try {
return failableUpdateWithRetries(account, updater, persister::accept, retriever, changeValidator);
} catch (UsernameHashNotAvailableException e) {
// not possible
throw new IllegalStateException(e);
}
}
private Account failableUpdateWithRetries(Account account,
final Function<Account, Boolean> updater,
final AccountPersister persister,
final Supplier<Account> retriever,
final AccountChangeValidator changeValidator) throws UsernameHashNotAvailableException {
Account originalAccount = AccountUtil.cloneAccountAsNotStale(account);
@ -731,7 +720,7 @@ public class AccountsManager {
while (tries < maxTries) {
try {
persister.persistAccount(account);
persister.accept(account);
final Account updatedAccount = AccountUtil.cloneAccountAsNotStale(account);
account.markStale();

View File

@ -0,0 +1,90 @@
package org.whispersystems.textsecuregcm.storage;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.util.Util;
import java.time.Clock;
import java.util.Arrays;
import java.util.Objects;
import java.util.Optional;
public record DeviceSpec(
byte[] deviceNameCiphertext,
String password,
String signalAgent,
Device.DeviceCapabilities capabilities,
int aciRegistrationId,
int pniRegistrationId,
boolean fetchesMessages,
Optional<ApnRegistrationId> apnRegistrationId,
Optional<GcmRegistrationId> gcmRegistrationId,
ECSignedPreKey aciSignedPreKey,
ECSignedPreKey pniSignedPreKey,
KEMSignedPreKey aciPqLastResortPreKey,
KEMSignedPreKey pniPqLastResortPreKey) {
public Device toDevice(final byte deviceId, final Clock clock) {
final Device device = new Device();
device.setId(deviceId);
device.setAuthTokenHash(SaltedTokenHash.generateFor(password()));
device.setFetchesMessages(fetchesMessages());
device.setRegistrationId(aciRegistrationId());
device.setPhoneNumberIdentityRegistrationId(pniRegistrationId());
device.setName(deviceNameCiphertext());
device.setCapabilities(capabilities());
device.setCreated(clock.millis());
device.setLastSeen(Util.todayInMillis());
device.setUserAgent(signalAgent());
device.setSignedPreKey(aciSignedPreKey());
device.setPhoneNumberIdentitySignedPreKey(pniSignedPreKey());
apnRegistrationId().ifPresent(apnRegistrationId -> {
device.setApnId(apnRegistrationId.apnRegistrationId());
device.setVoipApnId(apnRegistrationId.voipRegistrationId());
});
gcmRegistrationId().ifPresent(gcmRegistrationId ->
device.setGcmId(gcmRegistrationId.gcmRegistrationId()));
return device;
}
@Override
public boolean equals(final Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
final DeviceSpec that = (DeviceSpec) o;
return aciRegistrationId == that.aciRegistrationId
&& pniRegistrationId == that.pniRegistrationId
&& fetchesMessages == that.fetchesMessages
&& Arrays.equals(deviceNameCiphertext, that.deviceNameCiphertext)
&& Objects.equals(password, that.password)
&& Objects.equals(signalAgent, that.signalAgent)
&& Objects.equals(capabilities, that.capabilities)
&& Objects.equals(apnRegistrationId, that.apnRegistrationId)
&& Objects.equals(gcmRegistrationId, that.gcmRegistrationId)
&& Objects.equals(aciSignedPreKey, that.aciSignedPreKey)
&& Objects.equals(pniSignedPreKey, that.pniSignedPreKey)
&& Objects.equals(aciPqLastResortPreKey, that.aciPqLastResortPreKey)
&& Objects.equals(pniPqLastResortPreKey, that.pniPqLastResortPreKey);
}
@Override
public int hashCode() {
int result = Objects.hash(password, signalAgent, capabilities, aciRegistrationId, pniRegistrationId,
fetchesMessages, apnRegistrationId, gcmRegistrationId, aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey,
pniPqLastResortPreKey);
result = 31 * result + Arrays.hashCode(deviceNameCiphertext);
return result;
}
}

View File

@ -24,6 +24,7 @@ import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
@ -38,6 +39,7 @@ import java.util.stream.Stream;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.glassfish.jersey.server.ServerProperties;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
@ -72,12 +74,15 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.TestClock;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.VerificationCode;
@ -91,6 +96,7 @@ class DeviceControllerTest {
private static RateLimiters rateLimiters = mock(RateLimiters.class);
private static RateLimiter rateLimiter = mock(RateLimiter.class);
private static RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
private static RedisAdvancedClusterAsyncCommands<String, String> asyncCommands = mock(RedisAdvancedClusterAsyncCommands.class);
private static Account account = mock(Account.class);
private static Account maxedAccount = mock(Account.class);
private static Device primaryDevice = mock(Device.class);
@ -106,7 +112,10 @@ class DeviceControllerTest {
messagesManager,
keysManager,
rateLimiters,
RedisClusterHelper.builder().stringCommands(commands).build(),
RedisClusterHelper.builder()
.stringCommands(commands)
.stringAsyncCommands(asyncCommands)
.build(),
deviceConfiguration,
testClock);
@ -114,6 +123,7 @@ class DeviceControllerTest {
public static final AuthHelper.AuthFilterExtension AUTH_FILTER_EXTENSION = new AuthHelper.AuthFilterExtension();
private static final ResourceExtension resources = ResourceExtension.builder()
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
@ -166,6 +176,7 @@ class DeviceControllerTest {
rateLimiters,
rateLimiter,
commands,
asyncCommands,
account,
maxedAccount,
primaryDevice,
@ -300,11 +311,22 @@ class DeviceControllerTest {
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
when(accountsManager.addDevice(any(), any())).thenAnswer(invocation -> {
final Account a = invocation.getArgument(0);
final DeviceSpec deviceSpec = invocation.getArgument(1);
return CompletableFuture.completedFuture(new Pair<>(a, deviceSpec.toDevice(NEXT_DEVICE_ID, testClock)));
});
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null));
final AccountAttributes accountAttributes = new AccountAttributes(fetchesMessages, 1234, 5678, null, null, true, null);
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(),
new AccountAttributes(fetchesMessages, 1234, 5678, null, null, true, null),
accountAttributes,
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnRegistrationId, gcmRegistrationId));
final DeviceResponse response = resources.getJerseyTest()
@ -315,10 +337,10 @@ class DeviceControllerTest {
assertThat(response.getDeviceId()).isEqualTo(NEXT_DEVICE_ID);
final ArgumentCaptor<Device> deviceCaptor = ArgumentCaptor.forClass(Device.class);
verify(account).addDevice(deviceCaptor.capture());
final ArgumentCaptor<DeviceSpec> deviceSpecCaptor = ArgumentCaptor.forClass(DeviceSpec.class);
verify(accountsManager).addDevice(eq(account), deviceSpecCaptor.capture());
final Device device = deviceCaptor.getValue();
final Device device = deviceSpecCaptor.getValue().toDevice(NEXT_DEVICE_ID, testClock);
assertEquals(aciSignedPreKey, device.getSignedPreKey(IdentityType.ACI));
assertEquals(pniSignedPreKey, device.getSignedPreKey(IdentityType.PNI));
@ -333,14 +355,9 @@ class DeviceControllerTest {
expectedGcmToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getGcmId()),
() -> assertNull(device.getGcmId()));
verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(NEXT_DEVICE_ID));
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciSignedPreKey));
verify(keysManager).storeEcSignedPreKeys(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniSignedPreKey));
verify(keysManager).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey));
verify(keysManager).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey));
verify(commands).set(anyString(), anyString(), any());
verify(asyncCommands).set(anyString(), anyString(), any());
}
private static Stream<Arguments> linkDeviceAtomic() {
final String apnsToken = "apns-token";
final String apnsVoipToken = "apns-voip-token";
@ -596,9 +613,18 @@ class DeviceControllerTest {
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
when(accountsManager.addDevice(any(), any())).thenAnswer(invocation -> {
final Account a = invocation.getArgument(0);
final DeviceSpec deviceSpec = invocation.getArgument(1);
return CompletableFuture.completedFuture(new Pair<>(a, deviceSpec.toDevice(NEXT_DEVICE_ID, testClock)));
});
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(),
new AccountAttributes(false, registrationId, pniRegistrationId, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.of(new ApnRegistrationId("apn", null)), Optional.empty()));
@ -719,35 +745,66 @@ class DeviceControllerTest {
verifyNoMoreInteractions(messagesManager);
}
@Test
void deviceDowngradePniTest() {
DeviceCapabilities deviceCapabilities = new DeviceCapabilities(true, true,
false, true);
AccountAttributes accountAttributes =
new AccountAttributes(false, 1234, 5678, null, null, true, deviceCapabilities);
@ParameterizedTest
@MethodSource
void deviceDowngradePniTest(final boolean accountSupportsPni, final boolean deviceSupportsPni, final int expectedStatus) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
final String verificationToken = deviceController.generateVerificationToken(AuthHelper.VALID_UUID);
final Device primaryDevice = mock(Device.class);
when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(primaryDevice));
Response response = resources
.getJerseyTest()
.target("/v1/devices/" + verificationToken)
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
final KEMSignedPreKey aciPqLastResortPreKey;
final KEMSignedPreKey pniPqLastResortPreKey;
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair);
pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair);
aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
when(account.isPniSupported()).thenReturn(accountSupportsPni);
when(accountsManager.addDevice(any(), any())).thenAnswer(invocation -> {
final Account a = invocation.getArgument(0);
final DeviceSpec deviceSpec = invocation.getArgument(1);
return CompletableFuture.completedFuture(new Pair<>(a, deviceSpec.toDevice(NEXT_DEVICE_ID, testClock)));
});
when(keysManager.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(keysManager.storePqLastResort(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null));
final AccountAttributes accountAttributes = new AccountAttributes(false, 1234, 5678, null, null, true, new DeviceCapabilities(true, true, deviceSupportsPni, true));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceController.generateVerificationToken(AuthHelper.VALID_UUID),
accountAttributes,
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id"))));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/link")
.request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.USER_AGENT, "Signal-Android/5.42.8675309 Android/30")
.put(Entity.entity(accountAttributes, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(409);
.header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1"))
.put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE))) {
deviceCapabilities = new DeviceCapabilities(true, true, true, true);
accountAttributes = new AccountAttributes(false, 1234, 5678, null, null, true, deviceCapabilities);
response = resources
.getJerseyTest()
.target("/v1/devices/" + verificationToken)
.request()
.header("Authorization",
AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.header(HttpHeaders.USER_AGENT, "Signal-Android/5.42.8675309 Android/30")
.put(Entity.entity(accountAttributes, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(200);
assertEquals(expectedStatus, response.getStatus());
}
}
private static List<Arguments> deviceDowngradePniTest() {
return List.of(
Arguments.of(true, true, 200),
Arguments.of(true, false, 409),
Arguments.of(false, true, 200),
Arguments.of(false, false, 200));
}
@Test

View File

@ -9,7 +9,6 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
@ -74,6 +73,7 @@ import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper
import org.whispersystems.textsecuregcm.registration.RegistrationServiceClient;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
@ -167,7 +167,7 @@ class RegistrationControllerTest {
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(mock(Device.class));
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
when(accountsManager.create(any(), any(), any(), any(), any(), any()))
.thenReturn(account);
final String json = requestJson("sessionId", new byte[0], true, registrationId.orElse(0), pniRegistrationId.orElse(0));
@ -290,7 +290,7 @@ class RegistrationControllerTest {
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(mock(Device.class));
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
when(accountsManager.create(any(), any(), any(), any(), any(), any()))
.thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest()
@ -348,7 +348,7 @@ class RegistrationControllerTest {
final Account createdAccount = mock(Account.class);
when(createdAccount.getPrimaryDevice()).thenReturn(mock(Device.class));
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
when(accountsManager.create(any(), any(), any(), any(), any(), any()))
.thenReturn(createdAccount);
expectedStatus = 200;
@ -402,7 +402,7 @@ class RegistrationControllerTest {
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(mock(Device.class));
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
when(accountsManager.create(any(), any(), any(), any(), any(), any()))
.thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest()
@ -426,7 +426,7 @@ class RegistrationControllerTest {
final Account account = mock(Account.class);
when(account.getPrimaryDevice()).thenReturn(mock(Device.class));
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
when(accountsManager.create(any(), any(), any(), any(), any(), any()))
.thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest()
@ -658,16 +658,10 @@ class RegistrationControllerTest {
@ParameterizedTest
@MethodSource
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
void atomicAccountCreationSuccess(final RegistrationRequest registrationRequest,
final IdentityKey expectedAciIdentityKey,
final IdentityKey expectedPniIdentityKey,
final ECSignedPreKey expectedAciSignedPreKey,
final ECSignedPreKey expectedPniSignedPreKey,
final KEMSignedPreKey expectedAciPqLastResortPreKey,
final KEMSignedPreKey expectedPniPqLastResortPreKey,
final Optional<ApnRegistrationId> expectedApnRegistrationId,
final Optional<GcmRegistrationId> expectedGcmRegistrationId) throws InterruptedException {
final DeviceSpec expectedDeviceSpec) throws InterruptedException {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(
@ -685,7 +679,7 @@ class RegistrationControllerTest {
when(a.getPrimaryDevice()).thenReturn(device);
});
when(accountsManager.create(any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any(), any()))
when(accountsManager.create(any(), any(), any(), any(), any(), any()))
.thenReturn(account);
final Invocation.Builder request = resources.getJerseyTest()
@ -699,18 +693,11 @@ class RegistrationControllerTest {
verify(accountsManager).create(
eq(NUMBER),
eq(PASSWORD),
isNull(),
argThat(attributes -> accountAttributesEqual(attributes, registrationRequest.accountAttributes())),
eq(Collections.emptyList()),
eq(expectedAciIdentityKey),
eq(expectedPniIdentityKey),
eq(expectedAciSignedPreKey),
eq(expectedPniSignedPreKey),
eq(expectedAciPqLastResortPreKey),
eq(expectedPniPqLastResortPreKey),
eq(expectedApnRegistrationId),
eq(expectedGcmRegistrationId));
eq(expectedDeviceSpec));
}
private static boolean accountAttributesEqual(final AccountAttributes a, final AccountAttributes b) {
@ -745,11 +732,17 @@ class RegistrationControllerTest {
pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
}
final byte[] deviceName = "test".getBytes(StandardCharsets.UTF_8);
final int registrationId = 1;
final int pniRegistrationId = 2;
final Device.DeviceCapabilities deviceCapabilities = new Device.DeviceCapabilities(false, false, false, false);
final AccountAttributes fetchesMessagesAccountAttributes =
new AccountAttributes(true, 1, 1, "test".getBytes(StandardCharsets.UTF_8), null, true, new Device.DeviceCapabilities(false, false, false, false));
new AccountAttributes(true, registrationId, pniRegistrationId, "test".getBytes(StandardCharsets.UTF_8), null, true, new Device.DeviceCapabilities(false, false, false, false));
final AccountAttributes pushAccountAttributes =
new AccountAttributes(false, 1, 1, "test".getBytes(StandardCharsets.UTF_8), null, true, new Device.DeviceCapabilities(false, false, false, false));
new AccountAttributes(false, registrationId, pniRegistrationId, "test".getBytes(StandardCharsets.UTF_8), null, true, new Device.DeviceCapabilities(false, false, false, false));
final String apnsToken = "apns-token";
final String apnsVoipToken = "apns-voip-token";
@ -771,13 +764,20 @@ class RegistrationControllerTest {
Optional.empty()),
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.empty(),
Optional.empty(),
Optional.empty()),
new DeviceSpec(
deviceName,
PASSWORD,
null,
deviceCapabilities,
registrationId,
pniRegistrationId,
true,
Optional.empty(),
Optional.empty(),
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey)),
// Has APNs tokens
Arguments.of(new RegistrationRequest("session-id",
@ -794,36 +794,22 @@ class RegistrationControllerTest {
Optional.empty()),
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
Optional.empty()),
// requires the request to be atomic
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
pushAccountAttributes,
true,
aciIdentityKey,
pniIdentityKey,
new DeviceSpec(
deviceName,
PASSWORD,
null,
deviceCapabilities,
registrationId,
pniRegistrationId,
false,
Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
Optional.empty(),
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
Optional.empty()),
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
Optional.empty()),
pniPqLastResortPreKey)),
// Fetches messages; no push tokens
// Has GCM token
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
pushAccountAttributes,
@ -838,12 +824,21 @@ class RegistrationControllerTest {
Optional.of(new GcmRegistrationId(gcmToken))),
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.empty(),
Optional.of(new GcmRegistrationId(gcmToken))));
new DeviceSpec(
deviceName,
PASSWORD,
null,
deviceCapabilities,
registrationId,
pniRegistrationId,
false,
Optional.empty(),
Optional.of(new GcmRegistrationId(gcmToken)),
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey))
);
}
/**

View File

@ -211,18 +211,24 @@ public class AccountCreationIntegrationTest {
: Optional.empty();
final Account account = accountsManager.create(number,
password,
signalAgent,
accountAttributes,
badges,
new IdentityKey(aciKeyPair.getPublicKey()),
new IdentityKey(pniKeyPair.getPublicKey()),
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
maybeApnRegistrationId,
maybeGcmRegistrationId);
new DeviceSpec(
deviceName,
password,
signalAgent,
deviceCapabilities,
registrationId,
pniRegistrationId,
deliveryChannels.fetchesMessages(),
maybeApnRegistrationId,
maybeGcmRegistrationId,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey));
assertExpectedStoredAccount(account,
number,
@ -264,18 +270,23 @@ public class AccountCreationIntegrationTest {
final KEMSignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniKeyPair);
final Account originalAccount = accountsManager.create(number,
RandomStringUtils.randomAlphanumeric(16),
"OWI",
new AccountAttributes(true, 1, 1, "name".getBytes(StandardCharsets.UTF_8), "registration-lock", false, new Device.DeviceCapabilities(false, false, false, false)),
Collections.emptyList(),
new IdentityKey(aciKeyPair.getPublicKey()),
new IdentityKey(pniKeyPair.getPublicKey()),
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.empty(),
Optional.empty());
new DeviceSpec(null,
"password?",
"OWI",
new Device.DeviceCapabilities(false, false, false, false),
1,
2,
true,
Optional.empty(),
Optional.empty(),
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey));
existingAccountUuid = originalAccount.getUuid();
}
@ -324,18 +335,23 @@ public class AccountCreationIntegrationTest {
: Optional.empty();
final Account reregisteredAccount = accountsManager.create(number,
password,
signalAgent,
accountAttributes,
badges,
new IdentityKey(aciKeyPair.getPublicKey()),
new IdentityKey(pniKeyPair.getPublicKey()),
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
maybeApnRegistrationId,
maybeGcmRegistrationId);
new DeviceSpec(deviceName,
password,
signalAgent,
deviceCapabilities,
registrationId,
pniRegistrationId,
accountAttributes.getFetchesMessages(),
maybeApnRegistrationId,
maybeGcmRegistrationId,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey));
assertExpectedStoredAccount(reregisteredAccount,
number,

View File

@ -87,14 +87,6 @@ class AccountsManagerConcurrentModificationIntegrationTest {
mock(DynamicConfigurationManager.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration());
final KeysManager keysManager = new KeysManager(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
Tables.EC_KEYS.tableName(),
Tables.PQ_KEYS.tableName(),
Tables.REPEATED_USE_EC_SIGNED_PRE_KEYS.tableName(),
Tables.REPEATED_USE_KEM_SIGNED_PRE_KEYS.tableName(),
dynamicConfigurationManager);
accounts = new Accounts(
DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
@ -157,18 +149,24 @@ class AccountsManagerConcurrentModificationIntegrationTest {
final Account account = accountsManager.update(
accountsManager.create("+14155551212",
"password",
null,
new AccountAttributes(),
new ArrayList<>(),
new IdentityKey(aciKeyPair.getPublicKey()),
new IdentityKey(pniKeyPair.getPublicKey()),
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair),
Optional.empty(),
Optional.empty()),
new DeviceSpec(
null,
"password",
null,
new Device.DeviceCapabilities(false, false, false, false),
1,
2,
true,
Optional.empty(),
Optional.empty(),
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair))),
a -> {
a.setUnidentifiedAccessKey(new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
a.removeDevice(Device.PRIMARY_ID);

View File

@ -31,12 +31,12 @@ import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import io.lettuce.core.RedisException;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.time.Clock;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Base64;
@ -85,6 +85,8 @@ import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.TestClock;
@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
class AccountsManagerTest {
@ -109,6 +111,7 @@ class AccountsManagerTest {
private RedisAdvancedClusterCommands<String, String> commands;
private RedisAdvancedClusterAsyncCommands<String, String> asyncCommands;
private TestClock clock;
private AccountsManager accountsManager;
private static final Answer<?> ACCOUNT_UPDATE_ANSWER = (answer) -> {
@ -219,6 +222,8 @@ class AccountsManagerTest {
when(messagesManager.clear(any())).thenReturn(CompletableFuture.completedFuture(null));
when(profilesManager.deleteAll(any())).thenReturn(CompletableFuture.completedFuture(null));
clock = TestClock.now();
accountsManager = new AccountsManager(
accounts,
phoneNumberIdentifiers,
@ -237,7 +242,7 @@ class AccountsManagerTest {
registrationRecoveryPasswordsManager,
mock(Executor.class),
clientPresenceExecutor,
mock(Clock.class));
clock);
}
@Test
@ -1074,6 +1079,84 @@ class AccountsManagerTest {
assertEquals(hasStorage, account.isStorageSupported());
}
@Test
void testAddDevice() {
final String phoneNumber =
PhoneNumberUtil.getInstance().format(PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164);
final Account account = AccountsHelper.generateTestAccount(phoneNumber, List.of(generateTestDevice(clock.millis())));
final UUID aci = account.getIdentifier(IdentityType.ACI);
final UUID pni = account.getIdentifier(IdentityType.PNI);
final byte nextDeviceId = account.getNextDeviceId();
final ECKeyPair aciKeyPair = Curve.generateKeyPair();
final ECKeyPair pniKeyPair = Curve.generateKeyPair();
final byte[] deviceNameCiphertext = "device-name".getBytes(StandardCharsets.UTF_8);
final String password = "password";
final String signalAgent = "OWT";
final DeviceCapabilities deviceCapabilities = new DeviceCapabilities(true, true, true, true);
final int aciRegistrationId = 17;
final int pniRegistrationId = 19;
final ECSignedPreKey aciSignedPreKey = KeysHelper.signedECPreKey(1, aciKeyPair);
final ECSignedPreKey pniSignedPreKey = KeysHelper.signedECPreKey(2, pniKeyPair);
final KEMSignedPreKey aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciKeyPair);
final KEMSignedPreKey pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniKeyPair);
when(keysManager.delete(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null));
when(messagesManager.clear(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(null));
when(accounts.getByAccountIdentifierAsync(aci)).thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
when(accounts.updateTransactionallyAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
clock.pin(clock.instant().plusSeconds(60));
final Pair<Account, Device> updatedAccountAndDevice = accountsManager.addDevice(account, new DeviceSpec(
deviceNameCiphertext,
password,
signalAgent,
deviceCapabilities,
aciRegistrationId,
pniRegistrationId,
true,
Optional.empty(),
Optional.empty(),
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey))
.join();
verify(keysManager).delete(aci, nextDeviceId);
verify(keysManager).delete(pni, nextDeviceId);
verify(messagesManager).clear(aci, nextDeviceId);
verify(keysManager).buildWriteItemsForRepeatedUseKeys(
aci,
pni,
nextDeviceId,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey);
final Device device = updatedAccountAndDevice.second();
assertEquals(deviceNameCiphertext, device.getName());
assertTrue(device.getAuthTokenHash().verify(password));
assertEquals(signalAgent, device.getUserAgent());
assertEquals(deviceCapabilities, device.getCapabilities());
assertEquals(aciRegistrationId, device.getRegistrationId());
assertEquals(pniRegistrationId, device.getPhoneNumberIdentityRegistrationId().getAsInt());
assertTrue(device.getFetchesMessages());
assertNull(device.getApnId());
assertNull(device.getVoipApnId());
assertNull(device.getGcmId());
assertEquals(aciSignedPreKey, device.getSignedPreKey(IdentityType.ACI));
assertEquals(pniSignedPreKey, device.getSignedPreKey(IdentityType.PNI));
}
@ParameterizedTest
@MethodSource
void testUpdateDeviceLastSeen(final boolean expectUpdate, final long initialLastSeen, final long updatedLastSeen) {
@ -1649,17 +1732,23 @@ class AccountsManagerTest {
final ECKeyPair pniKeyPair = Curve.generateKeyPair();
return accountsManager.create(e164,
"password",
null,
accountAttributes,
new ArrayList<>(),
new IdentityKey(aciKeyPair.getPublicKey()),
new IdentityKey(pniKeyPair.getPublicKey()),
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair),
Optional.empty(),
Optional.empty());
new DeviceSpec(
accountAttributes.getName(),
"password",
null,
accountAttributes.getCapabilities(),
accountAttributes.getRegistrationId(),
accountAttributes.getPhoneNumberIdentityRegistrationId(),
accountAttributes.getFetchesMessages(),
Optional.empty(),
Optional.empty(),
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair)));
}
}

View File

@ -30,6 +30,7 @@ import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -171,17 +172,23 @@ public class AccountsHelper {
final ECKeyPair pniKeyPair = Curve.generateKeyPair();
return accountsManager.create(e164,
"password",
null,
accountAttributes,
new ArrayList<>(),
new IdentityKey(aciKeyPair.getPublicKey()),
new IdentityKey(pniKeyPair.getPublicKey()),
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair),
Optional.empty(),
Optional.empty());
new DeviceSpec(
accountAttributes.getName(),
"password",
"OWT",
accountAttributes.getCapabilities(),
accountAttributes.getRegistrationId(),
accountAttributes.getPhoneNumberIdentityRegistrationId(),
accountAttributes.getFetchesMessages(),
Optional.empty(),
Optional.empty(),
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair)));
}
}