From 8c30a359e7457ca60731841726971bd501d78acc Mon Sep 17 00:00:00 2001 From: Jon Chambers <63609320+jon-signal@users.noreply.github.com> Date: Thu, 10 Oct 2024 10:11:32 -0400 Subject: [PATCH] Add API endpoints for waiting for newly-linked devices --- .../textsecuregcm/WhisperServerService.java | 3 +- .../controllers/DeviceController.java | 155 +++++++++++++-- .../textsecuregcm/entities/DeviceInfo.java | 5 + .../textsecuregcm/limits/RateLimiters.java | 5 + .../storage/AccountsManager.java | 140 +++++++++++++- .../textsecuregcm/util/LinkDeviceToken.java | 24 +++ .../textsecuregcm/util/VerificationCode.java | 8 - .../workers/CommandDependencies.java | 5 +- .../controllers/DeviceControllerTest.java | 141 +++++++++++++- ...ccountCreationDeletionIntegrationTest.java | 2 + ...ntsManagerChangeNumberIntegrationTest.java | 2 + ...ConcurrentModificationIntegrationTest.java | 2 + .../storage/AccountsManagerTest.java | 179 ++++++++++-------- ...ccountsManagerUsernameIntegrationTest.java | 2 + .../AddRemoveDeviceIntegrationTest.java | 130 ++++++++++++- .../tests/util/RedisServerHelper.java | 112 +++++++++++ 16 files changed, 793 insertions(+), 122 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/util/LinkDeviceToken.java delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/util/VerificationCode.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisServerHelper.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 104a68f94..fb47fc141 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -642,7 +642,7 @@ public class WhisperServerService extends Application maxDeviceConfiguration; + private final EnumMap linkedDeviceListenersByPlatform; + private final AtomicInteger linkedDeviceListenersForUnrecognizedPlatforms; + + private static final String LINKED_DEVICE_LISTENER_GAUGE_NAME = + MetricsUtil.name(DeviceController.class, "linkedDeviceListeners"); + + private static final String WAIT_FOR_LINKED_DEVICE_TIMER_NAME = + MetricsUtil.name(DeviceController.class, "waitForLinkedDeviceDuration"); + + @VisibleForTesting + static final int MIN_TOKEN_IDENTIFIER_LENGTH = 32; + + @VisibleForTesting + static final int MAX_TOKEN_IDENTIFIER_LENGTH = 64; + public DeviceController(final AccountsManager accounts, final ClientPublicKeysManager clientPublicKeysManager, final RateLimiters rateLimiters, @@ -78,19 +115,32 @@ public class DeviceController { this.clientPublicKeysManager = clientPublicKeysManager; this.rateLimiters = rateLimiters; this.maxDeviceConfiguration = maxDeviceConfiguration; + + linkedDeviceListenersByPlatform = Arrays.stream(ClientPlatform.values()) + .collect(Collectors.toMap( + Function.identity(), + clientPlatform -> buildGauge(clientPlatform.name().toLowerCase()), + (a, b) -> { + throw new AssertionError("Duplicate client platform enumeration key"); + }, + () -> new EnumMap<>(ClientPlatform.class) + )); + + linkedDeviceListenersForUnrecognizedPlatforms = buildGauge("unknown"); + } + + private static AtomicInteger buildGauge(final String clientPlatformName) { + return Metrics.gauge(LINKED_DEVICE_LISTENER_GAUGE_NAME, + Tags.of(io.micrometer.core.instrument.Tag.of(UserAgentTagUtil.PLATFORM_TAG, clientPlatformName)), + new AtomicInteger(0)); } @GET @Produces(MediaType.APPLICATION_JSON) public DeviceInfoList getDevices(@ReadOnly @Auth AuthenticatedDevice auth) { - List devices = new LinkedList<>(); - - for (Device device : auth.getAccount().getDevices()) { - devices.add(new DeviceInfo(device.getId(), device.getName(), - device.getLastSeen(), device.getCreated())); - } - - return new DeviceInfoList(devices); + return new DeviceInfoList(auth.getAccount().getDevices().stream() + .map(DeviceInfo::forDevice) + .toList()); } @DELETE @@ -138,7 +188,7 @@ public class DeviceController { @ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header( name = "Retry-After", description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed")) - public VerificationCode createDeviceToken(@ReadOnly @Auth AuthenticatedDevice auth) + public LinkDeviceToken createDeviceToken(@ReadOnly @Auth AuthenticatedDevice auth) throws RateLimitExceededException, DeviceLimitExceededException { final Account account = auth.getAccount(); @@ -159,7 +209,9 @@ public class DeviceController { throw new WebApplicationException(Response.Status.UNAUTHORIZED); } - return new VerificationCode(accounts.generateDeviceLinkingToken(account.getUuid())); + final String token = accounts.generateLinkDeviceToken(account.getUuid()); + + return new LinkDeviceToken(token, AccountsManager.getLinkDeviceTokenIdentifier(token)); } @PUT @@ -266,6 +318,83 @@ public class DeviceController { } } + @GET + @Path("/wait_for_linked_device/{tokenIdentifier}") + @Produces(MediaType.APPLICATION_JSON) + @Operation(summary = "Wait for a new device to be linked to an account", + description = """ + Waits for a new device to be linked to an account and returns basic information about the new device when + available. + """) + @ApiResponse(responseCode = "200", description = "The specified was linked to an account") + @ApiResponse(responseCode = "204", description = "No device was linked to the account before the call completed") + @ApiResponse(responseCode = "400", description = "The given token identifier or timeout was invalid") + @ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay") + @Schema(description = "Basic information about the linked device", implementation = DeviceInfo.class) + public CompletableFuture waitForLinkedDevice( + @ReadOnly @Auth final AuthenticatedDevice authenticatedDevice, + + @PathParam("tokenIdentifier") + @Schema(description = "A 'link device' token identifier provided by the 'create link device token' endpoint") + @Size(min = MIN_TOKEN_IDENTIFIER_LENGTH, max = MAX_TOKEN_IDENTIFIER_LENGTH) + final String tokenIdentifier, + + @QueryParam("timeout") + @DefaultValue("30") + @Min(1) + @Max(3600) + @Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, + minimum = "1", + maximum = "3600", + description = """ + The amount of time (in seconds) to wait for a response. If the expected device is not linked within the + given amount of time, this endpoint will return a status of HTTP/204. + """) final int timeoutSeconds, + + @HeaderParam(HttpHeaders.USER_AGENT) String userAgent) throws RateLimitExceededException { + + rateLimiters.getWaitForLinkedDeviceLimiter().validate(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI)); + + final AtomicInteger linkedDeviceListenerCounter = getCounterForLinkedDeviceListeners(userAgent); + linkedDeviceListenerCounter.incrementAndGet(); + + final Timer.Sample sample = Timer.start(); + + try { + return accounts.waitForNewLinkedDevice(tokenIdentifier, Duration.ofSeconds(timeoutSeconds)) + .thenApply(maybeDeviceInfo -> maybeDeviceInfo + .map(deviceInfo -> Response.status(Response.Status.OK).entity(deviceInfo).build()) + .orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build())) + .exceptionally(ExceptionUtils.exceptionallyHandler(IllegalArgumentException.class, + e -> Response.status(Response.Status.BAD_REQUEST).build())) + .whenComplete((response, throwable) -> { + linkedDeviceListenerCounter.decrementAndGet(); + + if (response != null) { + sample.stop(Timer.builder(WAIT_FOR_LINKED_DEVICE_TIMER_NAME) + .publishPercentileHistogram(true) + .tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), + io.micrometer.core.instrument.Tag.of("deviceFound", + String.valueOf(response.getStatus() == Response.Status.OK.getStatusCode())))) + .register(Metrics.globalRegistry)); + } + }); + } catch (final RedisException e) { + // `waitForNewLinkedDevice` could fail synchronously if the Redis circuit breaker is open; prevent counter drift + // if that happens + linkedDeviceListenerCounter.decrementAndGet(); + throw e; + } + } + + private AtomicInteger getCounterForLinkedDeviceListeners(final String userAgent) { + try { + return linkedDeviceListenersByPlatform.get(UserAgentUtil.parseUserAgentString(userAgent).getPlatform()); + } catch (final UnrecognizedUserAgentException ignored) { + return linkedDeviceListenersForUnrecognizedPlatforms; + } + } + @PUT @Produces(MediaType.APPLICATION_JSON) @Path("/unauthenticated_delivery") diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/DeviceInfo.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/DeviceInfo.java index 5e96168d2..6f627be3c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/DeviceInfo.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/DeviceInfo.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.entities; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize; +import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.util.ByteArrayBase64WithPaddingAdapter; public record DeviceInfo(long id, @@ -17,4 +18,8 @@ public record DeviceInfo(long id, long lastSeen, long created) { + + public static DeviceInfo forDevice(final Device device) { + return new DeviceInfo(device.getId(), device.getName(), device.getLastSeen(), device.getCreated()); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java index 5ae44d05f..02bb63cb5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java @@ -50,6 +50,7 @@ public class RateLimiters extends BaseRateLimiters { EXTERNAL_SERVICE_CREDENTIALS("externalServiceCredentials", true, new RateLimiterConfig(100, Duration.ofMinutes(15))), KEY_TRANSPARENCY_SEARCH_PER_IP("keyTransparencySearch", true, new RateLimiterConfig(100, Duration.ofSeconds(15))), KEY_TRANSPARENCY_MONITOR_PER_IP("keyTransparencyMonitor", true, new RateLimiterConfig(100, Duration.ofSeconds(15))), + WAIT_FOR_LINKED_DEVICE("waitForLinkedDevice", true, new RateLimiterConfig(10, Duration.ofSeconds(30))), ; private final String id; @@ -205,4 +206,8 @@ public class RateLimiters extends BaseRateLimiters { public RateLimiter getStoriesLimiter() { return forDescriptor(For.STORIES); } + + public RateLimiter getWaitForLinkedDeviceLimiter() { + return forDescriptor(For.WAIT_FOR_LINKED_DEVICE); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index 8f9610a65..63e0fc2d7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -12,8 +12,11 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectWriter; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import io.dropwizard.lifecycle.Managed; import io.lettuce.core.RedisException; +import io.lettuce.core.SetArgs; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; +import io.lettuce.core.pubsub.RedisPubSubAdapter; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tags; @@ -42,7 +45,9 @@ import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiFunction; import java.util.function.Consumer; @@ -61,13 +66,16 @@ import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.entities.AccountAttributes; +import org.whispersystems.textsecuregcm.entities.DeviceInfo; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.redis.RedisOperation; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; @@ -82,14 +90,13 @@ import reactor.core.scheduler.Scheduler; import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem; import software.amazon.awssdk.services.dynamodb.model.TransactionCanceledException; -public class AccountsManager { +public class AccountsManager extends RedisPubSubAdapter implements Managed { private static final Timer createTimer = Metrics.timer(name(AccountsManager.class, "create")); private static final Timer updateTimer = Metrics.timer(name(AccountsManager.class, "update")); private static final Timer getByNumberTimer = Metrics.timer(name(AccountsManager.class, "getByNumber")); private static final Timer getByUsernameHashTimer = Metrics.timer(name(AccountsManager.class, "getByUsernameHash")); - private static final Timer getByUsernameLinkHandleTimer = Metrics.timer( - name(AccountsManager.class, "getByUsernameLinkHandle")); + private static final Timer getByUsernameLinkHandleTimer = Metrics.timer(name(AccountsManager.class, "getByUsernameLinkHandle")); private static final Timer getByUuidTimer = Metrics.timer(name(AccountsManager.class, "getByUuid")); private static final Timer deleteTimer = Metrics.timer(name(AccountsManager.class, "delete")); @@ -108,6 +115,7 @@ public class AccountsManager { private final Accounts accounts; private final PhoneNumberIdentifiers phoneNumberIdentifiers; private final FaultTolerantRedisClusterClient cacheCluster; + private final FaultTolerantRedisClient pubSubRedisSingleton; private final AccountLockManager accountLockManager; private final KeysManager keysManager; private final MessagesManager messagesManager; @@ -124,6 +132,16 @@ public class AccountsManager { private final Key verificationTokenKey; + private final FaultTolerantPubSubConnection pubSubConnection; + + private final Map>> waitForDeviceFuturesByTokenIdentifier = + new ConcurrentHashMap<>(); + + private static final int SHA256_HASH_LENGTH = getSha256MessageDigest().getDigestLength(); + private static final Duration RECENTLY_ADDED_DEVICE_TTL = Duration.ofHours(1); + private static final String LINKED_DEVICE_PREFIX = "linked_device::"; + private static final String LINKED_DEVICE_KEYSPACE_PATTERN = "__keyspace@0__:" + LINKED_DEVICE_PREFIX + "*"; + private static final ObjectWriter ACCOUNT_REDIS_JSON_WRITER = SystemMapper.jsonMapper() .writer(SystemMapper.excludingField(Account.class, List.of("uuid"))); @@ -158,6 +176,7 @@ public class AccountsManager { public AccountsManager(final Accounts accounts, final PhoneNumberIdentifiers phoneNumberIdentifiers, final FaultTolerantRedisClusterClient cacheCluster, + final FaultTolerantRedisClient pubSubRedisSingleton, final AccountLockManager accountLockManager, final KeysManager keysManager, final MessagesManager messagesManager, @@ -175,6 +194,7 @@ public class AccountsManager { this.accounts = accounts; this.phoneNumberIdentifiers = phoneNumberIdentifiers; this.cacheCluster = cacheCluster; + this.pubSubRedisSingleton = pubSubRedisSingleton; this.accountLockManager = accountLockManager; this.keysManager = keysManager; this.messagesManager = messagesManager; @@ -197,6 +217,20 @@ public class AccountsManager { } catch (final InvalidKeyException e) { throw new IllegalArgumentException(e); } + + this.pubSubConnection = pubSubRedisSingleton.createPubSubConnection(); + } + + @Override + public void start() { + pubSubConnection.usePubSubConnection(connection -> connection.addListener(this)); + pubSubConnection.usePubSubConnection(connection -> connection.sync().psubscribe(LINKED_DEVICE_KEYSPACE_PATTERN)); + } + + @Override + public void stop() { + pubSubConnection.usePubSubConnection(connection -> connection.sync().punsubscribe()); + pubSubConnection.usePubSubConnection(connection -> connection.removeListener(this)); } public Account create(final String number, @@ -363,6 +397,26 @@ public class AccountsManager { } return CompletableFuture.failedFuture(throwable); + }) + .whenComplete((updatedAccountAndDevice, throwable) -> { + if (updatedAccountAndDevice != null) { + final String key = getLinkedDeviceKey(getLinkDeviceTokenIdentifier(linkDeviceToken)); + final String deviceInfoJson; + + try { + deviceInfoJson = SystemMapper.jsonMapper().writeValueAsString(DeviceInfo.forDevice(updatedAccountAndDevice.second())); + } catch (final JsonProcessingException e) { + throw new UncheckedIOException(e); + } + + pubSubRedisSingleton.withConnection(connection -> + connection.async().set(key, deviceInfoJson, SetArgs.Builder.ex(RECENTLY_ADDED_DEVICE_TTL))) + .whenComplete((ignored, pubSubThrowable) -> { + if (pubSubThrowable != null) { + logger.warn("Failed to record recently-created device", pubSubThrowable); + } + }); + } }); } @@ -386,7 +440,7 @@ public class AccountsManager { } } - public String generateDeviceLinkingToken(final UUID aci) { + public String generateLinkDeviceToken(final UUID aci) { final String claims = aci + "." + clock.instant().toEpochMilli(); final byte[] signature = getInitializedMac().doFinal(claims.getBytes(StandardCharsets.UTF_8)); @@ -394,7 +448,7 @@ public class AccountsManager { } @VisibleForTesting - static String generateDeviceLinkingToken(final UUID aci, final Key linkDeviceTokenKey, final Clock clock) + static String generateLinkDeviceToken(final UUID aci, final Key linkDeviceTokenKey, final Clock clock) throws InvalidKeyException { final String claims = aci + "." + clock.instant().toEpochMilli(); @@ -403,6 +457,11 @@ public class AccountsManager { return claims + ":" + Base64.getUrlEncoder().encodeToString(signature); } + public static String getLinkDeviceTokenIdentifier(final String linkDeviceToken) { + return Base64.getUrlEncoder().withoutPadding().encodeToString( + getSha256MessageDigest().digest(linkDeviceToken.getBytes(StandardCharsets.UTF_8))); + } + /** * Checks that a device-linking token is valid and returns the account identifier from the token if so, or empty if * the token was invalid @@ -1340,4 +1399,75 @@ public class AccountsManager { .whenComplete((ignoredResult, ignoredException) -> sample.stop(redisDeleteTimer)) .thenRun(Util.NOOP); } + + public CompletableFuture> waitForNewLinkedDevice(final String linkDeviceTokenIdentifier, final Duration timeout) { + // Unbeknownst to callers but beknownst to us, the "link device token identifier" is the base64/url-encoded SHA256 + // hash of a device-linking token. Before we use the string anywhere, make sure it's the right "shape" for a hash. + if (Base64.getUrlDecoder().decode(linkDeviceTokenIdentifier).length != SHA256_HASH_LENGTH) { + return CompletableFuture.failedFuture(new IllegalArgumentException("Invalid token identifier")); + } + + final CompletableFuture> waitForDeviceFuture = new CompletableFuture<>(); + + waitForDeviceFuture + .completeOnTimeout(Optional.empty(), TimeUnit.MILLISECONDS.convert(timeout), TimeUnit.MILLISECONDS) + .whenComplete((maybeDevice, throwable) -> waitForDeviceFuturesByTokenIdentifier.compute(linkDeviceTokenIdentifier, + (ignored, existingFuture) -> { + // Only remove the future from the map if it's THIS future, and not one that later displaced this one + return existingFuture == waitForDeviceFuture ? null : existingFuture; + })); + + { + final CompletableFuture> displacedFuture = + waitForDeviceFuturesByTokenIdentifier.put(linkDeviceTokenIdentifier, waitForDeviceFuture); + + if (displacedFuture != null) { + displacedFuture.complete(Optional.empty()); + } + } + + // The device may already have been linked by the time the caller started watching for it; perform an immediate + // check to see if the device is already there. + pubSubRedisSingleton.withConnection(connection -> connection.async().get(getLinkedDeviceKey(linkDeviceTokenIdentifier))) + .thenAccept(response -> { + if (StringUtils.isNotBlank(response)) { + handleDeviceAdded(waitForDeviceFuture, response); + } + }); + + return waitForDeviceFuture; + } + + private static String getLinkedDeviceKey(final String linkDeviceTokenIdentifier) { + return LINKED_DEVICE_PREFIX + linkDeviceTokenIdentifier; + } + + @Override + public void message(final String pattern, final String channel, final String message) { + if (LINKED_DEVICE_KEYSPACE_PATTERN.equals(pattern) && "set".equalsIgnoreCase(message)) { + // The `- 1` here compensates for the '*' in the pattern + final String tokenIdentifier = channel.substring(LINKED_DEVICE_KEYSPACE_PATTERN.length() - 1); + + Optional.ofNullable(waitForDeviceFuturesByTokenIdentifier.remove(tokenIdentifier)) + .ifPresent(future -> pubSubRedisSingleton.withConnection(connection -> connection.async().get(getLinkedDeviceKey(tokenIdentifier))) + .thenAccept(deviceInfoJson -> handleDeviceAdded(future, deviceInfoJson))); + } + } + + private void handleDeviceAdded(final CompletableFuture> future, final String deviceInfoJson) { + try { + future.complete(Optional.of(SystemMapper.jsonMapper().readValue(deviceInfoJson, DeviceInfo.class))); + } catch (final JsonProcessingException e) { + logger.error("Could not parse device json", e); + future.completeExceptionally(e); + } + } + + private static MessageDigest getSha256MessageDigest() { + try { + return MessageDigest.getInstance("SHA-256"); + } catch (final NoSuchAlgorithmException e) { + throw new AssertionError("Every implementation of the Java platform is required to support the SHA-256 MessageDigest algorithm", e); + } + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/LinkDeviceToken.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/LinkDeviceToken.java new file mode 100644 index 000000000..d132b3fcf --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/LinkDeviceToken.java @@ -0,0 +1,24 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.util; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; +import io.swagger.v3.oas.annotations.media.Schema; + +public record LinkDeviceToken( + @Schema(description = """ + An opaque token to send to a new linked device that authorizes the new device to link itself to the account that + requested this token. + """) + @JsonProperty("verificationCode") String token, + + @Schema(description = """ + An opaque identifier for the generated token that the caller may use to watch for a new device to complete the + linking process. + """) + String tokenIdentifier) { +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/VerificationCode.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/VerificationCode.java deleted file mode 100644 index 5d502e1aa..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/VerificationCode.java +++ /dev/null @@ -1,8 +0,0 @@ -/* - * Copyright 2013-2020 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.textsecuregcm.util; - -public record VerificationCode(String verificationCode) { -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java index 72b292e4f..52b958cad 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java @@ -40,6 +40,7 @@ import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.FcmSender; import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; import org.whispersystems.textsecuregcm.storage.AccountLockManager; @@ -112,6 +113,8 @@ record CommandDependencies( .build("main_cache", redisClientResourcesBuilder); FaultTolerantRedisClusterClient pushSchedulerCluster = configuration.getPushSchedulerCluster() .build("push_scheduler", redisClientResourcesBuilder); + FaultTolerantRedisClient pubsubClient = + configuration.getRedisPubSubConfiguration().build("pubsub", redisClientResourcesBuilder.build()); ScheduledExecutorService recurringJobExecutor = environment.lifecycle() .scheduledExecutorService(name(name, "recurringJob-%d")).threads(2).build(); @@ -225,7 +228,7 @@ record CommandDependencies( ClientPublicKeysManager clientPublicKeysManager = new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor); AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster, - accountLockManager, keys, messagesManager, profilesManager, + pubsubClient, accountLockManager, keys, messagesManager, profilesManager, secureStorageClient, secureValueRecovery2Client, clientPresenceManager, registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, clientPresenceExecutor, clock, configuration.getLinkDeviceSecretConfiguration().secret().value(), dynamicConfigurationManager); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java index 9f28daa5a..1ce2ca208 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java @@ -5,12 +5,14 @@ package org.whispersystems.textsecuregcm.controllers; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.clearInvocations; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -18,6 +20,7 @@ import static org.mockito.Mockito.reset; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.amazonaws.util.Base64; import com.google.common.net.HttpHeaders; import io.dropwizard.auth.AuthValueFactoryProvider; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; @@ -25,6 +28,7 @@ import io.dropwizard.testing.junit5.ResourceExtension; import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import java.nio.charset.StandardCharsets; +import java.time.Duration; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -35,6 +39,7 @@ import java.util.stream.Stream; import javax.ws.rs.client.Entity; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; +import org.apache.commons.lang3.RandomStringUtils; import org.glassfish.jersey.server.ServerProperties; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.junit.jupiter.api.AfterEach; @@ -54,6 +59,7 @@ import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventLis import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.ApnRegistrationId; import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest; +import org.whispersystems.textsecuregcm.entities.DeviceInfo; import org.whispersystems.textsecuregcm.entities.DeviceResponse; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; @@ -64,6 +70,7 @@ import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper; +import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; @@ -79,7 +86,7 @@ import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture; import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.TestClock; import org.whispersystems.textsecuregcm.util.TestRandomUtil; -import org.whispersystems.textsecuregcm.util.VerificationCode; +import org.whispersystems.textsecuregcm.util.LinkDeviceToken; @ExtendWith(DropwizardExtensionsSupport.class) class DeviceControllerTest { @@ -112,6 +119,7 @@ class DeviceControllerTest { .addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE) .addProvider(AuthHelper.getAuthFilter()) .addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class)) + .addProvider(new RateLimitExceededExceptionMapper()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .addProvider(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager)) .addProvider(new DeviceLimitExceededExceptionMapper()) @@ -122,6 +130,7 @@ class DeviceControllerTest { void setup() { when(rateLimiters.getAllocateDeviceLimiter()).thenReturn(rateLimiter); when(rateLimiters.getVerifyDeviceLimiter()).thenReturn(rateLimiter); + when(rateLimiters.getWaitForLinkedDeviceLimiter()).thenReturn(rateLimiter); when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID); @@ -479,16 +488,17 @@ class DeviceControllerTest { final Optional apnRegistrationId, final Optional gcmRegistrationId) { when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); + when(accountsManager.generateLinkDeviceToken(any())).thenReturn("test"); final Device existingDevice = mock(Device.class); when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID); when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice)); - VerificationCode deviceCode = resources.getJerseyTest() + final LinkDeviceToken deviceCode = resources.getJerseyTest() .target("/v1/devices/provisioning/code") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .get(VerificationCode.class); + .get(LinkDeviceToken.class); final ECSignedPreKey aciSignedPreKey; final ECSignedPreKey pniSignedPreKey; @@ -506,7 +516,7 @@ class DeviceControllerTest { when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); - final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(), + final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.token(), new AccountAttributes(fetchesMessages, 1234, 5678, null, null, true, null), new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnRegistrationId, gcmRegistrationId)); @@ -539,21 +549,22 @@ class DeviceControllerTest { final KEMSignedPreKey pniPqLastResortPreKey) { when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); + when(accountsManager.generateLinkDeviceToken(any())).thenReturn("test"); final Device existingDevice = mock(Device.class); when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID); when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice)); - VerificationCode deviceCode = resources.getJerseyTest() + final LinkDeviceToken deviceCode = resources.getJerseyTest() .target("/v1/devices/provisioning/code") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .get(VerificationCode.class); + .get(LinkDeviceToken.class); when(account.getIdentityKey(IdentityType.ACI)).thenReturn(aciIdentityKey); when(account.getIdentityKey(IdentityType.PNI)).thenReturn(pniIdentityKey); - final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(), + final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.token(), new AccountAttributes(true, 1234, 5678, null, null, true, null), new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.empty())); @@ -931,4 +942,120 @@ class DeviceControllerTest { verify(clientPublicKeysManager).setPublicKey(AuthHelper.VALID_ACCOUNT, AuthHelper.VALID_DEVICE.getId(), request.publicKey()); } + + @Test + void waitForLinkedDevice() { + final DeviceInfo deviceInfo = new DeviceInfo(Device.PRIMARY_ID, + "Device name ciphertext".getBytes(StandardCharsets.UTF_8), + System.currentTimeMillis(), + System.currentTimeMillis()); + + final String tokenIdentifier = Base64.encodeAsString(new byte[32]); + + when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any())) + .thenReturn(CompletableFuture.completedFuture(Optional.of(deviceInfo))); + + try (final Response response = resources.getJerseyTest() + .target("/v1/devices/wait_for_linked_device/" + tokenIdentifier) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get()) { + + assertEquals(200, response.getStatus()); + + final DeviceInfo retrievedDeviceInfo = response.readEntity(DeviceInfo.class); + assertEquals(deviceInfo.id(), retrievedDeviceInfo.id()); + assertArrayEquals(deviceInfo.name(), retrievedDeviceInfo.name()); + assertEquals(deviceInfo.created(), retrievedDeviceInfo.created()); + assertEquals(deviceInfo.lastSeen(), retrievedDeviceInfo.lastSeen()); + } + } + + @Test + void waitForLinkedDeviceNoDeviceLinked() { + final String tokenIdentifier = Base64.encodeAsString(new byte[32]); + + when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any())) + .thenReturn(CompletableFuture.completedFuture(Optional.empty())); + + try (final Response response = resources.getJerseyTest() + .target("/v1/devices/wait_for_linked_device/" + tokenIdentifier) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get()) { + + assertEquals(204, response.getStatus()); + } + } + + @Test + void waitForLinkedDeviceBadTokenIdentifier() { + final String tokenIdentifier = Base64.encodeAsString(new byte[32]); + + when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any())) + .thenReturn(CompletableFuture.failedFuture(new IllegalArgumentException())); + + try (final Response response = resources.getJerseyTest() + .target("/v1/devices/wait_for_linked_device/" + tokenIdentifier) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get()) { + + assertEquals(400, response.getStatus()); + } + } + + @ParameterizedTest + @MethodSource + void waitForLinkedDeviceBadTimeout(final int timeoutSeconds) { + final String tokenIdentifier = Base64.encodeAsString(new byte[32]); + + try (final Response response = resources.getJerseyTest() + .target("/v1/devices/wait_for_linked_device/" + tokenIdentifier) + .queryParam("timeout", timeoutSeconds) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get()) { + + assertEquals(400, response.getStatus()); + } + } + + private static List waitForLinkedDeviceBadTimeout() { + return List.of(0, -1, 3601); + } + + @ParameterizedTest + @MethodSource + void waitForLinkedDeviceBadTokenIdentifierLength(final String tokenIdentifier) { + try (final Response response = resources.getJerseyTest() + .target("/v1/devices/wait_for_linked_device/" + tokenIdentifier) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get()) { + + assertEquals(400, response.getStatus()); + } + } + + private static List waitForLinkedDeviceBadTokenIdentifierLength() { + return List.of(RandomStringUtils.randomAlphanumeric(DeviceController.MIN_TOKEN_IDENTIFIER_LENGTH - 1), + RandomStringUtils.randomAlphanumeric(DeviceController.MAX_TOKEN_IDENTIFIER_LENGTH + 1)); + } + + @Test + void waitForLinkedDeviceRateLimited() throws RateLimitExceededException { + final String tokenIdentifier = Base64.encodeAsString(new byte[32]); + + doThrow(new RateLimitExceededException(null)).when(rateLimiter).validate(AuthHelper.VALID_UUID); + + try (final Response response = resources.getJerseyTest() + .target("/v1/devices/wait_for_linked_device/" + tokenIdentifier) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get()) { + + assertEquals(429, response.getStatus()); + } + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java index b229b24ce..545e907d7 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountCreationDeletionIntegrationTest.java @@ -43,6 +43,7 @@ import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; @@ -142,6 +143,7 @@ public class AccountCreationDeletionIntegrationTest { accounts, phoneNumberIdentifiers, CACHE_CLUSTER_EXTENSION.getRedisCluster(), + mock(FaultTolerantRedisClient.class), accountLockManager, keysManager, messagesManager, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java index 52edd9f02..f6da8755c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java @@ -36,6 +36,7 @@ import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; @@ -137,6 +138,7 @@ class AccountsManagerChangeNumberIntegrationTest { accounts, phoneNumberIdentifiers, CACHE_CLUSTER_EXTENSION.getRedisCluster(), + mock(FaultTolerantRedisClient.class), accountLockManager, keysManager, messagesManager, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java index c8a7ee5ed..6801948d9 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java @@ -48,6 +48,7 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; @@ -124,6 +125,7 @@ class AccountsManagerConcurrentModificationIntegrationTest { accounts, phoneNumberIdentifiers, RedisClusterHelper.builder().stringCommands(commands).build(), + mock(FaultTolerantRedisClient.class), accountLockManager, mock(KeysManager.class), mock(MessagesManager.class), diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index 33d69464c..f2477daaa 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -33,6 +33,7 @@ import static org.mockito.Mockito.when; import com.google.i18n.phonenumbers.PhoneNumberUtil; import io.lettuce.core.RedisException; +import io.lettuce.core.api.async.RedisAsyncCommands; import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import java.io.InputStream; @@ -78,6 +79,7 @@ import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecoveryException; @@ -88,6 +90,7 @@ import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture; import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; +import org.whispersystems.textsecuregcm.tests.util.RedisServerHelper; import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil; import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.TestClock; @@ -119,8 +122,9 @@ class AccountsManagerTest { private Map phoneNumberIdentifiersByE164; - private RedisAdvancedClusterCommands commands; - private RedisAdvancedClusterAsyncCommands asyncCommands; + private RedisAsyncCommands asyncCommands; + private RedisAdvancedClusterCommands clusterCommands; + private RedisAdvancedClusterAsyncCommands asyncClusterCommands; private AccountsManager accountsManager; private SecureValueRecovery2Client svr2Client; private DynamicConfiguration dynamicConfiguration; @@ -162,14 +166,18 @@ class AccountsManagerTest { }).when(clientPresenceExecutor).execute(any()); //noinspection unchecked - commands = mock(RedisAdvancedClusterCommands.class); + asyncCommands = mock(RedisAsyncCommands.class); + when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); //noinspection unchecked - asyncCommands = mock(RedisAdvancedClusterAsyncCommands.class); - when(asyncCommands.del(any(String[].class))).thenReturn(MockRedisFuture.completedFuture(0L)); - when(asyncCommands.get(any())).thenReturn(MockRedisFuture.completedFuture(null)); - when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); - when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); + clusterCommands = mock(RedisAdvancedClusterCommands.class); + + //noinspection unchecked + asyncClusterCommands = mock(RedisAdvancedClusterAsyncCommands.class); + when(asyncClusterCommands.del(any(String[].class))).thenReturn(MockRedisFuture.completedFuture(0L)); + when(asyncClusterCommands.get(any())).thenReturn(MockRedisFuture.completedFuture(null)); + when(asyncClusterCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); + when(asyncClusterCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); when(accounts.updateAsync(any())).thenReturn(CompletableFuture.completedFuture(null)); when(accounts.updateTransactionallyAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); @@ -230,15 +238,20 @@ class AccountsManagerTest { CLOCK = TestClock.now(); - final FaultTolerantRedisClusterClient redisCluster = RedisClusterHelper.builder() - .stringCommands(commands) + final FaultTolerantRedisClient pubSubClient = RedisServerHelper.builder() .stringAsyncCommands(asyncCommands) .build(); + final FaultTolerantRedisClusterClient redisCluster = RedisClusterHelper.builder() + .stringCommands(clusterCommands) + .stringAsyncCommands(asyncClusterCommands) + .build(); + accountsManager = new AccountsManager( accounts, phoneNumberIdentifiers, redisCluster, + pubSubClient, accountLockManager, keysManager, messagesManager, @@ -285,8 +298,8 @@ class AccountsManagerTest { final UUID aci = UUID.randomUUID(); final UUID pni = UUID.randomUUID(); - when(commands.get(eq("AccountMap::" + pni))).thenReturn(aci.toString()); - when(commands.get(eq("Account3::" + aci))).thenReturn( + when(clusterCommands.get(eq("AccountMap::" + pni))).thenReturn(aci.toString()); + when(clusterCommands.get(eq("Account3::" + aci))).thenReturn( "{\"number\": \"+14152222222\", \"pni\": \"" + pni + "\"}"); assertTrue(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(aci)).isPresent()); @@ -300,11 +313,11 @@ class AccountsManagerTest { final UUID aci = UUID.randomUUID(); final UUID pni = UUID.randomUUID(); - when(asyncCommands.get(eq("AccountMap::" + pni))).thenReturn(MockRedisFuture.completedFuture(aci.toString())); - when(asyncCommands.get(eq("Account3::" + aci))).thenReturn(MockRedisFuture.completedFuture( + when(asyncClusterCommands.get(eq("AccountMap::" + pni))).thenReturn(MockRedisFuture.completedFuture(aci.toString())); + when(asyncClusterCommands.get(eq("Account3::" + aci))).thenReturn(MockRedisFuture.completedFuture( "{\"number\": \"+14152222222\", \"pni\": \"" + pni + "\"}")); - when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); + when(asyncClusterCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); when(accounts.getByAccountIdentifierAsync(any())) .thenReturn(CompletableFuture.completedFuture(Optional.empty())); @@ -323,7 +336,7 @@ class AccountsManagerTest { void testGetAccountByUuidInCache() { UUID uuid = UUID.randomUUID(); - when(commands.get(eq("Account3::" + uuid))).thenReturn( + when(clusterCommands.get(eq("Account3::" + uuid))).thenReturn( "{\"number\": \"+14152222222\", \"pni\": \"de24dc73-fbd8-41be-a7d5-764c70d9da7e\"}"); Optional account = accountsManager.getByAccountIdentifier(uuid); @@ -333,8 +346,8 @@ class AccountsManagerTest { assertEquals(account.get().getUuid(), uuid); assertEquals(UUID.fromString("de24dc73-fbd8-41be-a7d5-764c70d9da7e"), account.get().getPhoneNumberIdentifier()); - verify(commands, times(1)).get(eq("Account3::" + uuid)); - verifyNoMoreInteractions(commands); + verify(clusterCommands, times(1)).get(eq("Account3::" + uuid)); + verifyNoMoreInteractions(clusterCommands); verifyNoInteractions(accounts); } @@ -343,10 +356,10 @@ class AccountsManagerTest { void testGetAccountByUuidInCacheAsync() { UUID uuid = UUID.randomUUID(); - when(asyncCommands.get(eq("Account3::" + uuid))).thenReturn(MockRedisFuture.completedFuture( + when(asyncClusterCommands.get(eq("Account3::" + uuid))).thenReturn(MockRedisFuture.completedFuture( "{\"number\": \"+14152222222\", \"pni\": \"de24dc73-fbd8-41be-a7d5-764c70d9da7e\"}")); - when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); + when(asyncClusterCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); Optional account = accountsManager.getByAccountIdentifierAsync(uuid).join(); @@ -355,8 +368,8 @@ class AccountsManagerTest { assertEquals(account.get().getUuid(), uuid); assertEquals(UUID.fromString("de24dc73-fbd8-41be-a7d5-764c70d9da7e"), account.get().getPhoneNumberIdentifier()); - verify(asyncCommands, times(1)).get(eq("Account3::" + uuid)); - verifyNoMoreInteractions(asyncCommands); + verify(asyncClusterCommands, times(1)).get(eq("Account3::" + uuid)); + verifyNoMoreInteractions(asyncClusterCommands); verifyNoInteractions(accounts); } @@ -366,8 +379,8 @@ class AccountsManagerTest { UUID uuid = UUID.randomUUID(); UUID pni = UUID.randomUUID(); - when(commands.get(eq("AccountMap::" + pni))).thenReturn(uuid.toString()); - when(commands.get(eq("Account3::" + uuid))).thenReturn( + when(clusterCommands.get(eq("AccountMap::" + pni))).thenReturn(uuid.toString()); + when(clusterCommands.get(eq("Account3::" + uuid))).thenReturn( "{\"number\": \"+14152222222\", \"pni\": \"de24dc73-fbd8-41be-a7d5-764c70d9da7e\"}"); Optional account = accountsManager.getByPhoneNumberIdentifier(pni); @@ -376,9 +389,9 @@ class AccountsManagerTest { assertEquals(account.get().getNumber(), "+14152222222"); assertEquals(UUID.fromString("de24dc73-fbd8-41be-a7d5-764c70d9da7e"), account.get().getPhoneNumberIdentifier()); - verify(commands).get(eq("AccountMap::" + pni)); - verify(commands).get(eq("Account3::" + uuid)); - verifyNoMoreInteractions(commands); + verify(clusterCommands).get(eq("AccountMap::" + pni)); + verify(clusterCommands).get(eq("Account3::" + uuid)); + verifyNoMoreInteractions(clusterCommands); verifyNoInteractions(accounts); } @@ -388,13 +401,13 @@ class AccountsManagerTest { UUID uuid = UUID.randomUUID(); UUID pni = UUID.randomUUID(); - when(asyncCommands.get(eq("AccountMap::" + pni))) + when(asyncClusterCommands.get(eq("AccountMap::" + pni))) .thenReturn(MockRedisFuture.completedFuture(uuid.toString())); - when(asyncCommands.get(eq("Account3::" + uuid))).thenReturn(MockRedisFuture.completedFuture( + when(asyncClusterCommands.get(eq("Account3::" + uuid))).thenReturn(MockRedisFuture.completedFuture( "{\"number\": \"+14152222222\", \"pni\": \"de24dc73-fbd8-41be-a7d5-764c70d9da7e\"}")); - when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); + when(asyncClusterCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); Optional account = accountsManager.getByPhoneNumberIdentifierAsync(pni).join(); @@ -402,9 +415,9 @@ class AccountsManagerTest { assertEquals(account.get().getNumber(), "+14152222222"); assertEquals(UUID.fromString("de24dc73-fbd8-41be-a7d5-764c70d9da7e"), account.get().getPhoneNumberIdentifier()); - verify(asyncCommands).get(eq("AccountMap::" + pni)); - verify(asyncCommands).get(eq("Account3::" + uuid)); - verifyNoMoreInteractions(asyncCommands); + verify(asyncClusterCommands).get(eq("AccountMap::" + pni)); + verify(asyncClusterCommands).get(eq("Account3::" + uuid)); + verifyNoMoreInteractions(asyncClusterCommands); verifyNoInteractions(accounts); } @@ -415,7 +428,7 @@ class AccountsManagerTest { UUID pni = UUID.randomUUID(); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); - when(commands.get(eq("Account3::" + uuid))).thenReturn(null); + when(clusterCommands.get(eq("Account3::" + uuid))).thenReturn(null); when(accounts.getByAccountIdentifier(eq(uuid))).thenReturn(Optional.of(account)); Optional retrieved = accountsManager.getByAccountIdentifier(uuid); @@ -423,10 +436,10 @@ class AccountsManagerTest { assertTrue(retrieved.isPresent()); assertSame(retrieved.get(), account); - verify(commands, times(1)).get(eq("Account3::" + uuid)); - verify(commands, times(1)).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString())); - verify(commands, times(1)).setex(eq("Account3::" + uuid), anyLong(), anyString()); - verifyNoMoreInteractions(commands); + verify(clusterCommands, times(1)).get(eq("Account3::" + uuid)); + verify(clusterCommands, times(1)).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString())); + verify(clusterCommands, times(1)).setex(eq("Account3::" + uuid), anyLong(), anyString()); + verifyNoMoreInteractions(clusterCommands); verify(accounts, times(1)).getByAccountIdentifier(eq(uuid)); verifyNoMoreInteractions(accounts); @@ -438,8 +451,8 @@ class AccountsManagerTest { UUID pni = UUID.randomUUID(); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); - when(asyncCommands.get(eq("Account3::" + uuid))).thenReturn(MockRedisFuture.completedFuture(null)); - when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); + when(asyncClusterCommands.get(eq("Account3::" + uuid))).thenReturn(MockRedisFuture.completedFuture(null)); + when(asyncClusterCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); when(accounts.getByAccountIdentifierAsync(eq(uuid))) .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); @@ -448,10 +461,10 @@ class AccountsManagerTest { assertTrue(retrieved.isPresent()); assertSame(retrieved.get(), account); - verify(asyncCommands).get(eq("Account3::" + uuid)); - verify(asyncCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString())); - verify(asyncCommands).setex(eq("Account3::" + uuid), anyLong(), anyString()); - verifyNoMoreInteractions(asyncCommands); + verify(asyncClusterCommands).get(eq("Account3::" + uuid)); + verify(asyncClusterCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString())); + verify(asyncClusterCommands).setex(eq("Account3::" + uuid), anyLong(), anyString()); + verifyNoMoreInteractions(asyncClusterCommands); verify(accounts).getByAccountIdentifierAsync(eq(uuid)); verifyNoMoreInteractions(accounts); @@ -464,7 +477,7 @@ class AccountsManagerTest { Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); - when(commands.get(eq("AccountMap::" + pni))).thenReturn(null); + when(clusterCommands.get(eq("AccountMap::" + pni))).thenReturn(null); when(accounts.getByPhoneNumberIdentifier(pni)).thenReturn(Optional.of(account)); Optional retrieved = accountsManager.getByPhoneNumberIdentifier(pni); @@ -472,10 +485,10 @@ class AccountsManagerTest { assertTrue(retrieved.isPresent()); assertSame(retrieved.get(), account); - verify(commands).get(eq("AccountMap::" + pni)); - verify(commands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString())); - verify(commands).setex(eq("Account3::" + uuid), anyLong(), anyString()); - verifyNoMoreInteractions(commands); + verify(clusterCommands).get(eq("AccountMap::" + pni)); + verify(clusterCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString())); + verify(clusterCommands).setex(eq("Account3::" + uuid), anyLong(), anyString()); + verifyNoMoreInteractions(clusterCommands); verify(accounts).getByPhoneNumberIdentifier(pni); verifyNoMoreInteractions(accounts); @@ -488,8 +501,8 @@ class AccountsManagerTest { Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); - when(asyncCommands.get(eq("AccountMap::" + pni))).thenReturn(MockRedisFuture.completedFuture(null)); - when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); + when(asyncClusterCommands.get(eq("AccountMap::" + pni))).thenReturn(MockRedisFuture.completedFuture(null)); + when(asyncClusterCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); when(accounts.getByPhoneNumberIdentifierAsync(pni)) .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); @@ -498,10 +511,10 @@ class AccountsManagerTest { assertTrue(retrieved.isPresent()); assertSame(retrieved.get(), account); - verify(asyncCommands).get(eq("AccountMap::" + pni)); - verify(asyncCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString())); - verify(asyncCommands).setex(eq("Account3::" + uuid), anyLong(), anyString()); - verifyNoMoreInteractions(asyncCommands); + verify(asyncClusterCommands).get(eq("AccountMap::" + pni)); + verify(asyncClusterCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString())); + verify(asyncClusterCommands).setex(eq("Account3::" + uuid), anyLong(), anyString()); + verifyNoMoreInteractions(asyncClusterCommands); verify(accounts).getByPhoneNumberIdentifierAsync(pni); verifyNoMoreInteractions(accounts); @@ -528,7 +541,7 @@ class AccountsManagerTest { UUID pni = UUID.randomUUID(); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); - when(commands.get(eq("Account3::" + uuid))).thenThrow(new RedisException("Connection lost!")); + when(clusterCommands.get(eq("Account3::" + uuid))).thenThrow(new RedisException("Connection lost!")); when(accounts.getByAccountIdentifier(eq(uuid))).thenReturn(Optional.of(account)); Optional retrieved = accountsManager.getByAccountIdentifier(uuid); @@ -536,10 +549,10 @@ class AccountsManagerTest { assertTrue(retrieved.isPresent()); assertSame(retrieved.get(), account); - verify(commands, times(1)).get(eq("Account3::" + uuid)); - verify(commands, times(1)).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString())); - verify(commands, times(1)).setex(eq("Account3::" + uuid), anyLong(), anyString()); - verifyNoMoreInteractions(commands); + verify(clusterCommands, times(1)).get(eq("Account3::" + uuid)); + verify(clusterCommands, times(1)).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString())); + verify(clusterCommands, times(1)).setex(eq("Account3::" + uuid), anyLong(), anyString()); + verifyNoMoreInteractions(clusterCommands); verify(accounts, times(1)).getByAccountIdentifier(eq(uuid)); verifyNoMoreInteractions(accounts); @@ -551,10 +564,10 @@ class AccountsManagerTest { UUID pni = UUID.randomUUID(); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); - when(asyncCommands.get(eq("Account3::" + uuid))) + when(asyncClusterCommands.get(eq("Account3::" + uuid))) .thenReturn(MockRedisFuture.failedFuture(new RedisException("Connection lost!"))); - when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); + when(asyncClusterCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); when(accounts.getByAccountIdentifierAsync(eq(uuid))) .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); @@ -564,10 +577,10 @@ class AccountsManagerTest { assertTrue(retrieved.isPresent()); assertSame(retrieved.get(), account); - verify(asyncCommands).get(eq("Account3::" + uuid)); - verify(asyncCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString())); - verify(asyncCommands).setex(eq("Account3::" + uuid), anyLong(), anyString()); - verifyNoMoreInteractions(asyncCommands); + verify(asyncClusterCommands).get(eq("Account3::" + uuid)); + verify(asyncClusterCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString())); + verify(asyncClusterCommands).setex(eq("Account3::" + uuid), anyLong(), anyString()); + verifyNoMoreInteractions(asyncClusterCommands); verify(accounts).getByAccountIdentifierAsync(eq(uuid)); verifyNoMoreInteractions(accounts); @@ -580,7 +593,7 @@ class AccountsManagerTest { Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); - when(commands.get(eq("AccountMap::" + pni))).thenThrow(new RedisException("OH NO")); + when(clusterCommands.get(eq("AccountMap::" + pni))).thenThrow(new RedisException("OH NO")); when(accounts.getByPhoneNumberIdentifier(pni)).thenReturn(Optional.of(account)); Optional retrieved = accountsManager.getByPhoneNumberIdentifier(pni); @@ -588,10 +601,10 @@ class AccountsManagerTest { assertTrue(retrieved.isPresent()); assertSame(retrieved.get(), account); - verify(commands).get(eq("AccountMap::" + pni)); - verify(commands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString())); - verify(commands).setex(eq("Account3::" + uuid), anyLong(), anyString()); - verifyNoMoreInteractions(commands); + verify(clusterCommands).get(eq("AccountMap::" + pni)); + verify(clusterCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString())); + verify(clusterCommands).setex(eq("Account3::" + uuid), anyLong(), anyString()); + verifyNoMoreInteractions(clusterCommands); verify(accounts).getByPhoneNumberIdentifier(pni); verifyNoMoreInteractions(accounts); @@ -604,10 +617,10 @@ class AccountsManagerTest { Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); - when(asyncCommands.get(eq("AccountMap::" + pni))) + when(asyncClusterCommands.get(eq("AccountMap::" + pni))) .thenReturn(MockRedisFuture.failedFuture(new RedisException("OH NO"))); - when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); + when(asyncClusterCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); when(accounts.getByPhoneNumberIdentifierAsync(pni)) .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); @@ -617,10 +630,10 @@ class AccountsManagerTest { assertTrue(retrieved.isPresent()); assertSame(retrieved.get(), account); - verify(asyncCommands).get(eq("AccountMap::" + pni)); - verify(asyncCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString())); - verify(asyncCommands).setex(eq("Account3::" + uuid), anyLong(), anyString()); - verifyNoMoreInteractions(asyncCommands); + verify(asyncClusterCommands).get(eq("AccountMap::" + pni)); + verify(asyncClusterCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString())); + verify(asyncClusterCommands).setex(eq("Account3::" + uuid), anyLong(), anyString()); + verifyNoMoreInteractions(asyncClusterCommands); verify(accounts).getByPhoneNumberIdentifierAsync(pni); verifyNoMoreInteractions(accounts); @@ -632,7 +645,7 @@ class AccountsManagerTest { UUID pni = UUID.randomUUID(); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); - when(commands.get(eq("Account3::" + uuid))).thenReturn(null); + when(clusterCommands.get(eq("Account3::" + uuid))).thenReturn(null); when(accounts.getByAccountIdentifier(uuid)).thenReturn( Optional.of(AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]))); @@ -658,7 +671,7 @@ class AccountsManagerTest { UUID pni = UUID.randomUUID(); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); - when(asyncCommands.get(eq("Account3::" + uuid))).thenReturn(null); + when(asyncClusterCommands.get(eq("Account3::" + uuid))).thenReturn(null); when(accounts.getByAccountIdentifierAsync(uuid)).thenReturn(CompletableFuture.completedFuture( Optional.of(AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH])))); @@ -684,7 +697,7 @@ class AccountsManagerTest { UUID uuid = UUID.randomUUID(); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); - when(commands.get(eq("Account3::" + uuid))).thenReturn(null); + when(clusterCommands.get(eq("Account3::" + uuid))).thenReturn(null); when(accounts.getByAccountIdentifier(uuid)).thenReturn(Optional.empty()) .thenReturn(Optional.of(account)); when(accounts.create(any(), any())).thenThrow(ContestedOptimisticLockException.class); @@ -971,7 +984,7 @@ class AccountsManagerTest { pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey), - accountsManager.generateDeviceLinkingToken(aci)) + accountsManager.generateLinkDeviceToken(aci)) .join(); verify(keysManager).deleteSingleUsePreKeys(aci, nextDeviceId); @@ -1606,7 +1619,7 @@ class AccountsManagerTest { final UUID aci = UUID.randomUUID(); assertEquals(Optional.of(aci), - accountsManager.checkDeviceLinkingToken(accountsManager.generateDeviceLinkingToken(aci))); + accountsManager.checkDeviceLinkingToken(accountsManager.generateLinkDeviceToken(aci))); } @ParameterizedTest @@ -1622,7 +1635,7 @@ class AccountsManagerTest { return Stream.of( // Expired token - Arguments.of(AccountsManager.generateDeviceLinkingToken(UUID.randomUUID(), + Arguments.of(AccountsManager.generateLinkDeviceToken(UUID.randomUUID(), new SecretKeySpec(LINK_DEVICE_SECRET, AccountsManager.LINK_DEVICE_VERIFICATION_TOKEN_ALGORITHM), CLOCK), tokenTimestamp.plus(AccountsManager.LINK_DEVICE_TOKEN_EXPIRATION_DURATION).plusSeconds(1)), diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java index 58cdaf678..918a9c275 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java @@ -36,6 +36,7 @@ import org.junit.jupiter.api.extension.RegisterExtension; import org.mockito.Mockito; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; @@ -137,6 +138,7 @@ class AccountsManagerUsernameIntegrationTest { accounts, phoneNumberIdentifiers, CACHE_CLUSTER_EXTENSION.getRedisCluster(), + mock(FaultTolerantRedisClient.class), accountLockManager, keysManager, messageManager, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java index d2c78e44a..5eb5e2a01 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java @@ -13,6 +13,7 @@ import static org.mockito.Mockito.when; import com.google.i18n.phonenumbers.PhoneNumberUtil; import java.nio.charset.StandardCharsets; import java.time.Clock; +import java.time.Duration; import java.time.Instant; import java.time.ZoneId; import java.util.Optional; @@ -29,9 +30,11 @@ import org.junit.jupiter.api.extension.RegisterExtension; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.entities.DeviceInfo; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; +import org.whispersystems.textsecuregcm.redis.RedisServerExtension; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; @@ -59,6 +62,9 @@ public class AddRemoveDeviceIntegrationTest { @RegisterExtension static final RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); + @RegisterExtension + static final RedisServerExtension PUBSUB_SERVER_EXTENSION = RedisServerExtension.builder().build(); + private static final Clock CLOCK = Clock.fixed(Instant.now(), ZoneId.systemDefault()); private ExecutorService accountLockExecutor; @@ -128,10 +134,16 @@ public class AddRemoveDeviceIntegrationTest { when(registrationRecoveryPasswordsManager.removeForNumber(any())) .thenReturn(CompletableFuture.completedFuture(null)); + PUBSUB_SERVER_EXTENSION.getRedisClient().useConnection(connection -> { + connection.sync().flushall(); + connection.sync().configSet("notify-keyspace-events", "K$"); + }); + accountsManager = new AccountsManager( accounts, phoneNumberIdentifiers, CACHE_CLUSTER_EXTENSION.getRedisCluster(), + PUBSUB_SERVER_EXTENSION.getRedisClient(), accountLockManager, keysManager, messagesManager, @@ -146,10 +158,14 @@ public class AddRemoveDeviceIntegrationTest { CLOCK, "link-device-secret".getBytes(StandardCharsets.UTF_8), dynamicConfigurationManager); + + accountsManager.start(); } @AfterEach void tearDown() throws InterruptedException { + accountsManager.stop(); + accountLockExecutor.shutdown(); clientPresenceExecutor.shutdown(); @@ -187,7 +203,7 @@ public class AddRemoveDeviceIntegrationTest { KeysHelper.signedECPreKey(2, pniKeyPair), KeysHelper.signedKEMPreKey(3, aciKeyPair), KeysHelper.signedKEMPreKey(4, pniKeyPair)), - accountsManager.generateDeviceLinkingToken(account.getIdentifier(IdentityType.ACI))) + accountsManager.generateLinkDeviceToken(account.getIdentifier(IdentityType.ACI))) .join(); assertEquals(2, updatedAccountAndDevice.first().getDevices().size()); @@ -216,7 +232,7 @@ public class AddRemoveDeviceIntegrationTest { final Account account = AccountsHelper.createAccount(accountsManager, number); assertEquals(1, accountsManager.getByAccountIdentifier(account.getUuid()).orElseThrow().getDevices().size()); - final String linkDeviceToken = accountsManager.generateDeviceLinkingToken(account.getIdentifier(IdentityType.ACI)); + final String linkDeviceToken = accountsManager.generateLinkDeviceToken(account.getIdentifier(IdentityType.ACI)); final Pair updatedAccountAndDevice = accountsManager.addDevice(account, new DeviceSpec( @@ -292,7 +308,7 @@ public class AddRemoveDeviceIntegrationTest { KeysHelper.signedECPreKey(2, pniKeyPair), KeysHelper.signedKEMPreKey(3, aciKeyPair), KeysHelper.signedKEMPreKey(4, pniKeyPair)), - accountsManager.generateDeviceLinkingToken(account.getIdentifier(IdentityType.ACI))) + accountsManager.generateLinkDeviceToken(account.getIdentifier(IdentityType.ACI))) .join(); final byte addedDeviceId = updatedAccountAndDevice.second().getId(); @@ -346,7 +362,7 @@ public class AddRemoveDeviceIntegrationTest { KeysHelper.signedECPreKey(2, pniKeyPair), KeysHelper.signedKEMPreKey(3, aciKeyPair), KeysHelper.signedKEMPreKey(4, pniKeyPair)), - accountsManager.generateDeviceLinkingToken(account.getIdentifier(IdentityType.ACI))) + accountsManager.generateLinkDeviceToken(account.getIdentifier(IdentityType.ACI))) .join(); final byte addedDeviceId = updatedAccountAndDevice.second().getId(); @@ -376,4 +392,110 @@ public class AddRemoveDeviceIntegrationTest { assertTrue(keysManager.getLastResort(retrievedAccount.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join().isPresent()); assertTrue(clientPublicKeysManager.findPublicKey(retrievedAccount.getUuid(), Device.PRIMARY_ID).join().isPresent()); } + + @Test + void waitForNewLinkedDevice() throws InterruptedException { + final String number = PhoneNumberUtil.getInstance().format( + PhoneNumberUtil.getInstance().getExampleNumber("US"), + PhoneNumberUtil.PhoneNumberFormat.E164); + + final ECKeyPair aciKeyPair = Curve.generateKeyPair(); + final ECKeyPair pniKeyPair = Curve.generateKeyPair(); + + final Account account = AccountsHelper.createAccount(accountsManager, number); + + final String linkDeviceToken = accountsManager.generateLinkDeviceToken(account.getIdentifier(IdentityType.ACI)); + final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken); + + final CompletableFuture> displacedFuture = + accountsManager.waitForNewLinkedDevice(linkDeviceTokenIdentifier, Duration.ofSeconds(5)); + + final CompletableFuture> activeFuture = + accountsManager.waitForNewLinkedDevice(linkDeviceTokenIdentifier, Duration.ofSeconds(5)); + + assertEquals(Optional.empty(), displacedFuture.join()); + + final Pair updatedAccountAndDevice = + accountsManager.addDevice(account, new DeviceSpec( + "device-name".getBytes(StandardCharsets.UTF_8), + "password", + "OWT", + new Device.DeviceCapabilities(true, true, true, false), + 1, + 2, + true, + Optional.empty(), + Optional.empty(), + KeysHelper.signedECPreKey(1, aciKeyPair), + KeysHelper.signedECPreKey(2, pniKeyPair), + KeysHelper.signedKEMPreKey(3, aciKeyPair), + KeysHelper.signedKEMPreKey(4, pniKeyPair)), + linkDeviceToken) + .join(); + + final Optional maybeDeviceInfo = activeFuture.join(); + + assertTrue(maybeDeviceInfo.isPresent()); + final DeviceInfo deviceInfo = maybeDeviceInfo.get(); + + assertEquals(updatedAccountAndDevice.second().getId(), deviceInfo.id()); + assertEquals(updatedAccountAndDevice.second().getCreated(), deviceInfo.created()); + } + + @Test + void waitForNewLinkedDeviceAlreadyAdded() throws InterruptedException { + final String number = PhoneNumberUtil.getInstance().format( + PhoneNumberUtil.getInstance().getExampleNumber("US"), + PhoneNumberUtil.PhoneNumberFormat.E164); + + final ECKeyPair aciKeyPair = Curve.generateKeyPair(); + final ECKeyPair pniKeyPair = Curve.generateKeyPair(); + + final Account account = AccountsHelper.createAccount(accountsManager, number); + + final String linkDeviceToken = accountsManager.generateLinkDeviceToken(account.getIdentifier(IdentityType.ACI)); + final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken); + + final Pair updatedAccountAndDevice = + accountsManager.addDevice(account, new DeviceSpec( + "device-name".getBytes(StandardCharsets.UTF_8), + "password", + "OWT", + new Device.DeviceCapabilities(true, true, true, false), + 1, + 2, + true, + Optional.empty(), + Optional.empty(), + KeysHelper.signedECPreKey(1, aciKeyPair), + KeysHelper.signedECPreKey(2, pniKeyPair), + KeysHelper.signedKEMPreKey(3, aciKeyPair), + KeysHelper.signedKEMPreKey(4, pniKeyPair)), + linkDeviceToken) + .join(); + + final CompletableFuture> linkedDeviceFuture = + accountsManager.waitForNewLinkedDevice(linkDeviceTokenIdentifier, Duration.ofMinutes(1)); + + final Optional maybeDeviceInfo = linkedDeviceFuture.join(); + + assertTrue(maybeDeviceInfo.isPresent()); + final DeviceInfo deviceInfo = maybeDeviceInfo.get(); + + assertEquals(updatedAccountAndDevice.second().getId(), deviceInfo.id()); + assertEquals(updatedAccountAndDevice.second().getCreated(), deviceInfo.created()); + } + + @Test + void waitForNewLinkedDeviceTimeout() { + final String linkDeviceToken = accountsManager.generateLinkDeviceToken(UUID.randomUUID()); + final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken); + + final CompletableFuture> linkedDeviceFuture = + accountsManager.waitForNewLinkedDevice(linkDeviceTokenIdentifier, Duration.ofMillis(10)); + + final Optional maybeDeviceInfo = linkedDeviceFuture.join(); + + assertTrue(maybeDeviceInfo.isEmpty()); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisServerHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisServerHelper.java new file mode 100644 index 000000000..8a1aeb4ef --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/RedisServerHelper.java @@ -0,0 +1,112 @@ +package org.whispersystems.textsecuregcm.tests.util; + +import io.lettuce.core.api.StatefulRedisConnection; +import io.lettuce.core.api.async.RedisAsyncCommands; +import io.lettuce.core.api.reactive.RedisReactiveCommands; +import io.lettuce.core.api.sync.RedisCommands; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient; +import java.util.function.Consumer; +import java.util.function.Function; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class RedisServerHelper { + + public static RedisServerHelper.Builder builder() { + return new RedisServerHelper.Builder(); + } + + @SuppressWarnings("unchecked") + private static FaultTolerantRedisClient buildMockRedisClient( + final RedisCommands stringCommands, + final RedisAsyncCommands stringAsyncCommands, + final RedisCommands binaryCommands, + final RedisAsyncCommands binaryAsyncCommands, + final RedisReactiveCommands binaryReactiveCommands) { + final FaultTolerantRedisClient client = mock(FaultTolerantRedisClient.class); + final StatefulRedisConnection stringConnection = mock(StatefulRedisConnection.class); + final StatefulRedisConnection binaryConnection = mock(StatefulRedisConnection.class); + + when(stringConnection.sync()).thenReturn(stringCommands); + when(stringConnection.async()).thenReturn(stringAsyncCommands); + when(binaryConnection.sync()).thenReturn(binaryCommands); + when(binaryConnection.async()).thenReturn(binaryAsyncCommands); + when(binaryConnection.reactive()).thenReturn(binaryReactiveCommands); + + when(client.withConnection(any(Function.class))).thenAnswer(invocation -> { + return invocation.getArgument(0, Function.class).apply(stringConnection); + }); + + doAnswer(invocation -> { + invocation.getArgument(0, Consumer.class).accept(stringConnection); + return null; + }).when(client).useConnection(any(Consumer.class)); + + when(client.withBinaryConnection(any(Function.class))).thenAnswer(invocation -> { + return invocation.getArgument(0, Function.class).apply(binaryConnection); + }); + + doAnswer(invocation -> { + invocation.getArgument(0, Consumer.class).accept(binaryConnection); + return null; + }).when(client).useBinaryConnection(any(Consumer.class)); + + return client; + } + + @SuppressWarnings("unchecked") + public static class Builder { + + private RedisCommands stringCommands = mock(RedisCommands.class); + private RedisAsyncCommands stringAsyncCommands = mock(RedisAsyncCommands.class); + + private RedisCommands binaryCommands = mock(RedisCommands.class); + + private RedisAsyncCommands binaryAsyncCommands = + mock(RedisAsyncCommands.class); + + private RedisReactiveCommands binaryReactiveCommands = + mock(RedisReactiveCommands.class); + + private Builder() { + + } + + public RedisServerHelper.Builder stringCommands(final RedisCommands stringCommands) { + this.stringCommands = stringCommands; + return this; + } + + public RedisServerHelper.Builder stringAsyncCommands(final RedisAsyncCommands stringAsyncCommands) { + this.stringAsyncCommands = stringAsyncCommands; + return this; + } + + public RedisServerHelper.Builder binaryCommands(final RedisCommands binaryCommands) { + this.binaryCommands = binaryCommands; + return this; + } + + public RedisServerHelper.Builder binaryAsyncCommands(final RedisAsyncCommands binaryAsyncCommands) { + this.binaryAsyncCommands = binaryAsyncCommands; + return this; + } + + public RedisServerHelper.Builder binaryReactiveCommands( + final RedisReactiveCommands binaryReactiveCommands) { + this.binaryReactiveCommands = binaryReactiveCommands; + return this; + } + + public FaultTolerantRedisClient build() { + return RedisServerHelper.buildMockRedisClient(stringCommands, + stringAsyncCommands, + binaryCommands, + binaryAsyncCommands, + binaryReactiveCommands); + } + } +}