Add API endpoints for waiting for newly-linked devices

This commit is contained in:
Jon Chambers 2024-10-10 10:11:32 -04:00 committed by GitHub
parent 087c2b61ee
commit 8c30a359e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 793 additions and 122 deletions

View File

@ -642,7 +642,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
ClientPublicKeysManager clientPublicKeysManager =
new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor);
AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster,
accountLockManager, keysManager, messagesManager, profilesManager,
pubsubClient, accountLockManager, keysManager, messagesManager, profilesManager,
secureStorageClient, secureValueRecovery2Client,
clientPresenceManager,
registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, clientPresenceExecutor,
@ -764,6 +764,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.lifecycle().manage(keyTransparencyServiceClient);
environment.lifecycle().manage(clientReleaseManager);
environment.lifecycle().manage(virtualThreadPinEventMonitor);
environment.lifecycle().manage(accountsManager);
final RegistrationCaptchaManager registrationCaptchaManager = new RegistrationCaptchaManager(captchaChecker);

View File

@ -4,22 +4,36 @@
*/
package org.whispersystems.textsecuregcm.controllers;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.Auth;
import io.lettuce.core.RedisException;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import io.micrometer.core.instrument.Timer;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.headers.Header;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag;
import java.util.LinkedList;
import java.time.Duration;
import java.util.Arrays;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import javax.validation.Valid;
import javax.validation.constraints.Max;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotNull;
import javax.validation.constraints.Size;
import javax.ws.rs.Consumes;
import javax.ws.rs.DELETE;
import javax.ws.rs.DefaultValue;
import javax.ws.rs.ForbiddenException;
import javax.ws.rs.GET;
import javax.ws.rs.HeaderParam;
@ -27,10 +41,12 @@ import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.QueryParam;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import io.swagger.v3.oas.annotations.tags.Tag;
import org.glassfish.jersey.server.ContainerRequest;
import org.whispersystems.textsecuregcm.auth.LinkedDeviceRefreshRequirementProvider;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
@ -47,6 +63,8 @@ import org.whispersystems.textsecuregcm.entities.ProvisioningMessage;
import org.whispersystems.textsecuregcm.entities.SetPublicKeyRequest;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
@ -54,7 +72,11 @@ import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.LinkDeviceTokenAlreadyUsedException;
import org.whispersystems.textsecuregcm.util.VerificationCode;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.LinkDeviceToken;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import org.whispersystems.websocket.auth.Mutable;
import org.whispersystems.websocket.auth.ReadOnly;
@ -69,6 +91,21 @@ public class DeviceController {
private final RateLimiters rateLimiters;
private final Map<String, Integer> maxDeviceConfiguration;
private final EnumMap<ClientPlatform, AtomicInteger> linkedDeviceListenersByPlatform;
private final AtomicInteger linkedDeviceListenersForUnrecognizedPlatforms;
private static final String LINKED_DEVICE_LISTENER_GAUGE_NAME =
MetricsUtil.name(DeviceController.class, "linkedDeviceListeners");
private static final String WAIT_FOR_LINKED_DEVICE_TIMER_NAME =
MetricsUtil.name(DeviceController.class, "waitForLinkedDeviceDuration");
@VisibleForTesting
static final int MIN_TOKEN_IDENTIFIER_LENGTH = 32;
@VisibleForTesting
static final int MAX_TOKEN_IDENTIFIER_LENGTH = 64;
public DeviceController(final AccountsManager accounts,
final ClientPublicKeysManager clientPublicKeysManager,
final RateLimiters rateLimiters,
@ -78,19 +115,32 @@ public class DeviceController {
this.clientPublicKeysManager = clientPublicKeysManager;
this.rateLimiters = rateLimiters;
this.maxDeviceConfiguration = maxDeviceConfiguration;
linkedDeviceListenersByPlatform = Arrays.stream(ClientPlatform.values())
.collect(Collectors.toMap(
Function.identity(),
clientPlatform -> buildGauge(clientPlatform.name().toLowerCase()),
(a, b) -> {
throw new AssertionError("Duplicate client platform enumeration key");
},
() -> new EnumMap<>(ClientPlatform.class)
));
linkedDeviceListenersForUnrecognizedPlatforms = buildGauge("unknown");
}
private static AtomicInteger buildGauge(final String clientPlatformName) {
return Metrics.gauge(LINKED_DEVICE_LISTENER_GAUGE_NAME,
Tags.of(io.micrometer.core.instrument.Tag.of(UserAgentTagUtil.PLATFORM_TAG, clientPlatformName)),
new AtomicInteger(0));
}
@GET
@Produces(MediaType.APPLICATION_JSON)
public DeviceInfoList getDevices(@ReadOnly @Auth AuthenticatedDevice auth) {
List<DeviceInfo> devices = new LinkedList<>();
for (Device device : auth.getAccount().getDevices()) {
devices.add(new DeviceInfo(device.getId(), device.getName(),
device.getLastSeen(), device.getCreated()));
}
return new DeviceInfoList(devices);
return new DeviceInfoList(auth.getAccount().getDevices().stream()
.map(DeviceInfo::forDevice)
.toList());
}
@DELETE
@ -138,7 +188,7 @@ public class DeviceController {
@ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header(
name = "Retry-After",
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public VerificationCode createDeviceToken(@ReadOnly @Auth AuthenticatedDevice auth)
public LinkDeviceToken createDeviceToken(@ReadOnly @Auth AuthenticatedDevice auth)
throws RateLimitExceededException, DeviceLimitExceededException {
final Account account = auth.getAccount();
@ -159,7 +209,9 @@ public class DeviceController {
throw new WebApplicationException(Response.Status.UNAUTHORIZED);
}
return new VerificationCode(accounts.generateDeviceLinkingToken(account.getUuid()));
final String token = accounts.generateLinkDeviceToken(account.getUuid());
return new LinkDeviceToken(token, AccountsManager.getLinkDeviceTokenIdentifier(token));
}
@PUT
@ -266,6 +318,83 @@ public class DeviceController {
}
}
@GET
@Path("/wait_for_linked_device/{tokenIdentifier}")
@Produces(MediaType.APPLICATION_JSON)
@Operation(summary = "Wait for a new device to be linked to an account",
description = """
Waits for a new device to be linked to an account and returns basic information about the new device when
available.
""")
@ApiResponse(responseCode = "200", description = "The specified was linked to an account")
@ApiResponse(responseCode = "204", description = "No device was linked to the account before the call completed")
@ApiResponse(responseCode = "400", description = "The given token identifier or timeout was invalid")
@ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay")
@Schema(description = "Basic information about the linked device", implementation = DeviceInfo.class)
public CompletableFuture<Response> waitForLinkedDevice(
@ReadOnly @Auth final AuthenticatedDevice authenticatedDevice,
@PathParam("tokenIdentifier")
@Schema(description = "A 'link device' token identifier provided by the 'create link device token' endpoint")
@Size(min = MIN_TOKEN_IDENTIFIER_LENGTH, max = MAX_TOKEN_IDENTIFIER_LENGTH)
final String tokenIdentifier,
@QueryParam("timeout")
@DefaultValue("30")
@Min(1)
@Max(3600)
@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED,
minimum = "1",
maximum = "3600",
description = """
The amount of time (in seconds) to wait for a response. If the expected device is not linked within the
given amount of time, this endpoint will return a status of HTTP/204.
""") final int timeoutSeconds,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent) throws RateLimitExceededException {
rateLimiters.getWaitForLinkedDeviceLimiter().validate(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI));
final AtomicInteger linkedDeviceListenerCounter = getCounterForLinkedDeviceListeners(userAgent);
linkedDeviceListenerCounter.incrementAndGet();
final Timer.Sample sample = Timer.start();
try {
return accounts.waitForNewLinkedDevice(tokenIdentifier, Duration.ofSeconds(timeoutSeconds))
.thenApply(maybeDeviceInfo -> maybeDeviceInfo
.map(deviceInfo -> Response.status(Response.Status.OK).entity(deviceInfo).build())
.orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build()))
.exceptionally(ExceptionUtils.exceptionallyHandler(IllegalArgumentException.class,
e -> Response.status(Response.Status.BAD_REQUEST).build()))
.whenComplete((response, throwable) -> {
linkedDeviceListenerCounter.decrementAndGet();
if (response != null) {
sample.stop(Timer.builder(WAIT_FOR_LINKED_DEVICE_TIMER_NAME)
.publishPercentileHistogram(true)
.tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent),
io.micrometer.core.instrument.Tag.of("deviceFound",
String.valueOf(response.getStatus() == Response.Status.OK.getStatusCode()))))
.register(Metrics.globalRegistry));
}
});
} catch (final RedisException e) {
// `waitForNewLinkedDevice` could fail synchronously if the Redis circuit breaker is open; prevent counter drift
// if that happens
linkedDeviceListenerCounter.decrementAndGet();
throw e;
}
}
private AtomicInteger getCounterForLinkedDeviceListeners(final String userAgent) {
try {
return linkedDeviceListenersByPlatform.get(UserAgentUtil.parseUserAgentString(userAgent).getPlatform());
} catch (final UnrecognizedUserAgentException ignored) {
return linkedDeviceListenersForUnrecognizedPlatforms;
}
}
@PUT
@Produces(MediaType.APPLICATION_JSON)
@Path("/unauthenticated_delivery")

View File

@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.ByteArrayBase64WithPaddingAdapter;
public record DeviceInfo(long id,
@ -17,4 +18,8 @@ public record DeviceInfo(long id,
long lastSeen,
long created) {
public static DeviceInfo forDevice(final Device device) {
return new DeviceInfo(device.getId(), device.getName(), device.getLastSeen(), device.getCreated());
}
}

View File

@ -50,6 +50,7 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
EXTERNAL_SERVICE_CREDENTIALS("externalServiceCredentials", true, new RateLimiterConfig(100, Duration.ofMinutes(15))),
KEY_TRANSPARENCY_SEARCH_PER_IP("keyTransparencySearch", true, new RateLimiterConfig(100, Duration.ofSeconds(15))),
KEY_TRANSPARENCY_MONITOR_PER_IP("keyTransparencyMonitor", true, new RateLimiterConfig(100, Duration.ofSeconds(15))),
WAIT_FOR_LINKED_DEVICE("waitForLinkedDevice", true, new RateLimiterConfig(10, Duration.ofSeconds(30))),
;
private final String id;
@ -205,4 +206,8 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
public RateLimiter getStoriesLimiter() {
return forDescriptor(For.STORIES);
}
public RateLimiter getWaitForLinkedDeviceLimiter() {
return forDescriptor(For.WAIT_FOR_LINKED_DEVICE);
}
}

View File

@ -12,8 +12,11 @@ import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectWriter;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.dropwizard.lifecycle.Managed;
import io.lettuce.core.RedisException;
import io.lettuce.core.SetArgs;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import io.lettuce.core.pubsub.RedisPubSubAdapter;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
@ -42,7 +45,9 @@ import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction;
import java.util.function.Consumer;
@ -61,13 +66,16 @@ import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.DeviceInfo;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.RedisOperation;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
@ -82,14 +90,13 @@ import reactor.core.scheduler.Scheduler;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem;
import software.amazon.awssdk.services.dynamodb.model.TransactionCanceledException;
public class AccountsManager {
public class AccountsManager extends RedisPubSubAdapter<String, String> implements Managed {
private static final Timer createTimer = Metrics.timer(name(AccountsManager.class, "create"));
private static final Timer updateTimer = Metrics.timer(name(AccountsManager.class, "update"));
private static final Timer getByNumberTimer = Metrics.timer(name(AccountsManager.class, "getByNumber"));
private static final Timer getByUsernameHashTimer = Metrics.timer(name(AccountsManager.class, "getByUsernameHash"));
private static final Timer getByUsernameLinkHandleTimer = Metrics.timer(
name(AccountsManager.class, "getByUsernameLinkHandle"));
private static final Timer getByUsernameLinkHandleTimer = Metrics.timer(name(AccountsManager.class, "getByUsernameLinkHandle"));
private static final Timer getByUuidTimer = Metrics.timer(name(AccountsManager.class, "getByUuid"));
private static final Timer deleteTimer = Metrics.timer(name(AccountsManager.class, "delete"));
@ -108,6 +115,7 @@ public class AccountsManager {
private final Accounts accounts;
private final PhoneNumberIdentifiers phoneNumberIdentifiers;
private final FaultTolerantRedisClusterClient cacheCluster;
private final FaultTolerantRedisClient pubSubRedisSingleton;
private final AccountLockManager accountLockManager;
private final KeysManager keysManager;
private final MessagesManager messagesManager;
@ -124,6 +132,16 @@ public class AccountsManager {
private final Key verificationTokenKey;
private final FaultTolerantPubSubConnection<String, String> pubSubConnection;
private final Map<String, CompletableFuture<Optional<DeviceInfo>>> waitForDeviceFuturesByTokenIdentifier =
new ConcurrentHashMap<>();
private static final int SHA256_HASH_LENGTH = getSha256MessageDigest().getDigestLength();
private static final Duration RECENTLY_ADDED_DEVICE_TTL = Duration.ofHours(1);
private static final String LINKED_DEVICE_PREFIX = "linked_device::";
private static final String LINKED_DEVICE_KEYSPACE_PATTERN = "__keyspace@0__:" + LINKED_DEVICE_PREFIX + "*";
private static final ObjectWriter ACCOUNT_REDIS_JSON_WRITER = SystemMapper.jsonMapper()
.writer(SystemMapper.excludingField(Account.class, List.of("uuid")));
@ -158,6 +176,7 @@ public class AccountsManager {
public AccountsManager(final Accounts accounts,
final PhoneNumberIdentifiers phoneNumberIdentifiers,
final FaultTolerantRedisClusterClient cacheCluster,
final FaultTolerantRedisClient pubSubRedisSingleton,
final AccountLockManager accountLockManager,
final KeysManager keysManager,
final MessagesManager messagesManager,
@ -175,6 +194,7 @@ public class AccountsManager {
this.accounts = accounts;
this.phoneNumberIdentifiers = phoneNumberIdentifiers;
this.cacheCluster = cacheCluster;
this.pubSubRedisSingleton = pubSubRedisSingleton;
this.accountLockManager = accountLockManager;
this.keysManager = keysManager;
this.messagesManager = messagesManager;
@ -197,6 +217,20 @@ public class AccountsManager {
} catch (final InvalidKeyException e) {
throw new IllegalArgumentException(e);
}
this.pubSubConnection = pubSubRedisSingleton.createPubSubConnection();
}
@Override
public void start() {
pubSubConnection.usePubSubConnection(connection -> connection.addListener(this));
pubSubConnection.usePubSubConnection(connection -> connection.sync().psubscribe(LINKED_DEVICE_KEYSPACE_PATTERN));
}
@Override
public void stop() {
pubSubConnection.usePubSubConnection(connection -> connection.sync().punsubscribe());
pubSubConnection.usePubSubConnection(connection -> connection.removeListener(this));
}
public Account create(final String number,
@ -363,6 +397,26 @@ public class AccountsManager {
}
return CompletableFuture.failedFuture(throwable);
})
.whenComplete((updatedAccountAndDevice, throwable) -> {
if (updatedAccountAndDevice != null) {
final String key = getLinkedDeviceKey(getLinkDeviceTokenIdentifier(linkDeviceToken));
final String deviceInfoJson;
try {
deviceInfoJson = SystemMapper.jsonMapper().writeValueAsString(DeviceInfo.forDevice(updatedAccountAndDevice.second()));
} catch (final JsonProcessingException e) {
throw new UncheckedIOException(e);
}
pubSubRedisSingleton.withConnection(connection ->
connection.async().set(key, deviceInfoJson, SetArgs.Builder.ex(RECENTLY_ADDED_DEVICE_TTL)))
.whenComplete((ignored, pubSubThrowable) -> {
if (pubSubThrowable != null) {
logger.warn("Failed to record recently-created device", pubSubThrowable);
}
});
}
});
}
@ -386,7 +440,7 @@ public class AccountsManager {
}
}
public String generateDeviceLinkingToken(final UUID aci) {
public String generateLinkDeviceToken(final UUID aci) {
final String claims = aci + "." + clock.instant().toEpochMilli();
final byte[] signature = getInitializedMac().doFinal(claims.getBytes(StandardCharsets.UTF_8));
@ -394,7 +448,7 @@ public class AccountsManager {
}
@VisibleForTesting
static String generateDeviceLinkingToken(final UUID aci, final Key linkDeviceTokenKey, final Clock clock)
static String generateLinkDeviceToken(final UUID aci, final Key linkDeviceTokenKey, final Clock clock)
throws InvalidKeyException {
final String claims = aci + "." + clock.instant().toEpochMilli();
@ -403,6 +457,11 @@ public class AccountsManager {
return claims + ":" + Base64.getUrlEncoder().encodeToString(signature);
}
public static String getLinkDeviceTokenIdentifier(final String linkDeviceToken) {
return Base64.getUrlEncoder().withoutPadding().encodeToString(
getSha256MessageDigest().digest(linkDeviceToken.getBytes(StandardCharsets.UTF_8)));
}
/**
* Checks that a device-linking token is valid and returns the account identifier from the token if so, or empty if
* the token was invalid
@ -1340,4 +1399,75 @@ public class AccountsManager {
.whenComplete((ignoredResult, ignoredException) -> sample.stop(redisDeleteTimer))
.thenRun(Util.NOOP);
}
public CompletableFuture<Optional<DeviceInfo>> waitForNewLinkedDevice(final String linkDeviceTokenIdentifier, final Duration timeout) {
// Unbeknownst to callers but beknownst to us, the "link device token identifier" is the base64/url-encoded SHA256
// hash of a device-linking token. Before we use the string anywhere, make sure it's the right "shape" for a hash.
if (Base64.getUrlDecoder().decode(linkDeviceTokenIdentifier).length != SHA256_HASH_LENGTH) {
return CompletableFuture.failedFuture(new IllegalArgumentException("Invalid token identifier"));
}
final CompletableFuture<Optional<DeviceInfo>> waitForDeviceFuture = new CompletableFuture<>();
waitForDeviceFuture
.completeOnTimeout(Optional.empty(), TimeUnit.MILLISECONDS.convert(timeout), TimeUnit.MILLISECONDS)
.whenComplete((maybeDevice, throwable) -> waitForDeviceFuturesByTokenIdentifier.compute(linkDeviceTokenIdentifier,
(ignored, existingFuture) -> {
// Only remove the future from the map if it's THIS future, and not one that later displaced this one
return existingFuture == waitForDeviceFuture ? null : existingFuture;
}));
{
final CompletableFuture<Optional<DeviceInfo>> displacedFuture =
waitForDeviceFuturesByTokenIdentifier.put(linkDeviceTokenIdentifier, waitForDeviceFuture);
if (displacedFuture != null) {
displacedFuture.complete(Optional.empty());
}
}
// The device may already have been linked by the time the caller started watching for it; perform an immediate
// check to see if the device is already there.
pubSubRedisSingleton.withConnection(connection -> connection.async().get(getLinkedDeviceKey(linkDeviceTokenIdentifier)))
.thenAccept(response -> {
if (StringUtils.isNotBlank(response)) {
handleDeviceAdded(waitForDeviceFuture, response);
}
});
return waitForDeviceFuture;
}
private static String getLinkedDeviceKey(final String linkDeviceTokenIdentifier) {
return LINKED_DEVICE_PREFIX + linkDeviceTokenIdentifier;
}
@Override
public void message(final String pattern, final String channel, final String message) {
if (LINKED_DEVICE_KEYSPACE_PATTERN.equals(pattern) && "set".equalsIgnoreCase(message)) {
// The `- 1` here compensates for the '*' in the pattern
final String tokenIdentifier = channel.substring(LINKED_DEVICE_KEYSPACE_PATTERN.length() - 1);
Optional.ofNullable(waitForDeviceFuturesByTokenIdentifier.remove(tokenIdentifier))
.ifPresent(future -> pubSubRedisSingleton.withConnection(connection -> connection.async().get(getLinkedDeviceKey(tokenIdentifier)))
.thenAccept(deviceInfoJson -> handleDeviceAdded(future, deviceInfoJson)));
}
}
private void handleDeviceAdded(final CompletableFuture<Optional<DeviceInfo>> future, final String deviceInfoJson) {
try {
future.complete(Optional.of(SystemMapper.jsonMapper().readValue(deviceInfoJson, DeviceInfo.class)));
} catch (final JsonProcessingException e) {
logger.error("Could not parse device json", e);
future.completeExceptionally(e);
}
}
private static MessageDigest getSha256MessageDigest() {
try {
return MessageDigest.getInstance("SHA-256");
} catch (final NoSuchAlgorithmException e) {
throw new AssertionError("Every implementation of the Java platform is required to support the SHA-256 MessageDigest algorithm", e);
}
}
}

View File

@ -0,0 +1,24 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import io.swagger.v3.oas.annotations.media.Schema;
public record LinkDeviceToken(
@Schema(description = """
An opaque token to send to a new linked device that authorizes the new device to link itself to the account that
requested this token.
""")
@JsonProperty("verificationCode") String token,
@Schema(description = """
An opaque identifier for the generated token that the caller may use to watch for a new device to complete the
linking process.
""")
String tokenIdentifier) {
}

View File

@ -1,8 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
public record VerificationCode(String verificationCode) {
}

View File

@ -40,6 +40,7 @@ import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.FcmSender;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import org.whispersystems.textsecuregcm.storage.AccountLockManager;
@ -112,6 +113,8 @@ record CommandDependencies(
.build("main_cache", redisClientResourcesBuilder);
FaultTolerantRedisClusterClient pushSchedulerCluster = configuration.getPushSchedulerCluster()
.build("push_scheduler", redisClientResourcesBuilder);
FaultTolerantRedisClient pubsubClient =
configuration.getRedisPubSubConfiguration().build("pubsub", redisClientResourcesBuilder.build());
ScheduledExecutorService recurringJobExecutor = environment.lifecycle()
.scheduledExecutorService(name(name, "recurringJob-%d")).threads(2).build();
@ -225,7 +228,7 @@ record CommandDependencies(
ClientPublicKeysManager clientPublicKeysManager =
new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor);
AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster,
accountLockManager, keys, messagesManager, profilesManager,
pubsubClient, accountLockManager, keys, messagesManager, profilesManager,
secureStorageClient, secureValueRecovery2Client, clientPresenceManager,
registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, clientPresenceExecutor,
clock, configuration.getLinkDeviceSecretConfiguration().secret().value(), dynamicConfigurationManager);

View File

@ -5,12 +5,14 @@
package org.whispersystems.textsecuregcm.controllers;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
@ -18,6 +20,7 @@ import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.amazonaws.util.Base64;
import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.AuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
@ -25,6 +28,7 @@ import io.dropwizard.testing.junit5.ResourceExtension;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@ -35,6 +39,7 @@ import java.util.stream.Stream;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.apache.commons.lang3.RandomStringUtils;
import org.glassfish.jersey.server.ServerProperties;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.AfterEach;
@ -54,6 +59,7 @@ import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventLis
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest;
import org.whispersystems.textsecuregcm.entities.DeviceInfo;
import org.whispersystems.textsecuregcm.entities.DeviceResponse;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
@ -64,6 +70,7 @@ import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
@ -79,7 +86,7 @@ import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.TestClock;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.VerificationCode;
import org.whispersystems.textsecuregcm.util.LinkDeviceToken;
@ExtendWith(DropwizardExtensionsSupport.class)
class DeviceControllerTest {
@ -112,6 +119,7 @@ class DeviceControllerTest {
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class))
.addProvider(new RateLimitExceededExceptionMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addProvider(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager))
.addProvider(new DeviceLimitExceededExceptionMapper())
@ -122,6 +130,7 @@ class DeviceControllerTest {
void setup() {
when(rateLimiters.getAllocateDeviceLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getVerifyDeviceLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getWaitForLinkedDeviceLimiter()).thenReturn(rateLimiter);
when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
@ -479,16 +488,17 @@ class DeviceControllerTest {
final Optional<ApnRegistrationId> apnRegistrationId,
final Optional<GcmRegistrationId> gcmRegistrationId) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.generateLinkDeviceToken(any())).thenReturn("test");
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
VerificationCode deviceCode = resources.getJerseyTest()
final LinkDeviceToken deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class);
.get(LinkDeviceToken.class);
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
@ -506,7 +516,7 @@ class DeviceControllerTest {
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(),
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.token(),
new AccountAttributes(fetchesMessages, 1234, 5678, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnRegistrationId, gcmRegistrationId));
@ -539,21 +549,22 @@ class DeviceControllerTest {
final KEMSignedPreKey pniPqLastResortPreKey) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.generateLinkDeviceToken(any())).thenReturn("test");
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
VerificationCode deviceCode = resources.getJerseyTest()
final LinkDeviceToken deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class);
.get(LinkDeviceToken.class);
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(aciIdentityKey);
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(pniIdentityKey);
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(),
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.token(),
new AccountAttributes(true, 1234, 5678, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.empty()));
@ -931,4 +942,120 @@ class DeviceControllerTest {
verify(clientPublicKeysManager).setPublicKey(AuthHelper.VALID_ACCOUNT, AuthHelper.VALID_DEVICE.getId(), request.publicKey());
}
@Test
void waitForLinkedDevice() {
final DeviceInfo deviceInfo = new DeviceInfo(Device.PRIMARY_ID,
"Device name ciphertext".getBytes(StandardCharsets.UTF_8),
System.currentTimeMillis(),
System.currentTimeMillis());
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(deviceInfo)));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get()) {
assertEquals(200, response.getStatus());
final DeviceInfo retrievedDeviceInfo = response.readEntity(DeviceInfo.class);
assertEquals(deviceInfo.id(), retrievedDeviceInfo.id());
assertArrayEquals(deviceInfo.name(), retrievedDeviceInfo.name());
assertEquals(deviceInfo.created(), retrievedDeviceInfo.created());
assertEquals(deviceInfo.lastSeen(), retrievedDeviceInfo.lastSeen());
}
}
@Test
void waitForLinkedDeviceNoDeviceLinked() {
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get()) {
assertEquals(204, response.getStatus());
}
}
@Test
void waitForLinkedDeviceBadTokenIdentifier() {
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.failedFuture(new IllegalArgumentException()));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get()) {
assertEquals(400, response.getStatus());
}
}
@ParameterizedTest
@MethodSource
void waitForLinkedDeviceBadTimeout(final int timeoutSeconds) {
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
.queryParam("timeout", timeoutSeconds)
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get()) {
assertEquals(400, response.getStatus());
}
}
private static List<Integer> waitForLinkedDeviceBadTimeout() {
return List.of(0, -1, 3601);
}
@ParameterizedTest
@MethodSource
void waitForLinkedDeviceBadTokenIdentifierLength(final String tokenIdentifier) {
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get()) {
assertEquals(400, response.getStatus());
}
}
private static List<String> waitForLinkedDeviceBadTokenIdentifierLength() {
return List.of(RandomStringUtils.randomAlphanumeric(DeviceController.MIN_TOKEN_IDENTIFIER_LENGTH - 1),
RandomStringUtils.randomAlphanumeric(DeviceController.MAX_TOKEN_IDENTIFIER_LENGTH + 1));
}
@Test
void waitForLinkedDeviceRateLimited() throws RateLimitExceededException {
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
doThrow(new RateLimitExceededException(null)).when(rateLimiter).validate(AuthHelper.VALID_UUID);
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get()) {
assertEquals(429, response.getStatus());
}
}
}

View File

@ -43,6 +43,7 @@ import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
@ -142,6 +143,7 @@ public class AccountCreationDeletionIntegrationTest {
accounts,
phoneNumberIdentifiers,
CACHE_CLUSTER_EXTENSION.getRedisCluster(),
mock(FaultTolerantRedisClient.class),
accountLockManager,
keysManager,
messagesManager,

View File

@ -36,6 +36,7 @@ import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
@ -137,6 +138,7 @@ class AccountsManagerChangeNumberIntegrationTest {
accounts,
phoneNumberIdentifiers,
CACHE_CLUSTER_EXTENSION.getRedisCluster(),
mock(FaultTolerantRedisClient.class),
accountLockManager,
keysManager,
messagesManager,

View File

@ -48,6 +48,7 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
@ -124,6 +125,7 @@ class AccountsManagerConcurrentModificationIntegrationTest {
accounts,
phoneNumberIdentifiers,
RedisClusterHelper.builder().stringCommands(commands).build(),
mock(FaultTolerantRedisClient.class),
accountLockManager,
mock(KeysManager.class),
mock(MessagesManager.class),

View File

@ -33,6 +33,7 @@ import static org.mockito.Mockito.when;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import io.lettuce.core.RedisException;
import io.lettuce.core.api.async.RedisAsyncCommands;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.io.InputStream;
@ -78,6 +79,7 @@ import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecoveryException;
@ -88,6 +90,7 @@ import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import org.whispersystems.textsecuregcm.tests.util.RedisServerHelper;
import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.TestClock;
@ -119,8 +122,9 @@ class AccountsManagerTest {
private Map<String, UUID> phoneNumberIdentifiersByE164;
private RedisAdvancedClusterCommands<String, String> commands;
private RedisAdvancedClusterAsyncCommands<String, String> asyncCommands;
private RedisAsyncCommands<String, String> asyncCommands;
private RedisAdvancedClusterCommands<String, String> clusterCommands;
private RedisAdvancedClusterAsyncCommands<String, String> asyncClusterCommands;
private AccountsManager accountsManager;
private SecureValueRecovery2Client svr2Client;
private DynamicConfiguration dynamicConfiguration;
@ -162,14 +166,18 @@ class AccountsManagerTest {
}).when(clientPresenceExecutor).execute(any());
//noinspection unchecked
commands = mock(RedisAdvancedClusterCommands.class);
asyncCommands = mock(RedisAsyncCommands.class);
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
//noinspection unchecked
asyncCommands = mock(RedisAdvancedClusterAsyncCommands.class);
when(asyncCommands.del(any(String[].class))).thenReturn(MockRedisFuture.completedFuture(0L));
when(asyncCommands.get(any())).thenReturn(MockRedisFuture.completedFuture(null));
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
clusterCommands = mock(RedisAdvancedClusterCommands.class);
//noinspection unchecked
asyncClusterCommands = mock(RedisAdvancedClusterAsyncCommands.class);
when(asyncClusterCommands.del(any(String[].class))).thenReturn(MockRedisFuture.completedFuture(0L));
when(asyncClusterCommands.get(any())).thenReturn(MockRedisFuture.completedFuture(null));
when(asyncClusterCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(asyncClusterCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(accounts.updateAsync(any())).thenReturn(CompletableFuture.completedFuture(null));
when(accounts.updateTransactionallyAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
@ -230,15 +238,20 @@ class AccountsManagerTest {
CLOCK = TestClock.now();
final FaultTolerantRedisClusterClient redisCluster = RedisClusterHelper.builder()
.stringCommands(commands)
final FaultTolerantRedisClient pubSubClient = RedisServerHelper.builder()
.stringAsyncCommands(asyncCommands)
.build();
final FaultTolerantRedisClusterClient redisCluster = RedisClusterHelper.builder()
.stringCommands(clusterCommands)
.stringAsyncCommands(asyncClusterCommands)
.build();
accountsManager = new AccountsManager(
accounts,
phoneNumberIdentifiers,
redisCluster,
pubSubClient,
accountLockManager,
keysManager,
messagesManager,
@ -285,8 +298,8 @@ class AccountsManagerTest {
final UUID aci = UUID.randomUUID();
final UUID pni = UUID.randomUUID();
when(commands.get(eq("AccountMap::" + pni))).thenReturn(aci.toString());
when(commands.get(eq("Account3::" + aci))).thenReturn(
when(clusterCommands.get(eq("AccountMap::" + pni))).thenReturn(aci.toString());
when(clusterCommands.get(eq("Account3::" + aci))).thenReturn(
"{\"number\": \"+14152222222\", \"pni\": \"" + pni + "\"}");
assertTrue(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(aci)).isPresent());
@ -300,11 +313,11 @@ class AccountsManagerTest {
final UUID aci = UUID.randomUUID();
final UUID pni = UUID.randomUUID();
when(asyncCommands.get(eq("AccountMap::" + pni))).thenReturn(MockRedisFuture.completedFuture(aci.toString()));
when(asyncCommands.get(eq("Account3::" + aci))).thenReturn(MockRedisFuture.completedFuture(
when(asyncClusterCommands.get(eq("AccountMap::" + pni))).thenReturn(MockRedisFuture.completedFuture(aci.toString()));
when(asyncClusterCommands.get(eq("Account3::" + aci))).thenReturn(MockRedisFuture.completedFuture(
"{\"number\": \"+14152222222\", \"pni\": \"" + pni + "\"}"));
when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(asyncClusterCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(accounts.getByAccountIdentifierAsync(any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
@ -323,7 +336,7 @@ class AccountsManagerTest {
void testGetAccountByUuidInCache() {
UUID uuid = UUID.randomUUID();
when(commands.get(eq("Account3::" + uuid))).thenReturn(
when(clusterCommands.get(eq("Account3::" + uuid))).thenReturn(
"{\"number\": \"+14152222222\", \"pni\": \"de24dc73-fbd8-41be-a7d5-764c70d9da7e\"}");
Optional<Account> account = accountsManager.getByAccountIdentifier(uuid);
@ -333,8 +346,8 @@ class AccountsManagerTest {
assertEquals(account.get().getUuid(), uuid);
assertEquals(UUID.fromString("de24dc73-fbd8-41be-a7d5-764c70d9da7e"), account.get().getPhoneNumberIdentifier());
verify(commands, times(1)).get(eq("Account3::" + uuid));
verifyNoMoreInteractions(commands);
verify(clusterCommands, times(1)).get(eq("Account3::" + uuid));
verifyNoMoreInteractions(clusterCommands);
verifyNoInteractions(accounts);
}
@ -343,10 +356,10 @@ class AccountsManagerTest {
void testGetAccountByUuidInCacheAsync() {
UUID uuid = UUID.randomUUID();
when(asyncCommands.get(eq("Account3::" + uuid))).thenReturn(MockRedisFuture.completedFuture(
when(asyncClusterCommands.get(eq("Account3::" + uuid))).thenReturn(MockRedisFuture.completedFuture(
"{\"number\": \"+14152222222\", \"pni\": \"de24dc73-fbd8-41be-a7d5-764c70d9da7e\"}"));
when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(asyncClusterCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
Optional<Account> account = accountsManager.getByAccountIdentifierAsync(uuid).join();
@ -355,8 +368,8 @@ class AccountsManagerTest {
assertEquals(account.get().getUuid(), uuid);
assertEquals(UUID.fromString("de24dc73-fbd8-41be-a7d5-764c70d9da7e"), account.get().getPhoneNumberIdentifier());
verify(asyncCommands, times(1)).get(eq("Account3::" + uuid));
verifyNoMoreInteractions(asyncCommands);
verify(asyncClusterCommands, times(1)).get(eq("Account3::" + uuid));
verifyNoMoreInteractions(asyncClusterCommands);
verifyNoInteractions(accounts);
}
@ -366,8 +379,8 @@ class AccountsManagerTest {
UUID uuid = UUID.randomUUID();
UUID pni = UUID.randomUUID();
when(commands.get(eq("AccountMap::" + pni))).thenReturn(uuid.toString());
when(commands.get(eq("Account3::" + uuid))).thenReturn(
when(clusterCommands.get(eq("AccountMap::" + pni))).thenReturn(uuid.toString());
when(clusterCommands.get(eq("Account3::" + uuid))).thenReturn(
"{\"number\": \"+14152222222\", \"pni\": \"de24dc73-fbd8-41be-a7d5-764c70d9da7e\"}");
Optional<Account> account = accountsManager.getByPhoneNumberIdentifier(pni);
@ -376,9 +389,9 @@ class AccountsManagerTest {
assertEquals(account.get().getNumber(), "+14152222222");
assertEquals(UUID.fromString("de24dc73-fbd8-41be-a7d5-764c70d9da7e"), account.get().getPhoneNumberIdentifier());
verify(commands).get(eq("AccountMap::" + pni));
verify(commands).get(eq("Account3::" + uuid));
verifyNoMoreInteractions(commands);
verify(clusterCommands).get(eq("AccountMap::" + pni));
verify(clusterCommands).get(eq("Account3::" + uuid));
verifyNoMoreInteractions(clusterCommands);
verifyNoInteractions(accounts);
}
@ -388,13 +401,13 @@ class AccountsManagerTest {
UUID uuid = UUID.randomUUID();
UUID pni = UUID.randomUUID();
when(asyncCommands.get(eq("AccountMap::" + pni)))
when(asyncClusterCommands.get(eq("AccountMap::" + pni)))
.thenReturn(MockRedisFuture.completedFuture(uuid.toString()));
when(asyncCommands.get(eq("Account3::" + uuid))).thenReturn(MockRedisFuture.completedFuture(
when(asyncClusterCommands.get(eq("Account3::" + uuid))).thenReturn(MockRedisFuture.completedFuture(
"{\"number\": \"+14152222222\", \"pni\": \"de24dc73-fbd8-41be-a7d5-764c70d9da7e\"}"));
when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(asyncClusterCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
Optional<Account> account = accountsManager.getByPhoneNumberIdentifierAsync(pni).join();
@ -402,9 +415,9 @@ class AccountsManagerTest {
assertEquals(account.get().getNumber(), "+14152222222");
assertEquals(UUID.fromString("de24dc73-fbd8-41be-a7d5-764c70d9da7e"), account.get().getPhoneNumberIdentifier());
verify(asyncCommands).get(eq("AccountMap::" + pni));
verify(asyncCommands).get(eq("Account3::" + uuid));
verifyNoMoreInteractions(asyncCommands);
verify(asyncClusterCommands).get(eq("AccountMap::" + pni));
verify(asyncClusterCommands).get(eq("Account3::" + uuid));
verifyNoMoreInteractions(asyncClusterCommands);
verifyNoInteractions(accounts);
}
@ -415,7 +428,7 @@ class AccountsManagerTest {
UUID pni = UUID.randomUUID();
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(commands.get(eq("Account3::" + uuid))).thenReturn(null);
when(clusterCommands.get(eq("Account3::" + uuid))).thenReturn(null);
when(accounts.getByAccountIdentifier(eq(uuid))).thenReturn(Optional.of(account));
Optional<Account> retrieved = accountsManager.getByAccountIdentifier(uuid);
@ -423,10 +436,10 @@ class AccountsManagerTest {
assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account);
verify(commands, times(1)).get(eq("Account3::" + uuid));
verify(commands, times(1)).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(commands, times(1)).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(commands);
verify(clusterCommands, times(1)).get(eq("Account3::" + uuid));
verify(clusterCommands, times(1)).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(clusterCommands, times(1)).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(clusterCommands);
verify(accounts, times(1)).getByAccountIdentifier(eq(uuid));
verifyNoMoreInteractions(accounts);
@ -438,8 +451,8 @@ class AccountsManagerTest {
UUID pni = UUID.randomUUID();
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(asyncCommands.get(eq("Account3::" + uuid))).thenReturn(MockRedisFuture.completedFuture(null));
when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(asyncClusterCommands.get(eq("Account3::" + uuid))).thenReturn(MockRedisFuture.completedFuture(null));
when(asyncClusterCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(accounts.getByAccountIdentifierAsync(eq(uuid)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
@ -448,10 +461,10 @@ class AccountsManagerTest {
assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account);
verify(asyncCommands).get(eq("Account3::" + uuid));
verify(asyncCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(asyncCommands).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(asyncCommands);
verify(asyncClusterCommands).get(eq("Account3::" + uuid));
verify(asyncClusterCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(asyncClusterCommands).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(asyncClusterCommands);
verify(accounts).getByAccountIdentifierAsync(eq(uuid));
verifyNoMoreInteractions(accounts);
@ -464,7 +477,7 @@ class AccountsManagerTest {
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(commands.get(eq("AccountMap::" + pni))).thenReturn(null);
when(clusterCommands.get(eq("AccountMap::" + pni))).thenReturn(null);
when(accounts.getByPhoneNumberIdentifier(pni)).thenReturn(Optional.of(account));
Optional<Account> retrieved = accountsManager.getByPhoneNumberIdentifier(pni);
@ -472,10 +485,10 @@ class AccountsManagerTest {
assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account);
verify(commands).get(eq("AccountMap::" + pni));
verify(commands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(commands).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(commands);
verify(clusterCommands).get(eq("AccountMap::" + pni));
verify(clusterCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(clusterCommands).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(clusterCommands);
verify(accounts).getByPhoneNumberIdentifier(pni);
verifyNoMoreInteractions(accounts);
@ -488,8 +501,8 @@ class AccountsManagerTest {
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(asyncCommands.get(eq("AccountMap::" + pni))).thenReturn(MockRedisFuture.completedFuture(null));
when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(asyncClusterCommands.get(eq("AccountMap::" + pni))).thenReturn(MockRedisFuture.completedFuture(null));
when(asyncClusterCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(accounts.getByPhoneNumberIdentifierAsync(pni))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
@ -498,10 +511,10 @@ class AccountsManagerTest {
assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account);
verify(asyncCommands).get(eq("AccountMap::" + pni));
verify(asyncCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(asyncCommands).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(asyncCommands);
verify(asyncClusterCommands).get(eq("AccountMap::" + pni));
verify(asyncClusterCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(asyncClusterCommands).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(asyncClusterCommands);
verify(accounts).getByPhoneNumberIdentifierAsync(pni);
verifyNoMoreInteractions(accounts);
@ -528,7 +541,7 @@ class AccountsManagerTest {
UUID pni = UUID.randomUUID();
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(commands.get(eq("Account3::" + uuid))).thenThrow(new RedisException("Connection lost!"));
when(clusterCommands.get(eq("Account3::" + uuid))).thenThrow(new RedisException("Connection lost!"));
when(accounts.getByAccountIdentifier(eq(uuid))).thenReturn(Optional.of(account));
Optional<Account> retrieved = accountsManager.getByAccountIdentifier(uuid);
@ -536,10 +549,10 @@ class AccountsManagerTest {
assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account);
verify(commands, times(1)).get(eq("Account3::" + uuid));
verify(commands, times(1)).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(commands, times(1)).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(commands);
verify(clusterCommands, times(1)).get(eq("Account3::" + uuid));
verify(clusterCommands, times(1)).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(clusterCommands, times(1)).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(clusterCommands);
verify(accounts, times(1)).getByAccountIdentifier(eq(uuid));
verifyNoMoreInteractions(accounts);
@ -551,10 +564,10 @@ class AccountsManagerTest {
UUID pni = UUID.randomUUID();
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(asyncCommands.get(eq("Account3::" + uuid)))
when(asyncClusterCommands.get(eq("Account3::" + uuid)))
.thenReturn(MockRedisFuture.failedFuture(new RedisException("Connection lost!")));
when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(asyncClusterCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(accounts.getByAccountIdentifierAsync(eq(uuid)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
@ -564,10 +577,10 @@ class AccountsManagerTest {
assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account);
verify(asyncCommands).get(eq("Account3::" + uuid));
verify(asyncCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(asyncCommands).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(asyncCommands);
verify(asyncClusterCommands).get(eq("Account3::" + uuid));
verify(asyncClusterCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(asyncClusterCommands).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(asyncClusterCommands);
verify(accounts).getByAccountIdentifierAsync(eq(uuid));
verifyNoMoreInteractions(accounts);
@ -580,7 +593,7 @@ class AccountsManagerTest {
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(commands.get(eq("AccountMap::" + pni))).thenThrow(new RedisException("OH NO"));
when(clusterCommands.get(eq("AccountMap::" + pni))).thenThrow(new RedisException("OH NO"));
when(accounts.getByPhoneNumberIdentifier(pni)).thenReturn(Optional.of(account));
Optional<Account> retrieved = accountsManager.getByPhoneNumberIdentifier(pni);
@ -588,10 +601,10 @@ class AccountsManagerTest {
assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account);
verify(commands).get(eq("AccountMap::" + pni));
verify(commands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(commands).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(commands);
verify(clusterCommands).get(eq("AccountMap::" + pni));
verify(clusterCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(clusterCommands).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(clusterCommands);
verify(accounts).getByPhoneNumberIdentifier(pni);
verifyNoMoreInteractions(accounts);
@ -604,10 +617,10 @@ class AccountsManagerTest {
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(asyncCommands.get(eq("AccountMap::" + pni)))
when(asyncClusterCommands.get(eq("AccountMap::" + pni)))
.thenReturn(MockRedisFuture.failedFuture(new RedisException("OH NO")));
when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(asyncClusterCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(accounts.getByPhoneNumberIdentifierAsync(pni))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
@ -617,10 +630,10 @@ class AccountsManagerTest {
assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account);
verify(asyncCommands).get(eq("AccountMap::" + pni));
verify(asyncCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(asyncCommands).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(asyncCommands);
verify(asyncClusterCommands).get(eq("AccountMap::" + pni));
verify(asyncClusterCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(asyncClusterCommands).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(asyncClusterCommands);
verify(accounts).getByPhoneNumberIdentifierAsync(pni);
verifyNoMoreInteractions(accounts);
@ -632,7 +645,7 @@ class AccountsManagerTest {
UUID pni = UUID.randomUUID();
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(commands.get(eq("Account3::" + uuid))).thenReturn(null);
when(clusterCommands.get(eq("Account3::" + uuid))).thenReturn(null);
when(accounts.getByAccountIdentifier(uuid)).thenReturn(
Optional.of(AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH])));
@ -658,7 +671,7 @@ class AccountsManagerTest {
UUID pni = UUID.randomUUID();
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(asyncCommands.get(eq("Account3::" + uuid))).thenReturn(null);
when(asyncClusterCommands.get(eq("Account3::" + uuid))).thenReturn(null);
when(accounts.getByAccountIdentifierAsync(uuid)).thenReturn(CompletableFuture.completedFuture(
Optional.of(AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]))));
@ -684,7 +697,7 @@ class AccountsManagerTest {
UUID uuid = UUID.randomUUID();
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(commands.get(eq("Account3::" + uuid))).thenReturn(null);
when(clusterCommands.get(eq("Account3::" + uuid))).thenReturn(null);
when(accounts.getByAccountIdentifier(uuid)).thenReturn(Optional.empty())
.thenReturn(Optional.of(account));
when(accounts.create(any(), any())).thenThrow(ContestedOptimisticLockException.class);
@ -971,7 +984,7 @@ class AccountsManagerTest {
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey),
accountsManager.generateDeviceLinkingToken(aci))
accountsManager.generateLinkDeviceToken(aci))
.join();
verify(keysManager).deleteSingleUsePreKeys(aci, nextDeviceId);
@ -1606,7 +1619,7 @@ class AccountsManagerTest {
final UUID aci = UUID.randomUUID();
assertEquals(Optional.of(aci),
accountsManager.checkDeviceLinkingToken(accountsManager.generateDeviceLinkingToken(aci)));
accountsManager.checkDeviceLinkingToken(accountsManager.generateLinkDeviceToken(aci)));
}
@ParameterizedTest
@ -1622,7 +1635,7 @@ class AccountsManagerTest {
return Stream.of(
// Expired token
Arguments.of(AccountsManager.generateDeviceLinkingToken(UUID.randomUUID(),
Arguments.of(AccountsManager.generateLinkDeviceToken(UUID.randomUUID(),
new SecretKeySpec(LINK_DEVICE_SECRET, AccountsManager.LINK_DEVICE_VERIFICATION_TOKEN_ALGORITHM),
CLOCK),
tokenTimestamp.plus(AccountsManager.LINK_DEVICE_TOKEN_EXPIRATION_DURATION).plusSeconds(1)),

View File

@ -36,6 +36,7 @@ import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.Mockito;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
@ -137,6 +138,7 @@ class AccountsManagerUsernameIntegrationTest {
accounts,
phoneNumberIdentifiers,
CACHE_CLUSTER_EXTENSION.getRedisCluster(),
mock(FaultTolerantRedisClient.class),
accountLockManager,
keysManager,
messageManager,

View File

@ -13,6 +13,7 @@ import static org.mockito.Mockito.when;
import com.google.i18n.phonenumbers.PhoneNumberUtil;
import java.nio.charset.StandardCharsets;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.ZoneId;
import java.util.Optional;
@ -29,9 +30,11 @@ import org.junit.jupiter.api.extension.RegisterExtension;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.DeviceInfo;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.redis.RedisServerExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
@ -59,6 +62,9 @@ public class AddRemoveDeviceIntegrationTest {
@RegisterExtension
static final RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@RegisterExtension
static final RedisServerExtension PUBSUB_SERVER_EXTENSION = RedisServerExtension.builder().build();
private static final Clock CLOCK = Clock.fixed(Instant.now(), ZoneId.systemDefault());
private ExecutorService accountLockExecutor;
@ -128,10 +134,16 @@ public class AddRemoveDeviceIntegrationTest {
when(registrationRecoveryPasswordsManager.removeForNumber(any()))
.thenReturn(CompletableFuture.completedFuture(null));
PUBSUB_SERVER_EXTENSION.getRedisClient().useConnection(connection -> {
connection.sync().flushall();
connection.sync().configSet("notify-keyspace-events", "K$");
});
accountsManager = new AccountsManager(
accounts,
phoneNumberIdentifiers,
CACHE_CLUSTER_EXTENSION.getRedisCluster(),
PUBSUB_SERVER_EXTENSION.getRedisClient(),
accountLockManager,
keysManager,
messagesManager,
@ -146,10 +158,14 @@ public class AddRemoveDeviceIntegrationTest {
CLOCK,
"link-device-secret".getBytes(StandardCharsets.UTF_8),
dynamicConfigurationManager);
accountsManager.start();
}
@AfterEach
void tearDown() throws InterruptedException {
accountsManager.stop();
accountLockExecutor.shutdown();
clientPresenceExecutor.shutdown();
@ -187,7 +203,7 @@ public class AddRemoveDeviceIntegrationTest {
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair)),
accountsManager.generateDeviceLinkingToken(account.getIdentifier(IdentityType.ACI)))
accountsManager.generateLinkDeviceToken(account.getIdentifier(IdentityType.ACI)))
.join();
assertEquals(2, updatedAccountAndDevice.first().getDevices().size());
@ -216,7 +232,7 @@ public class AddRemoveDeviceIntegrationTest {
final Account account = AccountsHelper.createAccount(accountsManager, number);
assertEquals(1, accountsManager.getByAccountIdentifier(account.getUuid()).orElseThrow().getDevices().size());
final String linkDeviceToken = accountsManager.generateDeviceLinkingToken(account.getIdentifier(IdentityType.ACI));
final String linkDeviceToken = accountsManager.generateLinkDeviceToken(account.getIdentifier(IdentityType.ACI));
final Pair<Account, Device> updatedAccountAndDevice =
accountsManager.addDevice(account, new DeviceSpec(
@ -292,7 +308,7 @@ public class AddRemoveDeviceIntegrationTest {
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair)),
accountsManager.generateDeviceLinkingToken(account.getIdentifier(IdentityType.ACI)))
accountsManager.generateLinkDeviceToken(account.getIdentifier(IdentityType.ACI)))
.join();
final byte addedDeviceId = updatedAccountAndDevice.second().getId();
@ -346,7 +362,7 @@ public class AddRemoveDeviceIntegrationTest {
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair)),
accountsManager.generateDeviceLinkingToken(account.getIdentifier(IdentityType.ACI)))
accountsManager.generateLinkDeviceToken(account.getIdentifier(IdentityType.ACI)))
.join();
final byte addedDeviceId = updatedAccountAndDevice.second().getId();
@ -376,4 +392,110 @@ public class AddRemoveDeviceIntegrationTest {
assertTrue(keysManager.getLastResort(retrievedAccount.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join().isPresent());
assertTrue(clientPublicKeysManager.findPublicKey(retrievedAccount.getUuid(), Device.PRIMARY_ID).join().isPresent());
}
@Test
void waitForNewLinkedDevice() throws InterruptedException {
final String number = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164);
final ECKeyPair aciKeyPair = Curve.generateKeyPair();
final ECKeyPair pniKeyPair = Curve.generateKeyPair();
final Account account = AccountsHelper.createAccount(accountsManager, number);
final String linkDeviceToken = accountsManager.generateLinkDeviceToken(account.getIdentifier(IdentityType.ACI));
final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken);
final CompletableFuture<Optional<DeviceInfo>> displacedFuture =
accountsManager.waitForNewLinkedDevice(linkDeviceTokenIdentifier, Duration.ofSeconds(5));
final CompletableFuture<Optional<DeviceInfo>> activeFuture =
accountsManager.waitForNewLinkedDevice(linkDeviceTokenIdentifier, Duration.ofSeconds(5));
assertEquals(Optional.empty(), displacedFuture.join());
final Pair<Account, Device> updatedAccountAndDevice =
accountsManager.addDevice(account, new DeviceSpec(
"device-name".getBytes(StandardCharsets.UTF_8),
"password",
"OWT",
new Device.DeviceCapabilities(true, true, true, false),
1,
2,
true,
Optional.empty(),
Optional.empty(),
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair)),
linkDeviceToken)
.join();
final Optional<DeviceInfo> maybeDeviceInfo = activeFuture.join();
assertTrue(maybeDeviceInfo.isPresent());
final DeviceInfo deviceInfo = maybeDeviceInfo.get();
assertEquals(updatedAccountAndDevice.second().getId(), deviceInfo.id());
assertEquals(updatedAccountAndDevice.second().getCreated(), deviceInfo.created());
}
@Test
void waitForNewLinkedDeviceAlreadyAdded() throws InterruptedException {
final String number = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164);
final ECKeyPair aciKeyPair = Curve.generateKeyPair();
final ECKeyPair pniKeyPair = Curve.generateKeyPair();
final Account account = AccountsHelper.createAccount(accountsManager, number);
final String linkDeviceToken = accountsManager.generateLinkDeviceToken(account.getIdentifier(IdentityType.ACI));
final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken);
final Pair<Account, Device> updatedAccountAndDevice =
accountsManager.addDevice(account, new DeviceSpec(
"device-name".getBytes(StandardCharsets.UTF_8),
"password",
"OWT",
new Device.DeviceCapabilities(true, true, true, false),
1,
2,
true,
Optional.empty(),
Optional.empty(),
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair)),
linkDeviceToken)
.join();
final CompletableFuture<Optional<DeviceInfo>> linkedDeviceFuture =
accountsManager.waitForNewLinkedDevice(linkDeviceTokenIdentifier, Duration.ofMinutes(1));
final Optional<DeviceInfo> maybeDeviceInfo = linkedDeviceFuture.join();
assertTrue(maybeDeviceInfo.isPresent());
final DeviceInfo deviceInfo = maybeDeviceInfo.get();
assertEquals(updatedAccountAndDevice.second().getId(), deviceInfo.id());
assertEquals(updatedAccountAndDevice.second().getCreated(), deviceInfo.created());
}
@Test
void waitForNewLinkedDeviceTimeout() {
final String linkDeviceToken = accountsManager.generateLinkDeviceToken(UUID.randomUUID());
final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken);
final CompletableFuture<Optional<DeviceInfo>> linkedDeviceFuture =
accountsManager.waitForNewLinkedDevice(linkDeviceTokenIdentifier, Duration.ofMillis(10));
final Optional<DeviceInfo> maybeDeviceInfo = linkedDeviceFuture.join();
assertTrue(maybeDeviceInfo.isEmpty());
}
}

View File

@ -0,0 +1,112 @@
package org.whispersystems.textsecuregcm.tests.util;
import io.lettuce.core.api.StatefulRedisConnection;
import io.lettuce.core.api.async.RedisAsyncCommands;
import io.lettuce.core.api.reactive.RedisReactiveCommands;
import io.lettuce.core.api.sync.RedisCommands;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import java.util.function.Consumer;
import java.util.function.Function;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class RedisServerHelper {
public static RedisServerHelper.Builder builder() {
return new RedisServerHelper.Builder();
}
@SuppressWarnings("unchecked")
private static FaultTolerantRedisClient buildMockRedisClient(
final RedisCommands<String, String> stringCommands,
final RedisAsyncCommands<String, String> stringAsyncCommands,
final RedisCommands<byte[], byte[]> binaryCommands,
final RedisAsyncCommands<byte[], byte[]> binaryAsyncCommands,
final RedisReactiveCommands<byte[], byte[]> binaryReactiveCommands) {
final FaultTolerantRedisClient client = mock(FaultTolerantRedisClient.class);
final StatefulRedisConnection<String, String> stringConnection = mock(StatefulRedisConnection.class);
final StatefulRedisConnection<byte[], byte[]> binaryConnection = mock(StatefulRedisConnection.class);
when(stringConnection.sync()).thenReturn(stringCommands);
when(stringConnection.async()).thenReturn(stringAsyncCommands);
when(binaryConnection.sync()).thenReturn(binaryCommands);
when(binaryConnection.async()).thenReturn(binaryAsyncCommands);
when(binaryConnection.reactive()).thenReturn(binaryReactiveCommands);
when(client.withConnection(any(Function.class))).thenAnswer(invocation -> {
return invocation.getArgument(0, Function.class).apply(stringConnection);
});
doAnswer(invocation -> {
invocation.getArgument(0, Consumer.class).accept(stringConnection);
return null;
}).when(client).useConnection(any(Consumer.class));
when(client.withBinaryConnection(any(Function.class))).thenAnswer(invocation -> {
return invocation.getArgument(0, Function.class).apply(binaryConnection);
});
doAnswer(invocation -> {
invocation.getArgument(0, Consumer.class).accept(binaryConnection);
return null;
}).when(client).useBinaryConnection(any(Consumer.class));
return client;
}
@SuppressWarnings("unchecked")
public static class Builder {
private RedisCommands<String, String> stringCommands = mock(RedisCommands.class);
private RedisAsyncCommands<String, String> stringAsyncCommands = mock(RedisAsyncCommands.class);
private RedisCommands<byte[], byte[]> binaryCommands = mock(RedisCommands.class);
private RedisAsyncCommands<byte[], byte[]> binaryAsyncCommands =
mock(RedisAsyncCommands.class);
private RedisReactiveCommands<byte[], byte[]> binaryReactiveCommands =
mock(RedisReactiveCommands.class);
private Builder() {
}
public RedisServerHelper.Builder stringCommands(final RedisCommands<String, String> stringCommands) {
this.stringCommands = stringCommands;
return this;
}
public RedisServerHelper.Builder stringAsyncCommands(final RedisAsyncCommands<String, String> stringAsyncCommands) {
this.stringAsyncCommands = stringAsyncCommands;
return this;
}
public RedisServerHelper.Builder binaryCommands(final RedisCommands<byte[], byte[]> binaryCommands) {
this.binaryCommands = binaryCommands;
return this;
}
public RedisServerHelper.Builder binaryAsyncCommands(final RedisAsyncCommands<byte[], byte[]> binaryAsyncCommands) {
this.binaryAsyncCommands = binaryAsyncCommands;
return this;
}
public RedisServerHelper.Builder binaryReactiveCommands(
final RedisReactiveCommands<byte[], byte[]> binaryReactiveCommands) {
this.binaryReactiveCommands = binaryReactiveCommands;
return this;
}
public FaultTolerantRedisClient build() {
return RedisServerHelper.buildMockRedisClient(stringCommands,
stringAsyncCommands,
binaryCommands,
binaryAsyncCommands,
binaryReactiveCommands);
}
}
}