Add API endpoints for waiting for newly-linked devices

This commit is contained in:
Jon Chambers 2024-10-10 10:11:32 -04:00 committed by GitHub
parent 087c2b61ee
commit 8c30a359e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 793 additions and 122 deletions

View File

@ -642,7 +642,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
ClientPublicKeysManager clientPublicKeysManager = ClientPublicKeysManager clientPublicKeysManager =
new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor); new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor);
AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster, AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster,
accountLockManager, keysManager, messagesManager, profilesManager, pubsubClient, accountLockManager, keysManager, messagesManager, profilesManager,
secureStorageClient, secureValueRecovery2Client, secureStorageClient, secureValueRecovery2Client,
clientPresenceManager, clientPresenceManager,
registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, clientPresenceExecutor, registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, clientPresenceExecutor,
@ -764,6 +764,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.lifecycle().manage(keyTransparencyServiceClient); environment.lifecycle().manage(keyTransparencyServiceClient);
environment.lifecycle().manage(clientReleaseManager); environment.lifecycle().manage(clientReleaseManager);
environment.lifecycle().manage(virtualThreadPinEventMonitor); environment.lifecycle().manage(virtualThreadPinEventMonitor);
environment.lifecycle().manage(accountsManager);
final RegistrationCaptchaManager registrationCaptchaManager = new RegistrationCaptchaManager(captchaChecker); final RegistrationCaptchaManager registrationCaptchaManager = new RegistrationCaptchaManager(captchaChecker);

View File

@ -4,22 +4,36 @@
*/ */
package org.whispersystems.textsecuregcm.controllers; package org.whispersystems.textsecuregcm.controllers;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.net.HttpHeaders; import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.Auth; import io.dropwizard.auth.Auth;
import io.lettuce.core.RedisException;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import io.micrometer.core.instrument.Timer;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.headers.Header; import io.swagger.v3.oas.annotations.headers.Header;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse; import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag; import java.time.Duration;
import java.util.LinkedList; import java.util.Arrays;
import java.util.EnumMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.Max;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import javax.validation.constraints.Size;
import javax.ws.rs.Consumes; import javax.ws.rs.Consumes;
import javax.ws.rs.DELETE; import javax.ws.rs.DELETE;
import javax.ws.rs.DefaultValue;
import javax.ws.rs.ForbiddenException; import javax.ws.rs.ForbiddenException;
import javax.ws.rs.GET; import javax.ws.rs.GET;
import javax.ws.rs.HeaderParam; import javax.ws.rs.HeaderParam;
@ -27,10 +41,12 @@ import javax.ws.rs.PUT;
import javax.ws.rs.Path; import javax.ws.rs.Path;
import javax.ws.rs.PathParam; import javax.ws.rs.PathParam;
import javax.ws.rs.Produces; import javax.ws.rs.Produces;
import javax.ws.rs.QueryParam;
import javax.ws.rs.WebApplicationException; import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.Context; import javax.ws.rs.core.Context;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import io.swagger.v3.oas.annotations.tags.Tag;
import org.glassfish.jersey.server.ContainerRequest; import org.glassfish.jersey.server.ContainerRequest;
import org.whispersystems.textsecuregcm.auth.LinkedDeviceRefreshRequirementProvider; import org.whispersystems.textsecuregcm.auth.LinkedDeviceRefreshRequirementProvider;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
@ -47,6 +63,8 @@ import org.whispersystems.textsecuregcm.entities.ProvisioningMessage;
import org.whispersystems.textsecuregcm.entities.SetPublicKeyRequest; import org.whispersystems.textsecuregcm.entities.SetPublicKeyRequest;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager; import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
@ -54,7 +72,11 @@ import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities; import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.storage.DeviceSpec; import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.LinkDeviceTokenAlreadyUsedException; import org.whispersystems.textsecuregcm.storage.LinkDeviceTokenAlreadyUsedException;
import org.whispersystems.textsecuregcm.util.VerificationCode; import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.LinkDeviceToken;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import org.whispersystems.websocket.auth.Mutable; import org.whispersystems.websocket.auth.Mutable;
import org.whispersystems.websocket.auth.ReadOnly; import org.whispersystems.websocket.auth.ReadOnly;
@ -69,6 +91,21 @@ public class DeviceController {
private final RateLimiters rateLimiters; private final RateLimiters rateLimiters;
private final Map<String, Integer> maxDeviceConfiguration; private final Map<String, Integer> maxDeviceConfiguration;
private final EnumMap<ClientPlatform, AtomicInteger> linkedDeviceListenersByPlatform;
private final AtomicInteger linkedDeviceListenersForUnrecognizedPlatforms;
private static final String LINKED_DEVICE_LISTENER_GAUGE_NAME =
MetricsUtil.name(DeviceController.class, "linkedDeviceListeners");
private static final String WAIT_FOR_LINKED_DEVICE_TIMER_NAME =
MetricsUtil.name(DeviceController.class, "waitForLinkedDeviceDuration");
@VisibleForTesting
static final int MIN_TOKEN_IDENTIFIER_LENGTH = 32;
@VisibleForTesting
static final int MAX_TOKEN_IDENTIFIER_LENGTH = 64;
public DeviceController(final AccountsManager accounts, public DeviceController(final AccountsManager accounts,
final ClientPublicKeysManager clientPublicKeysManager, final ClientPublicKeysManager clientPublicKeysManager,
final RateLimiters rateLimiters, final RateLimiters rateLimiters,
@ -78,19 +115,32 @@ public class DeviceController {
this.clientPublicKeysManager = clientPublicKeysManager; this.clientPublicKeysManager = clientPublicKeysManager;
this.rateLimiters = rateLimiters; this.rateLimiters = rateLimiters;
this.maxDeviceConfiguration = maxDeviceConfiguration; this.maxDeviceConfiguration = maxDeviceConfiguration;
linkedDeviceListenersByPlatform = Arrays.stream(ClientPlatform.values())
.collect(Collectors.toMap(
Function.identity(),
clientPlatform -> buildGauge(clientPlatform.name().toLowerCase()),
(a, b) -> {
throw new AssertionError("Duplicate client platform enumeration key");
},
() -> new EnumMap<>(ClientPlatform.class)
));
linkedDeviceListenersForUnrecognizedPlatforms = buildGauge("unknown");
}
private static AtomicInteger buildGauge(final String clientPlatformName) {
return Metrics.gauge(LINKED_DEVICE_LISTENER_GAUGE_NAME,
Tags.of(io.micrometer.core.instrument.Tag.of(UserAgentTagUtil.PLATFORM_TAG, clientPlatformName)),
new AtomicInteger(0));
} }
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public DeviceInfoList getDevices(@ReadOnly @Auth AuthenticatedDevice auth) { public DeviceInfoList getDevices(@ReadOnly @Auth AuthenticatedDevice auth) {
List<DeviceInfo> devices = new LinkedList<>(); return new DeviceInfoList(auth.getAccount().getDevices().stream()
.map(DeviceInfo::forDevice)
for (Device device : auth.getAccount().getDevices()) { .toList());
devices.add(new DeviceInfo(device.getId(), device.getName(),
device.getLastSeen(), device.getCreated()));
}
return new DeviceInfoList(devices);
} }
@DELETE @DELETE
@ -138,7 +188,7 @@ public class DeviceController {
@ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header( @ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header(
name = "Retry-After", name = "Retry-After",
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed")) description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public VerificationCode createDeviceToken(@ReadOnly @Auth AuthenticatedDevice auth) public LinkDeviceToken createDeviceToken(@ReadOnly @Auth AuthenticatedDevice auth)
throws RateLimitExceededException, DeviceLimitExceededException { throws RateLimitExceededException, DeviceLimitExceededException {
final Account account = auth.getAccount(); final Account account = auth.getAccount();
@ -159,7 +209,9 @@ public class DeviceController {
throw new WebApplicationException(Response.Status.UNAUTHORIZED); throw new WebApplicationException(Response.Status.UNAUTHORIZED);
} }
return new VerificationCode(accounts.generateDeviceLinkingToken(account.getUuid())); final String token = accounts.generateLinkDeviceToken(account.getUuid());
return new LinkDeviceToken(token, AccountsManager.getLinkDeviceTokenIdentifier(token));
} }
@PUT @PUT
@ -266,6 +318,83 @@ public class DeviceController {
} }
} }
@GET
@Path("/wait_for_linked_device/{tokenIdentifier}")
@Produces(MediaType.APPLICATION_JSON)
@Operation(summary = "Wait for a new device to be linked to an account",
description = """
Waits for a new device to be linked to an account and returns basic information about the new device when
available.
""")
@ApiResponse(responseCode = "200", description = "The specified was linked to an account")
@ApiResponse(responseCode = "204", description = "No device was linked to the account before the call completed")
@ApiResponse(responseCode = "400", description = "The given token identifier or timeout was invalid")
@ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay")
@Schema(description = "Basic information about the linked device", implementation = DeviceInfo.class)
public CompletableFuture<Response> waitForLinkedDevice(
@ReadOnly @Auth final AuthenticatedDevice authenticatedDevice,
@PathParam("tokenIdentifier")
@Schema(description = "A 'link device' token identifier provided by the 'create link device token' endpoint")
@Size(min = MIN_TOKEN_IDENTIFIER_LENGTH, max = MAX_TOKEN_IDENTIFIER_LENGTH)
final String tokenIdentifier,
@QueryParam("timeout")
@DefaultValue("30")
@Min(1)
@Max(3600)
@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED,
minimum = "1",
maximum = "3600",
description = """
The amount of time (in seconds) to wait for a response. If the expected device is not linked within the
given amount of time, this endpoint will return a status of HTTP/204.
""") final int timeoutSeconds,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent) throws RateLimitExceededException {
rateLimiters.getWaitForLinkedDeviceLimiter().validate(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI));
final AtomicInteger linkedDeviceListenerCounter = getCounterForLinkedDeviceListeners(userAgent);
linkedDeviceListenerCounter.incrementAndGet();
final Timer.Sample sample = Timer.start();
try {
return accounts.waitForNewLinkedDevice(tokenIdentifier, Duration.ofSeconds(timeoutSeconds))
.thenApply(maybeDeviceInfo -> maybeDeviceInfo
.map(deviceInfo -> Response.status(Response.Status.OK).entity(deviceInfo).build())
.orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build()))
.exceptionally(ExceptionUtils.exceptionallyHandler(IllegalArgumentException.class,
e -> Response.status(Response.Status.BAD_REQUEST).build()))
.whenComplete((response, throwable) -> {
linkedDeviceListenerCounter.decrementAndGet();
if (response != null) {
sample.stop(Timer.builder(WAIT_FOR_LINKED_DEVICE_TIMER_NAME)
.publishPercentileHistogram(true)
.tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent),
io.micrometer.core.instrument.Tag.of("deviceFound",
String.valueOf(response.getStatus() == Response.Status.OK.getStatusCode()))))
.register(Metrics.globalRegistry));
}
});
} catch (final RedisException e) {
// `waitForNewLinkedDevice` could fail synchronously if the Redis circuit breaker is open; prevent counter drift
// if that happens
linkedDeviceListenerCounter.decrementAndGet();
throw e;
}
}
private AtomicInteger getCounterForLinkedDeviceListeners(final String userAgent) {
try {
return linkedDeviceListenersByPlatform.get(UserAgentUtil.parseUserAgentString(userAgent).getPlatform());
} catch (final UnrecognizedUserAgentException ignored) {
return linkedDeviceListenersForUnrecognizedPlatforms;
}
}
@PUT @PUT
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Path("/unauthenticated_delivery") @Path("/unauthenticated_delivery")

View File

@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.ByteArrayBase64WithPaddingAdapter; import org.whispersystems.textsecuregcm.util.ByteArrayBase64WithPaddingAdapter;
public record DeviceInfo(long id, public record DeviceInfo(long id,
@ -17,4 +18,8 @@ public record DeviceInfo(long id,
long lastSeen, long lastSeen,
long created) { long created) {
public static DeviceInfo forDevice(final Device device) {
return new DeviceInfo(device.getId(), device.getName(), device.getLastSeen(), device.getCreated());
}
} }

View File

@ -50,6 +50,7 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
EXTERNAL_SERVICE_CREDENTIALS("externalServiceCredentials", true, new RateLimiterConfig(100, Duration.ofMinutes(15))), EXTERNAL_SERVICE_CREDENTIALS("externalServiceCredentials", true, new RateLimiterConfig(100, Duration.ofMinutes(15))),
KEY_TRANSPARENCY_SEARCH_PER_IP("keyTransparencySearch", true, new RateLimiterConfig(100, Duration.ofSeconds(15))), KEY_TRANSPARENCY_SEARCH_PER_IP("keyTransparencySearch", true, new RateLimiterConfig(100, Duration.ofSeconds(15))),
KEY_TRANSPARENCY_MONITOR_PER_IP("keyTransparencyMonitor", true, new RateLimiterConfig(100, Duration.ofSeconds(15))), KEY_TRANSPARENCY_MONITOR_PER_IP("keyTransparencyMonitor", true, new RateLimiterConfig(100, Duration.ofSeconds(15))),
WAIT_FOR_LINKED_DEVICE("waitForLinkedDevice", true, new RateLimiterConfig(10, Duration.ofSeconds(30))),
; ;
private final String id; private final String id;
@ -205,4 +206,8 @@ public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
public RateLimiter getStoriesLimiter() { public RateLimiter getStoriesLimiter() {
return forDescriptor(For.STORIES); return forDescriptor(For.STORIES);
} }
public RateLimiter getWaitForLinkedDeviceLimiter() {
return forDescriptor(For.WAIT_FOR_LINKED_DEVICE);
}
} }

View File

@ -12,8 +12,11 @@ import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectWriter; import com.fasterxml.jackson.databind.ObjectWriter;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import io.dropwizard.lifecycle.Managed;
import io.lettuce.core.RedisException; import io.lettuce.core.RedisException;
import io.lettuce.core.SetArgs;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import io.lettuce.core.pubsub.RedisPubSubAdapter;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags; import io.micrometer.core.instrument.Tags;
@ -42,7 +45,9 @@ import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage; import java.util.concurrent.CompletionStage;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction; import java.util.function.BiFunction;
import java.util.function.Consumer; import java.util.function.Consumer;
@ -61,13 +66,16 @@ import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException; import org.whispersystems.textsecuregcm.controllers.MismatchedDevicesException;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.DeviceInfo;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubConnection;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.RedisOperation; import org.whispersystems.textsecuregcm.redis.RedisOperation;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
@ -82,14 +90,13 @@ import reactor.core.scheduler.Scheduler;
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem; import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem;
import software.amazon.awssdk.services.dynamodb.model.TransactionCanceledException; import software.amazon.awssdk.services.dynamodb.model.TransactionCanceledException;
public class AccountsManager { public class AccountsManager extends RedisPubSubAdapter<String, String> implements Managed {
private static final Timer createTimer = Metrics.timer(name(AccountsManager.class, "create")); private static final Timer createTimer = Metrics.timer(name(AccountsManager.class, "create"));
private static final Timer updateTimer = Metrics.timer(name(AccountsManager.class, "update")); private static final Timer updateTimer = Metrics.timer(name(AccountsManager.class, "update"));
private static final Timer getByNumberTimer = Metrics.timer(name(AccountsManager.class, "getByNumber")); private static final Timer getByNumberTimer = Metrics.timer(name(AccountsManager.class, "getByNumber"));
private static final Timer getByUsernameHashTimer = Metrics.timer(name(AccountsManager.class, "getByUsernameHash")); private static final Timer getByUsernameHashTimer = Metrics.timer(name(AccountsManager.class, "getByUsernameHash"));
private static final Timer getByUsernameLinkHandleTimer = Metrics.timer( private static final Timer getByUsernameLinkHandleTimer = Metrics.timer(name(AccountsManager.class, "getByUsernameLinkHandle"));
name(AccountsManager.class, "getByUsernameLinkHandle"));
private static final Timer getByUuidTimer = Metrics.timer(name(AccountsManager.class, "getByUuid")); private static final Timer getByUuidTimer = Metrics.timer(name(AccountsManager.class, "getByUuid"));
private static final Timer deleteTimer = Metrics.timer(name(AccountsManager.class, "delete")); private static final Timer deleteTimer = Metrics.timer(name(AccountsManager.class, "delete"));
@ -108,6 +115,7 @@ public class AccountsManager {
private final Accounts accounts; private final Accounts accounts;
private final PhoneNumberIdentifiers phoneNumberIdentifiers; private final PhoneNumberIdentifiers phoneNumberIdentifiers;
private final FaultTolerantRedisClusterClient cacheCluster; private final FaultTolerantRedisClusterClient cacheCluster;
private final FaultTolerantRedisClient pubSubRedisSingleton;
private final AccountLockManager accountLockManager; private final AccountLockManager accountLockManager;
private final KeysManager keysManager; private final KeysManager keysManager;
private final MessagesManager messagesManager; private final MessagesManager messagesManager;
@ -124,6 +132,16 @@ public class AccountsManager {
private final Key verificationTokenKey; private final Key verificationTokenKey;
private final FaultTolerantPubSubConnection<String, String> pubSubConnection;
private final Map<String, CompletableFuture<Optional<DeviceInfo>>> waitForDeviceFuturesByTokenIdentifier =
new ConcurrentHashMap<>();
private static final int SHA256_HASH_LENGTH = getSha256MessageDigest().getDigestLength();
private static final Duration RECENTLY_ADDED_DEVICE_TTL = Duration.ofHours(1);
private static final String LINKED_DEVICE_PREFIX = "linked_device::";
private static final String LINKED_DEVICE_KEYSPACE_PATTERN = "__keyspace@0__:" + LINKED_DEVICE_PREFIX + "*";
private static final ObjectWriter ACCOUNT_REDIS_JSON_WRITER = SystemMapper.jsonMapper() private static final ObjectWriter ACCOUNT_REDIS_JSON_WRITER = SystemMapper.jsonMapper()
.writer(SystemMapper.excludingField(Account.class, List.of("uuid"))); .writer(SystemMapper.excludingField(Account.class, List.of("uuid")));
@ -158,6 +176,7 @@ public class AccountsManager {
public AccountsManager(final Accounts accounts, public AccountsManager(final Accounts accounts,
final PhoneNumberIdentifiers phoneNumberIdentifiers, final PhoneNumberIdentifiers phoneNumberIdentifiers,
final FaultTolerantRedisClusterClient cacheCluster, final FaultTolerantRedisClusterClient cacheCluster,
final FaultTolerantRedisClient pubSubRedisSingleton,
final AccountLockManager accountLockManager, final AccountLockManager accountLockManager,
final KeysManager keysManager, final KeysManager keysManager,
final MessagesManager messagesManager, final MessagesManager messagesManager,
@ -175,6 +194,7 @@ public class AccountsManager {
this.accounts = accounts; this.accounts = accounts;
this.phoneNumberIdentifiers = phoneNumberIdentifiers; this.phoneNumberIdentifiers = phoneNumberIdentifiers;
this.cacheCluster = cacheCluster; this.cacheCluster = cacheCluster;
this.pubSubRedisSingleton = pubSubRedisSingleton;
this.accountLockManager = accountLockManager; this.accountLockManager = accountLockManager;
this.keysManager = keysManager; this.keysManager = keysManager;
this.messagesManager = messagesManager; this.messagesManager = messagesManager;
@ -197,6 +217,20 @@ public class AccountsManager {
} catch (final InvalidKeyException e) { } catch (final InvalidKeyException e) {
throw new IllegalArgumentException(e); throw new IllegalArgumentException(e);
} }
this.pubSubConnection = pubSubRedisSingleton.createPubSubConnection();
}
@Override
public void start() {
pubSubConnection.usePubSubConnection(connection -> connection.addListener(this));
pubSubConnection.usePubSubConnection(connection -> connection.sync().psubscribe(LINKED_DEVICE_KEYSPACE_PATTERN));
}
@Override
public void stop() {
pubSubConnection.usePubSubConnection(connection -> connection.sync().punsubscribe());
pubSubConnection.usePubSubConnection(connection -> connection.removeListener(this));
} }
public Account create(final String number, public Account create(final String number,
@ -363,6 +397,26 @@ public class AccountsManager {
} }
return CompletableFuture.failedFuture(throwable); return CompletableFuture.failedFuture(throwable);
})
.whenComplete((updatedAccountAndDevice, throwable) -> {
if (updatedAccountAndDevice != null) {
final String key = getLinkedDeviceKey(getLinkDeviceTokenIdentifier(linkDeviceToken));
final String deviceInfoJson;
try {
deviceInfoJson = SystemMapper.jsonMapper().writeValueAsString(DeviceInfo.forDevice(updatedAccountAndDevice.second()));
} catch (final JsonProcessingException e) {
throw new UncheckedIOException(e);
}
pubSubRedisSingleton.withConnection(connection ->
connection.async().set(key, deviceInfoJson, SetArgs.Builder.ex(RECENTLY_ADDED_DEVICE_TTL)))
.whenComplete((ignored, pubSubThrowable) -> {
if (pubSubThrowable != null) {
logger.warn("Failed to record recently-created device", pubSubThrowable);
}
});
}
}); });
} }
@ -386,7 +440,7 @@ public class AccountsManager {
} }
} }
public String generateDeviceLinkingToken(final UUID aci) { public String generateLinkDeviceToken(final UUID aci) {
final String claims = aci + "." + clock.instant().toEpochMilli(); final String claims = aci + "." + clock.instant().toEpochMilli();
final byte[] signature = getInitializedMac().doFinal(claims.getBytes(StandardCharsets.UTF_8)); final byte[] signature = getInitializedMac().doFinal(claims.getBytes(StandardCharsets.UTF_8));
@ -394,7 +448,7 @@ public class AccountsManager {
} }
@VisibleForTesting @VisibleForTesting
static String generateDeviceLinkingToken(final UUID aci, final Key linkDeviceTokenKey, final Clock clock) static String generateLinkDeviceToken(final UUID aci, final Key linkDeviceTokenKey, final Clock clock)
throws InvalidKeyException { throws InvalidKeyException {
final String claims = aci + "." + clock.instant().toEpochMilli(); final String claims = aci + "." + clock.instant().toEpochMilli();
@ -403,6 +457,11 @@ public class AccountsManager {
return claims + ":" + Base64.getUrlEncoder().encodeToString(signature); return claims + ":" + Base64.getUrlEncoder().encodeToString(signature);
} }
public static String getLinkDeviceTokenIdentifier(final String linkDeviceToken) {
return Base64.getUrlEncoder().withoutPadding().encodeToString(
getSha256MessageDigest().digest(linkDeviceToken.getBytes(StandardCharsets.UTF_8)));
}
/** /**
* Checks that a device-linking token is valid and returns the account identifier from the token if so, or empty if * Checks that a device-linking token is valid and returns the account identifier from the token if so, or empty if
* the token was invalid * the token was invalid
@ -1340,4 +1399,75 @@ public class AccountsManager {
.whenComplete((ignoredResult, ignoredException) -> sample.stop(redisDeleteTimer)) .whenComplete((ignoredResult, ignoredException) -> sample.stop(redisDeleteTimer))
.thenRun(Util.NOOP); .thenRun(Util.NOOP);
} }
public CompletableFuture<Optional<DeviceInfo>> waitForNewLinkedDevice(final String linkDeviceTokenIdentifier, final Duration timeout) {
// Unbeknownst to callers but beknownst to us, the "link device token identifier" is the base64/url-encoded SHA256
// hash of a device-linking token. Before we use the string anywhere, make sure it's the right "shape" for a hash.
if (Base64.getUrlDecoder().decode(linkDeviceTokenIdentifier).length != SHA256_HASH_LENGTH) {
return CompletableFuture.failedFuture(new IllegalArgumentException("Invalid token identifier"));
}
final CompletableFuture<Optional<DeviceInfo>> waitForDeviceFuture = new CompletableFuture<>();
waitForDeviceFuture
.completeOnTimeout(Optional.empty(), TimeUnit.MILLISECONDS.convert(timeout), TimeUnit.MILLISECONDS)
.whenComplete((maybeDevice, throwable) -> waitForDeviceFuturesByTokenIdentifier.compute(linkDeviceTokenIdentifier,
(ignored, existingFuture) -> {
// Only remove the future from the map if it's THIS future, and not one that later displaced this one
return existingFuture == waitForDeviceFuture ? null : existingFuture;
}));
{
final CompletableFuture<Optional<DeviceInfo>> displacedFuture =
waitForDeviceFuturesByTokenIdentifier.put(linkDeviceTokenIdentifier, waitForDeviceFuture);
if (displacedFuture != null) {
displacedFuture.complete(Optional.empty());
}
}
// The device may already have been linked by the time the caller started watching for it; perform an immediate
// check to see if the device is already there.
pubSubRedisSingleton.withConnection(connection -> connection.async().get(getLinkedDeviceKey(linkDeviceTokenIdentifier)))
.thenAccept(response -> {
if (StringUtils.isNotBlank(response)) {
handleDeviceAdded(waitForDeviceFuture, response);
}
});
return waitForDeviceFuture;
}
private static String getLinkedDeviceKey(final String linkDeviceTokenIdentifier) {
return LINKED_DEVICE_PREFIX + linkDeviceTokenIdentifier;
}
@Override
public void message(final String pattern, final String channel, final String message) {
if (LINKED_DEVICE_KEYSPACE_PATTERN.equals(pattern) && "set".equalsIgnoreCase(message)) {
// The `- 1` here compensates for the '*' in the pattern
final String tokenIdentifier = channel.substring(LINKED_DEVICE_KEYSPACE_PATTERN.length() - 1);
Optional.ofNullable(waitForDeviceFuturesByTokenIdentifier.remove(tokenIdentifier))
.ifPresent(future -> pubSubRedisSingleton.withConnection(connection -> connection.async().get(getLinkedDeviceKey(tokenIdentifier)))
.thenAccept(deviceInfoJson -> handleDeviceAdded(future, deviceInfoJson)));
}
}
private void handleDeviceAdded(final CompletableFuture<Optional<DeviceInfo>> future, final String deviceInfoJson) {
try {
future.complete(Optional.of(SystemMapper.jsonMapper().readValue(deviceInfoJson, DeviceInfo.class)));
} catch (final JsonProcessingException e) {
logger.error("Could not parse device json", e);
future.completeExceptionally(e);
}
}
private static MessageDigest getSha256MessageDigest() {
try {
return MessageDigest.getInstance("SHA-256");
} catch (final NoSuchAlgorithmException e) {
throw new AssertionError("Every implementation of the Java platform is required to support the SHA-256 MessageDigest algorithm", e);
}
}
} }

View File

@ -0,0 +1,24 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import io.swagger.v3.oas.annotations.media.Schema;
public record LinkDeviceToken(
@Schema(description = """
An opaque token to send to a new linked device that authorizes the new device to link itself to the account that
requested this token.
""")
@JsonProperty("verificationCode") String token,
@Schema(description = """
An opaque identifier for the generated token that the caller may use to watch for a new device to complete the
linking process.
""")
String tokenIdentifier) {
}

View File

@ -1,8 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
public record VerificationCode(String verificationCode) {
}

View File

@ -40,6 +40,7 @@ import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.FcmSender; import org.whispersystems.textsecuregcm.push.FcmSender;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import org.whispersystems.textsecuregcm.storage.AccountLockManager; import org.whispersystems.textsecuregcm.storage.AccountLockManager;
@ -112,6 +113,8 @@ record CommandDependencies(
.build("main_cache", redisClientResourcesBuilder); .build("main_cache", redisClientResourcesBuilder);
FaultTolerantRedisClusterClient pushSchedulerCluster = configuration.getPushSchedulerCluster() FaultTolerantRedisClusterClient pushSchedulerCluster = configuration.getPushSchedulerCluster()
.build("push_scheduler", redisClientResourcesBuilder); .build("push_scheduler", redisClientResourcesBuilder);
FaultTolerantRedisClient pubsubClient =
configuration.getRedisPubSubConfiguration().build("pubsub", redisClientResourcesBuilder.build());
ScheduledExecutorService recurringJobExecutor = environment.lifecycle() ScheduledExecutorService recurringJobExecutor = environment.lifecycle()
.scheduledExecutorService(name(name, "recurringJob-%d")).threads(2).build(); .scheduledExecutorService(name(name, "recurringJob-%d")).threads(2).build();
@ -225,7 +228,7 @@ record CommandDependencies(
ClientPublicKeysManager clientPublicKeysManager = ClientPublicKeysManager clientPublicKeysManager =
new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor); new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor);
AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster, AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster,
accountLockManager, keys, messagesManager, profilesManager, pubsubClient, accountLockManager, keys, messagesManager, profilesManager,
secureStorageClient, secureValueRecovery2Client, clientPresenceManager, secureStorageClient, secureValueRecovery2Client, clientPresenceManager,
registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, clientPresenceExecutor, registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, clientPresenceExecutor,
clock, configuration.getLinkDeviceSecretConfiguration().secret().value(), dynamicConfigurationManager); clock, configuration.getLinkDeviceSecretConfiguration().secret().value(), dynamicConfigurationManager);

View File

@ -5,12 +5,14 @@
package org.whispersystems.textsecuregcm.controllers; package org.whispersystems.textsecuregcm.controllers;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.eq; import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
@ -18,6 +20,7 @@ import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.amazonaws.util.Base64;
import com.google.common.net.HttpHeaders; import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.AuthValueFactoryProvider; import io.dropwizard.auth.AuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
@ -25,6 +28,7 @@ import io.dropwizard.testing.junit5.ResourceExtension;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands; import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -35,6 +39,7 @@ import java.util.stream.Stream;
import javax.ws.rs.client.Entity; import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import org.apache.commons.lang3.RandomStringUtils;
import org.glassfish.jersey.server.ServerProperties; import org.glassfish.jersey.server.ServerProperties;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
@ -54,6 +59,7 @@ import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventLis
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId; import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest; import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest;
import org.whispersystems.textsecuregcm.entities.DeviceInfo;
import org.whispersystems.textsecuregcm.entities.DeviceResponse; import org.whispersystems.textsecuregcm.entities.DeviceResponse;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
@ -64,6 +70,7 @@ import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
@ -79,7 +86,7 @@ import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.TestClock; import org.whispersystems.textsecuregcm.util.TestClock;
import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import org.whispersystems.textsecuregcm.util.VerificationCode; import org.whispersystems.textsecuregcm.util.LinkDeviceToken;
@ExtendWith(DropwizardExtensionsSupport.class) @ExtendWith(DropwizardExtensionsSupport.class)
class DeviceControllerTest { class DeviceControllerTest {
@ -112,6 +119,7 @@ class DeviceControllerTest {
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE) .addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
.addProvider(AuthHelper.getAuthFilter()) .addProvider(AuthHelper.getAuthFilter())
.addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class)) .addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class))
.addProvider(new RateLimitExceededExceptionMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addProvider(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager)) .addProvider(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager))
.addProvider(new DeviceLimitExceededExceptionMapper()) .addProvider(new DeviceLimitExceededExceptionMapper())
@ -122,6 +130,7 @@ class DeviceControllerTest {
void setup() { void setup() {
when(rateLimiters.getAllocateDeviceLimiter()).thenReturn(rateLimiter); when(rateLimiters.getAllocateDeviceLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getVerifyDeviceLimiter()).thenReturn(rateLimiter); when(rateLimiters.getVerifyDeviceLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getWaitForLinkedDeviceLimiter()).thenReturn(rateLimiter);
when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID); when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
@ -479,16 +488,17 @@ class DeviceControllerTest {
final Optional<ApnRegistrationId> apnRegistrationId, final Optional<ApnRegistrationId> apnRegistrationId,
final Optional<GcmRegistrationId> gcmRegistrationId) { final Optional<GcmRegistrationId> gcmRegistrationId) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.generateLinkDeviceToken(any())).thenReturn("test");
final Device existingDevice = mock(Device.class); final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID); when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice)); when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
VerificationCode deviceCode = resources.getJerseyTest() final LinkDeviceToken deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code") .target("/v1/devices/provisioning/code")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class); .get(LinkDeviceToken.class);
final ECSignedPreKey aciSignedPreKey; final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey; final ECSignedPreKey pniSignedPreKey;
@ -506,7 +516,7 @@ class DeviceControllerTest {
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(), final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.token(),
new AccountAttributes(fetchesMessages, 1234, 5678, null, null, true, null), new AccountAttributes(fetchesMessages, 1234, 5678, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnRegistrationId, gcmRegistrationId)); new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnRegistrationId, gcmRegistrationId));
@ -539,21 +549,22 @@ class DeviceControllerTest {
final KEMSignedPreKey pniPqLastResortPreKey) { final KEMSignedPreKey pniPqLastResortPreKey) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.generateLinkDeviceToken(any())).thenReturn("test");
final Device existingDevice = mock(Device.class); final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID); when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice)); when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
VerificationCode deviceCode = resources.getJerseyTest() final LinkDeviceToken deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code") .target("/v1/devices/provisioning/code")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class); .get(LinkDeviceToken.class);
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(aciIdentityKey); when(account.getIdentityKey(IdentityType.ACI)).thenReturn(aciIdentityKey);
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(pniIdentityKey); when(account.getIdentityKey(IdentityType.PNI)).thenReturn(pniIdentityKey);
final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.verificationCode(), final LinkDeviceRequest request = new LinkDeviceRequest(deviceCode.token(),
new AccountAttributes(true, 1234, 5678, null, null, true, null), new AccountAttributes(true, 1234, 5678, null, null, true, null),
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.empty())); new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, Optional.empty(), Optional.empty()));
@ -931,4 +942,120 @@ class DeviceControllerTest {
verify(clientPublicKeysManager).setPublicKey(AuthHelper.VALID_ACCOUNT, AuthHelper.VALID_DEVICE.getId(), request.publicKey()); verify(clientPublicKeysManager).setPublicKey(AuthHelper.VALID_ACCOUNT, AuthHelper.VALID_DEVICE.getId(), request.publicKey());
} }
@Test
void waitForLinkedDevice() {
final DeviceInfo deviceInfo = new DeviceInfo(Device.PRIMARY_ID,
"Device name ciphertext".getBytes(StandardCharsets.UTF_8),
System.currentTimeMillis(),
System.currentTimeMillis());
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(deviceInfo)));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get()) {
assertEquals(200, response.getStatus());
final DeviceInfo retrievedDeviceInfo = response.readEntity(DeviceInfo.class);
assertEquals(deviceInfo.id(), retrievedDeviceInfo.id());
assertArrayEquals(deviceInfo.name(), retrievedDeviceInfo.name());
assertEquals(deviceInfo.created(), retrievedDeviceInfo.created());
assertEquals(deviceInfo.lastSeen(), retrievedDeviceInfo.lastSeen());
}
}
@Test
void waitForLinkedDeviceNoDeviceLinked() {
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get()) {
assertEquals(204, response.getStatus());
}
}
@Test
void waitForLinkedDeviceBadTokenIdentifier() {
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.failedFuture(new IllegalArgumentException()));
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get()) {
assertEquals(400, response.getStatus());
}
}
@ParameterizedTest
@MethodSource
void waitForLinkedDeviceBadTimeout(final int timeoutSeconds) {
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
.queryParam("timeout", timeoutSeconds)
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get()) {
assertEquals(400, response.getStatus());
}
}
private static List<Integer> waitForLinkedDeviceBadTimeout() {
return List.of(0, -1, 3601);
}
@ParameterizedTest
@MethodSource
void waitForLinkedDeviceBadTokenIdentifierLength(final String tokenIdentifier) {
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get()) {
assertEquals(400, response.getStatus());
}
}
private static List<String> waitForLinkedDeviceBadTokenIdentifierLength() {
return List.of(RandomStringUtils.randomAlphanumeric(DeviceController.MIN_TOKEN_IDENTIFIER_LENGTH - 1),
RandomStringUtils.randomAlphanumeric(DeviceController.MAX_TOKEN_IDENTIFIER_LENGTH + 1));
}
@Test
void waitForLinkedDeviceRateLimited() throws RateLimitExceededException {
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
doThrow(new RateLimitExceededException(null)).when(rateLimiter).validate(AuthHelper.VALID_UUID);
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get()) {
assertEquals(429, response.getStatus());
}
}
} }

View File

@ -43,6 +43,7 @@ import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
@ -142,6 +143,7 @@ public class AccountCreationDeletionIntegrationTest {
accounts, accounts,
phoneNumberIdentifiers, phoneNumberIdentifiers,
CACHE_CLUSTER_EXTENSION.getRedisCluster(), CACHE_CLUSTER_EXTENSION.getRedisCluster(),
mock(FaultTolerantRedisClient.class),
accountLockManager, accountLockManager,
keysManager, keysManager,
messagesManager, messagesManager,

View File

@ -36,6 +36,7 @@ import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
@ -137,6 +138,7 @@ class AccountsManagerChangeNumberIntegrationTest {
accounts, accounts,
phoneNumberIdentifiers, phoneNumberIdentifiers,
CACHE_CLUSTER_EXTENSION.getRedisCluster(), CACHE_CLUSTER_EXTENSION.getRedisCluster(),
mock(FaultTolerantRedisClient.class),
accountLockManager, accountLockManager,
keysManager, keysManager,
messagesManager, messagesManager,

View File

@ -48,6 +48,7 @@ import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfigurati
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
@ -124,6 +125,7 @@ class AccountsManagerConcurrentModificationIntegrationTest {
accounts, accounts,
phoneNumberIdentifiers, phoneNumberIdentifiers,
RedisClusterHelper.builder().stringCommands(commands).build(), RedisClusterHelper.builder().stringCommands(commands).build(),
mock(FaultTolerantRedisClient.class),
accountLockManager, accountLockManager,
mock(KeysManager.class), mock(KeysManager.class),
mock(MessagesManager.class), mock(MessagesManager.class),

View File

@ -33,6 +33,7 @@ import static org.mockito.Mockito.when;
import com.google.i18n.phonenumbers.PhoneNumberUtil; import com.google.i18n.phonenumbers.PhoneNumberUtil;
import io.lettuce.core.RedisException; import io.lettuce.core.RedisException;
import io.lettuce.core.api.async.RedisAsyncCommands;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands; import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.io.InputStream; import java.io.InputStream;
@ -78,6 +79,7 @@ import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecoveryException; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecoveryException;
@ -88,6 +90,7 @@ import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture; import org.whispersystems.textsecuregcm.tests.util.MockRedisFuture;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper; import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import org.whispersystems.textsecuregcm.tests.util.RedisServerHelper;
import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil; import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.TestClock; import org.whispersystems.textsecuregcm.util.TestClock;
@ -119,8 +122,9 @@ class AccountsManagerTest {
private Map<String, UUID> phoneNumberIdentifiersByE164; private Map<String, UUID> phoneNumberIdentifiersByE164;
private RedisAdvancedClusterCommands<String, String> commands; private RedisAsyncCommands<String, String> asyncCommands;
private RedisAdvancedClusterAsyncCommands<String, String> asyncCommands; private RedisAdvancedClusterCommands<String, String> clusterCommands;
private RedisAdvancedClusterAsyncCommands<String, String> asyncClusterCommands;
private AccountsManager accountsManager; private AccountsManager accountsManager;
private SecureValueRecovery2Client svr2Client; private SecureValueRecovery2Client svr2Client;
private DynamicConfiguration dynamicConfiguration; private DynamicConfiguration dynamicConfiguration;
@ -162,14 +166,18 @@ class AccountsManagerTest {
}).when(clientPresenceExecutor).execute(any()); }).when(clientPresenceExecutor).execute(any());
//noinspection unchecked //noinspection unchecked
commands = mock(RedisAdvancedClusterCommands.class); asyncCommands = mock(RedisAsyncCommands.class);
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
//noinspection unchecked //noinspection unchecked
asyncCommands = mock(RedisAdvancedClusterAsyncCommands.class); clusterCommands = mock(RedisAdvancedClusterCommands.class);
when(asyncCommands.del(any(String[].class))).thenReturn(MockRedisFuture.completedFuture(0L));
when(asyncCommands.get(any())).thenReturn(MockRedisFuture.completedFuture(null)); //noinspection unchecked
when(asyncCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); asyncClusterCommands = mock(RedisAdvancedClusterAsyncCommands.class);
when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); when(asyncClusterCommands.del(any(String[].class))).thenReturn(MockRedisFuture.completedFuture(0L));
when(asyncClusterCommands.get(any())).thenReturn(MockRedisFuture.completedFuture(null));
when(asyncClusterCommands.set(any(), any(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(asyncClusterCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(accounts.updateAsync(any())).thenReturn(CompletableFuture.completedFuture(null)); when(accounts.updateAsync(any())).thenReturn(CompletableFuture.completedFuture(null));
when(accounts.updateTransactionallyAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(accounts.updateTransactionallyAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
@ -230,15 +238,20 @@ class AccountsManagerTest {
CLOCK = TestClock.now(); CLOCK = TestClock.now();
final FaultTolerantRedisClusterClient redisCluster = RedisClusterHelper.builder() final FaultTolerantRedisClient pubSubClient = RedisServerHelper.builder()
.stringCommands(commands)
.stringAsyncCommands(asyncCommands) .stringAsyncCommands(asyncCommands)
.build(); .build();
final FaultTolerantRedisClusterClient redisCluster = RedisClusterHelper.builder()
.stringCommands(clusterCommands)
.stringAsyncCommands(asyncClusterCommands)
.build();
accountsManager = new AccountsManager( accountsManager = new AccountsManager(
accounts, accounts,
phoneNumberIdentifiers, phoneNumberIdentifiers,
redisCluster, redisCluster,
pubSubClient,
accountLockManager, accountLockManager,
keysManager, keysManager,
messagesManager, messagesManager,
@ -285,8 +298,8 @@ class AccountsManagerTest {
final UUID aci = UUID.randomUUID(); final UUID aci = UUID.randomUUID();
final UUID pni = UUID.randomUUID(); final UUID pni = UUID.randomUUID();
when(commands.get(eq("AccountMap::" + pni))).thenReturn(aci.toString()); when(clusterCommands.get(eq("AccountMap::" + pni))).thenReturn(aci.toString());
when(commands.get(eq("Account3::" + aci))).thenReturn( when(clusterCommands.get(eq("Account3::" + aci))).thenReturn(
"{\"number\": \"+14152222222\", \"pni\": \"" + pni + "\"}"); "{\"number\": \"+14152222222\", \"pni\": \"" + pni + "\"}");
assertTrue(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(aci)).isPresent()); assertTrue(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(aci)).isPresent());
@ -300,11 +313,11 @@ class AccountsManagerTest {
final UUID aci = UUID.randomUUID(); final UUID aci = UUID.randomUUID();
final UUID pni = UUID.randomUUID(); final UUID pni = UUID.randomUUID();
when(asyncCommands.get(eq("AccountMap::" + pni))).thenReturn(MockRedisFuture.completedFuture(aci.toString())); when(asyncClusterCommands.get(eq("AccountMap::" + pni))).thenReturn(MockRedisFuture.completedFuture(aci.toString()));
when(asyncCommands.get(eq("Account3::" + aci))).thenReturn(MockRedisFuture.completedFuture( when(asyncClusterCommands.get(eq("Account3::" + aci))).thenReturn(MockRedisFuture.completedFuture(
"{\"number\": \"+14152222222\", \"pni\": \"" + pni + "\"}")); "{\"number\": \"+14152222222\", \"pni\": \"" + pni + "\"}"));
when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); when(asyncClusterCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(accounts.getByAccountIdentifierAsync(any())) when(accounts.getByAccountIdentifierAsync(any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty())); .thenReturn(CompletableFuture.completedFuture(Optional.empty()));
@ -323,7 +336,7 @@ class AccountsManagerTest {
void testGetAccountByUuidInCache() { void testGetAccountByUuidInCache() {
UUID uuid = UUID.randomUUID(); UUID uuid = UUID.randomUUID();
when(commands.get(eq("Account3::" + uuid))).thenReturn( when(clusterCommands.get(eq("Account3::" + uuid))).thenReturn(
"{\"number\": \"+14152222222\", \"pni\": \"de24dc73-fbd8-41be-a7d5-764c70d9da7e\"}"); "{\"number\": \"+14152222222\", \"pni\": \"de24dc73-fbd8-41be-a7d5-764c70d9da7e\"}");
Optional<Account> account = accountsManager.getByAccountIdentifier(uuid); Optional<Account> account = accountsManager.getByAccountIdentifier(uuid);
@ -333,8 +346,8 @@ class AccountsManagerTest {
assertEquals(account.get().getUuid(), uuid); assertEquals(account.get().getUuid(), uuid);
assertEquals(UUID.fromString("de24dc73-fbd8-41be-a7d5-764c70d9da7e"), account.get().getPhoneNumberIdentifier()); assertEquals(UUID.fromString("de24dc73-fbd8-41be-a7d5-764c70d9da7e"), account.get().getPhoneNumberIdentifier());
verify(commands, times(1)).get(eq("Account3::" + uuid)); verify(clusterCommands, times(1)).get(eq("Account3::" + uuid));
verifyNoMoreInteractions(commands); verifyNoMoreInteractions(clusterCommands);
verifyNoInteractions(accounts); verifyNoInteractions(accounts);
} }
@ -343,10 +356,10 @@ class AccountsManagerTest {
void testGetAccountByUuidInCacheAsync() { void testGetAccountByUuidInCacheAsync() {
UUID uuid = UUID.randomUUID(); UUID uuid = UUID.randomUUID();
when(asyncCommands.get(eq("Account3::" + uuid))).thenReturn(MockRedisFuture.completedFuture( when(asyncClusterCommands.get(eq("Account3::" + uuid))).thenReturn(MockRedisFuture.completedFuture(
"{\"number\": \"+14152222222\", \"pni\": \"de24dc73-fbd8-41be-a7d5-764c70d9da7e\"}")); "{\"number\": \"+14152222222\", \"pni\": \"de24dc73-fbd8-41be-a7d5-764c70d9da7e\"}"));
when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); when(asyncClusterCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
Optional<Account> account = accountsManager.getByAccountIdentifierAsync(uuid).join(); Optional<Account> account = accountsManager.getByAccountIdentifierAsync(uuid).join();
@ -355,8 +368,8 @@ class AccountsManagerTest {
assertEquals(account.get().getUuid(), uuid); assertEquals(account.get().getUuid(), uuid);
assertEquals(UUID.fromString("de24dc73-fbd8-41be-a7d5-764c70d9da7e"), account.get().getPhoneNumberIdentifier()); assertEquals(UUID.fromString("de24dc73-fbd8-41be-a7d5-764c70d9da7e"), account.get().getPhoneNumberIdentifier());
verify(asyncCommands, times(1)).get(eq("Account3::" + uuid)); verify(asyncClusterCommands, times(1)).get(eq("Account3::" + uuid));
verifyNoMoreInteractions(asyncCommands); verifyNoMoreInteractions(asyncClusterCommands);
verifyNoInteractions(accounts); verifyNoInteractions(accounts);
} }
@ -366,8 +379,8 @@ class AccountsManagerTest {
UUID uuid = UUID.randomUUID(); UUID uuid = UUID.randomUUID();
UUID pni = UUID.randomUUID(); UUID pni = UUID.randomUUID();
when(commands.get(eq("AccountMap::" + pni))).thenReturn(uuid.toString()); when(clusterCommands.get(eq("AccountMap::" + pni))).thenReturn(uuid.toString());
when(commands.get(eq("Account3::" + uuid))).thenReturn( when(clusterCommands.get(eq("Account3::" + uuid))).thenReturn(
"{\"number\": \"+14152222222\", \"pni\": \"de24dc73-fbd8-41be-a7d5-764c70d9da7e\"}"); "{\"number\": \"+14152222222\", \"pni\": \"de24dc73-fbd8-41be-a7d5-764c70d9da7e\"}");
Optional<Account> account = accountsManager.getByPhoneNumberIdentifier(pni); Optional<Account> account = accountsManager.getByPhoneNumberIdentifier(pni);
@ -376,9 +389,9 @@ class AccountsManagerTest {
assertEquals(account.get().getNumber(), "+14152222222"); assertEquals(account.get().getNumber(), "+14152222222");
assertEquals(UUID.fromString("de24dc73-fbd8-41be-a7d5-764c70d9da7e"), account.get().getPhoneNumberIdentifier()); assertEquals(UUID.fromString("de24dc73-fbd8-41be-a7d5-764c70d9da7e"), account.get().getPhoneNumberIdentifier());
verify(commands).get(eq("AccountMap::" + pni)); verify(clusterCommands).get(eq("AccountMap::" + pni));
verify(commands).get(eq("Account3::" + uuid)); verify(clusterCommands).get(eq("Account3::" + uuid));
verifyNoMoreInteractions(commands); verifyNoMoreInteractions(clusterCommands);
verifyNoInteractions(accounts); verifyNoInteractions(accounts);
} }
@ -388,13 +401,13 @@ class AccountsManagerTest {
UUID uuid = UUID.randomUUID(); UUID uuid = UUID.randomUUID();
UUID pni = UUID.randomUUID(); UUID pni = UUID.randomUUID();
when(asyncCommands.get(eq("AccountMap::" + pni))) when(asyncClusterCommands.get(eq("AccountMap::" + pni)))
.thenReturn(MockRedisFuture.completedFuture(uuid.toString())); .thenReturn(MockRedisFuture.completedFuture(uuid.toString()));
when(asyncCommands.get(eq("Account3::" + uuid))).thenReturn(MockRedisFuture.completedFuture( when(asyncClusterCommands.get(eq("Account3::" + uuid))).thenReturn(MockRedisFuture.completedFuture(
"{\"number\": \"+14152222222\", \"pni\": \"de24dc73-fbd8-41be-a7d5-764c70d9da7e\"}")); "{\"number\": \"+14152222222\", \"pni\": \"de24dc73-fbd8-41be-a7d5-764c70d9da7e\"}"));
when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); when(asyncClusterCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
Optional<Account> account = accountsManager.getByPhoneNumberIdentifierAsync(pni).join(); Optional<Account> account = accountsManager.getByPhoneNumberIdentifierAsync(pni).join();
@ -402,9 +415,9 @@ class AccountsManagerTest {
assertEquals(account.get().getNumber(), "+14152222222"); assertEquals(account.get().getNumber(), "+14152222222");
assertEquals(UUID.fromString("de24dc73-fbd8-41be-a7d5-764c70d9da7e"), account.get().getPhoneNumberIdentifier()); assertEquals(UUID.fromString("de24dc73-fbd8-41be-a7d5-764c70d9da7e"), account.get().getPhoneNumberIdentifier());
verify(asyncCommands).get(eq("AccountMap::" + pni)); verify(asyncClusterCommands).get(eq("AccountMap::" + pni));
verify(asyncCommands).get(eq("Account3::" + uuid)); verify(asyncClusterCommands).get(eq("Account3::" + uuid));
verifyNoMoreInteractions(asyncCommands); verifyNoMoreInteractions(asyncClusterCommands);
verifyNoInteractions(accounts); verifyNoInteractions(accounts);
} }
@ -415,7 +428,7 @@ class AccountsManagerTest {
UUID pni = UUID.randomUUID(); UUID pni = UUID.randomUUID();
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(commands.get(eq("Account3::" + uuid))).thenReturn(null); when(clusterCommands.get(eq("Account3::" + uuid))).thenReturn(null);
when(accounts.getByAccountIdentifier(eq(uuid))).thenReturn(Optional.of(account)); when(accounts.getByAccountIdentifier(eq(uuid))).thenReturn(Optional.of(account));
Optional<Account> retrieved = accountsManager.getByAccountIdentifier(uuid); Optional<Account> retrieved = accountsManager.getByAccountIdentifier(uuid);
@ -423,10 +436,10 @@ class AccountsManagerTest {
assertTrue(retrieved.isPresent()); assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account); assertSame(retrieved.get(), account);
verify(commands, times(1)).get(eq("Account3::" + uuid)); verify(clusterCommands, times(1)).get(eq("Account3::" + uuid));
verify(commands, times(1)).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString())); verify(clusterCommands, times(1)).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(commands, times(1)).setex(eq("Account3::" + uuid), anyLong(), anyString()); verify(clusterCommands, times(1)).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(commands); verifyNoMoreInteractions(clusterCommands);
verify(accounts, times(1)).getByAccountIdentifier(eq(uuid)); verify(accounts, times(1)).getByAccountIdentifier(eq(uuid));
verifyNoMoreInteractions(accounts); verifyNoMoreInteractions(accounts);
@ -438,8 +451,8 @@ class AccountsManagerTest {
UUID pni = UUID.randomUUID(); UUID pni = UUID.randomUUID();
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(asyncCommands.get(eq("Account3::" + uuid))).thenReturn(MockRedisFuture.completedFuture(null)); when(asyncClusterCommands.get(eq("Account3::" + uuid))).thenReturn(MockRedisFuture.completedFuture(null));
when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); when(asyncClusterCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(accounts.getByAccountIdentifierAsync(eq(uuid))) when(accounts.getByAccountIdentifierAsync(eq(uuid)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account))); .thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
@ -448,10 +461,10 @@ class AccountsManagerTest {
assertTrue(retrieved.isPresent()); assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account); assertSame(retrieved.get(), account);
verify(asyncCommands).get(eq("Account3::" + uuid)); verify(asyncClusterCommands).get(eq("Account3::" + uuid));
verify(asyncCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString())); verify(asyncClusterCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(asyncCommands).setex(eq("Account3::" + uuid), anyLong(), anyString()); verify(asyncClusterCommands).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(asyncCommands); verifyNoMoreInteractions(asyncClusterCommands);
verify(accounts).getByAccountIdentifierAsync(eq(uuid)); verify(accounts).getByAccountIdentifierAsync(eq(uuid));
verifyNoMoreInteractions(accounts); verifyNoMoreInteractions(accounts);
@ -464,7 +477,7 @@ class AccountsManagerTest {
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(commands.get(eq("AccountMap::" + pni))).thenReturn(null); when(clusterCommands.get(eq("AccountMap::" + pni))).thenReturn(null);
when(accounts.getByPhoneNumberIdentifier(pni)).thenReturn(Optional.of(account)); when(accounts.getByPhoneNumberIdentifier(pni)).thenReturn(Optional.of(account));
Optional<Account> retrieved = accountsManager.getByPhoneNumberIdentifier(pni); Optional<Account> retrieved = accountsManager.getByPhoneNumberIdentifier(pni);
@ -472,10 +485,10 @@ class AccountsManagerTest {
assertTrue(retrieved.isPresent()); assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account); assertSame(retrieved.get(), account);
verify(commands).get(eq("AccountMap::" + pni)); verify(clusterCommands).get(eq("AccountMap::" + pni));
verify(commands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString())); verify(clusterCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(commands).setex(eq("Account3::" + uuid), anyLong(), anyString()); verify(clusterCommands).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(commands); verifyNoMoreInteractions(clusterCommands);
verify(accounts).getByPhoneNumberIdentifier(pni); verify(accounts).getByPhoneNumberIdentifier(pni);
verifyNoMoreInteractions(accounts); verifyNoMoreInteractions(accounts);
@ -488,8 +501,8 @@ class AccountsManagerTest {
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(asyncCommands.get(eq("AccountMap::" + pni))).thenReturn(MockRedisFuture.completedFuture(null)); when(asyncClusterCommands.get(eq("AccountMap::" + pni))).thenReturn(MockRedisFuture.completedFuture(null));
when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); when(asyncClusterCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(accounts.getByPhoneNumberIdentifierAsync(pni)) when(accounts.getByPhoneNumberIdentifierAsync(pni))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account))); .thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
@ -498,10 +511,10 @@ class AccountsManagerTest {
assertTrue(retrieved.isPresent()); assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account); assertSame(retrieved.get(), account);
verify(asyncCommands).get(eq("AccountMap::" + pni)); verify(asyncClusterCommands).get(eq("AccountMap::" + pni));
verify(asyncCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString())); verify(asyncClusterCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(asyncCommands).setex(eq("Account3::" + uuid), anyLong(), anyString()); verify(asyncClusterCommands).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(asyncCommands); verifyNoMoreInteractions(asyncClusterCommands);
verify(accounts).getByPhoneNumberIdentifierAsync(pni); verify(accounts).getByPhoneNumberIdentifierAsync(pni);
verifyNoMoreInteractions(accounts); verifyNoMoreInteractions(accounts);
@ -528,7 +541,7 @@ class AccountsManagerTest {
UUID pni = UUID.randomUUID(); UUID pni = UUID.randomUUID();
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(commands.get(eq("Account3::" + uuid))).thenThrow(new RedisException("Connection lost!")); when(clusterCommands.get(eq("Account3::" + uuid))).thenThrow(new RedisException("Connection lost!"));
when(accounts.getByAccountIdentifier(eq(uuid))).thenReturn(Optional.of(account)); when(accounts.getByAccountIdentifier(eq(uuid))).thenReturn(Optional.of(account));
Optional<Account> retrieved = accountsManager.getByAccountIdentifier(uuid); Optional<Account> retrieved = accountsManager.getByAccountIdentifier(uuid);
@ -536,10 +549,10 @@ class AccountsManagerTest {
assertTrue(retrieved.isPresent()); assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account); assertSame(retrieved.get(), account);
verify(commands, times(1)).get(eq("Account3::" + uuid)); verify(clusterCommands, times(1)).get(eq("Account3::" + uuid));
verify(commands, times(1)).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString())); verify(clusterCommands, times(1)).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(commands, times(1)).setex(eq("Account3::" + uuid), anyLong(), anyString()); verify(clusterCommands, times(1)).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(commands); verifyNoMoreInteractions(clusterCommands);
verify(accounts, times(1)).getByAccountIdentifier(eq(uuid)); verify(accounts, times(1)).getByAccountIdentifier(eq(uuid));
verifyNoMoreInteractions(accounts); verifyNoMoreInteractions(accounts);
@ -551,10 +564,10 @@ class AccountsManagerTest {
UUID pni = UUID.randomUUID(); UUID pni = UUID.randomUUID();
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(asyncCommands.get(eq("Account3::" + uuid))) when(asyncClusterCommands.get(eq("Account3::" + uuid)))
.thenReturn(MockRedisFuture.failedFuture(new RedisException("Connection lost!"))); .thenReturn(MockRedisFuture.failedFuture(new RedisException("Connection lost!")));
when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); when(asyncClusterCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(accounts.getByAccountIdentifierAsync(eq(uuid))) when(accounts.getByAccountIdentifierAsync(eq(uuid)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account))); .thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
@ -564,10 +577,10 @@ class AccountsManagerTest {
assertTrue(retrieved.isPresent()); assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account); assertSame(retrieved.get(), account);
verify(asyncCommands).get(eq("Account3::" + uuid)); verify(asyncClusterCommands).get(eq("Account3::" + uuid));
verify(asyncCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString())); verify(asyncClusterCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(asyncCommands).setex(eq("Account3::" + uuid), anyLong(), anyString()); verify(asyncClusterCommands).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(asyncCommands); verifyNoMoreInteractions(asyncClusterCommands);
verify(accounts).getByAccountIdentifierAsync(eq(uuid)); verify(accounts).getByAccountIdentifierAsync(eq(uuid));
verifyNoMoreInteractions(accounts); verifyNoMoreInteractions(accounts);
@ -580,7 +593,7 @@ class AccountsManagerTest {
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(commands.get(eq("AccountMap::" + pni))).thenThrow(new RedisException("OH NO")); when(clusterCommands.get(eq("AccountMap::" + pni))).thenThrow(new RedisException("OH NO"));
when(accounts.getByPhoneNumberIdentifier(pni)).thenReturn(Optional.of(account)); when(accounts.getByPhoneNumberIdentifier(pni)).thenReturn(Optional.of(account));
Optional<Account> retrieved = accountsManager.getByPhoneNumberIdentifier(pni); Optional<Account> retrieved = accountsManager.getByPhoneNumberIdentifier(pni);
@ -588,10 +601,10 @@ class AccountsManagerTest {
assertTrue(retrieved.isPresent()); assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account); assertSame(retrieved.get(), account);
verify(commands).get(eq("AccountMap::" + pni)); verify(clusterCommands).get(eq("AccountMap::" + pni));
verify(commands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString())); verify(clusterCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(commands).setex(eq("Account3::" + uuid), anyLong(), anyString()); verify(clusterCommands).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(commands); verifyNoMoreInteractions(clusterCommands);
verify(accounts).getByPhoneNumberIdentifier(pni); verify(accounts).getByPhoneNumberIdentifier(pni);
verifyNoMoreInteractions(accounts); verifyNoMoreInteractions(accounts);
@ -604,10 +617,10 @@ class AccountsManagerTest {
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(asyncCommands.get(eq("AccountMap::" + pni))) when(asyncClusterCommands.get(eq("AccountMap::" + pni)))
.thenReturn(MockRedisFuture.failedFuture(new RedisException("OH NO"))); .thenReturn(MockRedisFuture.failedFuture(new RedisException("OH NO")));
when(asyncCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK")); when(asyncClusterCommands.setex(any(), anyLong(), any())).thenReturn(MockRedisFuture.completedFuture("OK"));
when(accounts.getByPhoneNumberIdentifierAsync(pni)) when(accounts.getByPhoneNumberIdentifierAsync(pni))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account))); .thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
@ -617,10 +630,10 @@ class AccountsManagerTest {
assertTrue(retrieved.isPresent()); assertTrue(retrieved.isPresent());
assertSame(retrieved.get(), account); assertSame(retrieved.get(), account);
verify(asyncCommands).get(eq("AccountMap::" + pni)); verify(asyncClusterCommands).get(eq("AccountMap::" + pni));
verify(asyncCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString())); verify(asyncClusterCommands).setex(eq("AccountMap::" + pni), anyLong(), eq(uuid.toString()));
verify(asyncCommands).setex(eq("Account3::" + uuid), anyLong(), anyString()); verify(asyncClusterCommands).setex(eq("Account3::" + uuid), anyLong(), anyString());
verifyNoMoreInteractions(asyncCommands); verifyNoMoreInteractions(asyncClusterCommands);
verify(accounts).getByPhoneNumberIdentifierAsync(pni); verify(accounts).getByPhoneNumberIdentifierAsync(pni);
verifyNoMoreInteractions(accounts); verifyNoMoreInteractions(accounts);
@ -632,7 +645,7 @@ class AccountsManagerTest {
UUID pni = UUID.randomUUID(); UUID pni = UUID.randomUUID();
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(commands.get(eq("Account3::" + uuid))).thenReturn(null); when(clusterCommands.get(eq("Account3::" + uuid))).thenReturn(null);
when(accounts.getByAccountIdentifier(uuid)).thenReturn( when(accounts.getByAccountIdentifier(uuid)).thenReturn(
Optional.of(AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]))); Optional.of(AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH])));
@ -658,7 +671,7 @@ class AccountsManagerTest {
UUID pni = UUID.randomUUID(); UUID pni = UUID.randomUUID();
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(asyncCommands.get(eq("Account3::" + uuid))).thenReturn(null); when(asyncClusterCommands.get(eq("Account3::" + uuid))).thenReturn(null);
when(accounts.getByAccountIdentifierAsync(uuid)).thenReturn(CompletableFuture.completedFuture( when(accounts.getByAccountIdentifierAsync(uuid)).thenReturn(CompletableFuture.completedFuture(
Optional.of(AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH])))); Optional.of(AccountsHelper.generateTestAccount("+14152222222", uuid, pni, new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]))));
@ -684,7 +697,7 @@ class AccountsManagerTest {
UUID uuid = UUID.randomUUID(); UUID uuid = UUID.randomUUID();
Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]); Account account = AccountsHelper.generateTestAccount("+14152222222", uuid, UUID.randomUUID(), new ArrayList<>(), new byte[UnidentifiedAccessUtil.UNIDENTIFIED_ACCESS_KEY_LENGTH]);
when(commands.get(eq("Account3::" + uuid))).thenReturn(null); when(clusterCommands.get(eq("Account3::" + uuid))).thenReturn(null);
when(accounts.getByAccountIdentifier(uuid)).thenReturn(Optional.empty()) when(accounts.getByAccountIdentifier(uuid)).thenReturn(Optional.empty())
.thenReturn(Optional.of(account)); .thenReturn(Optional.of(account));
when(accounts.create(any(), any())).thenThrow(ContestedOptimisticLockException.class); when(accounts.create(any(), any())).thenThrow(ContestedOptimisticLockException.class);
@ -971,7 +984,7 @@ class AccountsManagerTest {
pniSignedPreKey, pniSignedPreKey,
aciPqLastResortPreKey, aciPqLastResortPreKey,
pniPqLastResortPreKey), pniPqLastResortPreKey),
accountsManager.generateDeviceLinkingToken(aci)) accountsManager.generateLinkDeviceToken(aci))
.join(); .join();
verify(keysManager).deleteSingleUsePreKeys(aci, nextDeviceId); verify(keysManager).deleteSingleUsePreKeys(aci, nextDeviceId);
@ -1606,7 +1619,7 @@ class AccountsManagerTest {
final UUID aci = UUID.randomUUID(); final UUID aci = UUID.randomUUID();
assertEquals(Optional.of(aci), assertEquals(Optional.of(aci),
accountsManager.checkDeviceLinkingToken(accountsManager.generateDeviceLinkingToken(aci))); accountsManager.checkDeviceLinkingToken(accountsManager.generateLinkDeviceToken(aci)));
} }
@ParameterizedTest @ParameterizedTest
@ -1622,7 +1635,7 @@ class AccountsManagerTest {
return Stream.of( return Stream.of(
// Expired token // Expired token
Arguments.of(AccountsManager.generateDeviceLinkingToken(UUID.randomUUID(), Arguments.of(AccountsManager.generateLinkDeviceToken(UUID.randomUUID(),
new SecretKeySpec(LINK_DEVICE_SECRET, AccountsManager.LINK_DEVICE_VERIFICATION_TOKEN_ALGORITHM), new SecretKeySpec(LINK_DEVICE_SECRET, AccountsManager.LINK_DEVICE_VERIFICATION_TOKEN_ALGORITHM),
CLOCK), CLOCK),
tokenTimestamp.plus(AccountsManager.LINK_DEVICE_TOKEN_EXPIRATION_DURATION).plusSeconds(1)), tokenTimestamp.plus(AccountsManager.LINK_DEVICE_TOKEN_EXPIRATION_DURATION).plusSeconds(1)),

View File

@ -36,6 +36,7 @@ import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.Mockito; import org.mockito.Mockito;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
@ -137,6 +138,7 @@ class AccountsManagerUsernameIntegrationTest {
accounts, accounts,
phoneNumberIdentifiers, phoneNumberIdentifiers,
CACHE_CLUSTER_EXTENSION.getRedisCluster(), CACHE_CLUSTER_EXTENSION.getRedisCluster(),
mock(FaultTolerantRedisClient.class),
accountLockManager, accountLockManager,
keysManager, keysManager,
messageManager, messageManager,

View File

@ -13,6 +13,7 @@ import static org.mockito.Mockito.when;
import com.google.i18n.phonenumbers.PhoneNumberUtil; import com.google.i18n.phonenumbers.PhoneNumberUtil;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Clock; import java.time.Clock;
import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.time.ZoneId; import java.time.ZoneId;
import java.util.Optional; import java.util.Optional;
@ -29,9 +30,11 @@ import org.junit.jupiter.api.extension.RegisterExtension;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.DeviceInfo;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
import org.whispersystems.textsecuregcm.redis.RedisServerExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
@ -59,6 +62,9 @@ public class AddRemoveDeviceIntegrationTest {
@RegisterExtension @RegisterExtension
static final RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); static final RedisClusterExtension CACHE_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@RegisterExtension
static final RedisServerExtension PUBSUB_SERVER_EXTENSION = RedisServerExtension.builder().build();
private static final Clock CLOCK = Clock.fixed(Instant.now(), ZoneId.systemDefault()); private static final Clock CLOCK = Clock.fixed(Instant.now(), ZoneId.systemDefault());
private ExecutorService accountLockExecutor; private ExecutorService accountLockExecutor;
@ -128,10 +134,16 @@ public class AddRemoveDeviceIntegrationTest {
when(registrationRecoveryPasswordsManager.removeForNumber(any())) when(registrationRecoveryPasswordsManager.removeForNumber(any()))
.thenReturn(CompletableFuture.completedFuture(null)); .thenReturn(CompletableFuture.completedFuture(null));
PUBSUB_SERVER_EXTENSION.getRedisClient().useConnection(connection -> {
connection.sync().flushall();
connection.sync().configSet("notify-keyspace-events", "K$");
});
accountsManager = new AccountsManager( accountsManager = new AccountsManager(
accounts, accounts,
phoneNumberIdentifiers, phoneNumberIdentifiers,
CACHE_CLUSTER_EXTENSION.getRedisCluster(), CACHE_CLUSTER_EXTENSION.getRedisCluster(),
PUBSUB_SERVER_EXTENSION.getRedisClient(),
accountLockManager, accountLockManager,
keysManager, keysManager,
messagesManager, messagesManager,
@ -146,10 +158,14 @@ public class AddRemoveDeviceIntegrationTest {
CLOCK, CLOCK,
"link-device-secret".getBytes(StandardCharsets.UTF_8), "link-device-secret".getBytes(StandardCharsets.UTF_8),
dynamicConfigurationManager); dynamicConfigurationManager);
accountsManager.start();
} }
@AfterEach @AfterEach
void tearDown() throws InterruptedException { void tearDown() throws InterruptedException {
accountsManager.stop();
accountLockExecutor.shutdown(); accountLockExecutor.shutdown();
clientPresenceExecutor.shutdown(); clientPresenceExecutor.shutdown();
@ -187,7 +203,7 @@ public class AddRemoveDeviceIntegrationTest {
KeysHelper.signedECPreKey(2, pniKeyPair), KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair), KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair)), KeysHelper.signedKEMPreKey(4, pniKeyPair)),
accountsManager.generateDeviceLinkingToken(account.getIdentifier(IdentityType.ACI))) accountsManager.generateLinkDeviceToken(account.getIdentifier(IdentityType.ACI)))
.join(); .join();
assertEquals(2, updatedAccountAndDevice.first().getDevices().size()); assertEquals(2, updatedAccountAndDevice.first().getDevices().size());
@ -216,7 +232,7 @@ public class AddRemoveDeviceIntegrationTest {
final Account account = AccountsHelper.createAccount(accountsManager, number); final Account account = AccountsHelper.createAccount(accountsManager, number);
assertEquals(1, accountsManager.getByAccountIdentifier(account.getUuid()).orElseThrow().getDevices().size()); assertEquals(1, accountsManager.getByAccountIdentifier(account.getUuid()).orElseThrow().getDevices().size());
final String linkDeviceToken = accountsManager.generateDeviceLinkingToken(account.getIdentifier(IdentityType.ACI)); final String linkDeviceToken = accountsManager.generateLinkDeviceToken(account.getIdentifier(IdentityType.ACI));
final Pair<Account, Device> updatedAccountAndDevice = final Pair<Account, Device> updatedAccountAndDevice =
accountsManager.addDevice(account, new DeviceSpec( accountsManager.addDevice(account, new DeviceSpec(
@ -292,7 +308,7 @@ public class AddRemoveDeviceIntegrationTest {
KeysHelper.signedECPreKey(2, pniKeyPair), KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair), KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair)), KeysHelper.signedKEMPreKey(4, pniKeyPair)),
accountsManager.generateDeviceLinkingToken(account.getIdentifier(IdentityType.ACI))) accountsManager.generateLinkDeviceToken(account.getIdentifier(IdentityType.ACI)))
.join(); .join();
final byte addedDeviceId = updatedAccountAndDevice.second().getId(); final byte addedDeviceId = updatedAccountAndDevice.second().getId();
@ -346,7 +362,7 @@ public class AddRemoveDeviceIntegrationTest {
KeysHelper.signedECPreKey(2, pniKeyPair), KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair), KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair)), KeysHelper.signedKEMPreKey(4, pniKeyPair)),
accountsManager.generateDeviceLinkingToken(account.getIdentifier(IdentityType.ACI))) accountsManager.generateLinkDeviceToken(account.getIdentifier(IdentityType.ACI)))
.join(); .join();
final byte addedDeviceId = updatedAccountAndDevice.second().getId(); final byte addedDeviceId = updatedAccountAndDevice.second().getId();
@ -376,4 +392,110 @@ public class AddRemoveDeviceIntegrationTest {
assertTrue(keysManager.getLastResort(retrievedAccount.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join().isPresent()); assertTrue(keysManager.getLastResort(retrievedAccount.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join().isPresent());
assertTrue(clientPublicKeysManager.findPublicKey(retrievedAccount.getUuid(), Device.PRIMARY_ID).join().isPresent()); assertTrue(clientPublicKeysManager.findPublicKey(retrievedAccount.getUuid(), Device.PRIMARY_ID).join().isPresent());
} }
@Test
void waitForNewLinkedDevice() throws InterruptedException {
final String number = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164);
final ECKeyPair aciKeyPair = Curve.generateKeyPair();
final ECKeyPair pniKeyPair = Curve.generateKeyPair();
final Account account = AccountsHelper.createAccount(accountsManager, number);
final String linkDeviceToken = accountsManager.generateLinkDeviceToken(account.getIdentifier(IdentityType.ACI));
final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken);
final CompletableFuture<Optional<DeviceInfo>> displacedFuture =
accountsManager.waitForNewLinkedDevice(linkDeviceTokenIdentifier, Duration.ofSeconds(5));
final CompletableFuture<Optional<DeviceInfo>> activeFuture =
accountsManager.waitForNewLinkedDevice(linkDeviceTokenIdentifier, Duration.ofSeconds(5));
assertEquals(Optional.empty(), displacedFuture.join());
final Pair<Account, Device> updatedAccountAndDevice =
accountsManager.addDevice(account, new DeviceSpec(
"device-name".getBytes(StandardCharsets.UTF_8),
"password",
"OWT",
new Device.DeviceCapabilities(true, true, true, false),
1,
2,
true,
Optional.empty(),
Optional.empty(),
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair)),
linkDeviceToken)
.join();
final Optional<DeviceInfo> maybeDeviceInfo = activeFuture.join();
assertTrue(maybeDeviceInfo.isPresent());
final DeviceInfo deviceInfo = maybeDeviceInfo.get();
assertEquals(updatedAccountAndDevice.second().getId(), deviceInfo.id());
assertEquals(updatedAccountAndDevice.second().getCreated(), deviceInfo.created());
}
@Test
void waitForNewLinkedDeviceAlreadyAdded() throws InterruptedException {
final String number = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164);
final ECKeyPair aciKeyPair = Curve.generateKeyPair();
final ECKeyPair pniKeyPair = Curve.generateKeyPair();
final Account account = AccountsHelper.createAccount(accountsManager, number);
final String linkDeviceToken = accountsManager.generateLinkDeviceToken(account.getIdentifier(IdentityType.ACI));
final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken);
final Pair<Account, Device> updatedAccountAndDevice =
accountsManager.addDevice(account, new DeviceSpec(
"device-name".getBytes(StandardCharsets.UTF_8),
"password",
"OWT",
new Device.DeviceCapabilities(true, true, true, false),
1,
2,
true,
Optional.empty(),
Optional.empty(),
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair)),
linkDeviceToken)
.join();
final CompletableFuture<Optional<DeviceInfo>> linkedDeviceFuture =
accountsManager.waitForNewLinkedDevice(linkDeviceTokenIdentifier, Duration.ofMinutes(1));
final Optional<DeviceInfo> maybeDeviceInfo = linkedDeviceFuture.join();
assertTrue(maybeDeviceInfo.isPresent());
final DeviceInfo deviceInfo = maybeDeviceInfo.get();
assertEquals(updatedAccountAndDevice.second().getId(), deviceInfo.id());
assertEquals(updatedAccountAndDevice.second().getCreated(), deviceInfo.created());
}
@Test
void waitForNewLinkedDeviceTimeout() {
final String linkDeviceToken = accountsManager.generateLinkDeviceToken(UUID.randomUUID());
final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken);
final CompletableFuture<Optional<DeviceInfo>> linkedDeviceFuture =
accountsManager.waitForNewLinkedDevice(linkDeviceTokenIdentifier, Duration.ofMillis(10));
final Optional<DeviceInfo> maybeDeviceInfo = linkedDeviceFuture.join();
assertTrue(maybeDeviceInfo.isEmpty());
}
} }

View File

@ -0,0 +1,112 @@
package org.whispersystems.textsecuregcm.tests.util;
import io.lettuce.core.api.StatefulRedisConnection;
import io.lettuce.core.api.async.RedisAsyncCommands;
import io.lettuce.core.api.reactive.RedisReactiveCommands;
import io.lettuce.core.api.sync.RedisCommands;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import java.util.function.Consumer;
import java.util.function.Function;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class RedisServerHelper {
public static RedisServerHelper.Builder builder() {
return new RedisServerHelper.Builder();
}
@SuppressWarnings("unchecked")
private static FaultTolerantRedisClient buildMockRedisClient(
final RedisCommands<String, String> stringCommands,
final RedisAsyncCommands<String, String> stringAsyncCommands,
final RedisCommands<byte[], byte[]> binaryCommands,
final RedisAsyncCommands<byte[], byte[]> binaryAsyncCommands,
final RedisReactiveCommands<byte[], byte[]> binaryReactiveCommands) {
final FaultTolerantRedisClient client = mock(FaultTolerantRedisClient.class);
final StatefulRedisConnection<String, String> stringConnection = mock(StatefulRedisConnection.class);
final StatefulRedisConnection<byte[], byte[]> binaryConnection = mock(StatefulRedisConnection.class);
when(stringConnection.sync()).thenReturn(stringCommands);
when(stringConnection.async()).thenReturn(stringAsyncCommands);
when(binaryConnection.sync()).thenReturn(binaryCommands);
when(binaryConnection.async()).thenReturn(binaryAsyncCommands);
when(binaryConnection.reactive()).thenReturn(binaryReactiveCommands);
when(client.withConnection(any(Function.class))).thenAnswer(invocation -> {
return invocation.getArgument(0, Function.class).apply(stringConnection);
});
doAnswer(invocation -> {
invocation.getArgument(0, Consumer.class).accept(stringConnection);
return null;
}).when(client).useConnection(any(Consumer.class));
when(client.withBinaryConnection(any(Function.class))).thenAnswer(invocation -> {
return invocation.getArgument(0, Function.class).apply(binaryConnection);
});
doAnswer(invocation -> {
invocation.getArgument(0, Consumer.class).accept(binaryConnection);
return null;
}).when(client).useBinaryConnection(any(Consumer.class));
return client;
}
@SuppressWarnings("unchecked")
public static class Builder {
private RedisCommands<String, String> stringCommands = mock(RedisCommands.class);
private RedisAsyncCommands<String, String> stringAsyncCommands = mock(RedisAsyncCommands.class);
private RedisCommands<byte[], byte[]> binaryCommands = mock(RedisCommands.class);
private RedisAsyncCommands<byte[], byte[]> binaryAsyncCommands =
mock(RedisAsyncCommands.class);
private RedisReactiveCommands<byte[], byte[]> binaryReactiveCommands =
mock(RedisReactiveCommands.class);
private Builder() {
}
public RedisServerHelper.Builder stringCommands(final RedisCommands<String, String> stringCommands) {
this.stringCommands = stringCommands;
return this;
}
public RedisServerHelper.Builder stringAsyncCommands(final RedisAsyncCommands<String, String> stringAsyncCommands) {
this.stringAsyncCommands = stringAsyncCommands;
return this;
}
public RedisServerHelper.Builder binaryCommands(final RedisCommands<byte[], byte[]> binaryCommands) {
this.binaryCommands = binaryCommands;
return this;
}
public RedisServerHelper.Builder binaryAsyncCommands(final RedisAsyncCommands<byte[], byte[]> binaryAsyncCommands) {
this.binaryAsyncCommands = binaryAsyncCommands;
return this;
}
public RedisServerHelper.Builder binaryReactiveCommands(
final RedisReactiveCommands<byte[], byte[]> binaryReactiveCommands) {
this.binaryReactiveCommands = binaryReactiveCommands;
return this;
}
public FaultTolerantRedisClient build() {
return RedisServerHelper.buildMockRedisClient(stringCommands,
stringAsyncCommands,
binaryCommands,
binaryAsyncCommands,
binaryReactiveCommands);
}
}
}