Wait for messages in waitForNewLinkedDevice

This commit is contained in:
Ravi Khadiwala 2024-11-07 15:47:21 -06:00 committed by ravi-signal
parent 3288d3d538
commit 81f3ba17c7
16 changed files with 374 additions and 60 deletions

View File

@ -559,6 +559,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
.scheduledExecutorService(name(getClass(), "subscriptionProcessorRetry-%d")).threads(1).build(); .scheduledExecutorService(name(getClass(), "subscriptionProcessorRetry-%d")).threads(1).build();
ScheduledExecutorService cloudflareTurnRetryExecutor = environment.lifecycle() ScheduledExecutorService cloudflareTurnRetryExecutor = environment.lifecycle()
.scheduledExecutorService(name(getClass(), "cloudflareTurnRetry-%d")).threads(1).build(); .scheduledExecutorService(name(getClass(), "cloudflareTurnRetry-%d")).threads(1).build();
ScheduledExecutorService messagePollExecutor = environment.lifecycle()
.scheduledExecutorService(name(getClass(), "messagePollExecutor-%d")).threads(1).build();
final ManagedNioEventLoopGroup dnsResolutionEventLoopGroup = new ManagedNioEventLoopGroup(); final ManagedNioEventLoopGroup dnsResolutionEventLoopGroup = new ManagedNioEventLoopGroup();
final DnsNameResolver cloudflareDnsResolver = new DnsNameResolverBuilder(dnsResolutionEventLoopGroup.next()) final DnsNameResolver cloudflareDnsResolver = new DnsNameResolverBuilder(dnsResolutionEventLoopGroup.next())
@ -620,7 +622,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster, AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster,
pubsubClient, accountLockManager, keysManager, messagesManager, profilesManager, pubsubClient, accountLockManager, keysManager, messagesManager, profilesManager,
secureStorageClient, secureValueRecovery2Client, disconnectionRequestManager, webSocketConnectionEventManager, secureStorageClient, secureValueRecovery2Client, disconnectionRequestManager, webSocketConnectionEventManager,
registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, messagePollExecutor,
clock, config.getLinkDeviceSecretConfiguration().secret().value(), dynamicConfigurationManager); clock, config.getLinkDeviceSecretConfiguration().secret().value(), dynamicConfigurationManager);
RemoteConfigsManager remoteConfigsManager = new RemoteConfigsManager(remoteConfigs); RemoteConfigsManager remoteConfigsManager = new RemoteConfigsManager(remoteConfigs);
APNSender apnSender = new APNSender(apnSenderExecutor, config.getApnConfiguration()); APNSender apnSender = new APNSender(apnSenderExecutor, config.getApnConfiguration());

View File

@ -362,9 +362,9 @@ public class DeviceController {
linkedDeviceListenerCounter.incrementAndGet(); linkedDeviceListenerCounter.incrementAndGet();
final Timer.Sample sample = Timer.start(); final Timer.Sample sample = Timer.start();
try { try {
return accounts.waitForNewLinkedDevice(tokenIdentifier, Duration.ofSeconds(timeoutSeconds)) return accounts.waitForNewLinkedDevice(authenticatedDevice.getAccount().getUuid(),
authenticatedDevice.getAuthenticatedDevice(), tokenIdentifier, Duration.ofSeconds(timeoutSeconds))
.thenApply(maybeDeviceInfo -> maybeDeviceInfo .thenApply(maybeDeviceInfo -> maybeDeviceInfo
.map(deviceInfo -> Response.status(Response.Status.OK).entity(deviceInfo).build()) .map(deviceInfo -> Response.status(Response.Status.OK).entity(deviceInfo).build())
.orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build())) .orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build()))

View File

@ -47,6 +47,7 @@ import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage; import java.util.concurrent.CompletionStage;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;
@ -130,6 +131,7 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager; private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager;
private final ClientPublicKeysManager clientPublicKeysManager; private final ClientPublicKeysManager clientPublicKeysManager;
private final Executor accountLockExecutor; private final Executor accountLockExecutor;
private final ScheduledExecutorService messagesPollExecutor;
private final Clock clock; private final Clock clock;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager; private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
@ -163,6 +165,9 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
private static final ObjectWriter ACCOUNT_REDIS_JSON_WRITER = SystemMapper.jsonMapper() private static final ObjectWriter ACCOUNT_REDIS_JSON_WRITER = SystemMapper.jsonMapper()
.writer(SystemMapper.excludingField(Account.class, List.of("uuid"))); .writer(SystemMapper.excludingField(Account.class, List.of("uuid")));
private static Duration MESSAGE_POLL_INTERVAL = Duration.ofSeconds(1);
private static Duration MAX_SERVER_CLOCK_DRIFT = Duration.ofSeconds(5);
// An account that's used at least daily will get reset in the cache at least once per day when its "last seen" // An account that's used at least daily will get reset in the cache at least once per day when its "last seen"
// timestamp updates; expiring entries after two days will help clear out "zombie" cache entries that are read // timestamp updates; expiring entries after two days will help clear out "zombie" cache entries that are read
// frequently (e.g. the account is in an active group and receives messages frequently), but aren't actively used by // frequently (e.g. the account is in an active group and receives messages frequently), but aren't actively used by
@ -209,6 +214,7 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager, final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager,
final ClientPublicKeysManager clientPublicKeysManager, final ClientPublicKeysManager clientPublicKeysManager,
final Executor accountLockExecutor, final Executor accountLockExecutor,
final ScheduledExecutorService messagesPollExecutor,
final Clock clock, final Clock clock,
final byte[] linkDeviceSecret, final byte[] linkDeviceSecret,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager) { final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager) {
@ -227,6 +233,7 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
this.registrationRecoveryPasswordsManager = requireNonNull(registrationRecoveryPasswordsManager); this.registrationRecoveryPasswordsManager = requireNonNull(registrationRecoveryPasswordsManager);
this.clientPublicKeysManager = clientPublicKeysManager; this.clientPublicKeysManager = clientPublicKeysManager;
this.accountLockExecutor = accountLockExecutor; this.accountLockExecutor = accountLockExecutor;
this.messagesPollExecutor = messagesPollExecutor;
this.clock = requireNonNull(clock); this.clock = requireNonNull(clock);
this.dynamicConfigurationManager = dynamicConfigurationManager; this.dynamicConfigurationManager = dynamicConfigurationManager;
@ -1428,20 +1435,90 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
.thenRun(Util.NOOP); .thenRun(Util.NOOP);
} }
public CompletableFuture<Optional<DeviceInfo>> waitForNewLinkedDevice(final String linkDeviceTokenIdentifier, final Duration timeout) { public CompletableFuture<Optional<DeviceInfo>> waitForNewLinkedDevice(
final UUID accountIdentifier,
final Device linkingDevice,
final String linkDeviceTokenIdentifier,
final Duration timeout) {
if (!linkingDevice.isPrimary()) {
throw new IllegalArgumentException("Only primary devices can link devices");
}
// Unbeknownst to callers but beknownst to us, the "link device token identifier" is the base64/url-encoded SHA256 // Unbeknownst to callers but beknownst to us, the "link device token identifier" is the base64/url-encoded SHA256
// hash of a device-linking token. Before we use the string anywhere, make sure it's the right "shape" for a hash. // hash of a device-linking token. Before we use the string anywhere, make sure it's the right "shape" for a hash.
if (Base64.getUrlDecoder().decode(linkDeviceTokenIdentifier).length != SHA256_HASH_LENGTH) { if (Base64.getUrlDecoder().decode(linkDeviceTokenIdentifier).length != SHA256_HASH_LENGTH) {
return CompletableFuture.failedFuture(new IllegalArgumentException("Invalid token identifier")); return CompletableFuture.failedFuture(new IllegalArgumentException("Invalid token identifier"));
} }
return waitForPubSubKey(waitForDeviceFuturesByTokenIdentifier, final Instant deadline = clock.instant().plus(timeout);
linkDeviceTokenIdentifier, final CompletableFuture<Optional<DeviceInfo>> deviceAdded = waitForPubSubKey(waitForDeviceFuturesByTokenIdentifier,
getLinkedDeviceKey(linkDeviceTokenIdentifier), linkDeviceTokenIdentifier, getLinkedDeviceKey(linkDeviceTokenIdentifier), timeout, this::handleDeviceAdded);
timeout,
this::handleDeviceAdded); return deviceAdded.thenCompose(maybeDeviceInfo -> maybeDeviceInfo.map(deviceInfo -> {
// The device finished linking, we now want to make sure the client has fetched messages that could
// have come in before the device's mailbox was set up.
// A worst case estimate of the wall clock time at which the linked device was added to the account record
Instant deviceLinked = Instant.ofEpochMilli(deviceInfo.created()).plus(MAX_SERVER_CLOCK_DRIFT);
Instant now = clock.instant();
// We know at `now` the device finished linking, so if we waited for all the messages before now it would be
// sufficient. However, now might be much later that the device was linked, so we don't want to force
// the client to wait for messages that are past our worst case estimate of when the device was linked
Instant messageEpoch = Collections.min(List.of(deviceLinked, now));
// We assume that any message with a timestamp after the messageEpoch made it into the linked device's queues
return waitForPreLinkMessagesToBeFetched(accountIdentifier, linkingDevice, deviceInfo, messageEpoch, deadline);
})
.orElseGet(() -> CompletableFuture.completedFuture(maybeDeviceInfo)));
} }
/**
* Wait until there are no pending messages for the authenticatedDevice that have a timestamp lower than the provided
* messageEpoch.
*
* @param aci The account identifier of the device doing the linking
* @param linkingDevice The device doing the linking
* @param linkedDeviceInfo Information about the newly linked device
* @param messageEpoch A time at which the device was linked
* @param deadline The time at which the method will stop waiting
* @return A future that completes when there are no pending messages for the linking device with a timestamp earlier
* the provided messageEpoch, or after the deadline is reached. If the deadline was exceeded, the future will be empty.
*/
private CompletableFuture<Optional<DeviceInfo>> waitForPreLinkMessagesToBeFetched(
final UUID aci,
final Device linkingDevice,
final DeviceInfo linkedDeviceInfo,
final Instant messageEpoch,
final Instant deadline) {
return messagesManager.getEarliestUndeliveredTimestampForDevice(aci, linkingDevice)
.thenCompose(maybeEarliestTimestamp -> {
final boolean clientHasOldMessages = maybeEarliestTimestamp
.map(earliestTimestamp -> earliestTimestamp.isBefore(messageEpoch))
.orElse(false);
if (!clientHasOldMessages) {
// The client has fetched all messages before the messageEpoch
return CompletableFuture.completedFuture(Optional.of(linkedDeviceInfo));
}
final Instant now = clock.instant();
if (now.plus(MESSAGE_POLL_INTERVAL).isAfter(deadline)) {
// Not enough time to try again before the deadline
return CompletableFuture.completedFuture(Optional.empty());
}
// Schedule a retry
return CompletableFuture.supplyAsync(
() -> waitForPreLinkMessagesToBeFetched(aci, linkingDevice, linkedDeviceInfo, messageEpoch, deadline),
r -> messagesPollExecutor.schedule(r, MESSAGE_POLL_INTERVAL.toMillis(), TimeUnit.MILLISECONDS))
.thenCompose(Function.identity());
});
}
private void handleDeviceAdded(final CompletableFuture<Optional<DeviceInfo>> future, final String deviceInfoJson) { private void handleDeviceAdded(final CompletableFuture<Optional<DeviceInfo>> future, final String deviceInfoJson) {
try { try {
future.complete(Optional.of(SystemMapper.jsonMapper().readValue(deviceInfoJson, DeviceInfo.class))); future.complete(Optional.of(SystemMapper.jsonMapper().readValue(deviceInfoJson, DeviceInfo.class)));

View File

@ -290,7 +290,7 @@ public class MessagesCache {
clock.millis() - MAX_EPHEMERAL_MESSAGE_DELAY.toMillis(); clock.millis() - MAX_EPHEMERAL_MESSAGE_DELAY.toMillis();
final Flux<MessageProtos.Envelope> allMessages = getAllMessages(destinationUuid, destinationDevice, final Flux<MessageProtos.Envelope> allMessages = getAllMessages(destinationUuid, destinationDevice,
earliestAllowableEphemeralTimestamp) earliestAllowableEphemeralTimestamp, PAGE_SIZE)
.publish() .publish()
// We expect exactly two subscribers to this base flux: // We expect exactly two subscribers to this base flux:
// 1. the websocket that delivers messages to clients // 1. the websocket that delivers messages to clients
@ -311,6 +311,12 @@ public class MessagesCache {
.tap(Micrometer.metrics(Metrics.globalRegistry)); .tap(Micrometer.metrics(Metrics.globalRegistry));
} }
public Mono<Long> getEarliestUndeliveredTimestamp(final UUID destinationUuid, final byte destinationDevice) {
return getAllMessages(destinationUuid, destinationDevice, -1, 1)
.next()
.map(MessageProtos.Envelope::getServerTimestamp);
}
private static boolean isStaleEphemeralMessage(final MessageProtos.Envelope message, private static boolean isStaleEphemeralMessage(final MessageProtos.Envelope message,
long earliestAllowableTimestamp) { long earliestAllowableTimestamp) {
return message.getEphemeral() && message.getClientTimestamp() < earliestAllowableTimestamp; return message.getEphemeral() && message.getClientTimestamp() < earliestAllowableTimestamp;
@ -330,17 +336,17 @@ public class MessagesCache {
@VisibleForTesting @VisibleForTesting
Flux<MessageProtos.Envelope> getAllMessages(final UUID destinationUuid, final byte destinationDevice, Flux<MessageProtos.Envelope> getAllMessages(final UUID destinationUuid, final byte destinationDevice,
final long earliestAllowableEphemeralTimestamp) { final long earliestAllowableEphemeralTimestamp, final int pageSize) {
// fetch messages by page // fetch messages by page
return getNextMessagePage(destinationUuid, destinationDevice, -1) return getNextMessagePage(destinationUuid, destinationDevice, -1, pageSize)
.expand(queueItemsAndLastMessageId -> { .expand(queueItemsAndLastMessageId -> {
// expand() is breadth-first, so each page will be published in order // expand() is breadth-first, so each page will be published in order
if (queueItemsAndLastMessageId.first().isEmpty()) { if (queueItemsAndLastMessageId.first().isEmpty()) {
return Mono.empty(); return Mono.empty();
} }
return getNextMessagePage(destinationUuid, destinationDevice, queueItemsAndLastMessageId.second()); return getNextMessagePage(destinationUuid, destinationDevice, queueItemsAndLastMessageId.second(), pageSize);
}) })
.limitRate(1) .limitRate(1)
// we want to ensure we dont accidentally block the Lettuce/netty i/o executors // we want to ensure we dont accidentally block the Lettuce/netty i/o executors
@ -478,9 +484,9 @@ public class MessagesCache {
} }
private Mono<Pair<List<byte[]>, Long>> getNextMessagePage(final UUID destinationUuid, final byte destinationDevice, private Mono<Pair<List<byte[]>, Long>> getNextMessagePage(final UUID destinationUuid, final byte destinationDevice,
long messageId) { long messageId, int pageSize) {
return getItemsScript.execute(destinationUuid, destinationDevice, PAGE_SIZE, messageId) return getItemsScript.execute(destinationUuid, destinationDevice, pageSize, messageId)
.map(queueItems -> { .map(queueItems -> {
logger.trace("Processing page: {}", messageId); logger.trace("Processing page: {}", messageId);

View File

@ -8,6 +8,7 @@ import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import java.time.Instant;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
@ -200,6 +201,16 @@ public class MessagesManager {
return messagesRemovedFromCache; return messagesRemovedFromCache;
} }
public CompletableFuture<Optional<Instant>> getEarliestUndeliveredTimestampForDevice(UUID destinationUuid, Device destinationDevice) {
// If there's any message in the persisted layer, return the oldest
return Mono.from(messagesDynamoDb.load(destinationUuid, destinationDevice, 1)).map(Envelope::getServerTimestamp)
// If not, return the oldest message in the cache
.switchIfEmpty(messagesCache.getEarliestUndeliveredTimestamp(destinationUuid, destinationDevice.getId()))
.map(epochMilli -> Optional.of(Instant.ofEpochMilli(epochMilli)))
.switchIfEmpty(Mono.just(Optional.empty()))
.toFuture();
}
/** /**
* Inserts the shared multi-recipient message payload to storage. * Inserts the shared multi-recipient message payload to storage.
* *

View File

@ -146,6 +146,8 @@ record CommandDependencies(
.scheduledExecutorService(name(name, "remoteStorageRetry-%d")).threads(1).build(); .scheduledExecutorService(name(name, "remoteStorageRetry-%d")).threads(1).build();
ScheduledExecutorService storageServiceRetryExecutor = environment.lifecycle() ScheduledExecutorService storageServiceRetryExecutor = environment.lifecycle()
.scheduledExecutorService(name(name, "storageServiceRetry-%d")).threads(1).build(); .scheduledExecutorService(name(name, "storageServiceRetry-%d")).threads(1).build();
ScheduledExecutorService messagePollExecutor = environment.lifecycle()
.scheduledExecutorService(name(name, "messagePollExecutor-%d")).threads(1).build();
ExternalServiceCredentialsGenerator storageCredentialsGenerator = SecureStorageController.credentialsGenerator( ExternalServiceCredentialsGenerator storageCredentialsGenerator = SecureStorageController.credentialsGenerator(
configuration.getSecureStorageServiceConfiguration()); configuration.getSecureStorageServiceConfiguration());
@ -227,7 +229,7 @@ record CommandDependencies(
AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster, AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster,
pubsubClient, accountLockManager, keys, messagesManager, profilesManager, pubsubClient, accountLockManager, keys, messagesManager, profilesManager,
secureStorageClient, secureValueRecovery2Client, disconnectionRequestManager, webSocketConnectionEventManager, secureStorageClient, secureValueRecovery2Client, disconnectionRequestManager, webSocketConnectionEventManager,
registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, messagePollExecutor,
clock, configuration.getLinkDeviceSecretConfiguration().secret().value(), dynamicConfigurationManager); clock, configuration.getLinkDeviceSecretConfiguration().secret().value(), dynamicConfigurationManager);
RateLimiters rateLimiters = RateLimiters.createAndValidate(configuration.getLimitsConfiguration(), RateLimiters rateLimiters = RateLimiters.createAndValidate(configuration.getLimitsConfiguration(),
dynamicConfigurationManager, rateLimitersCluster); dynamicConfigurationManager, rateLimitersCluster);

View File

@ -919,7 +919,8 @@ class DeviceControllerTest {
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]); final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);
when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any())) when(accountsManager
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(deviceInfo))); .thenReturn(CompletableFuture.completedFuture(Optional.of(deviceInfo)));
try (final Response response = resources.getJerseyTest() try (final Response response = resources.getJerseyTest()
@ -942,7 +943,8 @@ class DeviceControllerTest {
void waitForLinkedDeviceNoDeviceLinked() { void waitForLinkedDeviceNoDeviceLinked() {
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]); final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);
when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any())) when(accountsManager
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty())); .thenReturn(CompletableFuture.completedFuture(Optional.empty()));
try (final Response response = resources.getJerseyTest() try (final Response response = resources.getJerseyTest()
@ -959,7 +961,8 @@ class DeviceControllerTest {
void waitForLinkedDeviceBadTokenIdentifier() { void waitForLinkedDeviceBadTokenIdentifier() {
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]); final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);
when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any())) when(accountsManager
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.failedFuture(new IllegalArgumentException())); .thenReturn(CompletableFuture.failedFuture(new IllegalArgumentException()));
try (final Response response = resources.getJerseyTest() try (final Response response = resources.getJerseyTest()

View File

@ -25,6 +25,7 @@ import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.RandomStringUtils; import org.apache.commons.lang3.RandomStringUtils;
@ -74,7 +75,7 @@ public class AccountCreationDeletionIntegrationTest {
private static final Clock CLOCK = Clock.fixed(Instant.now(), ZoneId.systemDefault()); private static final Clock CLOCK = Clock.fixed(Instant.now(), ZoneId.systemDefault());
private ExecutorService accountLockExecutor; private ScheduledExecutorService executor;
private AccountsManager accountsManager; private AccountsManager accountsManager;
private KeysManager keysManager; private KeysManager keysManager;
@ -113,12 +114,12 @@ public class AccountCreationDeletionIntegrationTest {
DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS.tableName(), DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS.tableName(),
DynamoDbExtensionSchema.Tables.USED_LINK_DEVICE_TOKENS.tableName()); DynamoDbExtensionSchema.Tables.USED_LINK_DEVICE_TOKENS.tableName());
accountLockExecutor = Executors.newSingleThreadExecutor(); executor = Executors.newSingleThreadScheduledExecutor();
final AccountLockManager accountLockManager = new AccountLockManager(DYNAMO_DB_EXTENSION.getDynamoDbClient(), final AccountLockManager accountLockManager = new AccountLockManager(DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS_LOCK.tableName()); DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS_LOCK.tableName());
clientPublicKeysManager = new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor); clientPublicKeysManager = new ClientPublicKeysManager(clientPublicKeys, accountLockManager, executor);
final SecureStorageClient secureStorageClient = mock(SecureStorageClient.class); final SecureStorageClient secureStorageClient = mock(SecureStorageClient.class);
when(secureStorageClient.deleteStoredData(any())).thenReturn(CompletableFuture.completedFuture(null)); when(secureStorageClient.deleteStoredData(any())).thenReturn(CompletableFuture.completedFuture(null));
@ -164,7 +165,8 @@ public class AccountCreationDeletionIntegrationTest {
webSocketConnectionEventManager, webSocketConnectionEventManager,
registrationRecoveryPasswordsManager, registrationRecoveryPasswordsManager,
clientPublicKeysManager, clientPublicKeysManager,
accountLockExecutor, executor,
executor,
CLOCK, CLOCK,
"link-device-secret".getBytes(StandardCharsets.UTF_8), "link-device-secret".getBytes(StandardCharsets.UTF_8),
dynamicConfigurationManager); dynamicConfigurationManager);
@ -172,10 +174,10 @@ public class AccountCreationDeletionIntegrationTest {
@AfterEach @AfterEach
void tearDown() throws InterruptedException { void tearDown() throws InterruptedException {
accountLockExecutor.shutdown(); executor.shutdown();
//noinspection ResultOfMethodCallIgnored //noinspection ResultOfMethodCallIgnored
accountLockExecutor.awaitTermination(1, TimeUnit.SECONDS); executor.awaitTermination(1, TimeUnit.SECONDS);
} }
@CartesianTest @CartesianTest

View File

@ -23,6 +23,7 @@ import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
@ -69,7 +70,7 @@ class AccountsManagerChangeNumberIntegrationTest {
private KeysManager keysManager; private KeysManager keysManager;
private DisconnectionRequestManager disconnectionRequestManager; private DisconnectionRequestManager disconnectionRequestManager;
private WebSocketConnectionEventManager webSocketConnectionEventManager; private WebSocketConnectionEventManager webSocketConnectionEventManager;
private ExecutorService accountLockExecutor; private ScheduledExecutorService executor;
private AccountsManager accountsManager; private AccountsManager accountsManager;
@ -104,13 +105,13 @@ class AccountsManagerChangeNumberIntegrationTest {
Tables.DELETED_ACCOUNTS.tableName(), Tables.DELETED_ACCOUNTS.tableName(),
Tables.USED_LINK_DEVICE_TOKENS.tableName()); Tables.USED_LINK_DEVICE_TOKENS.tableName());
accountLockExecutor = Executors.newSingleThreadExecutor(); executor = Executors.newSingleThreadScheduledExecutor();
final AccountLockManager accountLockManager = new AccountLockManager(DYNAMO_DB_EXTENSION.getDynamoDbClient(), final AccountLockManager accountLockManager = new AccountLockManager(DYNAMO_DB_EXTENSION.getDynamoDbClient(),
Tables.DELETED_ACCOUNTS_LOCK.tableName()); Tables.DELETED_ACCOUNTS_LOCK.tableName());
final ClientPublicKeysManager clientPublicKeysManager = final ClientPublicKeysManager clientPublicKeysManager =
new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor); new ClientPublicKeysManager(clientPublicKeys, accountLockManager, executor);
final SecureStorageClient secureStorageClient = mock(SecureStorageClient.class); final SecureStorageClient secureStorageClient = mock(SecureStorageClient.class);
when(secureStorageClient.deleteStoredData(any())).thenReturn(CompletableFuture.completedFuture(null)); when(secureStorageClient.deleteStoredData(any())).thenReturn(CompletableFuture.completedFuture(null));
@ -151,7 +152,8 @@ class AccountsManagerChangeNumberIntegrationTest {
webSocketConnectionEventManager, webSocketConnectionEventManager,
registrationRecoveryPasswordsManager, registrationRecoveryPasswordsManager,
clientPublicKeysManager, clientPublicKeysManager,
accountLockExecutor, executor,
executor,
mock(Clock.class), mock(Clock.class),
"link-device-secret".getBytes(StandardCharsets.UTF_8), "link-device-secret".getBytes(StandardCharsets.UTF_8),
dynamicConfigurationManager); dynamicConfigurationManager);
@ -160,10 +162,10 @@ class AccountsManagerChangeNumberIntegrationTest {
@AfterEach @AfterEach
void tearDown() throws InterruptedException { void tearDown() throws InterruptedException {
accountLockExecutor.shutdown(); executor.shutdown();
//noinspection ResultOfMethodCallIgnored //noinspection ResultOfMethodCallIgnored
accountLockExecutor.awaitTermination(1, TimeUnit.SECONDS); executor.awaitTermination(1, TimeUnit.SECONDS);
} }
@Test @Test

View File

@ -30,6 +30,7 @@ import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.function.Consumer; import java.util.function.Consumer;
@ -139,6 +140,7 @@ class AccountsManagerConcurrentModificationIntegrationTest {
mock(RegistrationRecoveryPasswordsManager.class), mock(RegistrationRecoveryPasswordsManager.class),
mock(ClientPublicKeysManager.class), mock(ClientPublicKeysManager.class),
mock(Executor.class), mock(Executor.class),
mock(ScheduledExecutorService.class),
mock(Clock.class), mock(Clock.class),
"link-device-secret".getBytes(StandardCharsets.UTF_8), "link-device-secret".getBytes(StandardCharsets.UTF_8),
dynamicConfigurationManager dynamicConfigurationManager

View File

@ -29,6 +29,7 @@ import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
@ -68,6 +69,7 @@ public class AccountsManagerDeviceTransferIntegrationTest {
mock(RegistrationRecoveryPasswordsManager.class), mock(RegistrationRecoveryPasswordsManager.class),
mock(ClientPublicKeysManager.class), mock(ClientPublicKeysManager.class),
mock(ExecutorService.class), mock(ExecutorService.class),
mock(ScheduledExecutorService.class),
Clock.systemUTC(), Clock.systemUTC(),
"link-device-secret".getBytes(StandardCharsets.UTF_8), "link-device-secret".getBytes(StandardCharsets.UTF_8),
mock(DynamicConfigurationManager.class)); mock(DynamicConfigurationManager.class));

View File

@ -54,6 +54,7 @@ import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionException;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Supplier; import java.util.function.Supplier;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -262,6 +263,7 @@ class AccountsManagerTest {
registrationRecoveryPasswordsManager, registrationRecoveryPasswordsManager,
clientPublicKeysManager, clientPublicKeysManager,
mock(Executor.class), mock(Executor.class),
mock(ScheduledExecutorService.class),
CLOCK, CLOCK,
LINK_DEVICE_SECRET, LINK_DEVICE_SECRET,
dynamicConfigurationManager); dynamicConfigurationManager);
@ -1537,6 +1539,21 @@ class AccountsManagerTest {
assertThrows(AssertionError.class, () -> accountsManager.update(account, a -> a.setUsernameHash(USERNAME_HASH_1))); assertThrows(AssertionError.class, () -> accountsManager.update(account, a -> a.setUsernameHash(USERNAME_HASH_1)));
} }
@Test
void testOnlyPrimaryCanWaitForDeviceLinked() {
final Device primaryDevice = new Device();
primaryDevice.setId(Device.PRIMARY_ID);
final Device linkedDevice = new Device();
linkedDevice.setId((byte) (Device.PRIMARY_ID + 1));
final Account account = AccountsHelper.generateTestAccount("+14152222222", List.of(primaryDevice, linkedDevice));
assertThrows(IllegalArgumentException.class,
() -> accountsManager.waitForNewLinkedDevice(account.getUuid(), linkedDevice, "", Duration.ofSeconds(1)));
}
@Test @Test
void testJsonRoundTripSerialization() throws Exception { void testJsonRoundTripSerialization() throws Exception {
String originalJson; String originalJson;

View File

@ -157,6 +157,7 @@ class AccountsManagerUsernameIntegrationTest {
mock(RegistrationRecoveryPasswordsManager.class), mock(RegistrationRecoveryPasswordsManager.class),
mock(ClientPublicKeysManager.class), mock(ClientPublicKeysManager.class),
Executors.newSingleThreadExecutor(), Executors.newSingleThreadExecutor(),
Executors.newSingleThreadScheduledExecutor(),
mock(Clock.class), mock(Clock.class),
"link-device-secret".getBytes(StandardCharsets.UTF_8), "link-device-secret".getBytes(StandardCharsets.UTF_8),
dynamicConfigurationManager); dynamicConfigurationManager);

View File

@ -7,15 +7,14 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.ArgumentMatchers.anyByte;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.i18n.phonenumbers.PhoneNumberUtil; import com.google.i18n.phonenumbers.PhoneNumberUtil;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Clock;
import java.time.Duration; import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.time.ZoneId;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.UUID; import java.util.UUID;
@ -23,11 +22,15 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager; import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
@ -42,6 +45,7 @@ import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.tests.util.KeysHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.TestClock;
public class AddRemoveDeviceIntegrationTest { public class AddRemoveDeviceIntegrationTest {
@ -67,14 +71,14 @@ public class AddRemoveDeviceIntegrationTest {
@RegisterExtension @RegisterExtension
static final RedisServerExtension PUBSUB_SERVER_EXTENSION = RedisServerExtension.builder().build(); static final RedisServerExtension PUBSUB_SERVER_EXTENSION = RedisServerExtension.builder().build();
private static final Clock CLOCK = Clock.fixed(Instant.now(), ZoneId.systemDefault());
private ExecutorService accountLockExecutor; private ExecutorService accountLockExecutor;
private ScheduledExecutorService messagePollExecutor;
private KeysManager keysManager; private KeysManager keysManager;
private ClientPublicKeysManager clientPublicKeysManager; private ClientPublicKeysManager clientPublicKeysManager;
private MessagesManager messagesManager; private MessagesManager messagesManager;
private AccountsManager accountsManager; private AccountsManager accountsManager;
private TestClock clock;
@BeforeEach @BeforeEach
void setUp() { void setUp() {
@ -84,6 +88,8 @@ public class AddRemoveDeviceIntegrationTest {
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class); final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration); when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
clock = TestClock.pinned(Instant.now());
keysManager = new KeysManager( keysManager = new KeysManager(
DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(),
DynamoDbExtensionSchema.Tables.EC_KEYS.tableName(), DynamoDbExtensionSchema.Tables.EC_KEYS.tableName(),
@ -106,6 +112,7 @@ public class AddRemoveDeviceIntegrationTest {
DynamoDbExtensionSchema.Tables.USED_LINK_DEVICE_TOKENS.tableName()); DynamoDbExtensionSchema.Tables.USED_LINK_DEVICE_TOKENS.tableName());
accountLockExecutor = Executors.newSingleThreadExecutor(); accountLockExecutor = Executors.newSingleThreadExecutor();
messagePollExecutor = mock(ScheduledExecutorService.class);
final AccountLockManager accountLockManager = new AccountLockManager(DYNAMO_DB_EXTENSION.getDynamoDbClient(), final AccountLockManager accountLockManager = new AccountLockManager(DYNAMO_DB_EXTENSION.getDynamoDbClient(),
DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS_LOCK.tableName()); DynamoDbExtensionSchema.Tables.DELETED_ACCOUNTS_LOCK.tableName());
@ -155,7 +162,8 @@ public class AddRemoveDeviceIntegrationTest {
registrationRecoveryPasswordsManager, registrationRecoveryPasswordsManager,
clientPublicKeysManager, clientPublicKeysManager,
accountLockExecutor, accountLockExecutor,
CLOCK, messagePollExecutor,
clock,
"link-device-secret".getBytes(StandardCharsets.UTF_8), "link-device-secret".getBytes(StandardCharsets.UTF_8),
dynamicConfigurationManager); dynamicConfigurationManager);
@ -210,10 +218,15 @@ public class AddRemoveDeviceIntegrationTest {
final byte addedDeviceId = updatedAccountAndDevice.second().getId(); final byte addedDeviceId = updatedAccountAndDevice.second().getId();
assertTrue(keysManager.getEcSignedPreKey(updatedAccountAndDevice.first().getUuid(), addedDeviceId).join().isPresent()); assertTrue(
assertTrue(keysManager.getEcSignedPreKey(updatedAccountAndDevice.first().getPhoneNumberIdentifier(), addedDeviceId).join().isPresent()); keysManager.getEcSignedPreKey(updatedAccountAndDevice.first().getUuid(), addedDeviceId).join().isPresent());
assertTrue(
keysManager.getEcSignedPreKey(updatedAccountAndDevice.first().getPhoneNumberIdentifier(), addedDeviceId).join()
.isPresent());
assertTrue(keysManager.getLastResort(updatedAccountAndDevice.first().getUuid(), addedDeviceId).join().isPresent()); assertTrue(keysManager.getLastResort(updatedAccountAndDevice.first().getUuid(), addedDeviceId).join().isPresent());
assertTrue(keysManager.getLastResort(updatedAccountAndDevice.first().getPhoneNumberIdentifier(), addedDeviceId).join().isPresent()); assertTrue(
keysManager.getLastResort(updatedAccountAndDevice.first().getPhoneNumberIdentifier(), addedDeviceId).join()
.isPresent());
} }
@Test @Test
@ -317,15 +330,18 @@ public class AddRemoveDeviceIntegrationTest {
assertEquals(1, updatedAccount.getDevices().size()); assertEquals(1, updatedAccount.getDevices().size());
assertFalse(keysManager.getEcSignedPreKey(updatedAccount.getUuid(), addedDeviceId).join().isPresent()); assertFalse(keysManager.getEcSignedPreKey(updatedAccount.getUuid(), addedDeviceId).join().isPresent());
assertFalse(keysManager.getEcSignedPreKey(updatedAccount.getPhoneNumberIdentifier(), addedDeviceId).join().isPresent()); assertFalse(
keysManager.getEcSignedPreKey(updatedAccount.getPhoneNumberIdentifier(), addedDeviceId).join().isPresent());
assertFalse(keysManager.getLastResort(updatedAccount.getUuid(), addedDeviceId).join().isPresent()); assertFalse(keysManager.getLastResort(updatedAccount.getUuid(), addedDeviceId).join().isPresent());
assertFalse(keysManager.getLastResort(updatedAccount.getPhoneNumberIdentifier(), addedDeviceId).join().isPresent()); assertFalse(keysManager.getLastResort(updatedAccount.getPhoneNumberIdentifier(), addedDeviceId).join().isPresent());
assertFalse(clientPublicKeysManager.findPublicKey(updatedAccount.getUuid(), addedDeviceId).join().isPresent()); assertFalse(clientPublicKeysManager.findPublicKey(updatedAccount.getUuid(), addedDeviceId).join().isPresent());
assertTrue(keysManager.getEcSignedPreKey(updatedAccount.getUuid(), Device.PRIMARY_ID).join().isPresent()); assertTrue(keysManager.getEcSignedPreKey(updatedAccount.getUuid(), Device.PRIMARY_ID).join().isPresent());
assertTrue(keysManager.getEcSignedPreKey(updatedAccount.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join().isPresent()); assertTrue(
keysManager.getEcSignedPreKey(updatedAccount.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join().isPresent());
assertTrue(keysManager.getLastResort(updatedAccount.getUuid(), Device.PRIMARY_ID).join().isPresent()); assertTrue(keysManager.getLastResort(updatedAccount.getUuid(), Device.PRIMARY_ID).join().isPresent());
assertTrue(keysManager.getLastResort(updatedAccount.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join().isPresent()); assertTrue(
keysManager.getLastResort(updatedAccount.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join().isPresent());
assertTrue(clientPublicKeysManager.findPublicKey(updatedAccount.getUuid(), Device.PRIMARY_ID).join().isPresent()); assertTrue(clientPublicKeysManager.findPublicKey(updatedAccount.getUuid(), Device.PRIMARY_ID).join().isPresent());
} }
@ -371,21 +387,27 @@ public class AddRemoveDeviceIntegrationTest {
final Account retrievedAccount = accountsManager.getByAccountIdentifierAsync(aci).join().orElseThrow(); final Account retrievedAccount = accountsManager.getByAccountIdentifierAsync(aci).join().orElseThrow();
clientPublicKeysManager.setPublicKey(retrievedAccount, Device.PRIMARY_ID, Curve.generateKeyPair().getPublicKey()).join(); clientPublicKeysManager.setPublicKey(retrievedAccount, Device.PRIMARY_ID, Curve.generateKeyPair().getPublicKey())
clientPublicKeysManager.setPublicKey(retrievedAccount, addedDeviceId, Curve.generateKeyPair().getPublicKey()).join(); .join();
clientPublicKeysManager.setPublicKey(retrievedAccount, addedDeviceId, Curve.generateKeyPair().getPublicKey())
.join();
assertEquals(2, retrievedAccount.getDevices().size()); assertEquals(2, retrievedAccount.getDevices().size());
assertTrue(keysManager.getEcSignedPreKey(retrievedAccount.getUuid(), addedDeviceId).join().isPresent()); assertTrue(keysManager.getEcSignedPreKey(retrievedAccount.getUuid(), addedDeviceId).join().isPresent());
assertTrue(keysManager.getEcSignedPreKey(retrievedAccount.getPhoneNumberIdentifier(), addedDeviceId).join().isPresent()); assertTrue(
keysManager.getEcSignedPreKey(retrievedAccount.getPhoneNumberIdentifier(), addedDeviceId).join().isPresent());
assertTrue(keysManager.getLastResort(retrievedAccount.getUuid(), addedDeviceId).join().isPresent()); assertTrue(keysManager.getLastResort(retrievedAccount.getUuid(), addedDeviceId).join().isPresent());
assertTrue(keysManager.getLastResort(retrievedAccount.getPhoneNumberIdentifier(), addedDeviceId).join().isPresent()); assertTrue(
keysManager.getLastResort(retrievedAccount.getPhoneNumberIdentifier(), addedDeviceId).join().isPresent());
assertTrue(clientPublicKeysManager.findPublicKey(retrievedAccount.getUuid(), addedDeviceId).join().isPresent()); assertTrue(clientPublicKeysManager.findPublicKey(retrievedAccount.getUuid(), addedDeviceId).join().isPresent());
assertTrue(keysManager.getEcSignedPreKey(retrievedAccount.getUuid(), Device.PRIMARY_ID).join().isPresent()); assertTrue(keysManager.getEcSignedPreKey(retrievedAccount.getUuid(), Device.PRIMARY_ID).join().isPresent());
assertTrue(keysManager.getEcSignedPreKey(retrievedAccount.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join().isPresent()); assertTrue(keysManager.getEcSignedPreKey(retrievedAccount.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join()
.isPresent());
assertTrue(keysManager.getLastResort(retrievedAccount.getUuid(), Device.PRIMARY_ID).join().isPresent()); assertTrue(keysManager.getLastResort(retrievedAccount.getUuid(), Device.PRIMARY_ID).join().isPresent());
assertTrue(keysManager.getLastResort(retrievedAccount.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join().isPresent()); assertTrue(
keysManager.getLastResort(retrievedAccount.getPhoneNumberIdentifier(), Device.PRIMARY_ID).join().isPresent());
assertTrue(clientPublicKeysManager.findPublicKey(retrievedAccount.getUuid(), Device.PRIMARY_ID).join().isPresent()); assertTrue(clientPublicKeysManager.findPublicKey(retrievedAccount.getUuid(), Device.PRIMARY_ID).join().isPresent());
} }
@ -403,11 +425,15 @@ public class AddRemoveDeviceIntegrationTest {
final String linkDeviceToken = accountsManager.generateLinkDeviceToken(account.getIdentifier(IdentityType.ACI)); final String linkDeviceToken = accountsManager.generateLinkDeviceToken(account.getIdentifier(IdentityType.ACI));
final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken); final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken);
final CompletableFuture<Optional<DeviceInfo>> displacedFuture = final CompletableFuture<Optional<DeviceInfo>> displacedFuture = accountsManager.waitForNewLinkedDevice(
accountsManager.waitForNewLinkedDevice(linkDeviceTokenIdentifier, Duration.ofSeconds(5)); account.getUuid(), account.getPrimaryDevice(),
linkDeviceTokenIdentifier, Duration.ofSeconds(5));
when(messagesManager.getEarliestUndeliveredTimestampForDevice(account.getUuid(), account.getPrimaryDevice()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
final CompletableFuture<Optional<DeviceInfo>> activeFuture = final CompletableFuture<Optional<DeviceInfo>> activeFuture =
accountsManager.waitForNewLinkedDevice(linkDeviceTokenIdentifier, Duration.ofSeconds(5)); accountsManager.waitForNewLinkedDevice(account.getUuid(), account.getPrimaryDevice(), linkDeviceTokenIdentifier,
Duration.ofSeconds(5));
assertEquals(Optional.empty(), displacedFuture.join()); assertEquals(Optional.empty(), displacedFuture.join());
@ -470,8 +496,11 @@ public class AddRemoveDeviceIntegrationTest {
linkDeviceToken) linkDeviceToken)
.join(); .join();
final CompletableFuture<Optional<DeviceInfo>> linkedDeviceFuture = when(messagesManager.getEarliestUndeliveredTimestampForDevice(account.getUuid(), account.getPrimaryDevice()))
accountsManager.waitForNewLinkedDevice(linkDeviceTokenIdentifier, Duration.ofMinutes(1)); .thenReturn(CompletableFuture.completedFuture(Optional.empty()));
final CompletableFuture<Optional<DeviceInfo>> linkedDeviceFuture = accountsManager.waitForNewLinkedDevice(
account.getUuid(), account.getPrimaryDevice(), linkDeviceTokenIdentifier, Duration.ofMinutes(1));
final Optional<DeviceInfo> maybeDeviceInfo = linkedDeviceFuture.join(); final Optional<DeviceInfo> maybeDeviceInfo = linkedDeviceFuture.join();
@ -483,15 +512,121 @@ public class AddRemoveDeviceIntegrationTest {
} }
@Test @Test
void waitForNewLinkedDeviceTimeout() { void waitForNewLinkedDeviceTimeout() throws InterruptedException {
final String number = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164);
final Account account = AccountsHelper.createAccount(accountsManager, number);
final String linkDeviceToken = accountsManager.generateLinkDeviceToken(UUID.randomUUID()); final String linkDeviceToken = accountsManager.generateLinkDeviceToken(UUID.randomUUID());
final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken); final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken);
final CompletableFuture<Optional<DeviceInfo>> linkedDeviceFuture = final CompletableFuture<Optional<DeviceInfo>> linkedDeviceFuture = accountsManager.waitForNewLinkedDevice(
accountsManager.waitForNewLinkedDevice(linkDeviceTokenIdentifier, Duration.ofMillis(1)); account.getUuid(), account.getPrimaryDevice(), linkDeviceTokenIdentifier, Duration.ofMillis(1));
final Optional<DeviceInfo> maybeDeviceInfo = linkedDeviceFuture.join(); final Optional<DeviceInfo> maybeDeviceInfo = linkedDeviceFuture.join();
assertTrue(maybeDeviceInfo.isEmpty()); assertTrue(maybeDeviceInfo.isEmpty());
} }
@ParameterizedTest
@CsvSource({
"10_000,1000,,false", // no pending messages
"10_000,1000,1000,true", // pending message at device creation
"10_000,1000,5999,true", // pending message right before device creation + fudge factor
"10_000,1000,6000,false", // pending message at device creation + fudge factor
"3000,5000,4000,false", // pending message after current time but before device creation
})
void waitForMessageFetch(long currentTime, long deviceCreation, Long oldestMessage, boolean shouldWait)
throws InterruptedException {
final String number = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164);
final ECKeyPair aciKeyPair = Curve.generateKeyPair();
final ECKeyPair pniKeyPair = Curve.generateKeyPair();
final Account account = AccountsHelper.createAccount(accountsManager, number);
final String linkDeviceToken = accountsManager.generateLinkDeviceToken(UUID.randomUUID());
final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken);
clock.pin(Instant.ofEpochMilli(deviceCreation));
final Pair<Account, Device> updatedAccountAndDevice =
accountsManager.addDevice(account, new DeviceSpec(
"device-name".getBytes(StandardCharsets.UTF_8),
"password",
"OWT",
Set.of(),
1,
2,
true,
Optional.empty(),
Optional.empty(),
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair)),
linkDeviceToken)
.join();
assertEquals(updatedAccountAndDevice.second().getCreated(), deviceCreation);
when(messagesManager.getEarliestUndeliveredTimestampForDevice(account.getUuid(), account.getPrimaryDevice()))
.thenReturn(CompletableFuture.completedFuture(Optional.ofNullable(oldestMessage).map(Instant::ofEpochMilli)));
clock.pin(Instant.ofEpochMilli(currentTime));
Duration timeout = shouldWait ? Duration.ofMillis(5) : Duration.ofMillis(1000);
Optional<DeviceInfo> result = accountsManager.waitForNewLinkedDevice(account.getUuid(),
account.getPrimaryDevice(), linkDeviceTokenIdentifier, timeout).join();
assertEquals(result.isEmpty(), shouldWait);
}
// ThreadMode.SEPARATE_THREAD protects against hangs in the async calls, as this mode allows the test code to be
// preempted by the timeout check
@Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
@Test
void waitForMessageFetchRetries()
throws InterruptedException {
final String number = PhoneNumberUtil.getInstance().format(
PhoneNumberUtil.getInstance().getExampleNumber("US"),
PhoneNumberUtil.PhoneNumberFormat.E164);
final ECKeyPair aciKeyPair = Curve.generateKeyPair();
final ECKeyPair pniKeyPair = Curve.generateKeyPair();
final Account account = AccountsHelper.createAccount(accountsManager, number);
final String linkDeviceToken = accountsManager.generateLinkDeviceToken(UUID.randomUUID());
final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken);
clock.pin(Instant.ofEpochMilli(0));
accountsManager.addDevice(account, new DeviceSpec(
"device-name".getBytes(StandardCharsets.UTF_8),
"password",
"OWT",
Set.of(),
1,
2,
true,
Optional.empty(),
Optional.empty(),
KeysHelper.signedECPreKey(1, aciKeyPair),
KeysHelper.signedECPreKey(2, pniKeyPair),
KeysHelper.signedKEMPreKey(3, aciKeyPair),
KeysHelper.signedKEMPreKey(4, pniKeyPair)),
linkDeviceToken)
.join();
when(messagesManager.getEarliestUndeliveredTimestampForDevice(account.getUuid(), account.getPrimaryDevice()))
// Has a message older than the message epoch
.thenReturn(CompletableFuture.completedFuture(Optional.of(Instant.ofEpochMilli(1000))))
// The message was fetched
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
clock.pin(Instant.ofEpochMilli(10_000));
// Run any scheduled job right away
when(messagePollExecutor.schedule(any(Runnable.class), anyLong(), any())).thenAnswer(x -> {
x.getArgument(0, Runnable.class).run();
return null;
});
Optional<DeviceInfo> result = accountsManager.waitForNewLinkedDevice(account.getUuid(),
account.getPrimaryDevice(), linkDeviceTokenIdentifier, Duration.ofSeconds(10)).join();
assertTrue(result.isPresent());
}
} }

View File

@ -9,6 +9,7 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
@ -145,7 +146,7 @@ class MessagesCacheTest {
messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage); messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage);
messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage); messagesCache.insert(duplicateGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, duplicateMessage);
assertEquals(1, messagesCache.getAllMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID, 0) assertEquals(1, messagesCache.getAllMessages(DESTINATION_UUID, DESTINATION_DEVICE_ID, 0, 10)
.count() .count()
.blockOptional() .blockOptional()
.orElse(0L)); .orElse(0L));
@ -225,6 +226,31 @@ class MessagesCacheTest {
assertTrue(messagesCache.hasMessagesAsync(DESTINATION_UUID, DESTINATION_DEVICE_ID).join()); assertTrue(messagesCache.hasMessagesAsync(DESTINATION_UUID, DESTINATION_DEVICE_ID).join());
} }
@Test
void getOldestTimestamp() {
final int messageCount = 100;
final List<MessageProtos.Envelope> expectedMessages = new ArrayList<>(messageCount);
long expectedOldestTimestamp = serialTimestamp;
for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, i % 2 == 0);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
assertEquals(expectedOldestTimestamp,
messagesCache.getEarliestUndeliveredTimestamp(DESTINATION_UUID, DESTINATION_DEVICE_ID).block());
expectedMessages.add(message);
}
for (final MessageProtos.Envelope message : expectedMessages) {
assertEquals(expectedOldestTimestamp,
messagesCache.getEarliestUndeliveredTimestamp(DESTINATION_UUID, DESTINATION_DEVICE_ID).block());
messagesCache.remove(DESTINATION_UUID, DESTINATION_DEVICE_ID, UUID.fromString(message.getServerGuid())).join();
expectedOldestTimestamp += 1;
}
assertNull(messagesCache.getEarliestUndeliveredTimestamp(DESTINATION_UUID, DESTINATION_DEVICE_ID).block());
}
@ParameterizedTest @ParameterizedTest
@ValueSource(booleans = {true, false}) @ValueSource(booleans = {true, false})
void testGetMessages(final boolean sealedSender) throws Exception { void testGetMessages(final boolean sealedSender) throws Exception {
@ -236,7 +262,6 @@ class MessagesCacheTest {
final UUID messageGuid = UUID.randomUUID(); final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); messagesCache.insert(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message);
expectedMessages.add(message); expectedMessages.add(message);
} }
@ -322,7 +347,7 @@ class MessagesCacheTest {
.get(5, TimeUnit.SECONDS); .get(5, TimeUnit.SECONDS);
final List<MessageProtos.Envelope> messages = messagesCache.getAllMessages(DESTINATION_UUID, final List<MessageProtos.Envelope> messages = messagesCache.getAllMessages(DESTINATION_UUID,
DESTINATION_DEVICE_ID, 0) DESTINATION_DEVICE_ID, 0, 10)
.collectList() .collectList()
.toFuture().get(5, TimeUnit.SECONDS); .toFuture().get(5, TimeUnit.SECONDS);
@ -655,7 +680,7 @@ class MessagesCacheTest {
.thenReturn(Flux.from(emptyFinalPagePublisher)) .thenReturn(Flux.from(emptyFinalPagePublisher))
.thenReturn(Flux.empty()); .thenReturn(Flux.empty());
final Flux<?> allMessages = messagesCache.getAllMessages(UUID.randomUUID(), Device.PRIMARY_ID, 0); final Flux<?> allMessages = messagesCache.getAllMessages(UUID.randomUUID(), Device.PRIMARY_ID, 0, 10);
// Why initialValue = 3? // Why initialValue = 3?
// 1. messagesCache.getAllMessages() above produces the first call // 1. messagesCache.getAllMessages() above produces the first call

View File

@ -14,6 +14,8 @@ import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import java.time.Instant;
import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
@ -21,6 +23,7 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.CsvSource;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import reactor.core.publisher.Mono;
class MessagesManagerTest { class MessagesManagerTest {
@ -77,4 +80,28 @@ class MessagesManagerTest {
assertEquals(expectMayHaveMessages, messagesManager.mayHaveMessages(accountIdentifier, device).join()); assertEquals(expectMayHaveMessages, messagesManager.mayHaveMessages(accountIdentifier, device).join());
} }
@ParameterizedTest
@CsvSource({
",,",
"1,,1",
",1,1",
"2,1,1",
"1,2,2"
})
public void oldestMessageTimestamp(Long oldestCached, Long oldestPersisted, Long expected) {
final UUID accountIdentifier = UUID.randomUUID();
final Device device = mock(Device.class);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(messagesCache.getEarliestUndeliveredTimestamp(accountIdentifier, Device.PRIMARY_ID))
.thenReturn(oldestCached == null ? Mono.empty() : Mono.just(oldestCached));
when(messagesDynamoDb.load(accountIdentifier, device, 1))
.thenReturn(oldestPersisted == null
? Mono.empty()
: Mono.just(Envelope.newBuilder().setServerTimestamp(oldestPersisted).build()));
final Optional<Instant> earliest =
messagesManager.getEarliestUndeliveredTimestampForDevice(accountIdentifier, device).join();
assertEquals(Optional.ofNullable(expected).map(Instant::ofEpochMilli), earliest);
}
} }