From 7ff48155d65fee406196ebdbbe6b20e37d13d0df Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Thu, 10 Oct 2024 11:59:37 -0400 Subject: [PATCH] Add plumbing for a "wait for transfer archive" system --- .../storage/AccountsManager.java | 213 +++++++++++++----- ...ManagerTransferArchiveIntegrationTest.java | 147 ++++++++++++ .../AddRemoveDeviceIntegrationTest.java | 2 +- 3 files changed, 307 insertions(+), 55 deletions(-) create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTransferArchiveIntegrationTest.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index 2d637aeae..344dd6826 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -49,6 +49,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Function; @@ -69,6 +70,7 @@ import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.DeviceInfo; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; +import org.whispersystems.textsecuregcm.entities.RemoteAttachment; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; @@ -115,7 +117,7 @@ public class AccountsManager extends RedisPubSubAdapter implemen private final Accounts accounts; private final PhoneNumberIdentifiers phoneNumberIdentifiers; private final FaultTolerantRedisClusterClient cacheCluster; - private final FaultTolerantRedisClient pubSubRedisSingleton; + private final FaultTolerantRedisClient pubSubRedisClient; private final AccountLockManager accountLockManager; private final KeysManager keysManager; private final MessagesManager messagesManager; @@ -137,11 +139,19 @@ public class AccountsManager extends RedisPubSubAdapter implemen private final Map>> waitForDeviceFuturesByTokenIdentifier = new ConcurrentHashMap<>(); + private final Map>> waitForTransferArchiveFuturesByDeviceIdentifier = + new ConcurrentHashMap<>(); + private static final int SHA256_HASH_LENGTH = getSha256MessageDigest().getDigestLength(); + private static final Duration RECENTLY_ADDED_DEVICE_TTL = Duration.ofHours(1); private static final String LINKED_DEVICE_PREFIX = "linked_device::"; private static final String LINKED_DEVICE_KEYSPACE_PATTERN = "__keyspace@0__:" + LINKED_DEVICE_PREFIX + "*"; + private static final Duration RECENTLY_ADDED_TRANSFER_ARCHIVE_TTL = Duration.ofHours(1); + private static final String TRANSFER_ARCHIVE_PREFIX = "transfer_archive::"; + private static final String TRANSFER_ARCHIVE_KEYSPACE_PATTERN = "__keyspace@0__:" + TRANSFER_ARCHIVE_PREFIX + "*"; + private static final ObjectWriter ACCOUNT_REDIS_JSON_WRITER = SystemMapper.jsonMapper() .writer(SystemMapper.excludingField(Account.class, List.of("uuid"))); @@ -173,10 +183,13 @@ public class AccountsManager extends RedisPubSubAdapter implemen } } + private record TimestampedDeviceIdentifier(UUID accountIdentifier, byte deviceId, Instant deviceCreationTimestamp) { + } + public AccountsManager(final Accounts accounts, final PhoneNumberIdentifiers phoneNumberIdentifiers, final FaultTolerantRedisClusterClient cacheCluster, - final FaultTolerantRedisClient pubSubRedisSingleton, + final FaultTolerantRedisClient pubSubRedisClient, final AccountLockManager accountLockManager, final KeysManager keysManager, final MessagesManager messagesManager, @@ -194,7 +207,7 @@ public class AccountsManager extends RedisPubSubAdapter implemen this.accounts = accounts; this.phoneNumberIdentifiers = phoneNumberIdentifiers; this.cacheCluster = cacheCluster; - this.pubSubRedisSingleton = pubSubRedisSingleton; + this.pubSubRedisClient = pubSubRedisClient; this.accountLockManager = accountLockManager; this.keysManager = keysManager; this.messagesManager = messagesManager; @@ -218,19 +231,23 @@ public class AccountsManager extends RedisPubSubAdapter implemen throw new IllegalArgumentException(e); } - this.pubSubConnection = pubSubRedisSingleton.createPubSubConnection(); + this.pubSubConnection = pubSubRedisClient.createPubSubConnection(); } @Override public void start() { - pubSubConnection.usePubSubConnection(connection -> connection.addListener(this)); - pubSubConnection.usePubSubConnection(connection -> connection.sync().psubscribe(LINKED_DEVICE_KEYSPACE_PATTERN)); + pubSubConnection.usePubSubConnection(connection -> { + connection.addListener(this); + connection.sync().psubscribe(LINKED_DEVICE_KEYSPACE_PATTERN, TRANSFER_ARCHIVE_KEYSPACE_PATTERN); + }); } @Override public void stop() { - pubSubConnection.usePubSubConnection(connection -> connection.sync().punsubscribe()); - pubSubConnection.usePubSubConnection(connection -> connection.removeListener(this)); + pubSubConnection.usePubSubConnection(connection -> { + connection.sync().punsubscribe(); + connection.removeListener(this); + }); } public Account create(final String number, @@ -409,7 +426,7 @@ public class AccountsManager extends RedisPubSubAdapter implemen throw new UncheckedIOException(e); } - pubSubRedisSingleton.withConnection(connection -> + pubSubRedisClient.withConnection(connection -> connection.async().set(key, deviceInfoJson, SetArgs.Builder.ex(RECENTLY_ADDED_DEVICE_TTL))) .whenComplete((ignored, pubSubThrowable) -> { if (pubSubThrowable != null) { @@ -1406,51 +1423,11 @@ public class AccountsManager extends RedisPubSubAdapter implemen return CompletableFuture.failedFuture(new IllegalArgumentException("Invalid token identifier")); } - final CompletableFuture> waitForDeviceFuture = new CompletableFuture<>(); - - waitForDeviceFuture - .completeOnTimeout(Optional.empty(), TimeUnit.MILLISECONDS.convert(timeout), TimeUnit.MILLISECONDS) - .whenComplete((maybeDevice, throwable) -> waitForDeviceFuturesByTokenIdentifier.compute(linkDeviceTokenIdentifier, - (ignored, existingFuture) -> { - // Only remove the future from the map if it's THIS future, and not one that later displaced this one - return existingFuture == waitForDeviceFuture ? null : existingFuture; - })); - - { - final CompletableFuture> displacedFuture = - waitForDeviceFuturesByTokenIdentifier.put(linkDeviceTokenIdentifier, waitForDeviceFuture); - - if (displacedFuture != null) { - displacedFuture.complete(Optional.empty()); - } - } - - // The device may already have been linked by the time the caller started watching for it; perform an immediate - // check to see if the device is already there. - pubSubRedisSingleton.withConnection(connection -> connection.async().get(getLinkedDeviceKey(linkDeviceTokenIdentifier))) - .thenAccept(response -> { - if (StringUtils.isNotBlank(response)) { - handleDeviceAdded(waitForDeviceFuture, response); - } - }); - - return waitForDeviceFuture; - } - - private static String getLinkedDeviceKey(final String linkDeviceTokenIdentifier) { - return LINKED_DEVICE_PREFIX + linkDeviceTokenIdentifier; - } - - @Override - public void message(final String pattern, final String channel, final String message) { - if (LINKED_DEVICE_KEYSPACE_PATTERN.equals(pattern) && "set".equalsIgnoreCase(message)) { - // The `- 1` here compensates for the '*' in the pattern - final String tokenIdentifier = channel.substring(LINKED_DEVICE_KEYSPACE_PATTERN.length() - 1); - - Optional.ofNullable(waitForDeviceFuturesByTokenIdentifier.remove(tokenIdentifier)) - .ifPresent(future -> pubSubRedisSingleton.withConnection(connection -> connection.async().get(getLinkedDeviceKey(tokenIdentifier))) - .thenAccept(deviceInfoJson -> handleDeviceAdded(future, deviceInfoJson))); - } + return waitForPubSubKey(waitForDeviceFuturesByTokenIdentifier, + linkDeviceTokenIdentifier, + getLinkedDeviceKey(linkDeviceTokenIdentifier), + timeout, + this::handleDeviceAdded); } private void handleDeviceAdded(final CompletableFuture> future, final String deviceInfoJson) { @@ -1462,6 +1439,134 @@ public class AccountsManager extends RedisPubSubAdapter implemen } } + private static String getLinkedDeviceKey(final String linkDeviceTokenIdentifier) { + return LINKED_DEVICE_PREFIX + linkDeviceTokenIdentifier; + } + + public CompletableFuture> waitForTransferArchive(final Account account, final Device device, final Duration timeout) { + final TimestampedDeviceIdentifier deviceIdentifier = + new TimestampedDeviceIdentifier(account.getIdentifier(IdentityType.ACI), + device.getId(), + Instant.ofEpochMilli(device.getCreated())); + + return waitForPubSubKey(waitForTransferArchiveFuturesByDeviceIdentifier, + deviceIdentifier, + getTransferArchiveKey(account.getIdentifier(IdentityType.ACI), device.getId(), Instant.ofEpochMilli(device.getCreated())), + timeout, + this::handleTransferArchiveAdded); + } + + public CompletableFuture recordTransferArchiveUpload(final Account account, + final byte destinationDeviceId, + final Instant destinationDeviceCreationTimestamp, + final RemoteAttachment transferArchive) { + + final String key = getTransferArchiveKey(account.getIdentifier(IdentityType.ACI), + destinationDeviceId, + destinationDeviceCreationTimestamp); + + try { + final String transferArchiveJson = SystemMapper.jsonMapper().writeValueAsString(transferArchive); + + return pubSubRedisClient.withConnection(connection -> + connection.async().set(key, transferArchiveJson, SetArgs.Builder.ex(RECENTLY_ADDED_TRANSFER_ARCHIVE_TTL))) + .thenRun(Util.NOOP) + .toCompletableFuture(); + } catch (final JsonProcessingException e) { + // This should never happen for well-defined objects we control + throw new UncheckedIOException(e); + } + } + + private void handleTransferArchiveAdded(final CompletableFuture> future, final String transferArchiveJson) { + try { + future.complete(Optional.of(SystemMapper.jsonMapper().readValue(transferArchiveJson, RemoteAttachment.class))); + } catch (final JsonProcessingException e) { + logger.error("Could not parse transfer archive json", e); + future.completeExceptionally(e); + } + } + + private static String getTransferArchiveKey(final UUID accountIdentifier, + final byte destinationDeviceId, + final Instant destinationDeviceCreationTimestamp) { + + return TRANSFER_ARCHIVE_PREFIX + accountIdentifier.toString() + + ":" + destinationDeviceId + + ":" + destinationDeviceCreationTimestamp.toEpochMilli(); + } + + private CompletableFuture> waitForPubSubKey(final Map>> futureMap, + final K mapKey, + final String redisKey, + final Duration timeout, + final BiConsumer>, String> handler) { + + final CompletableFuture> future = new CompletableFuture<>(); + + future.completeOnTimeout(Optional.empty(), TimeUnit.MILLISECONDS.convert(timeout), TimeUnit.MILLISECONDS) + .whenComplete((maybeBackup, throwable) -> futureMap.remove(mapKey, future)); + + { + final CompletableFuture> displacedFuture = futureMap.put(mapKey, future); + + if (displacedFuture != null) { + displacedFuture.complete(Optional.empty()); + } + } + + // The Redis key we're waiting for may have been added before the caller issued a request to watch for it; check to + // see if it's already there + pubSubRedisClient.withConnection(connection -> connection.async().get(redisKey)) + .thenAccept(response -> { + if (StringUtils.isNotBlank(response)) { + handler.accept(future, response); + } + }); + + return future; + } + + @Override + public void message(final String pattern, final String channel, final String message) { + if (LINKED_DEVICE_KEYSPACE_PATTERN.equals(pattern) && "set".equalsIgnoreCase(message)) { + // The `- 1` here compensates for the '*' in the pattern + final String tokenIdentifier = channel.substring(LINKED_DEVICE_KEYSPACE_PATTERN.length() - 1); + + Optional.ofNullable(waitForDeviceFuturesByTokenIdentifier.remove(tokenIdentifier)) + .ifPresent(future -> pubSubRedisClient.withConnection(connection -> connection.async().get(getLinkedDeviceKey(tokenIdentifier))) + .thenAccept(deviceInfoJson -> handleDeviceAdded(future, deviceInfoJson))); + } else if (TRANSFER_ARCHIVE_KEYSPACE_PATTERN.equals(pattern) && "set".equalsIgnoreCase(message)) { + // The `- 1` here compensates for the '*' in the pattern + final String[] deviceIdentifierComponents = + channel.substring(TRANSFER_ARCHIVE_KEYSPACE_PATTERN.length() - 1).split(":", 3); + + if (deviceIdentifierComponents.length != 3) { + logger.error("Could not parse timestamped device identifier; unexpected component count"); + return; + } + + try { + final TimestampedDeviceIdentifier deviceIdentifier; + final String transferArchiveKey; + { + final UUID accountIdentifier = UUID.fromString(deviceIdentifierComponents[0]); + final byte deviceId = Byte.parseByte(deviceIdentifierComponents[1]); + final Instant deviceCreationTimestamp = Instant.ofEpochMilli(Long.parseLong(deviceIdentifierComponents[2])); + + deviceIdentifier = new TimestampedDeviceIdentifier(accountIdentifier, deviceId, deviceCreationTimestamp); + transferArchiveKey = getTransferArchiveKey(accountIdentifier, deviceId, deviceCreationTimestamp); + } + + Optional.ofNullable(waitForTransferArchiveFuturesByDeviceIdentifier.remove(deviceIdentifier)) + .ifPresent(future -> pubSubRedisClient.withConnection(connection -> connection.async().get(transferArchiveKey)) + .thenAccept(transferArchiveJson -> handleTransferArchiveAdded(future, transferArchiveJson))); + } catch (final IllegalArgumentException e) { + logger.error("Could not parse timestamped device identifier", e); + } + } + } + private static MessageDigest getSha256MessageDigest() { try { return MessageDigest.getInstance("SHA-256"); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTransferArchiveIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTransferArchiveIntegrationTest.java new file mode 100644 index 000000000..db58467ff --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTransferArchiveIntegrationTest.java @@ -0,0 +1,147 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.whispersystems.textsecuregcm.entities.RemoteAttachment; +import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.push.ClientPresenceManager; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; +import org.whispersystems.textsecuregcm.redis.RedisServerExtension; +import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; +import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; +import java.nio.charset.StandardCharsets; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.Base64; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +// ThreadMode.SEPARATE_THREAD protects against hangs in the remote Redis calls, as this mode allows the test code to be +// preempted by the timeout check +@Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) +public class AccountsManagerTransferArchiveIntegrationTest { + + @RegisterExtension + static final RedisServerExtension PUBSUB_SERVER_EXTENSION = RedisServerExtension.builder().build(); + + private AccountsManager accountsManager; + + @BeforeEach + void setUp() { + PUBSUB_SERVER_EXTENSION.getRedisClient().useConnection(connection -> { + connection.sync().flushall(); + connection.sync().configSet("notify-keyspace-events", "K$"); + }); + + //noinspection unchecked + accountsManager = new AccountsManager( + mock(Accounts.class), + mock(PhoneNumberIdentifiers.class), + mock(FaultTolerantRedisClusterClient.class), + PUBSUB_SERVER_EXTENSION.getRedisClient(), + mock(AccountLockManager.class), + mock(KeysManager.class), + mock(MessagesManager.class), + mock(ProfilesManager.class), + mock(SecureStorageClient.class), + mock(SecureValueRecovery2Client.class), + mock(ClientPresenceManager.class), + mock(RegistrationRecoveryPasswordsManager.class), + mock(ClientPublicKeysManager.class), + mock(ExecutorService.class), + mock(ExecutorService.class), + Clock.systemUTC(), + "link-device-secret".getBytes(StandardCharsets.UTF_8), + mock(DynamicConfigurationManager.class)); + + accountsManager.start(); + } + + @AfterEach + void tearDown() { + accountsManager.stop(); + } + + @Test + void waitForTransferArchive() { + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = Device.PRIMARY_ID; + final long deviceCreated = System.currentTimeMillis(); + + final RemoteAttachment transferArchive = + new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("transfer-archive".getBytes(StandardCharsets.UTF_8))); + + final Device device = mock(Device.class); + when(device.getId()).thenReturn(deviceId); + when(device.getCreated()).thenReturn(deviceCreated); + + final Account account = mock(Account.class); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier); + + final CompletableFuture> displacedFuture = + accountsManager.waitForTransferArchive(account, device, Duration.ofSeconds(5)); + + final CompletableFuture> activeFuture = + accountsManager.waitForTransferArchive(account, device, Duration.ofSeconds(5)); + + assertEquals(Optional.empty(), displacedFuture.join()); + + accountsManager.recordTransferArchiveUpload(account, deviceId, Instant.ofEpochMilli(deviceCreated), transferArchive).join(); + + assertEquals(Optional.of(transferArchive), activeFuture.join()); + } + + @Test + void waitForTransferArchiveAlreadyAdded() { + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = Device.PRIMARY_ID; + final long deviceCreated = System.currentTimeMillis(); + + final RemoteAttachment transferArchive = + new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("transfer-archive".getBytes(StandardCharsets.UTF_8))); + + final Device device = mock(Device.class); + when(device.getId()).thenReturn(deviceId); + when(device.getCreated()).thenReturn(deviceCreated); + + final Account account = mock(Account.class); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier); + + accountsManager.recordTransferArchiveUpload(account, deviceId, Instant.ofEpochMilli(deviceCreated), transferArchive).join(); + + assertEquals(Optional.of(transferArchive), + accountsManager.waitForTransferArchive(account, device, Duration.ofSeconds(5)).join()); + } + + @Test + void waitForTransferArchiveTimeout() { + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceId = Device.PRIMARY_ID; + final long deviceCreated = System.currentTimeMillis(); + + final Device device = mock(Device.class); + when(device.getId()).thenReturn(deviceId); + when(device.getCreated()).thenReturn(deviceCreated); + + final Account account = mock(Account.class); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier); + + assertEquals(Optional.empty(), + accountsManager.waitForTransferArchive(account, device, Duration.ofMillis(1)).join()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java index 5eb5e2a01..cac7a31d4 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AddRemoveDeviceIntegrationTest.java @@ -492,7 +492,7 @@ public class AddRemoveDeviceIntegrationTest { final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken); final CompletableFuture> linkedDeviceFuture = - accountsManager.waitForNewLinkedDevice(linkDeviceTokenIdentifier, Duration.ofMillis(10)); + accountsManager.waitForNewLinkedDevice(linkDeviceTokenIdentifier, Duration.ofMillis(1)); final Optional maybeDeviceInfo = linkedDeviceFuture.join();