Manage device linking tokens transactionally

This commit is contained in:
Jon Chambers 2024-10-07 16:26:11 -04:00 committed by GitHub
parent 42e920cd5c
commit f7aacefc40
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 539 additions and 308 deletions

View File

@ -102,6 +102,7 @@ dynamoDbTables:
phoneNumberTableName: Example_Accounts_PhoneNumbers phoneNumberTableName: Example_Accounts_PhoneNumbers
phoneNumberIdentifierTableName: Example_Accounts_PhoneNumberIdentifiers phoneNumberIdentifierTableName: Example_Accounts_PhoneNumberIdentifiers
usernamesTableName: Example_Accounts_Usernames usernamesTableName: Example_Accounts_Usernames
usedLinkDeviceTokensTableName: Example_Accounts_UsedLinkDeviceTokens
backups: backups:
tableName: Example_Backups tableName: Example_Backups
clientReleases: clientReleases:

View File

@ -403,7 +403,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getDynamoDbTables().getAccounts().getPhoneNumberTableName(), config.getDynamoDbTables().getAccounts().getPhoneNumberTableName(),
config.getDynamoDbTables().getAccounts().getPhoneNumberIdentifierTableName(), config.getDynamoDbTables().getAccounts().getPhoneNumberIdentifierTableName(),
config.getDynamoDbTables().getAccounts().getUsernamesTableName(), config.getDynamoDbTables().getAccounts().getUsernamesTableName(),
config.getDynamoDbTables().getDeletedAccounts().getTableName()); config.getDynamoDbTables().getDeletedAccounts().getTableName(),
config.getDynamoDbTables().getAccounts().getUsedLinkDeviceTokensTableName());
ClientReleases clientReleases = new ClientReleases(dynamoDbAsyncClient, ClientReleases clientReleases = new ClientReleases(dynamoDbAsyncClient,
config.getDynamoDbTables().getClientReleases().getTableName()); config.getDynamoDbTables().getClientReleases().getTableName());
PhoneNumberIdentifiers phoneNumberIdentifiers = new PhoneNumberIdentifiers(dynamoDbClient, PhoneNumberIdentifiers phoneNumberIdentifiers = new PhoneNumberIdentifiers(dynamoDbClient,
@ -637,11 +638,11 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
ClientPublicKeysManager clientPublicKeysManager = ClientPublicKeysManager clientPublicKeysManager =
new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor); new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor);
AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster, AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster,
accountLockManager, keysManager, messagesManager, profilesManager, rateLimitersCluster, accountLockManager, keysManager, messagesManager, profilesManager,
secureStorageClient, secureValueRecovery2Client, secureStorageClient, secureValueRecovery2Client,
clientPresenceManager, clientPresenceManager,
registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, clientPresenceExecutor, registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, clientPresenceExecutor,
clock, dynamicConfigurationManager); clock, config.getLinkDeviceSecretConfiguration().secret().value(), dynamicConfigurationManager);
RemoteConfigsManager remoteConfigsManager = new RemoteConfigsManager(remoteConfigs); RemoteConfigsManager remoteConfigsManager = new RemoteConfigsManager(remoteConfigs);
APNSender apnSender = new APNSender(apnSenderExecutor, config.getApnConfiguration()); APNSender apnSender = new APNSender(apnSenderExecutor, config.getApnConfiguration());
FcmSender fcmSender = new FcmSender(fcmSenderExecutor, config.getFcmConfiguration().credentials().value()); FcmSender fcmSender = new FcmSender(fcmSenderExecutor, config.getFcmConfiguration().credentials().value());
@ -1107,8 +1108,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getDeliveryCertificate().ecPrivateKey(), config.getDeliveryCertificate().expiresDays()), config.getDeliveryCertificate().ecPrivateKey(), config.getDeliveryCertificate().expiresDays()),
zkAuthOperations, callingGenericZkSecretParams, clock), zkAuthOperations, callingGenericZkSecretParams, clock),
new ChallengeController(rateLimitChallengeManager, challengeConstraintChecker), new ChallengeController(rateLimitChallengeManager, challengeConstraintChecker),
new DeviceController(config.getLinkDeviceSecretConfiguration().secret().value(), accountsManager, new DeviceController(accountsManager, clientPublicKeysManager, rateLimiters, config.getMaxDevices()),
clientPublicKeysManager, rateLimiters, rateLimitersCluster, config.getMaxDevices(), clock),
new DirectoryV2Controller(directoryV2CredentialsGenerator), new DirectoryV2Controller(directoryV2CredentialsGenerator),
new DonationController(clock, zkReceiptOperations, redeemedReceiptsManager, accountsManager, config.getBadges(), new DonationController(clock, zkReceiptOperations, redeemedReceiptsManager, accountsManager, config.getBadges(),
ReceiptCredentialPresentation::new), ReceiptCredentialPresentation::new),

View File

@ -10,19 +10,22 @@ public class AccountsTableConfiguration extends Table {
private final String phoneNumberTableName; private final String phoneNumberTableName;
private final String phoneNumberIdentifierTableName; private final String phoneNumberIdentifierTableName;
private final String usernamesTableName; private final String usernamesTableName;
private final String usedLinkDeviceTokensTableName;
@JsonCreator @JsonCreator
public AccountsTableConfiguration( public AccountsTableConfiguration(
@JsonProperty("tableName") final String tableName, @JsonProperty("tableName") final String tableName,
@JsonProperty("phoneNumberTableName") final String phoneNumberTableName, @JsonProperty("phoneNumberTableName") final String phoneNumberTableName,
@JsonProperty("phoneNumberIdentifierTableName") final String phoneNumberIdentifierTableName, @JsonProperty("phoneNumberIdentifierTableName") final String phoneNumberIdentifierTableName,
@JsonProperty("usernamesTableName") final String usernamesTableName) { @JsonProperty("usernamesTableName") final String usernamesTableName,
@JsonProperty("usedLinkDeviceTokensTableName") final String usedLinkDeviceTokensTableName) {
super(tableName); super(tableName);
this.phoneNumberTableName = phoneNumberTableName; this.phoneNumberTableName = phoneNumberTableName;
this.phoneNumberIdentifierTableName = phoneNumberIdentifierTableName; this.phoneNumberIdentifierTableName = phoneNumberIdentifierTableName;
this.usernamesTableName = usernamesTableName; this.usernamesTableName = usernamesTableName;
this.usedLinkDeviceTokensTableName = usedLinkDeviceTokensTableName;
} }
@NotBlank @NotBlank
@ -39,4 +42,9 @@ public class AccountsTableConfiguration extends Table {
public String getUsernamesTableName() { public String getUsernamesTableName() {
return usernamesTableName; return usernamesTableName;
} }
@NotBlank
public String getUsedLinkDeviceTokensTableName() {
return usedLinkDeviceTokensTableName;
}
} }

View File

@ -4,32 +4,18 @@
*/ */
package org.whispersystems.textsecuregcm.controllers; package org.whispersystems.textsecuregcm.controllers;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.net.HttpHeaders; import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.Auth; import io.dropwizard.auth.Auth;
import io.lettuce.core.SetArgs;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.headers.Header; import io.swagger.v3.oas.annotations.headers.Header;
import io.swagger.v3.oas.annotations.responses.ApiResponse; import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Base64;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import javax.ws.rs.Consumes; import javax.ws.rs.Consumes;
@ -61,13 +47,13 @@ import org.whispersystems.textsecuregcm.entities.ProvisioningMessage;
import org.whispersystems.textsecuregcm.entities.SetPublicKeyRequest; import org.whispersystems.textsecuregcm.entities.SetPublicKeyRequest;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities; import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.storage.DeviceSpec; import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.LinkDeviceTokenAlreadyUsedException;
import org.whispersystems.textsecuregcm.util.VerificationCode; import org.whispersystems.textsecuregcm.util.VerificationCode;
import org.whispersystems.websocket.auth.Mutable; import org.whispersystems.websocket.auth.Mutable;
import org.whispersystems.websocket.auth.ReadOnly; import org.whispersystems.websocket.auth.ReadOnly;
@ -78,43 +64,20 @@ public class DeviceController {
static final int MAX_DEVICES = 6; static final int MAX_DEVICES = 6;
private final Key verificationTokenKey;
private final AccountsManager accounts; private final AccountsManager accounts;
private final ClientPublicKeysManager clientPublicKeysManager; private final ClientPublicKeysManager clientPublicKeysManager;
private final RateLimiters rateLimiters; private final RateLimiters rateLimiters;
private final FaultTolerantRedisCluster usedTokenCluster;
private final Map<String, Integer> maxDeviceConfiguration; private final Map<String, Integer> maxDeviceConfiguration;
private final Clock clock; public DeviceController(final AccountsManager accounts,
final ClientPublicKeysManager clientPublicKeysManager,
final RateLimiters rateLimiters,
final Map<String, Integer> maxDeviceConfiguration) {
private static final String VERIFICATION_TOKEN_ALGORITHM = "HmacSHA256";
@VisibleForTesting
static final Duration TOKEN_EXPIRATION_DURATION = Duration.ofMinutes(10);
public DeviceController(byte[] linkDeviceSecret,
AccountsManager accounts,
ClientPublicKeysManager clientPublicKeysManager,
RateLimiters rateLimiters,
FaultTolerantRedisCluster usedTokenCluster,
Map<String, Integer> maxDeviceConfiguration, final Clock clock) {
this.verificationTokenKey = new SecretKeySpec(linkDeviceSecret, VERIFICATION_TOKEN_ALGORITHM);
this.accounts = accounts; this.accounts = accounts;
this.clientPublicKeysManager = clientPublicKeysManager; this.clientPublicKeysManager = clientPublicKeysManager;
this.rateLimiters = rateLimiters; this.rateLimiters = rateLimiters;
this.usedTokenCluster = usedTokenCluster;
this.maxDeviceConfiguration = maxDeviceConfiguration; this.maxDeviceConfiguration = maxDeviceConfiguration;
this.clock = clock;
// Fail fast: reject bad keys
try {
final Mac mac = Mac.getInstance(VERIFICATION_TOKEN_ALGORITHM);
mac.init(verificationTokenKey);
} catch (final NoSuchAlgorithmException e) {
throw new AssertionError("All Java implementations must support HmacSHA256", e);
} catch (final InvalidKeyException e) {
throw new IllegalArgumentException(e);
}
} }
@GET @GET
@ -196,7 +159,7 @@ public class DeviceController {
throw new WebApplicationException(Response.Status.UNAUTHORIZED); throw new WebApplicationException(Response.Status.UNAUTHORIZED);
} }
return new VerificationCode(generateVerificationToken(account.getUuid())); return new VerificationCode(accounts.generateDeviceLinkingToken(account.getUuid()));
} }
@PUT @PUT
@ -222,7 +185,7 @@ public class DeviceController {
@Context ContainerRequest containerRequest) @Context ContainerRequest containerRequest)
throws RateLimitExceededException, DeviceLimitExceededException { throws RateLimitExceededException, DeviceLimitExceededException {
final Account account = checkVerificationToken(linkDeviceRequest.verificationCode()) final Account account = accounts.checkDeviceLinkingToken(linkDeviceRequest.verificationCode())
.flatMap(accounts::getByAccountIdentifier) .flatMap(accounts::getByAccountIdentifier)
.orElseThrow(ForbiddenException::new); .orElseThrow(ForbiddenException::new);
@ -274,27 +237,33 @@ public class DeviceController {
signalAgent = "OWD"; signalAgent = "OWD";
} }
return accounts.addDevice(account, new DeviceSpec(accountAttributes.getName(), try {
authorizationHeader.getPassword(), return accounts.addDevice(account, new DeviceSpec(accountAttributes.getName(),
signalAgent, authorizationHeader.getPassword(),
capabilities, signalAgent,
accountAttributes.getRegistrationId(), capabilities,
accountAttributes.getPhoneNumberIdentityRegistrationId(), accountAttributes.getRegistrationId(),
accountAttributes.getFetchesMessages(), accountAttributes.getPhoneNumberIdentityRegistrationId(),
deviceActivationRequest.apnToken(), accountAttributes.getFetchesMessages(),
deviceActivationRequest.gcmToken(), deviceActivationRequest.apnToken(),
deviceActivationRequest.aciSignedPreKey(), deviceActivationRequest.gcmToken(),
deviceActivationRequest.pniSignedPreKey(), deviceActivationRequest.aciSignedPreKey(),
deviceActivationRequest.aciPqLastResortPreKey(), deviceActivationRequest.pniSignedPreKey(),
deviceActivationRequest.pniPqLastResortPreKey())) deviceActivationRequest.aciPqLastResortPreKey(),
.thenCompose(a -> usedTokenCluster.withCluster(connection -> connection.async() deviceActivationRequest.pniPqLastResortPreKey()),
.set(getUsedTokenKey(linkDeviceRequest.verificationCode()), "", new SetArgs().ex(TOKEN_EXPIRATION_DURATION))) linkDeviceRequest.verificationCode())
.thenApply(ignored -> a)) .thenApply(accountAndDevice -> new DeviceResponse(
.thenApply(accountAndDevice -> new DeviceResponse( accountAndDevice.first().getIdentifier(IdentityType.ACI),
accountAndDevice.first().getIdentifier(IdentityType.ACI), accountAndDevice.first().getIdentifier(IdentityType.PNI),
accountAndDevice.first().getIdentifier(IdentityType.PNI), accountAndDevice.second().getId()))
accountAndDevice.second().getId())) .join();
.join(); } catch (final CompletionException e) {
if (e.getCause() instanceof LinkDeviceTokenAlreadyUsedException) {
throw new ForbiddenException();
}
throw e;
}
} }
@PUT @PUT
@ -336,95 +305,10 @@ public class DeviceController {
setPublicKeyRequest.publicKey()); setPublicKeyRequest.publicKey());
} }
private Mac getInitializedMac() {
try {
final Mac mac = Mac.getInstance(VERIFICATION_TOKEN_ALGORITHM);
mac.init(verificationTokenKey);
return mac;
} catch (final NoSuchAlgorithmException | InvalidKeyException e) {
// All Java implementations must support HmacSHA256 and we checked the key at construction time, so this can never
// happen
throw new AssertionError(e);
}
}
@VisibleForTesting
String generateVerificationToken(final UUID aci) {
final String claims = aci + "." + clock.instant().toEpochMilli();
final byte[] signature = getInitializedMac().doFinal(claims.getBytes(StandardCharsets.UTF_8));
return claims + ":" + Base64.getUrlEncoder().encodeToString(signature);
}
@VisibleForTesting
Optional<UUID> checkVerificationToken(final String verificationToken) {
final boolean tokenUsed = usedTokenCluster.withCluster(connection ->
connection.sync().get(getUsedTokenKey(verificationToken)) != null);
if (tokenUsed) {
return Optional.empty();
}
final String[] claimsAndSignature = verificationToken.split(":", 2);
if (claimsAndSignature.length != 2) {
return Optional.empty();
}
final byte[] expectedSignature = getInitializedMac().doFinal(
claimsAndSignature[0].getBytes(StandardCharsets.UTF_8));
final byte[] providedSignature;
try {
providedSignature = Base64.getUrlDecoder().decode(claimsAndSignature[1]);
} catch (final IllegalArgumentException e) {
return Optional.empty();
}
if (!MessageDigest.isEqual(expectedSignature, providedSignature)) {
return Optional.empty();
}
final String[] aciAndTimestamp = claimsAndSignature[0].split("\\.", 2);
if (aciAndTimestamp.length != 2) {
return Optional.empty();
}
final UUID aci;
try {
aci = UUID.fromString(aciAndTimestamp[0]);
} catch (final IllegalArgumentException e) {
return Optional.empty();
}
final Instant timestamp;
try {
timestamp = Instant.ofEpochMilli(Long.parseLong(aciAndTimestamp[1]));
} catch (final NumberFormatException e) {
return Optional.empty();
}
final Instant tokenExpiration = timestamp.plus(TOKEN_EXPIRATION_DURATION);
if (tokenExpiration.isBefore(clock.instant())) {
return Optional.empty();
}
return Optional.of(aci);
}
private static boolean isCapabilityDowngrade(Account account, DeviceCapabilities capabilities) { private static boolean isCapabilityDowngrade(Account account, DeviceCapabilities capabilities) {
boolean isDowngrade = false; boolean isDowngrade = false;
isDowngrade |= account.isDeleteSyncSupported() && !capabilities.deleteSync(); isDowngrade |= account.isDeleteSyncSupported() && !capabilities.deleteSync();
isDowngrade |= account.isVersionedExpirationTimerSupported() && !capabilities.versionedExpirationTimer(); isDowngrade |= account.isVersionedExpirationTimerSupported() && !capabilities.versionedExpirationTimer();
return isDowngrade; return isDowngrade;
} }
private static String getUsedTokenKey(final String token) {
return "usedToken::" + token;
}
} }

View File

@ -14,6 +14,9 @@ import com.google.common.base.Throwables;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Timer; import io.micrometer.core.instrument.Timer;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.Clock; import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
@ -43,6 +46,7 @@ import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Scheduler;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient; import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
@ -122,6 +126,12 @@ public class Accounts extends AbstractDynamoDbStore {
// username hash; byte[] or null // username hash; byte[] or null
static final String ATTR_USERNAME_HASH = "N"; static final String ATTR_USERNAME_HASH = "N";
// bytes, primary key
static final String KEY_LINK_DEVICE_TOKEN_HASH = "H";
// integer, seconds
static final String ATTR_LINK_DEVICE_TOKEN_TTL = "E";
// unidentified access key; byte[] or null // unidentified access key; byte[] or null
static final String ATTR_UAK = "UAK"; static final String ATTR_UAK = "UAK";
@ -154,6 +164,7 @@ public class Accounts extends AbstractDynamoDbStore {
private final String phoneNumberIdentifierConstraintTableName; private final String phoneNumberIdentifierConstraintTableName;
private final String usernamesConstraintTableName; private final String usernamesConstraintTableName;
private final String deletedAccountsTableName; private final String deletedAccountsTableName;
private final String usedLinkDeviceTokenTableName;
private final String accountsTableName; private final String accountsTableName;
@VisibleForTesting @VisibleForTesting
@ -165,7 +176,8 @@ public class Accounts extends AbstractDynamoDbStore {
final String phoneNumberConstraintTableName, final String phoneNumberConstraintTableName,
final String phoneNumberIdentifierConstraintTableName, final String phoneNumberIdentifierConstraintTableName,
final String usernamesConstraintTableName, final String usernamesConstraintTableName,
final String deletedAccountsTableName) { final String deletedAccountsTableName,
final String usedLinkDeviceTokenTableName) {
super(client); super(client);
this.clock = clock; this.clock = clock;
@ -175,6 +187,7 @@ public class Accounts extends AbstractDynamoDbStore {
this.accountsTableName = accountsTableName; this.accountsTableName = accountsTableName;
this.usernamesConstraintTableName = usernamesConstraintTableName; this.usernamesConstraintTableName = usernamesConstraintTableName;
this.deletedAccountsTableName = deletedAccountsTableName; this.deletedAccountsTableName = deletedAccountsTableName;
this.usedLinkDeviceTokenTableName = usedLinkDeviceTokenTableName;
} }
public Accounts( public Accounts(
@ -184,11 +197,12 @@ public class Accounts extends AbstractDynamoDbStore {
final String phoneNumberConstraintTableName, final String phoneNumberConstraintTableName,
final String phoneNumberIdentifierConstraintTableName, final String phoneNumberIdentifierConstraintTableName,
final String usernamesConstraintTableName, final String usernamesConstraintTableName,
final String deletedAccountsTableName) { final String deletedAccountsTableName,
final String usedLinkDeviceTokenTableName) {
this(Clock.systemUTC(), client, asyncClient, accountsTableName, this(Clock.systemUTC(), client, asyncClient, accountsTableName,
phoneNumberConstraintTableName, phoneNumberIdentifierConstraintTableName, usernamesConstraintTableName, phoneNumberConstraintTableName, phoneNumberIdentifierConstraintTableName, usernamesConstraintTableName,
deletedAccountsTableName); deletedAccountsTableName, usedLinkDeviceTokenTableName);
} }
static class UsernameTable { static class UsernameTable {
@ -1065,6 +1079,28 @@ public class Accounts extends AbstractDynamoDbStore {
}); });
} }
public TransactWriteItem buildTransactWriteItemForLinkDevice(final String linkDeviceToken, final Duration tokenTtl) {
final byte[] linkDeviceTokenHash;
try {
linkDeviceTokenHash = MessageDigest.getInstance("SHA-256").digest(linkDeviceToken.getBytes(StandardCharsets.UTF_8));
} catch (final NoSuchAlgorithmException e) {
throw new AssertionError("Every implementation of the Java platform is required to support the SHA-256 MessageDigest algorithm", e);
}
return TransactWriteItem.builder()
.put(Put.builder()
.tableName(usedLinkDeviceTokenTableName)
.item(Map.of(
KEY_LINK_DEVICE_TOKEN_HASH, AttributeValue.fromB(SdkBytes.fromByteArray(linkDeviceTokenHash)),
ATTR_LINK_DEVICE_TOKEN_TTL, AttributeValue.fromN(String.valueOf(clock.instant().plus(tokenTtl).getEpochSecond()))
))
.conditionExpression("attribute_not_exists(#linkDeviceTokenHash)")
.expressionAttributeNames(Map.of("#linkDeviceTokenHash", KEY_LINK_DEVICE_TOKEN_HASH))
.build())
.build();
}
@Nonnull @Nonnull
public Optional<Account> getByE164(final String number) { public Optional<Account> getByE164(final String number) {
return getByIndirectLookup( return getByIndirectLookup(

View File

@ -13,6 +13,7 @@ import com.fasterxml.jackson.databind.ObjectWriter;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import io.lettuce.core.RedisException; import io.lettuce.core.RedisException;
import io.lettuce.core.SetArgs;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tag;
@ -20,11 +21,18 @@ import io.micrometer.core.instrument.Tags;
import io.micrometer.core.instrument.Timer; import io.micrometer.core.instrument.Timer;
import java.io.IOException; import java.io.IOException;
import java.io.UncheckedIOException; import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.Clock; import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.time.Instant;
import java.util.ArrayDeque; import java.util.ArrayDeque;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Base64;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@ -43,6 +51,8 @@ import java.util.function.Function;
import java.util.function.Supplier; import java.util.function.Supplier;
import java.util.stream.Stream; import java.util.stream.Stream;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import org.apache.commons.lang3.ObjectUtils; import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
@ -71,6 +81,7 @@ import org.whispersystems.textsecuregcm.util.Util;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Scheduler;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem; import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem;
import software.amazon.awssdk.services.dynamodb.model.TransactionCanceledException;
public class AccountsManager { public class AccountsManager {
@ -98,6 +109,7 @@ public class AccountsManager {
private final Accounts accounts; private final Accounts accounts;
private final PhoneNumberIdentifiers phoneNumberIdentifiers; private final PhoneNumberIdentifiers phoneNumberIdentifiers;
private final FaultTolerantRedisCluster cacheCluster; private final FaultTolerantRedisCluster cacheCluster;
private final FaultTolerantRedisCluster rateLimitCluster;
private final AccountLockManager accountLockManager; private final AccountLockManager accountLockManager;
private final KeysManager keysManager; private final KeysManager keysManager;
private final MessagesManager messagesManager; private final MessagesManager messagesManager;
@ -112,6 +124,8 @@ public class AccountsManager {
private final Clock clock; private final Clock clock;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager; private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private final Key verificationTokenKey;
private static final ObjectWriter ACCOUNT_REDIS_JSON_WRITER = SystemMapper.jsonMapper() private static final ObjectWriter ACCOUNT_REDIS_JSON_WRITER = SystemMapper.jsonMapper()
.writer(SystemMapper.excludingField(Account.class, List.of("uuid"))); .writer(SystemMapper.excludingField(Account.class, List.of("uuid")));
@ -125,6 +139,12 @@ public class AccountsManager {
private static final int MAX_UPDATE_ATTEMPTS = 10; private static final int MAX_UPDATE_ATTEMPTS = 10;
@VisibleForTesting
static final Duration LINK_DEVICE_TOKEN_EXPIRATION_DURATION = Duration.ofMinutes(10);
@VisibleForTesting
static final String LINK_DEVICE_VERIFICATION_TOKEN_ALGORITHM = "HmacSHA256";
public enum DeletionReason { public enum DeletionReason {
ADMIN_DELETED("admin"), ADMIN_DELETED("admin"),
EXPIRED ("expired"), EXPIRED ("expired"),
@ -140,6 +160,7 @@ public class AccountsManager {
public AccountsManager(final Accounts accounts, public AccountsManager(final Accounts accounts,
final PhoneNumberIdentifiers phoneNumberIdentifiers, final PhoneNumberIdentifiers phoneNumberIdentifiers,
final FaultTolerantRedisCluster cacheCluster, final FaultTolerantRedisCluster cacheCluster,
final FaultTolerantRedisCluster rateLimitCluster,
final AccountLockManager accountLockManager, final AccountLockManager accountLockManager,
final KeysManager keysManager, final KeysManager keysManager,
final MessagesManager messagesManager, final MessagesManager messagesManager,
@ -152,10 +173,12 @@ public class AccountsManager {
final Executor accountLockExecutor, final Executor accountLockExecutor,
final Executor clientPresenceExecutor, final Executor clientPresenceExecutor,
final Clock clock, final Clock clock,
final byte[] linkDeviceSecret,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager) { final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager) {
this.accounts = accounts; this.accounts = accounts;
this.phoneNumberIdentifiers = phoneNumberIdentifiers; this.phoneNumberIdentifiers = phoneNumberIdentifiers;
this.cacheCluster = cacheCluster; this.cacheCluster = cacheCluster;
this.rateLimitCluster = rateLimitCluster;
this.accountLockManager = accountLockManager; this.accountLockManager = accountLockManager;
this.keysManager = keysManager; this.keysManager = keysManager;
this.messagesManager = messagesManager; this.messagesManager = messagesManager;
@ -169,6 +192,15 @@ public class AccountsManager {
this.clientPresenceExecutor = clientPresenceExecutor; this.clientPresenceExecutor = clientPresenceExecutor;
this.clock = requireNonNull(clock); this.clock = requireNonNull(clock);
this.dynamicConfigurationManager = dynamicConfigurationManager; this.dynamicConfigurationManager = dynamicConfigurationManager;
this.verificationTokenKey = new SecretKeySpec(linkDeviceSecret, LINK_DEVICE_VERIFICATION_TOKEN_ALGORITHM);
// Fail fast: reject bad keys
try {
getInitializedMac(verificationTokenKey);
} catch (final InvalidKeyException e) {
throw new IllegalArgumentException(e);
}
} }
public Account create(final String number, public Account create(final String number,
@ -275,46 +307,179 @@ public class AccountsManager {
}); });
} }
public CompletableFuture<Pair<Account, Device>> addDevice(final Account account, final DeviceSpec deviceSpec) { public CompletableFuture<Pair<Account, Device>> addDevice(final Account account, final DeviceSpec deviceSpec, final String linkDeviceToken) {
return accountLockManager.withLockAsync(List.of(account.getNumber()), return accountLockManager.withLockAsync(List.of(account.getNumber()),
() -> addDevice(account.getIdentifier(IdentityType.ACI), deviceSpec, MAX_UPDATE_ATTEMPTS), () -> addDevice(account.getIdentifier(IdentityType.ACI), deviceSpec, linkDeviceToken, MAX_UPDATE_ATTEMPTS),
accountLockExecutor); accountLockExecutor);
} }
private CompletableFuture<Pair<Account, Device>> addDevice(final UUID accountIdentifier, final DeviceSpec deviceSpec, final int retries) { private CompletableFuture<Pair<Account, Device>> addDevice(final UUID accountIdentifier, final DeviceSpec deviceSpec, final String linkDeviceToken, final int retries) {
return accounts.getByAccountIdentifierAsync(accountIdentifier) return accounts.getByAccountIdentifierAsync(accountIdentifier)
.thenApply(maybeAccount -> maybeAccount.orElseThrow(ContestedOptimisticLockException::new)) .thenApply(maybeAccount -> maybeAccount.orElseThrow(ContestedOptimisticLockException::new))
.thenCompose(account -> { .thenCompose(account -> {
final byte nextDeviceId = account.getNextDeviceId(); final byte nextDeviceId = account.getNextDeviceId();
return CompletableFuture.allOf(
keysManager.deleteSingleUsePreKeys(account.getUuid(), nextDeviceId),
keysManager.deleteSingleUsePreKeys(account.getPhoneNumberIdentifier(), nextDeviceId),
messagesManager.clear(account.getUuid(), nextDeviceId))
.thenApply(ignored -> new Pair<>(account, nextDeviceId));
})
.thenCompose(accountAndNextDeviceId -> {
final Account account = accountAndNextDeviceId.first();
final byte nextDeviceId = accountAndNextDeviceId.second();
account.addDevice(deviceSpec.toDevice(nextDeviceId, clock)); account.addDevice(deviceSpec.toDevice(nextDeviceId, clock));
final List<TransactWriteItem> additionalWriteItems = keysManager.buildWriteItemsForNewDevice( final List<TransactWriteItem> additionalWriteItems = new ArrayList<>(keysManager.buildWriteItemsForNewDevice(
account.getIdentifier(IdentityType.ACI), account.getIdentifier(IdentityType.ACI),
account.getIdentifier(IdentityType.PNI), account.getIdentifier(IdentityType.PNI),
nextDeviceId, nextDeviceId,
deviceSpec.aciSignedPreKey(), deviceSpec.aciSignedPreKey(),
deviceSpec.pniSignedPreKey(), deviceSpec.pniSignedPreKey(),
deviceSpec.aciPqLastResortPreKey(), deviceSpec.aciPqLastResortPreKey(),
deviceSpec.pniPqLastResortPreKey()); deviceSpec.pniPqLastResortPreKey()));
return CompletableFuture.allOf( additionalWriteItems.add(accounts.buildTransactWriteItemForLinkDevice(linkDeviceToken, LINK_DEVICE_TOKEN_EXPIRATION_DURATION));
keysManager.deleteSingleUsePreKeys(account.getUuid(), nextDeviceId),
keysManager.deleteSingleUsePreKeys(account.getPhoneNumberIdentifier(), nextDeviceId), return accounts.updateTransactionallyAsync(account, additionalWriteItems)
messagesManager.clear(account.getUuid(), nextDeviceId))
.thenCompose(ignored -> accounts.updateTransactionallyAsync(account, additionalWriteItems))
.thenApply(ignored -> new Pair<>(account, account.getDevice(nextDeviceId).orElseThrow())); .thenApply(ignored -> new Pair<>(account, account.getDevice(nextDeviceId).orElseThrow()));
}) })
.thenCompose(updatedAccountAndDevice -> rateLimitCluster.withCluster(connection ->
connection.async().set(getUsedTokenKey(linkDeviceToken), "", new SetArgs().ex(LINK_DEVICE_TOKEN_EXPIRATION_DURATION)))
.thenApply(ignored -> updatedAccountAndDevice))
.thenCompose(updatedAccountAndDevice -> redisDeleteAsync(updatedAccountAndDevice.first()) .thenCompose(updatedAccountAndDevice -> redisDeleteAsync(updatedAccountAndDevice.first())
.thenApply(ignored -> updatedAccountAndDevice)) .thenApply(ignored -> updatedAccountAndDevice))
.exceptionallyCompose(throwable -> { .exceptionallyCompose(throwable -> {
if (ExceptionUtils.unwrap(throwable) instanceof ContestedOptimisticLockException && retries > 0) { if (ExceptionUtils.unwrap(throwable) instanceof ContestedOptimisticLockException && retries > 0) {
return addDevice(accountIdentifier, deviceSpec, retries - 1); return addDevice(accountIdentifier, deviceSpec, linkDeviceToken, retries - 1);
} else if (ExceptionUtils.unwrap(throwable) instanceof TransactionCanceledException transactionCanceledException) {
// We can be confident the transaction was canceled because the linked device token was already used if the
// "check token" transaction write item is the only one that failed. That SHOULD be the last one in the
// list.
final long cancelledTransactions = transactionCanceledException.cancellationReasons().stream()
.filter(cancellationReason -> !"None".equals(cancellationReason.code()))
.count();
final boolean tokenReuseConditionFailed =
"ConditionalCheckFailed".equals(transactionCanceledException.cancellationReasons().getLast().code());
if (cancelledTransactions == 1 && tokenReuseConditionFailed) {
return CompletableFuture.failedFuture(new LinkDeviceTokenAlreadyUsedException());
}
} }
return CompletableFuture.failedFuture(throwable); return CompletableFuture.failedFuture(throwable);
}); });
} }
private Mac getInitializedMac() {
try {
return getInitializedMac(verificationTokenKey);
} catch (final InvalidKeyException e) {
// We checked the key at construction time, so this can never happen
throw new AssertionError("Previously valid key now invalid", e);
}
}
private static Mac getInitializedMac(final Key linkDeviceTokenKey) throws InvalidKeyException {
try {
final Mac mac = Mac.getInstance(LINK_DEVICE_VERIFICATION_TOKEN_ALGORITHM);
mac.init(linkDeviceTokenKey);
return mac;
} catch (final NoSuchAlgorithmException e) {
throw new AssertionError(e);
}
}
public String generateDeviceLinkingToken(final UUID aci) {
final String claims = aci + "." + clock.instant().toEpochMilli();
final byte[] signature = getInitializedMac().doFinal(claims.getBytes(StandardCharsets.UTF_8));
return claims + ":" + Base64.getUrlEncoder().encodeToString(signature);
}
@VisibleForTesting
static String generateDeviceLinkingToken(final UUID aci, final Key linkDeviceTokenKey, final Clock clock)
throws InvalidKeyException {
final String claims = aci + "." + clock.instant().toEpochMilli();
final byte[] signature = getInitializedMac(linkDeviceTokenKey).doFinal(claims.getBytes(StandardCharsets.UTF_8));
return claims + ":" + Base64.getUrlEncoder().encodeToString(signature);
}
/**
* Checks that a device-linking token is valid and returns the account identifier from the token if so, or empty if
* the token was invalid or has already been used
*
* @param token the device-linking token to check
*
* @return the account identifier from a valid token or empty if the token was invalid or already used
*/
public Optional<UUID> checkDeviceLinkingToken(final String token) {
final boolean tokenUsed = rateLimitCluster.withCluster(connection ->
connection.sync().get(getUsedTokenKey(token)) != null);
if (tokenUsed) {
return Optional.empty();
}
final String[] claimsAndSignature = token.split(":", 2);
if (claimsAndSignature.length != 2) {
return Optional.empty();
}
final byte[] expectedSignature = getInitializedMac().doFinal(claimsAndSignature[0].getBytes(StandardCharsets.UTF_8));
final byte[] providedSignature;
try {
providedSignature = Base64.getUrlDecoder().decode(claimsAndSignature[1]);
} catch (final IllegalArgumentException e) {
return Optional.empty();
}
if (!MessageDigest.isEqual(expectedSignature, providedSignature)) {
return Optional.empty();
}
final String[] aciAndTimestamp = claimsAndSignature[0].split("\\.", 2);
if (aciAndTimestamp.length != 2) {
return Optional.empty();
}
final UUID aci;
try {
aci = UUID.fromString(aciAndTimestamp[0]);
} catch (final IllegalArgumentException e) {
return Optional.empty();
}
final Instant timestamp;
try {
timestamp = Instant.ofEpochMilli(Long.parseLong(aciAndTimestamp[1]));
} catch (final NumberFormatException e) {
return Optional.empty();
}
final Instant tokenExpiration = timestamp.plus(LINK_DEVICE_TOKEN_EXPIRATION_DURATION);
if (tokenExpiration.isBefore(clock.instant())) {
return Optional.empty();
}
return Optional.of(aci);
}
private static String getUsedTokenKey(final String token) {
return "usedToken::" + token;
}
public CompletableFuture<Account> removeDevice(final Account account, final byte deviceId) { public CompletableFuture<Account> removeDevice(final Account account, final byte deviceId) {
if (deviceId == Device.PRIMARY_ID) { if (deviceId == Device.PRIMARY_ID) {
throw new IllegalArgumentException("Cannot remove primary device"); throw new IllegalArgumentException("Cannot remove primary device");

View File

@ -0,0 +1,4 @@
package org.whispersystems.textsecuregcm.storage;
public class LinkDeviceTokenAlreadyUsedException extends Exception {
}

View File

@ -180,7 +180,8 @@ record CommandDependencies(
configuration.getDynamoDbTables().getAccounts().getPhoneNumberTableName(), configuration.getDynamoDbTables().getAccounts().getPhoneNumberTableName(),
configuration.getDynamoDbTables().getAccounts().getPhoneNumberIdentifierTableName(), configuration.getDynamoDbTables().getAccounts().getPhoneNumberIdentifierTableName(),
configuration.getDynamoDbTables().getAccounts().getUsernamesTableName(), configuration.getDynamoDbTables().getAccounts().getUsernamesTableName(),
configuration.getDynamoDbTables().getDeletedAccounts().getTableName()); configuration.getDynamoDbTables().getDeletedAccounts().getTableName(),
configuration.getDynamoDbTables().getAccounts().getUsedLinkDeviceTokensTableName());
PhoneNumberIdentifiers phoneNumberIdentifiers = new PhoneNumberIdentifiers(dynamoDbClient, PhoneNumberIdentifiers phoneNumberIdentifiers = new PhoneNumberIdentifiers(dynamoDbClient,
configuration.getDynamoDbTables().getPhoneNumberIdentifiers().getTableName()); configuration.getDynamoDbTables().getPhoneNumberIdentifiers().getTableName());
Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient, Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient,
@ -225,10 +226,10 @@ record CommandDependencies(
ClientPublicKeysManager clientPublicKeysManager = ClientPublicKeysManager clientPublicKeysManager =
new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor); new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor);
AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster, AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster,
accountLockManager, keys, messagesManager, profilesManager, rateLimitersCluster, accountLockManager, keys, messagesManager, profilesManager,
secureStorageClient, secureValueRecovery2Client, clientPresenceManager, secureStorageClient, secureValueRecovery2Client, clientPresenceManager,
registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, clientPresenceExecutor, registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, clientPresenceExecutor,
clock, dynamicConfigurationManager); clock, configuration.getLinkDeviceSecretConfiguration().secret().value(), dynamicConfigurationManager);
RateLimiters rateLimiters = RateLimiters.createAndValidate(configuration.getLimitsConfiguration(), RateLimiters rateLimiters = RateLimiters.createAndValidate(configuration.getLimitsConfiguration(),
dynamicConfigurationManager, rateLimitersCluster); dynamicConfigurationManager, rateLimitersCluster);
final BackupsDb backupsDb = final BackupsDb backupsDb =

View File

@ -25,12 +25,10 @@ import io.dropwizard.testing.junit5.ResourceExtension;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands; import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.stream.IntStream; import java.util.stream.IntStream;
import java.util.stream.Stream; import java.util.stream.Stream;
@ -73,11 +71,11 @@ import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities; import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.storage.DeviceSpec; import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.LinkDeviceTokenAlreadyUsedException;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture; import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.TestClock; import org.whispersystems.textsecuregcm.util.TestClock;
import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.TestRandomUtil;
@ -102,16 +100,10 @@ class DeviceControllerTest {
private static final byte NEXT_DEVICE_ID = 42; private static final byte NEXT_DEVICE_ID = 42;
private static DeviceController deviceController = new DeviceController( private static DeviceController deviceController = new DeviceController(
generateLinkDeviceSecret(),
accountsManager, accountsManager,
clientPublicKeysManager, clientPublicKeysManager,
rateLimiters, rateLimiters,
RedisClusterHelper.builder() deviceConfiguration);
.stringCommands(commands)
.stringAsyncCommands(asyncCommands)
.build(),
deviceConfiguration,
testClock);
@RegisterExtension @RegisterExtension
public static final AuthHelper.AuthFilterExtension AUTH_FILTER_EXTENSION = new AuthHelper.AuthFilterExtension(); public static final AuthHelper.AuthFilterExtension AUTH_FILTER_EXTENSION = new AuthHelper.AuthFilterExtension();
@ -126,10 +118,6 @@ class DeviceControllerTest {
.addResource(deviceController) .addResource(deviceController)
.build(); .build();
private static byte[] generateLinkDeviceSecret() {
return TestRandomUtil.nextBytes(32);
}
@BeforeEach @BeforeEach
void setup() { void setup() {
when(rateLimiters.getAllocateDeviceLimiter()).thenReturn(rateLimiter); when(rateLimiters.getAllocateDeviceLimiter()).thenReturn(rateLimiter);
@ -183,12 +171,6 @@ class DeviceControllerTest {
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID); when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice)); when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
VerificationCode deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class);
final ECSignedPreKey aciSignedPreKey; final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey; final ECSignedPreKey pniSignedPreKey;
final KEMSignedPreKey aciPqLastResortPreKey; final KEMSignedPreKey aciPqLastResortPreKey;
@ -205,7 +187,9 @@ class DeviceControllerTest {
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
when(accountsManager.addDevice(any(), any())).thenAnswer(invocation -> { when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID));
when(accountsManager.addDevice(any(), any(), any())).thenAnswer(invocation -> {
final Account a = invocation.getArgument(0); final Account a = invocation.getArgument(0);
final DeviceSpec deviceSpec = invocation.getArgument(1); final DeviceSpec deviceSpec = invocation.getArgument(1);
@ -217,7 +201,7 @@ class DeviceControllerTest {
final AccountAttributes accountAttributes = new AccountAttributes(fetchesMessages, 1234, 5678, null, final AccountAttributes accountAttributes = new AccountAttributes(fetchesMessages, 1234, 5678, null,
null, true, new DeviceCapabilities(true, true, true, false, false)); null, true, new DeviceCapabilities(true, true, true, false, false));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(), final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
accountAttributes, accountAttributes,
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnRegistrationId, gcmRegistrationId)); new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnRegistrationId, gcmRegistrationId));
@ -230,7 +214,7 @@ class DeviceControllerTest {
assertThat(response.getDeviceId()).isEqualTo(NEXT_DEVICE_ID); assertThat(response.getDeviceId()).isEqualTo(NEXT_DEVICE_ID);
final ArgumentCaptor<DeviceSpec> deviceSpecCaptor = ArgumentCaptor.forClass(DeviceSpec.class); final ArgumentCaptor<DeviceSpec> deviceSpecCaptor = ArgumentCaptor.forClass(DeviceSpec.class);
verify(accountsManager).addDevice(eq(account), deviceSpecCaptor.capture()); verify(accountsManager).addDevice(eq(account), deviceSpecCaptor.capture(), any());
final Device device = deviceSpecCaptor.getValue().toDevice(NEXT_DEVICE_ID, testClock); final Device device = deviceSpecCaptor.getValue().toDevice(NEXT_DEVICE_ID, testClock);
@ -241,8 +225,6 @@ class DeviceControllerTest {
expectedGcmToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getGcmId()), expectedGcmToken.ifPresentOrElse(expectedToken -> assertEquals(expectedToken, device.getGcmId()),
() -> assertNull(device.getGcmId())); () -> assertNull(device.getGcmId()));
verify(asyncCommands).set(anyString(), anyString(), any());
} }
private static Stream<Arguments> linkDeviceAtomic() { private static Stream<Arguments> linkDeviceAtomic() {
@ -261,7 +243,7 @@ class DeviceControllerTest {
@MethodSource @MethodSource
void deviceDowngradeDeleteSync(final boolean accountSupportsDeleteSync, final boolean deviceSupportsDeleteSync, final int expectedStatus) { void deviceDowngradeDeleteSync(final boolean accountSupportsDeleteSync, final boolean deviceSupportsDeleteSync, final int expectedStatus) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account)); when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
when(accountsManager.addDevice(any(), any())) when(accountsManager.addDevice(any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(new Pair<>(mock(Account.class), mock(Device.class)))); .thenReturn(CompletableFuture.completedFuture(new Pair<>(mock(Account.class), mock(Device.class))));
final Device primaryDevice = mock(Device.class); final Device primaryDevice = mock(Device.class);
@ -287,7 +269,9 @@ class DeviceControllerTest {
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null)); when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceController.generateVerificationToken(AuthHelper.VALID_UUID), when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID));
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
new AccountAttributes(false, 1234, 5678, null, null, true, new DeviceCapabilities(true, true, true, deviceSupportsDeleteSync, false)), new AccountAttributes(false, 1234, 5678, null, null, true, new DeviceCapabilities(true, true, true, deviceSupportsDeleteSync, false)),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id")))); new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id"))));
@ -314,7 +298,7 @@ class DeviceControllerTest {
void deviceDowngradeVersionedExpirationTimer(final boolean accountSupportsVersionedExpirationTimer, void deviceDowngradeVersionedExpirationTimer(final boolean accountSupportsVersionedExpirationTimer,
final boolean deviceSupportsVersionedExpirationTimer, final int expectedStatus) { final boolean deviceSupportsVersionedExpirationTimer, final int expectedStatus) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account)); when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
when(accountsManager.addDevice(any(), any())) when(accountsManager.addDevice(any(), any(), any()))
.thenReturn(CompletableFuture.completedFuture(new Pair<>(mock(Account.class), mock(Device.class)))); .thenReturn(CompletableFuture.completedFuture(new Pair<>(mock(Account.class), mock(Device.class))));
final Device primaryDevice = mock(Device.class); final Device primaryDevice = mock(Device.class);
@ -340,7 +324,9 @@ class DeviceControllerTest {
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null)); when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceController.generateVerificationToken(AuthHelper.VALID_UUID), when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID));
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
new AccountAttributes(false, 1234, 5678, null, null, true, new DeviceCapabilities(true, true, true, deviceSupportsVersionedExpirationTimer, false)), new AccountAttributes(false, 1234, 5678, null, null, true, new DeviceCapabilities(true, true, true, deviceSupportsVersionedExpirationTimer, false)),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id")))); new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id"))));
@ -386,7 +372,7 @@ class DeviceControllerTest {
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceController.generateVerificationToken(AuthHelper.VALID_UUID), final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
new AccountAttributes(false, 1234, 5678, null, null, true, null), new AccountAttributes(false, 1234, 5678, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id")))); new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id"))));
@ -400,6 +386,52 @@ class DeviceControllerTest {
} }
} }
@Test
void linkDeviceAtomicReusedToken() {
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
final KEMSignedPreKey aciPqLastResortPreKey;
final KEMSignedPreKey pniPqLastResortPreKey;
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
aciSignedPreKey = KeysHelper.signedECPreKey(1, aciIdentityKeyPair);
pniSignedPreKey = KeysHelper.signedECPreKey(2, pniIdentityKeyPair);
aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID));
when(accountsManager.addDevice(any(), any(), any()))
.thenReturn(CompletableFuture.failedFuture(new LinkDeviceTokenAlreadyUsedException()));
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null));
final AccountAttributes accountAttributes = new AccountAttributes(true, 1234, 5678, null,
null, true, new DeviceCapabilities(true, true, true, false, false));
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
accountAttributes,
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.empty()));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/link")
.request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1"))
.put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE))) {
assertEquals(403, response.getStatus());
}
}
@Test @Test
void linkDeviceAtomicWithVerificationTokenUsed() { void linkDeviceAtomicWithVerificationTokenUsed() {
@ -427,7 +459,7 @@ class DeviceControllerTest {
when(commands.get(anyString())).thenReturn(""); when(commands.get(anyString())).thenReturn("");
final LinkDeviceRequest request = new LinkDeviceRequest(deviceController.generateVerificationToken(AuthHelper.VALID_UUID), final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
new AccountAttributes(false, 1234, 5678, null, null, true, null), new AccountAttributes(false, 1234, 5678, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id")))); new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id"))));
@ -577,16 +609,12 @@ class DeviceControllerTest {
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID); when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice)); when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
VerificationCode deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class);
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(), when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID));
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
new AccountAttributes(true, 1234, 5678, null, null, true, null), new AccountAttributes(true, 1234, 5678, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.empty())); new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.empty()));
@ -614,17 +642,12 @@ class DeviceControllerTest {
final Device existingDevice = mock(Device.class); final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID); when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice)); when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
VerificationCode deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class);
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(aciIdentityKey); when(account.getIdentityKey(IdentityType.ACI)).thenReturn(aciIdentityKey);
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(pniIdentityKey); when(account.getIdentityKey(IdentityType.PNI)).thenReturn(pniIdentityKey);
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(), when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID));
final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
new AccountAttributes(true, 1234, 5678, null, null, true, null), new AccountAttributes(true, 1234, 5678, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.empty())); new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.empty()));
@ -683,7 +706,7 @@ class DeviceControllerTest {
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceController.generateVerificationToken(AuthHelper.VALID_UUID), final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
new AccountAttributes(false, 1234, 5678, TestRandomUtil.nextBytes(512), null, true, null), new AccountAttributes(false, 1234, 5678, TestRandomUtil.nextBytes(512), null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id")))); new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.of(new GcmRegistrationId("gcm-id"))));
@ -704,12 +727,6 @@ class DeviceControllerTest {
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID); when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice)); when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
VerificationCode deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class);
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
@ -721,16 +738,18 @@ class DeviceControllerTest {
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
when(accountsManager.addDevice(any(), any())).thenAnswer(invocation -> { when(accountsManager.addDevice(any(), any(), any())).thenAnswer(invocation -> {
final Account a = invocation.getArgument(0); final Account a = invocation.getArgument(0);
final DeviceSpec deviceSpec = invocation.getArgument(1); final DeviceSpec deviceSpec = invocation.getArgument(1);
return CompletableFuture.completedFuture(new Pair<>(a, deviceSpec.toDevice(NEXT_DEVICE_ID, testClock))); return CompletableFuture.completedFuture(new Pair<>(a, deviceSpec.toDevice(NEXT_DEVICE_ID, testClock)));
}); });
when(accountsManager.checkDeviceLinkingToken(anyString())).thenReturn(Optional.of(AuthHelper.VALID_UUID));
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null)); when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture(null));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(), final LinkDeviceRequest request = new LinkDeviceRequest("link-device-token",
new AccountAttributes(false, registrationId, pniRegistrationId, null, null, true, new DeviceCapabilities(true, true, true, false, false)), new AccountAttributes(false, registrationId, pniRegistrationId, null, null, true, new DeviceCapabilities(true, true, true, false, false)),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.of(new ApnRegistrationId("apn")), Optional.empty())); new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.of(new ApnRegistrationId("apn")), Optional.empty()));
@ -785,7 +804,7 @@ class DeviceControllerTest {
.get(); .get();
assertEquals(411, response.getStatus()); assertEquals(411, response.getStatus());
verify(accountsManager, never()).addDevice(any(), any()); verify(accountsManager, never()).addDevice(any(), any(), any());
} }
@Test @Test
@ -898,60 +917,6 @@ class DeviceControllerTest {
} }
} }
@Test
void checkVerificationToken() {
final UUID uuid = UUID.randomUUID();
assertEquals(Optional.of(uuid),
deviceController.checkVerificationToken(deviceController.generateVerificationToken(uuid)));
}
@ParameterizedTest
@MethodSource
void checkVerificationTokenBadToken(final String token, final Instant currentTime) {
testClock.pin(currentTime);
assertEquals(Optional.empty(),
deviceController.checkVerificationToken(token));
}
private static Stream<Arguments> checkVerificationTokenBadToken() {
final Instant tokenTimestamp = testClock.instant();
return Stream.of(
// Expired token
Arguments.of(deviceController.generateVerificationToken(UUID.randomUUID()),
tokenTimestamp.plus(DeviceController.TOKEN_EXPIRATION_DURATION).plusSeconds(1)),
// Bad UUID
Arguments.of("not-a-valid-uuid.1691096565171:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// No UUID
Arguments.of(".1691096565171:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// Bad timestamp
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.not-a-valid-timestamp:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// No timestamp
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// Blank timestamp
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// No signature
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171", tokenTimestamp),
// Blank signature
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171:", tokenTimestamp),
// Incorrect signature
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// Invalid signature
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171:This is not valid base64", tokenTimestamp)
);
}
@Test @Test
void setPublicKey() { void setPublicKey() {
final SetPublicKeyRequest request = new SetPublicKeyRequest(Curve.generateKeyPair().getPublicKey()); final SetPublicKeyRequest request = new SetPublicKeyRequest(Curve.generateKeyPair().getPublicKey());

View File

@ -105,7 +105,8 @@ public class AccountCreationDeletionIntegrationTest {
DynamoDbExtensionSchema.Tables.NUMBERS.tableName(), DynamoDbExtensionSchema.Tables.NUMBERS.tableName(),
DynamoDbExtensionSchema.Tables.PNI_ASSIGNMENTS.tableName(), DynamoDbExtensionSchema.Tables.PNI_ASSIGNMENTS.tableName(),
DynamoDbExtensionSchema.Tables.USERNAMES.tableName(), DynamoDbExtensionSchema.Tables.USERNAMES.tableName(),
DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS.tableName()); DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS.tableName(),
DynamoDbExtensionSchema.Tables.USED_LINK_DEVICE_TOKENS.tableName());
accountLockExecutor = Executors.newSingleThreadExecutor(); accountLockExecutor = Executors.newSingleThreadExecutor();
clientPresenceExecutor = Executors.newSingleThreadExecutor(); clientPresenceExecutor = Executors.newSingleThreadExecutor();
@ -141,6 +142,7 @@ public class AccountCreationDeletionIntegrationTest {
accounts, accounts,
phoneNumberIdentifiers, phoneNumberIdentifiers,
CACHE_CLUSTER_EXTENSION.getRedisCluster(), CACHE_CLUSTER_EXTENSION.getRedisCluster(),
CACHE_CLUSTER_EXTENSION.getRedisCluster(),
accountLockManager, accountLockManager,
keysManager, keysManager,
messagesManager, messagesManager,
@ -153,6 +155,7 @@ public class AccountCreationDeletionIntegrationTest {
accountLockExecutor, accountLockExecutor,
clientPresenceExecutor, clientPresenceExecutor,
CLOCK, CLOCK,
"link-device-secret".getBytes(StandardCharsets.UTF_8),
dynamicConfigurationManager); dynamicConfigurationManager);
} }

View File

@ -98,7 +98,8 @@ class AccountsManagerChangeNumberIntegrationTest {
Tables.NUMBERS.tableName(), Tables.NUMBERS.tableName(),
Tables.PNI_ASSIGNMENTS.tableName(), Tables.PNI_ASSIGNMENTS.tableName(),
Tables.USERNAMES.tableName(), Tables.USERNAMES.tableName(),
Tables.DELETED_ACCOUNTS.tableName()); Tables.DELETED_ACCOUNTS.tableName(),
Tables.USED_LINK_DEVICE_TOKENS.tableName());
accountLockExecutor = Executors.newSingleThreadExecutor(); accountLockExecutor = Executors.newSingleThreadExecutor();
clientPresenceExecutor = Executors.newSingleThreadExecutor(); clientPresenceExecutor = Executors.newSingleThreadExecutor();
@ -136,6 +137,7 @@ class AccountsManagerChangeNumberIntegrationTest {
accounts, accounts,
phoneNumberIdentifiers, phoneNumberIdentifiers,
CACHE_CLUSTER_EXTENSION.getRedisCluster(), CACHE_CLUSTER_EXTENSION.getRedisCluster(),
CACHE_CLUSTER_EXTENSION.getRedisCluster(),
accountLockManager, accountLockManager,
keysManager, keysManager,
messagesManager, messagesManager,
@ -148,6 +150,7 @@ class AccountsManagerChangeNumberIntegrationTest {
accountLockExecutor, accountLockExecutor,
clientPresenceExecutor, clientPresenceExecutor,
mock(Clock.class), mock(Clock.class),
"link-device-secret".getBytes(StandardCharsets.UTF_8),
dynamicConfigurationManager); dynamicConfigurationManager);
} }
} }

View File

@ -93,7 +93,8 @@ class AccountsManagerConcurrentModificationIntegrationTest {
Tables.NUMBERS.tableName(), Tables.NUMBERS.tableName(),
Tables.PNI_ASSIGNMENTS.tableName(), Tables.PNI_ASSIGNMENTS.tableName(),
Tables.USERNAMES.tableName(), Tables.USERNAMES.tableName(),
Tables.DELETED_ACCOUNTS.tableName()); Tables.DELETED_ACCOUNTS.tableName(),
Tables.USED_LINK_DEVICE_TOKENS.tableName());
{ {
//noinspection unchecked //noinspection unchecked
@ -123,6 +124,7 @@ class AccountsManagerConcurrentModificationIntegrationTest {
accounts, accounts,
phoneNumberIdentifiers, phoneNumberIdentifiers,
RedisClusterHelper.builder().stringCommands(commands).build(), RedisClusterHelper.builder().stringCommands(commands).build(),
RedisClusterHelper.builder().stringCommands(commands).build(),
accountLockManager, accountLockManager,
mock(KeysManager.class), mock(KeysManager.class),
mock(MessagesManager.class), mock(MessagesManager.class),
@ -135,6 +137,7 @@ class AccountsManagerConcurrentModificationIntegrationTest {
mock(Executor.class), mock(Executor.class),
mock(Executor.class), mock(Executor.class),
mock(Clock.class), mock(Clock.class),
"link-device-secret".getBytes(StandardCharsets.UTF_8),
dynamicConfigurationManager dynamicConfigurationManager
); );
} }

View File

@ -37,7 +37,9 @@ import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.io.InputStream; import java.io.InputStream;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.time.Duration; import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Base64; import java.util.Base64;
import java.util.Collections; import java.util.Collections;
@ -75,6 +77,7 @@ import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecoveryException; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecoveryException;
@ -89,6 +92,7 @@ import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.TestClock; import org.whispersystems.textsecuregcm.util.TestClock;
import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import javax.crypto.spec.SecretKeySpec;
@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) @Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
class AccountsManagerTest { class AccountsManagerTest {
@ -102,6 +106,10 @@ class AccountsManagerTest {
private static final byte[] ENCRYPTED_USERNAME_1 = Base64.getUrlDecoder().decode(BASE_64_URL_ENCRYPTED_USERNAME_1); private static final byte[] ENCRYPTED_USERNAME_1 = Base64.getUrlDecoder().decode(BASE_64_URL_ENCRYPTED_USERNAME_1);
private static final byte[] ENCRYPTED_USERNAME_2 = Base64.getUrlDecoder().decode(BASE_64_URL_ENCRYPTED_USERNAME_2); private static final byte[] ENCRYPTED_USERNAME_2 = Base64.getUrlDecoder().decode(BASE_64_URL_ENCRYPTED_USERNAME_2);
private static final byte[] LINK_DEVICE_SECRET = "link-device-secret".getBytes(StandardCharsets.UTF_8);
private static TestClock CLOCK;
private Accounts accounts; private Accounts accounts;
private KeysManager keysManager; private KeysManager keysManager;
private MessagesManager messagesManager; private MessagesManager messagesManager;
@ -113,7 +121,6 @@ class AccountsManagerTest {
private RedisAdvancedClusterCommands<String, String> commands; private RedisAdvancedClusterCommands<String, String> commands;
private RedisAdvancedClusterAsyncCommands<String, String> asyncCommands; private RedisAdvancedClusterAsyncCommands<String, String> asyncCommands;
private TestClock clock;
private AccountsManager accountsManager; private AccountsManager accountsManager;
private SecureValueRecovery2Client svr2Client; private SecureValueRecovery2Client svr2Client;
private DynamicConfiguration dynamicConfiguration; private DynamicConfiguration dynamicConfiguration;
@ -161,6 +168,7 @@ class AccountsManagerTest {
asyncCommands = mock(RedisAdvancedClusterAsyncCommands.class); asyncCommands = mock(RedisAdvancedClusterAsyncCommands.class);
when(asyncCommands.del(any(String[].class))).thenReturn(MockRedisFuture.completedFuture(0L)); when(asyncCommands.del(any(String[].class))).thenReturn(MockRedisFuture.completedFuture(0L));
when(asyncCommands.get(any())).thenReturn(MockRedisFuture.completedFuture(null)); when(asyncCommands.get(any())).thenReturn(MockRedisFuture.completedFuture(null));
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(accounts.updateAsync(any())).thenReturn(CompletableFuture.completedFuture(null)); when(accounts.updateAsync(any())).thenReturn(CompletableFuture.completedFuture(null));
@ -220,16 +228,18 @@ class AccountsManagerTest {
when(messagesManager.clear(any())).thenReturn(CompletableFuture.completedFuture(null)); when(messagesManager.clear(any())).thenReturn(CompletableFuture.completedFuture(null));
when(profilesManager.deleteAll(any())).thenReturn(CompletableFuture.completedFuture(null)); when(profilesManager.deleteAll(any())).thenReturn(CompletableFuture.completedFuture(null));
clock = TestClock.now(); CLOCK = TestClock.now();
final FaultTolerantRedisCluster redisCluster = RedisClusterHelper.builder()
.stringCommands(commands)
.stringAsyncCommands(asyncCommands)
.build();
accountsManager = new AccountsManager( accountsManager = new AccountsManager(
accounts, accounts,
phoneNumberIdentifiers, phoneNumberIdentifiers,
RedisClusterHelper.builder() redisCluster,
.stringCommands(commands) redisCluster,
.stringAsyncCommands(asyncCommands)
.build(),
accountLockManager, accountLockManager,
keysManager, keysManager,
messagesManager, messagesManager,
@ -241,7 +251,8 @@ class AccountsManagerTest {
clientPublicKeysManager, clientPublicKeysManager,
mock(Executor.class), mock(Executor.class),
clientPresenceExecutor, clientPresenceExecutor,
clock, CLOCK,
LINK_DEVICE_SECRET,
dynamicConfigurationManager); dynamicConfigurationManager);
} }
@ -920,7 +931,7 @@ class AccountsManagerTest {
PhoneNumberUtil.getInstance().format(PhoneNumberUtil.getInstance().getExampleNumber("US"), PhoneNumberUtil.getInstance().format(PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164); PhoneNumberUtil.PhoneNumberFormat.E164);
final Account account = AccountsHelper.generateTestAccount(phoneNumber, List.of(generateTestDevice(clock.millis()))); final Account account = AccountsHelper.generateTestAccount(phoneNumber, List.of(generateTestDevice(CLOCK.millis())));
final UUID aci = account.getIdentifier(IdentityType.ACI); final UUID aci = account.getIdentifier(IdentityType.ACI);
final UUID pni = account.getIdentifier(IdentityType.PNI); final UUID pni = account.getIdentifier(IdentityType.PNI);
@ -945,7 +956,7 @@ class AccountsManagerTest {
when(accounts.getByAccountIdentifierAsync(aci)).thenReturn(CompletableFuture.completedFuture(Optional.of(account))); when(accounts.getByAccountIdentifierAsync(aci)).thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
when(accounts.updateTransactionallyAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(accounts.updateTransactionallyAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
clock.pin(clock.instant().plusSeconds(60)); CLOCK.pin(CLOCK.instant().plusSeconds(60));
final Pair<Account, Device> updatedAccountAndDevice = accountsManager.addDevice(account, new DeviceSpec( final Pair<Account, Device> updatedAccountAndDevice = accountsManager.addDevice(account, new DeviceSpec(
deviceNameCiphertext, deviceNameCiphertext,
@ -960,7 +971,8 @@ class AccountsManagerTest {
aciSignedPreKey, aciSignedPreKey,
pniSignedPreKey, pniSignedPreKey,
aciPqLastResortPreKey, aciPqLastResortPreKey,
pniPqLastResortPreKey)) pniPqLastResortPreKey),
accountsManager.generateDeviceLinkingToken(aci))
.join(); .join();
verify(keysManager).deleteSingleUsePreKeys(aci, nextDeviceId); verify(keysManager).deleteSingleUsePreKeys(aci, nextDeviceId);
@ -1589,4 +1601,59 @@ class AccountsManagerTest {
KeysHelper.signedKEMPreKey(4, pniKeyPair)), KeysHelper.signedKEMPreKey(4, pniKeyPair)),
null); null);
} }
@Test
void checkDeviceLinkingToken() {
final UUID aci = UUID.randomUUID();
assertEquals(Optional.of(aci),
accountsManager.checkDeviceLinkingToken(accountsManager.generateDeviceLinkingToken(aci)));
}
@ParameterizedTest
@MethodSource
void checkVerificationTokenBadToken(final String token, final Instant currentTime) {
CLOCK.pin(currentTime);
assertEquals(Optional.empty(), accountsManager.checkDeviceLinkingToken(token));
}
private static Stream<Arguments> checkVerificationTokenBadToken() throws InvalidKeyException {
final Instant tokenTimestamp = Instant.now();
return Stream.of(
// Expired token
Arguments.of(AccountsManager.generateDeviceLinkingToken(UUID.randomUUID(),
new SecretKeySpec(LINK_DEVICE_SECRET, AccountsManager.LINK_DEVICE_VERIFICATION_TOKEN_ALGORITHM),
CLOCK),
tokenTimestamp.plus(AccountsManager.LINK_DEVICE_TOKEN_EXPIRATION_DURATION).plusSeconds(1)),
// Bad UUID
Arguments.of("not-a-valid-uuid.1691096565171:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// No UUID
Arguments.of(".1691096565171:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// Bad timestamp
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.not-a-valid-timestamp:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// No timestamp
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// Blank timestamp
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// No signature
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171", tokenTimestamp),
// Blank signature
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171:", tokenTimestamp),
// Incorrect signature
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171:0CKWF7q3E9fi4sB2or4q1A0Up2z_73EQlMAy7Dpel9c=", tokenTimestamp),
// Invalid signature
Arguments.of("e552603a-1492-4de6-872d-bac19a2825b4.1691096565171:This is not valid base64", tokenTimestamp)
);
}
} }

View File

@ -14,6 +14,7 @@ import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import java.nio.charset.StandardCharsets;
import java.time.Clock; import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
@ -105,7 +106,8 @@ class AccountsManagerUsernameIntegrationTest {
Tables.NUMBERS.tableName(), Tables.NUMBERS.tableName(),
Tables.PNI_ASSIGNMENTS.tableName(), Tables.PNI_ASSIGNMENTS.tableName(),
Tables.USERNAMES.tableName(), Tables.USERNAMES.tableName(),
Tables.DELETED_ACCOUNTS.tableName())); Tables.DELETED_ACCOUNTS.tableName(),
Tables.USED_LINK_DEVICE_TOKENS.tableName()));
final AccountLockManager accountLockManager = mock(AccountLockManager.class); final AccountLockManager accountLockManager = mock(AccountLockManager.class);
@ -135,6 +137,7 @@ class AccountsManagerUsernameIntegrationTest {
accounts, accounts,
phoneNumberIdentifiers, phoneNumberIdentifiers,
CACHE_CLUSTER_EXTENSION.getRedisCluster(), CACHE_CLUSTER_EXTENSION.getRedisCluster(),
CACHE_CLUSTER_EXTENSION.getRedisCluster(),
accountLockManager, accountLockManager,
keysManager, keysManager,
messageManager, messageManager,
@ -147,6 +150,7 @@ class AccountsManagerUsernameIntegrationTest {
Executors.newSingleThreadExecutor(), Executors.newSingleThreadExecutor(),
Executors.newSingleThreadExecutor(), Executors.newSingleThreadExecutor(),
mock(Clock.class), mock(Clock.class),
"link-device-secret".getBytes(StandardCharsets.UTF_8),
dynamicConfigurationManager); dynamicConfigurationManager);
} }

View File

@ -106,6 +106,7 @@ class AccountsTest {
Tables.PNI_ASSIGNMENTS, Tables.PNI_ASSIGNMENTS,
Tables.USERNAMES, Tables.USERNAMES,
Tables.DELETED_ACCOUNTS, Tables.DELETED_ACCOUNTS,
Tables.USED_LINK_DEVICE_TOKENS,
// This is an unrelated table used to test "tag-along" transactional updates // This is an unrelated table used to test "tag-along" transactional updates
Tables.CLIENT_RELEASES); Tables.CLIENT_RELEASES);
@ -132,7 +133,8 @@ class AccountsTest {
Tables.NUMBERS.tableName(), Tables.NUMBERS.tableName(),
Tables.PNI_ASSIGNMENTS.tableName(), Tables.PNI_ASSIGNMENTS.tableName(),
Tables.USERNAMES.tableName(), Tables.USERNAMES.tableName(),
Tables.DELETED_ACCOUNTS.tableName()); Tables.DELETED_ACCOUNTS.tableName(),
Tables.USED_LINK_DEVICE_TOKENS.tableName());
} }
@Test @Test
@ -560,7 +562,8 @@ class AccountsTest {
Tables.NUMBERS.tableName(), Tables.NUMBERS.tableName(),
Tables.PNI_ASSIGNMENTS.tableName(), Tables.PNI_ASSIGNMENTS.tableName(),
Tables.USERNAMES.tableName(), Tables.USERNAMES.tableName(),
Tables.DELETED_ACCOUNTS.tableName()); Tables.DELETED_ACCOUNTS.tableName(),
Tables.USED_LINK_DEVICE_TOKENS.tableName());
Exception e = TransactionConflictException.builder().build(); Exception e = TransactionConflictException.builder().build();
e = wrapException ? new CompletionException(e) : e; e = wrapException ? new CompletionException(e) : e;
@ -648,7 +651,8 @@ class AccountsTest {
Tables.NUMBERS.tableName(), Tables.NUMBERS.tableName(),
Tables.PNI_ASSIGNMENTS.tableName(), Tables.PNI_ASSIGNMENTS.tableName(),
Tables.USERNAMES.tableName(), Tables.USERNAMES.tableName(),
Tables.DELETED_ACCOUNTS.tableName()); Tables.DELETED_ACCOUNTS.tableName(),
Tables.USED_LINK_DEVICE_TOKENS.tableName());
when(dynamoDbAsyncClient.transactWriteItems(any(TransactWriteItemsRequest.class))) when(dynamoDbAsyncClient.transactWriteItems(any(TransactWriteItemsRequest.class)))
.thenReturn(CompletableFuture.failedFuture(TransactionCanceledException.builder() .thenReturn(CompletableFuture.failedFuture(TransactionCanceledException.builder()
@ -1039,7 +1043,8 @@ class AccountsTest {
Tables.NUMBERS.tableName(), Tables.NUMBERS.tableName(),
Tables.PNI_ASSIGNMENTS.tableName(), Tables.PNI_ASSIGNMENTS.tableName(),
Tables.USERNAMES.tableName(), Tables.USERNAMES.tableName(),
Tables.DELETED_ACCOUNTS.tableName()); Tables.DELETED_ACCOUNTS.tableName(),
Tables.USED_LINK_DEVICE_TOKENS.tableName());
final Account account = generateAccount("+14155551111", UUID.randomUUID(), UUID.randomUUID()); final Account account = generateAccount("+14155551111", UUID.randomUUID(), UUID.randomUUID());
createAccount(account); createAccount(account);
@ -1081,7 +1086,8 @@ class AccountsTest {
Tables.NUMBERS.tableName(), Tables.NUMBERS.tableName(),
Tables.PNI_ASSIGNMENTS.tableName(), Tables.PNI_ASSIGNMENTS.tableName(),
Tables.USERNAMES.tableName(), Tables.USERNAMES.tableName(),
Tables.DELETED_ACCOUNTS.tableName()); Tables.DELETED_ACCOUNTS.tableName(),
Tables.USED_LINK_DEVICE_TOKENS.tableName());
final Account account = generateAccount("+14155551111", UUID.randomUUID(), UUID.randomUUID()); final Account account = generateAccount("+14155551111", UUID.randomUUID(), UUID.randomUUID());
createAccount(account); createAccount(account);
@ -1181,7 +1187,8 @@ class AccountsTest {
Tables.NUMBERS.tableName(), Tables.NUMBERS.tableName(),
Tables.PNI_ASSIGNMENTS.tableName(), Tables.PNI_ASSIGNMENTS.tableName(),
Tables.USERNAMES.tableName(), Tables.USERNAMES.tableName(),
Tables.DELETED_ACCOUNTS.tableName()); Tables.DELETED_ACCOUNTS.tableName(),
Tables.USED_LINK_DEVICE_TOKENS.tableName());
final Account account = generateAccount("+14155551111", UUID.randomUUID(), UUID.randomUUID()); final Account account = generateAccount("+14155551111", UUID.randomUUID(), UUID.randomUUID());
createAccount(account); createAccount(account);

View File

@ -2,6 +2,7 @@ package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
@ -45,6 +46,7 @@ public class AddRemoveDeviceIntegrationTest {
DynamoDbExtensionSchema.Tables.CLIENT_PUBLIC_KEYS, DynamoDbExtensionSchema.Tables.CLIENT_PUBLIC_KEYS,
DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS, DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS,
DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS_LOCK, DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS_LOCK,
DynamoDbExtensionSchema.Tables.USED_LINK_DEVICE_TOKENS,
DynamoDbExtensionSchema.Tables.NUMBERS, DynamoDbExtensionSchema.Tables.NUMBERS,
DynamoDbExtensionSchema.Tables.PNI, DynamoDbExtensionSchema.Tables.PNI,
DynamoDbExtensionSchema.Tables.PNI_ASSIGNMENTS, DynamoDbExtensionSchema.Tables.PNI_ASSIGNMENTS,
@ -93,7 +95,8 @@ public class AddRemoveDeviceIntegrationTest {
DynamoDbExtensionSchema.Tables.NUMBERS.tableName(), DynamoDbExtensionSchema.Tables.NUMBERS.tableName(),
DynamoDbExtensionSchema.Tables.PNI_ASSIGNMENTS.tableName(), DynamoDbExtensionSchema.Tables.PNI_ASSIGNMENTS.tableName(),
DynamoDbExtensionSchema.Tables.USERNAMES.tableName(), DynamoDbExtensionSchema.Tables.USERNAMES.tableName(),
DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS.tableName()); DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS.tableName(),
DynamoDbExtensionSchema.Tables.USED_LINK_DEVICE_TOKENS.tableName());
accountLockExecutor = Executors.newSingleThreadExecutor(); accountLockExecutor = Executors.newSingleThreadExecutor();
clientPresenceExecutor = Executors.newSingleThreadExecutor(); clientPresenceExecutor = Executors.newSingleThreadExecutor();
@ -129,6 +132,7 @@ public class AddRemoveDeviceIntegrationTest {
accounts, accounts,
phoneNumberIdentifiers, phoneNumberIdentifiers,
CACHE_CLUSTER_EXTENSION.getRedisCluster(), CACHE_CLUSTER_EXTENSION.getRedisCluster(),
CACHE_CLUSTER_EXTENSION.getRedisCluster(),
accountLockManager, accountLockManager,
keysManager, keysManager,
messagesManager, messagesManager,
@ -141,6 +145,7 @@ public class AddRemoveDeviceIntegrationTest {
accountLockExecutor, accountLockExecutor,
clientPresenceExecutor, clientPresenceExecutor,
CLOCK, CLOCK,
"link-device-secret".getBytes(StandardCharsets.UTF_8),
dynamicConfigurationManager); dynamicConfigurationManager);
} }
@ -182,7 +187,8 @@ public class AddRemoveDeviceIntegrationTest {
KeysHelper.signedECPreKey(1, aciKeyPair), KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair), KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair), KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair))) KeysHelper.signedKEMPreKey(4, pniKeyPair)),
accountsManager.generateDeviceLinkingToken(account.getIdentifier(IdentityType.ACI)))
.join(); .join();
assertEquals(2, updatedAccountAndDevice.first().getDevices().size()); assertEquals(2, updatedAccountAndDevice.first().getDevices().size());
@ -199,6 +205,67 @@ public class AddRemoveDeviceIntegrationTest {
assertTrue(keysManager.getLastResort(updatedAccountAndDevice.first().getPhoneNumberIdentifier(), addedDeviceId).join().isPresent()); assertTrue(keysManager.getLastResort(updatedAccountAndDevice.first().getPhoneNumberIdentifier(), addedDeviceId).join().isPresent());
} }
@Test
void addDeviceReusedToken() throws InterruptedException {
final String number = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164);
final ECKeyPair aciKeyPair = Curve.generateKeyPair();
final ECKeyPair pniKeyPair = Curve.generateKeyPair();
final Account account = AccountsHelper.createAccount(accountsManager, number);
assertEquals(1, accountsManager.getByAccountIdentifier(account.getUuid()).orElseThrow().getDevices().size());
final String linkDeviceToken = accountsManager.generateDeviceLinkingToken(account.getIdentifier(IdentityType.ACI));
final Pair<Account, Device> updatedAccountAndDevice =
accountsManager.addDevice(account, new DeviceSpec(
"device-name".getBytes(StandardCharsets.UTF_8),
"password",
"OWT",
new Device.DeviceCapabilities(true, true, true, false, false),
1,
2,
true,
Optional.empty(),
Optional.empty(),
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair)),
linkDeviceToken)
.join();
assertEquals(2,
accountsManager.getByAccountIdentifier(updatedAccountAndDevice.first().getUuid()).orElseThrow().getDevices()
.size());
final CompletionException completionException = assertThrows(CompletionException.class,
() -> accountsManager.addDevice(account, new DeviceSpec(
"device-name".getBytes(StandardCharsets.UTF_8),
"password",
"OWT",
new Device.DeviceCapabilities(true, true, true, false, false),
1,
2,
true,
Optional.empty(),
Optional.empty(),
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair)),
linkDeviceToken)
.join());
assertInstanceOf(LinkDeviceTokenAlreadyUsedException.class, completionException.getCause());
assertEquals(2,
accountsManager.getByAccountIdentifier(updatedAccountAndDevice.first().getUuid()).orElseThrow().getDevices()
.size());
}
@Test @Test
void removeDevice() throws InterruptedException { void removeDevice() throws InterruptedException {
final String number = PhoneNumberUtil.getInstance().format( final String number = PhoneNumberUtil.getInstance().format(
@ -225,7 +292,8 @@ public class AddRemoveDeviceIntegrationTest {
KeysHelper.signedECPreKey(1, aciKeyPair), KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair), KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair), KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair))) KeysHelper.signedKEMPreKey(4, pniKeyPair)),
accountsManager.generateDeviceLinkingToken(account.getIdentifier(IdentityType.ACI)))
.join(); .join();
final byte addedDeviceId = updatedAccountAndDevice.second().getId(); final byte addedDeviceId = updatedAccountAndDevice.second().getId();
@ -278,7 +346,8 @@ public class AddRemoveDeviceIntegrationTest {
KeysHelper.signedECPreKey(1, aciKeyPair), KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair), KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair), KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair))) KeysHelper.signedKEMPreKey(4, pniKeyPair)),
accountsManager.generateDeviceLinkingToken(account.getIdentifier(IdentityType.ACI)))
.join(); .join();
final byte addedDeviceId = updatedAccountAndDevice.second().getId(); final byte addedDeviceId = updatedAccountAndDevice.second().getId();

View File

@ -372,6 +372,16 @@ public final class DynamoDbExtensionSchema {
List.of(), List.of(),
List.of()), List.of()),
USED_LINK_DEVICE_TOKENS("used_link_device_tokens_test",
Accounts.KEY_LINK_DEVICE_TOKEN_HASH,
null,
List.of(AttributeDefinition.builder()
.attributeName(Accounts.KEY_LINK_DEVICE_TOKEN_HASH)
.attributeType(ScalarAttributeType.B)
.build()),
List.of(),
List.of()),
USERNAMES("usernames_test", USERNAMES("usernames_test",
Accounts.ATTR_USERNAME_HASH, Accounts.ATTR_USERNAME_HASH,
null, null,

View File

@ -97,6 +97,7 @@ dynamoDbTables:
phoneNumberTableName: numbers_test phoneNumberTableName: numbers_test
phoneNumberIdentifierTableName: pni_assignment_test phoneNumberIdentifierTableName: pni_assignment_test
usernamesTableName: usernames_test usernamesTableName: usernames_test
usedLinkDeviceTokensTableName: used_link_device_tokens_test
backups: backups:
tableName: backups_test tableName: backups_test
clientReleases: clientReleases: