diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 6a47e6716..e5e45da9b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -217,6 +217,7 @@ import org.whispersystems.textsecuregcm.workers.CertificateCommand; import org.whispersystems.textsecuregcm.workers.CheckDynamicConfigurationCommand; import org.whispersystems.textsecuregcm.workers.DeleteUserCommand; import org.whispersystems.textsecuregcm.workers.MessagePersisterServiceCommand; +import org.whispersystems.textsecuregcm.workers.MigrateSignedECPreKeysCommand; import org.whispersystems.textsecuregcm.workers.ProcessPushNotificationFeedbackCommand; import org.whispersystems.textsecuregcm.workers.RemoveExpiredAccountsCommand; import org.whispersystems.textsecuregcm.workers.ScheduledApnPushNotificationSenderServiceCommand; @@ -273,6 +274,7 @@ public class WhisperServerService extends Application storeEcSignedPreKeyIfAbsent(final UUID identifier, final byte deviceId, + final ECSignedPreKey signedPreKey) { + return ecSignedPreKeys.storeIfAbsent(identifier, deviceId, signedPreKey); + } + public CompletableFuture storePqLastResort(final UUID identifier, final Map keys) { return pqLastResortKeys.store(identifier, keys); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStore.java index 3193f32f4..9b9fb8253 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStore.java @@ -7,17 +7,27 @@ package org.whispersystems.textsecuregcm.storage; import java.util.Map; import java.util.UUID; +import java.util.concurrent.CompletableFuture; import org.signal.libsignal.protocol.InvalidKeyException; import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.util.AttributeValues; +import org.whispersystems.textsecuregcm.util.ExceptionUtils; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; +import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; public class RepeatedUseECSignedPreKeyStore extends RepeatedUseSignedPreKeyStore { + private final DynamoDbAsyncClient dynamoDbAsyncClient; + private final String tableName; + public RepeatedUseECSignedPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) { super(dynamoDbAsyncClient, tableName); + + this.dynamoDbAsyncClient = dynamoDbAsyncClient; + this.tableName = tableName; } @Override @@ -43,4 +53,21 @@ public class RepeatedUseECSignedPreKeyStore extends RepeatedUseSignedPreKeyStore throw new IllegalArgumentException(e); } } + + public CompletableFuture storeIfAbsent(final UUID identifier, final byte deviceId, final ECSignedPreKey signedPreKey) { + return dynamoDbAsyncClient.putItem(PutItemRequest.builder() + .tableName(tableName) + .item(getItemFromPreKey(identifier, deviceId, signedPreKey)) + .conditionExpression("attribute_not_exists(#public_key)") + .expressionAttributeNames(Map.of("#public_key", ATTR_PUBLIC_KEY)) + .build()) + .thenApply(ignored -> true) + .exceptionally(throwable -> { + if (ExceptionUtils.unwrap(throwable) instanceof ConditionalCheckFailedException) { + return false; + } + + throw ExceptionUtils.wrap(throwable); + }); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/MigrateSignedECPreKeysCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/MigrateSignedECPreKeysCommand.java new file mode 100644 index 000000000..1630946dd --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/MigrateSignedECPreKeysCommand.java @@ -0,0 +1,98 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.workers; + +import io.micrometer.core.instrument.Metrics; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.UUID; +import java.util.function.Function; +import net.sourceforge.argparse4j.inf.Subparser; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; +import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.KeysManager; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple3; +import reactor.util.function.Tuples; +import reactor.util.retry.Retry; + +public class MigrateSignedECPreKeysCommand extends AbstractSinglePassCrawlAccountsCommand { + + private static final String STORE_KEY_ATTEMPT_COUNTER_NAME = + MetricsUtil.name(MigrateSignedECPreKeysCommand.class, "storeKeyAttempt"); + + // It's tricky to find, but the default connection count for the AWS SDK's async DynamoDB client is 50. As long as + // we stay below that, we should be fine. + private static final int DEFAULT_MAX_CONCURRENCY = 32; + + private static final String BUFFER_ARGUMENT = "buffer"; + private static final String MAX_CONCURRENCY_ARGUMENT = "max-concurrency"; + + private static final Logger logger = LoggerFactory.getLogger(MigrateSignedECPreKeysCommand.class); + + public MigrateSignedECPreKeysCommand() { + super("migrate-signed-ec-pre-keys", "Migrate signed EC pre-keys from Account records to a dedicated table"); + } + + @Override + public void configure(final Subparser subparser) { + super.configure(subparser); + + subparser.addArgument("--max-concurrency") + .type(Integer.class) + .dest(MAX_CONCURRENCY_ARGUMENT) + .setDefault(DEFAULT_MAX_CONCURRENCY) + .help("Max concurrency for DynamoDB operations"); + + subparser.addArgument("--buffer") + .type(Integer.class) + .dest(BUFFER_ARGUMENT) + .setDefault(16_384) + .help("Devices to buffer"); + } + + @Override + protected void crawlAccounts(final Flux accounts) { + final KeysManager keysManager = getCommandDependencies().keysManager(); + final int maxConcurrency = getNamespace().getInt(MAX_CONCURRENCY_ARGUMENT); + final int bufferSize = getNamespace().getInt(BUFFER_ARGUMENT); + + accounts + .flatMap(account -> Flux.fromIterable(account.getDevices()) + .flatMap(device -> Flux.fromArray(IdentityType.values()) + .filter(identityType -> device.getSignedPreKey(identityType) != null) + .map(identityType -> Tuples.of(account.getIdentifier(identityType), device.getId(), device.getSignedPreKey(identityType))))) + .buffer(bufferSize) + .map(source -> { + final List> shuffled = new ArrayList<>(source); + Collections.shuffle(shuffled); + return shuffled; + }) + .flatMapIterable(Function.identity()) + .flatMap(keyTuple -> { + final UUID identifier = keyTuple.getT1(); + final byte deviceId = keyTuple.getT2(); + final ECSignedPreKey signedPreKey = keyTuple.getT3(); + + return Mono.fromFuture(() -> keysManager.storeEcSignedPreKeyIfAbsent(identifier, deviceId, signedPreKey)) + .retryWhen(Retry.backoff(3, Duration.ofSeconds(1)).onRetryExhaustedThrow((spec, rs) -> rs.failure())) + .onErrorResume(throwable -> { + logger.warn("Failed to migrate key for UUID {}, device {}", identifier, deviceId); + return Mono.just(false); + }) + .doOnSuccess(keyStored -> Metrics.counter(STORE_KEY_ATTEMPT_COUNTER_NAME, "stored", String.valueOf(keyStored)).increment()); + }, maxConcurrency) + .then() + .block(); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStoreTest.java index 11c467ddc..c2d0d7fb4 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseECSignedPreKeyStoreTest.java @@ -5,7 +5,14 @@ package org.whispersystems.textsecuregcm.storage; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Optional; +import java.util.UUID; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; @@ -39,4 +46,21 @@ class RepeatedUseECSignedPreKeyStoreTest extends RepeatedUseSignedPreKeyStoreTes protected ECSignedPreKey generateSignedPreKey() { return KeysHelper.signedECPreKey(currentKeyId++, IDENTITY_KEY_PAIR); } + + @Test + void storeIfAbsent() { + final UUID identifier = UUID.randomUUID(); + final byte deviceIdWithExistingKey = 1; + final byte deviceIdWithoutExistingKey = deviceIdWithExistingKey + 1; + + final ECSignedPreKey originalSignedPreKey = generateSignedPreKey(); + + keyStore.store(identifier, deviceIdWithExistingKey, originalSignedPreKey).join(); + + assertFalse(keyStore.storeIfAbsent(identifier, deviceIdWithExistingKey, generateSignedPreKey()).join()); + assertTrue(keyStore.storeIfAbsent(identifier, deviceIdWithoutExistingKey, generateSignedPreKey()).join()); + + assertEquals(Optional.of(originalSignedPreKey), keyStore.find(identifier, deviceIdWithExistingKey).join()); + assertTrue(keyStore.find(identifier, deviceIdWithoutExistingKey).join().isPresent()); + } }