Add plumbing for a "wait for transfer archive" system

This commit is contained in:
Jon Chambers 2024-10-10 11:59:37 -04:00 committed by Jon Chambers
parent 0adaa331a1
commit 7ff48155d6
3 changed files with 307 additions and 55 deletions

View File

@ -49,6 +49,7 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.BiFunction; import java.util.function.BiFunction;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Function; import java.util.function.Function;
@ -69,6 +70,7 @@ import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.DeviceInfo; import org.whispersystems.textsecuregcm.entities.DeviceInfo;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.RemoteAttachment;
import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
@ -115,7 +117,7 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
private final Accounts accounts; private final Accounts accounts;
private final PhoneNumberIdentifiers phoneNumberIdentifiers; private final PhoneNumberIdentifiers phoneNumberIdentifiers;
private final FaultTolerantRedisClusterClient cacheCluster; private final FaultTolerantRedisClusterClient cacheCluster;
private final FaultTolerantRedisClient pubSubRedisSingleton; private final FaultTolerantRedisClient pubSubRedisClient;
private final AccountLockManager accountLockManager; private final AccountLockManager accountLockManager;
private final KeysManager keysManager; private final KeysManager keysManager;
private final MessagesManager messagesManager; private final MessagesManager messagesManager;
@ -137,11 +139,19 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
private final Map<String, CompletableFuture<Optional<DeviceInfo>>> waitForDeviceFuturesByTokenIdentifier = private final Map<String, CompletableFuture<Optional<DeviceInfo>>> waitForDeviceFuturesByTokenIdentifier =
new ConcurrentHashMap<>(); new ConcurrentHashMap<>();
private final Map<TimestampedDeviceIdentifier, CompletableFuture<Optional<RemoteAttachment>>> waitForTransferArchiveFuturesByDeviceIdentifier =
new ConcurrentHashMap<>();
private static final int SHA256_HASH_LENGTH = getSha256MessageDigest().getDigestLength(); private static final int SHA256_HASH_LENGTH = getSha256MessageDigest().getDigestLength();
private static final Duration RECENTLY_ADDED_DEVICE_TTL = Duration.ofHours(1); private static final Duration RECENTLY_ADDED_DEVICE_TTL = Duration.ofHours(1);
private static final String LINKED_DEVICE_PREFIX = "linked_device::"; private static final String LINKED_DEVICE_PREFIX = "linked_device::";
private static final String LINKED_DEVICE_KEYSPACE_PATTERN = "__keyspace@0__:" + LINKED_DEVICE_PREFIX + "*"; private static final String LINKED_DEVICE_KEYSPACE_PATTERN = "__keyspace@0__:" + LINKED_DEVICE_PREFIX + "*";
private static final Duration RECENTLY_ADDED_TRANSFER_ARCHIVE_TTL = Duration.ofHours(1);
private static final String TRANSFER_ARCHIVE_PREFIX = "transfer_archive::";
private static final String TRANSFER_ARCHIVE_KEYSPACE_PATTERN = "__keyspace@0__:" + TRANSFER_ARCHIVE_PREFIX + "*";
private static final ObjectWriter ACCOUNT_REDIS_JSON_WRITER = SystemMapper.jsonMapper() private static final ObjectWriter ACCOUNT_REDIS_JSON_WRITER = SystemMapper.jsonMapper()
.writer(SystemMapper.excludingField(Account.class, List.of("uuid"))); .writer(SystemMapper.excludingField(Account.class, List.of("uuid")));
@ -173,10 +183,13 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
} }
} }
private record TimestampedDeviceIdentifier(UUID accountIdentifier, byte deviceId, Instant deviceCreationTimestamp) {
}
public AccountsManager(final Accounts accounts, public AccountsManager(final Accounts accounts,
final PhoneNumberIdentifiers phoneNumberIdentifiers, final PhoneNumberIdentifiers phoneNumberIdentifiers,
final FaultTolerantRedisClusterClient cacheCluster, final FaultTolerantRedisClusterClient cacheCluster,
final FaultTolerantRedisClient pubSubRedisSingleton, final FaultTolerantRedisClient pubSubRedisClient,
final AccountLockManager accountLockManager, final AccountLockManager accountLockManager,
final KeysManager keysManager, final KeysManager keysManager,
final MessagesManager messagesManager, final MessagesManager messagesManager,
@ -194,7 +207,7 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
this.accounts = accounts; this.accounts = accounts;
this.phoneNumberIdentifiers = phoneNumberIdentifiers; this.phoneNumberIdentifiers = phoneNumberIdentifiers;
this.cacheCluster = cacheCluster; this.cacheCluster = cacheCluster;
this.pubSubRedisSingleton = pubSubRedisSingleton; this.pubSubRedisClient = pubSubRedisClient;
this.accountLockManager = accountLockManager; this.accountLockManager = accountLockManager;
this.keysManager = keysManager; this.keysManager = keysManager;
this.messagesManager = messagesManager; this.messagesManager = messagesManager;
@ -218,19 +231,23 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
throw new IllegalArgumentException(e); throw new IllegalArgumentException(e);
} }
this.pubSubConnection = pubSubRedisSingleton.createPubSubConnection(); this.pubSubConnection = pubSubRedisClient.createPubSubConnection();
} }
@Override @Override
public void start() { public void start() {
pubSubConnection.usePubSubConnection(connection -> connection.addListener(this)); pubSubConnection.usePubSubConnection(connection -> {
pubSubConnection.usePubSubConnection(connection -> connection.sync().psubscribe(LINKED_DEVICE_KEYSPACE_PATTERN)); connection.addListener(this);
connection.sync().psubscribe(LINKED_DEVICE_KEYSPACE_PATTERN, TRANSFER_ARCHIVE_KEYSPACE_PATTERN);
});
} }
@Override @Override
public void stop() { public void stop() {
pubSubConnection.usePubSubConnection(connection -> connection.sync().punsubscribe()); pubSubConnection.usePubSubConnection(connection -> {
pubSubConnection.usePubSubConnection(connection -> connection.removeListener(this)); connection.sync().punsubscribe();
connection.removeListener(this);
});
} }
public Account create(final String number, public Account create(final String number,
@ -409,7 +426,7 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
throw new UncheckedIOException(e); throw new UncheckedIOException(e);
} }
pubSubRedisSingleton.withConnection(connection -> pubSubRedisClient.withConnection(connection ->
connection.async().set(key, deviceInfoJson, SetArgs.Builder.ex(RECENTLY_ADDED_DEVICE_TTL))) connection.async().set(key, deviceInfoJson, SetArgs.Builder.ex(RECENTLY_ADDED_DEVICE_TTL)))
.whenComplete((ignored, pubSubThrowable) -> { .whenComplete((ignored, pubSubThrowable) -> {
if (pubSubThrowable != null) { if (pubSubThrowable != null) {
@ -1406,51 +1423,11 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
return CompletableFuture.failedFuture(new IllegalArgumentException("Invalid token identifier")); return CompletableFuture.failedFuture(new IllegalArgumentException("Invalid token identifier"));
} }
final CompletableFuture<Optional<DeviceInfo>> waitForDeviceFuture = new CompletableFuture<>(); return waitForPubSubKey(waitForDeviceFuturesByTokenIdentifier,
linkDeviceTokenIdentifier,
waitForDeviceFuture getLinkedDeviceKey(linkDeviceTokenIdentifier),
.completeOnTimeout(Optional.empty(), TimeUnit.MILLISECONDS.convert(timeout), TimeUnit.MILLISECONDS) timeout,
.whenComplete((maybeDevice, throwable) -> waitForDeviceFuturesByTokenIdentifier.compute(linkDeviceTokenIdentifier, this::handleDeviceAdded);
(ignored, existingFuture) -> {
// Only remove the future from the map if it's THIS future, and not one that later displaced this one
return existingFuture == waitForDeviceFuture ? null : existingFuture;
}));
{
final CompletableFuture<Optional<DeviceInfo>> displacedFuture =
waitForDeviceFuturesByTokenIdentifier.put(linkDeviceTokenIdentifier, waitForDeviceFuture);
if (displacedFuture != null) {
displacedFuture.complete(Optional.empty());
}
}
// The device may already have been linked by the time the caller started watching for it; perform an immediate
// check to see if the device is already there.
pubSubRedisSingleton.withConnection(connection -> connection.async().get(getLinkedDeviceKey(linkDeviceTokenIdentifier)))
.thenAccept(response -> {
if (StringUtils.isNotBlank(response)) {
handleDeviceAdded(waitForDeviceFuture, response);
}
});
return waitForDeviceFuture;
}
private static String getLinkedDeviceKey(final String linkDeviceTokenIdentifier) {
return LINKED_DEVICE_PREFIX + linkDeviceTokenIdentifier;
}
@Override
public void message(final String pattern, final String channel, final String message) {
if (LINKED_DEVICE_KEYSPACE_PATTERN.equals(pattern) && "set".equalsIgnoreCase(message)) {
// The `- 1` here compensates for the '*' in the pattern
final String tokenIdentifier = channel.substring(LINKED_DEVICE_KEYSPACE_PATTERN.length() - 1);
Optional.ofNullable(waitForDeviceFuturesByTokenIdentifier.remove(tokenIdentifier))
.ifPresent(future -> pubSubRedisSingleton.withConnection(connection -> connection.async().get(getLinkedDeviceKey(tokenIdentifier)))
.thenAccept(deviceInfoJson -> handleDeviceAdded(future, deviceInfoJson)));
}
} }
private void handleDeviceAdded(final CompletableFuture<Optional<DeviceInfo>> future, final String deviceInfoJson) { private void handleDeviceAdded(final CompletableFuture<Optional<DeviceInfo>> future, final String deviceInfoJson) {
@ -1462,6 +1439,134 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
} }
} }
private static String getLinkedDeviceKey(final String linkDeviceTokenIdentifier) {
return LINKED_DEVICE_PREFIX + linkDeviceTokenIdentifier;
}
public CompletableFuture<Optional<RemoteAttachment>> waitForTransferArchive(final Account account, final Device device, final Duration timeout) {
final TimestampedDeviceIdentifier deviceIdentifier =
new TimestampedDeviceIdentifier(account.getIdentifier(IdentityType.ACI),
device.getId(),
Instant.ofEpochMilli(device.getCreated()));
return waitForPubSubKey(waitForTransferArchiveFuturesByDeviceIdentifier,
deviceIdentifier,
getTransferArchiveKey(account.getIdentifier(IdentityType.ACI), device.getId(), Instant.ofEpochMilli(device.getCreated())),
timeout,
this::handleTransferArchiveAdded);
}
public CompletableFuture<Void> recordTransferArchiveUpload(final Account account,
final byte destinationDeviceId,
final Instant destinationDeviceCreationTimestamp,
final RemoteAttachment transferArchive) {
final String key = getTransferArchiveKey(account.getIdentifier(IdentityType.ACI),
destinationDeviceId,
destinationDeviceCreationTimestamp);
try {
final String transferArchiveJson = SystemMapper.jsonMapper().writeValueAsString(transferArchive);
return pubSubRedisClient.withConnection(connection ->
connection.async().set(key, transferArchiveJson, SetArgs.Builder.ex(RECENTLY_ADDED_TRANSFER_ARCHIVE_TTL)))
.thenRun(Util.NOOP)
.toCompletableFuture();
} catch (final JsonProcessingException e) {
// This should never happen for well-defined objects we control
throw new UncheckedIOException(e);
}
}
private void handleTransferArchiveAdded(final CompletableFuture<Optional<RemoteAttachment>> future, final String transferArchiveJson) {
try {
future.complete(Optional.of(SystemMapper.jsonMapper().readValue(transferArchiveJson, RemoteAttachment.class)));
} catch (final JsonProcessingException e) {
logger.error("Could not parse transfer archive json", e);
future.completeExceptionally(e);
}
}
private static String getTransferArchiveKey(final UUID accountIdentifier,
final byte destinationDeviceId,
final Instant destinationDeviceCreationTimestamp) {
return TRANSFER_ARCHIVE_PREFIX + accountIdentifier.toString() +
":" + destinationDeviceId +
":" + destinationDeviceCreationTimestamp.toEpochMilli();
}
private <K, T> CompletableFuture<Optional<T>> waitForPubSubKey(final Map<K, CompletableFuture<Optional<T>>> futureMap,
final K mapKey,
final String redisKey,
final Duration timeout,
final BiConsumer<CompletableFuture<Optional<T>>, String> handler) {
final CompletableFuture<Optional<T>> future = new CompletableFuture<>();
future.completeOnTimeout(Optional.empty(), TimeUnit.MILLISECONDS.convert(timeout), TimeUnit.MILLISECONDS)
.whenComplete((maybeBackup, throwable) -> futureMap.remove(mapKey, future));
{
final CompletableFuture<Optional<T>> displacedFuture = futureMap.put(mapKey, future);
if (displacedFuture != null) {
displacedFuture.complete(Optional.empty());
}
}
// The Redis key we're waiting for may have been added before the caller issued a request to watch for it; check to
// see if it's already there
pubSubRedisClient.withConnection(connection -> connection.async().get(redisKey))
.thenAccept(response -> {
if (StringUtils.isNotBlank(response)) {
handler.accept(future, response);
}
});
return future;
}
@Override
public void message(final String pattern, final String channel, final String message) {
if (LINKED_DEVICE_KEYSPACE_PATTERN.equals(pattern) && "set".equalsIgnoreCase(message)) {
// The `- 1` here compensates for the '*' in the pattern
final String tokenIdentifier = channel.substring(LINKED_DEVICE_KEYSPACE_PATTERN.length() - 1);
Optional.ofNullable(waitForDeviceFuturesByTokenIdentifier.remove(tokenIdentifier))
.ifPresent(future -> pubSubRedisClient.withConnection(connection -> connection.async().get(getLinkedDeviceKey(tokenIdentifier)))
.thenAccept(deviceInfoJson -> handleDeviceAdded(future, deviceInfoJson)));
} else if (TRANSFER_ARCHIVE_KEYSPACE_PATTERN.equals(pattern) && "set".equalsIgnoreCase(message)) {
// The `- 1` here compensates for the '*' in the pattern
final String[] deviceIdentifierComponents =
channel.substring(TRANSFER_ARCHIVE_KEYSPACE_PATTERN.length() - 1).split(":", 3);
if (deviceIdentifierComponents.length != 3) {
logger.error("Could not parse timestamped device identifier; unexpected component count");
return;
}
try {
final TimestampedDeviceIdentifier deviceIdentifier;
final String transferArchiveKey;
{
final UUID accountIdentifier = UUID.fromString(deviceIdentifierComponents[0]);
final byte deviceId = Byte.parseByte(deviceIdentifierComponents[1]);
final Instant deviceCreationTimestamp = Instant.ofEpochMilli(Long.parseLong(deviceIdentifierComponents[2]));
deviceIdentifier = new TimestampedDeviceIdentifier(accountIdentifier, deviceId, deviceCreationTimestamp);
transferArchiveKey = getTransferArchiveKey(accountIdentifier, deviceId, deviceCreationTimestamp);
}
Optional.ofNullable(waitForTransferArchiveFuturesByDeviceIdentifier.remove(deviceIdentifier))
.ifPresent(future -> pubSubRedisClient.withConnection(connection -> connection.async().get(transferArchiveKey))
.thenAccept(transferArchiveJson -> handleTransferArchiveAdded(future, transferArchiveJson)));
} catch (final IllegalArgumentException e) {
logger.error("Could not parse timestamped device identifier", e);
}
}
}
private static MessageDigest getSha256MessageDigest() { private static MessageDigest getSha256MessageDigest() {
try { try {
return MessageDigest.getInstance("SHA-256"); return MessageDigest.getInstance("SHA-256");

View File

@ -0,0 +1,147 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.entities.RemoteAttachment;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.redis.RedisServerExtension;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import java.nio.charset.StandardCharsets;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Base64;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
// ThreadMode.SEPARATE_THREAD protects against hangs in the remote Redis calls, as this mode allows the test code to be
// preempted by the timeout check
@Timeout(value = 5, threadMode = Timeout.ThreadMode.SEPARATE_THREAD)
public class AccountsManagerTransferArchiveIntegrationTest {
@RegisterExtension
static final RedisServerExtension PUBSUB_SERVER_EXTENSION = RedisServerExtension.builder().build();
private AccountsManager accountsManager;
@BeforeEach
void setUp() {
PUBSUB_SERVER_EXTENSION.getRedisClient().useConnection(connection -> {
connection.sync().flushall();
connection.sync().configSet("notify-keyspace-events", "K$");
});
//noinspection unchecked
accountsManager = new AccountsManager(
mock(Accounts.class),
mock(PhoneNumberIdentifiers.class),
mock(FaultTolerantRedisClusterClient.class),
PUBSUB_SERVER_EXTENSION.getRedisClient(),
mock(AccountLockManager.class),
mock(KeysManager.class),
mock(MessagesManager.class),
mock(ProfilesManager.class),
mock(SecureStorageClient.class),
mock(SecureValueRecovery2Client.class),
mock(ClientPresenceManager.class),
mock(RegistrationRecoveryPasswordsManager.class),
mock(ClientPublicKeysManager.class),
mock(ExecutorService.class),
mock(ExecutorService.class),
Clock.systemUTC(),
"link-device-secret".getBytes(StandardCharsets.UTF_8),
mock(DynamicConfigurationManager.class));
accountsManager.start();
}
@AfterEach
void tearDown() {
accountsManager.stop();
}
@Test
void waitForTransferArchive() {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
final long deviceCreated = System.currentTimeMillis();
final RemoteAttachment transferArchive =
new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("transfer-archive".getBytes(StandardCharsets.UTF_8)));
final Device device = mock(Device.class);
when(device.getId()).thenReturn(deviceId);
when(device.getCreated()).thenReturn(deviceCreated);
final Account account = mock(Account.class);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
final CompletableFuture<Optional<RemoteAttachment>> displacedFuture =
accountsManager.waitForTransferArchive(account, device, Duration.ofSeconds(5));
final CompletableFuture<Optional<RemoteAttachment>> activeFuture =
accountsManager.waitForTransferArchive(account, device, Duration.ofSeconds(5));
assertEquals(Optional.empty(), displacedFuture.join());
accountsManager.recordTransferArchiveUpload(account, deviceId, Instant.ofEpochMilli(deviceCreated), transferArchive).join();
assertEquals(Optional.of(transferArchive), activeFuture.join());
}
@Test
void waitForTransferArchiveAlreadyAdded() {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
final long deviceCreated = System.currentTimeMillis();
final RemoteAttachment transferArchive =
new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("transfer-archive".getBytes(StandardCharsets.UTF_8)));
final Device device = mock(Device.class);
when(device.getId()).thenReturn(deviceId);
when(device.getCreated()).thenReturn(deviceCreated);
final Account account = mock(Account.class);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
accountsManager.recordTransferArchiveUpload(account, deviceId, Instant.ofEpochMilli(deviceCreated), transferArchive).join();
assertEquals(Optional.of(transferArchive),
accountsManager.waitForTransferArchive(account, device, Duration.ofSeconds(5)).join());
}
@Test
void waitForTransferArchiveTimeout() {
final UUID accountIdentifier = UUID.randomUUID();
final byte deviceId = Device.PRIMARY_ID;
final long deviceCreated = System.currentTimeMillis();
final Device device = mock(Device.class);
when(device.getId()).thenReturn(deviceId);
when(device.getCreated()).thenReturn(deviceCreated);
final Account account = mock(Account.class);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier);
assertEquals(Optional.empty(),
accountsManager.waitForTransferArchive(account, device, Duration.ofMillis(1)).join());
}
}

View File

@ -492,7 +492,7 @@ public class AddRemoveDeviceIntegrationTest {
final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken); final String linkDeviceTokenIdentifier = AccountsManager.getLinkDeviceTokenIdentifier(linkDeviceToken);
final CompletableFuture<Optional<DeviceInfo>> linkedDeviceFuture = final CompletableFuture<Optional<DeviceInfo>> linkedDeviceFuture =
accountsManager.waitForNewLinkedDevice(linkDeviceTokenIdentifier, Duration.ofMillis(10)); accountsManager.waitForNewLinkedDevice(linkDeviceTokenIdentifier, Duration.ofMillis(1));
final Optional<DeviceInfo> maybeDeviceInfo = linkedDeviceFuture.join(); final Optional<DeviceInfo> maybeDeviceInfo = linkedDeviceFuture.join();