From f20d3043d6f245efabcae199841167d887b6af65 Mon Sep 17 00:00:00 2001
From: Jon Chambers <63609320+jon-signal@users.noreply.github.com>
Date: Thu, 7 Dec 2023 21:42:49 -0500
Subject: [PATCH] Process key migrations sequentially to better control concurrency

---
 .../MigrateSignedECPreKeysCommand.java | 78 ++++++++++++++-----
 1 file changed, 57 insertions(+), 21 deletions(-)

diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/MigrateSignedECPreKeysCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/MigrateSignedECPreKeysCommand.java
index 5065ca0fc..c9c8702ab 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/MigrateSignedECPreKeysCommand.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/MigrateSignedECPreKeysCommand.java
@@ -8,8 +8,13 @@ package org.whispersystems.textsecuregcm.workers;
 import io.micrometer.core.instrument.Metrics;
 import java.time.Duration;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.UUID;
+import java.util.function.Function;
+import net.sourceforge.argparse4j.inf.Subparser;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
 import org.whispersystems.textsecuregcm.identity.IdentityType;
 import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
@@ -27,37 +32,68 @@ public class MigrateSignedECPreKeysCommand extends AbstractSinglePassCrawlAccoun
   private static final String STORE_KEY_ATTEMPT_COUNTER_NAME =
       MetricsUtil.name(MigrateSignedECPreKeysCommand.class, "storeKeyAttempt");
 
-  // It's tricky to find, but the default connection count for the AWS SDK's async DynamoDB client is 50. We expect
-  // four workers, so this should keep us below the concurrency limit.
-  private static final int MAX_CONCURRENCY = 12;
+  // It's tricky to find, but the default connection count for the AWS SDK's async DynamoDB client is 50. As long as
+  // we stay below that, we should be fine.
+  private static final int DEFAULT_MAX_CONCURRENCY = 32;
+
+  private static final String BUFFER_ARGUMENT = "buffer";
+  private static final String MAX_CONCURRENCY_ARGUMENT = "max-concurrency";
+
+  private static final Logger logger = LoggerFactory.getLogger(MigrateSignedECPreKeysCommand.class);
 
   public MigrateSignedECPreKeysCommand() {
     super("migrate-signed-ec-pre-keys", "Migrate signed EC pre-keys from Account records to a dedicated table");
   }
 
+  @Override
+  public void configure(final Subparser subparser) {
+    super.configure(subparser);
+
+    subparser.addArgument("--max-concurrency")
+        .type(Integer.class)
+        .dest(MAX_CONCURRENCY_ARGUMENT)
+        .setDefault(DEFAULT_MAX_CONCURRENCY)
+        .help("Max concurrency for DynamoDB operations");
+
+    subparser.addArgument("--buffer")
+        .type(Integer.class)
+        .dest(BUFFER_ARGUMENT)
+        .setDefault(16_384)
+        .help("Devices to buffer");
+  }
+
   @Override
   protected void crawlAccounts(final ParallelFlux<Account> accounts) {
     final KeysManager keysManager = getCommandDependencies().keysManager();
+    final int maxConcurrency = getNamespace().getInt(MAX_CONCURRENCY_ARGUMENT);
+    final int bufferSize = getNamespace().getInt(BUFFER_ARGUMENT);
 
-    accounts.flatMap(account -> Flux.fromIterable(account.getDevices())
-            .flatMap(device -> {
-              final List<Tuple3<UUID, Byte, ECSignedPreKey>> keys = new ArrayList<>(2);
+    accounts
+        .sequential()
+        .flatMap(account -> Flux.fromIterable(account.getDevices())
+            .flatMap(device -> Flux.fromArray(IdentityType.values())
+                .filter(identityType -> device.getSignedPreKey(identityType) != null)
+                .map(identityType -> Tuples.of(account.getIdentifier(identityType), device.getId(), device.getSignedPreKey(identityType)))))
+        .buffer(bufferSize)
+        .map(source -> {
+          final List<Tuple3<UUID, Byte, ECSignedPreKey>> shuffled = new ArrayList<>(source);
+          Collections.shuffle(shuffled);
+          return shuffled;
+        })
+        .flatMapIterable(Function.identity())
+        .flatMap(keyTuple -> {
+          final UUID identifier = keyTuple.getT1();
+          final byte deviceId = keyTuple.getT2();
+          final ECSignedPreKey signedPreKey = keyTuple.getT3();
 
-              if (device.getSignedPreKey(IdentityType.ACI) != null) {
-                keys.add(Tuples.of(account.getUuid(), device.getId(), device.getSignedPreKey(IdentityType.ACI)));
-              }
-
-              if (device.getSignedPreKey(IdentityType.PNI) != null) {
-                keys.add(Tuples.of(account.getPhoneNumberIdentifier(), device.getId(),
-                    device.getSignedPreKey(IdentityType.PNI)));
-              }
-
-              return Flux.fromIterable(keys);
-            }))
-        .flatMap(keyTuple -> Mono.fromFuture(() -> keysManager.storeEcSignedPreKeyIfAbsent(keyTuple.getT1(), keyTuple.getT2(), keyTuple.getT3()))
-                .retryWhen(Retry.backoff(3, Duration.ofSeconds(1)).onRetryExhaustedThrow((spec, rs) -> rs.failure())),
-            false, MAX_CONCURRENCY)
-        .doOnNext(keyStored -> Metrics.counter(STORE_KEY_ATTEMPT_COUNTER_NAME, "stored", String.valueOf(keyStored)).increment())
+          return Mono.fromFuture(() -> keysManager.storeEcSignedPreKeyIfAbsent(identifier, deviceId, signedPreKey))
+              .retryWhen(Retry.backoff(3, Duration.ofSeconds(1)).onRetryExhaustedThrow((spec, rs) -> rs.failure()))
+              .onErrorResume(throwable -> {
+                logger.warn("Failed to migrate key for UUID {}, device {}", identifier, deviceId);
+                return Mono.just(false);
+              })
+              .doOnSuccess(keyStored -> Metrics.counter(STORE_KEY_ATTEMPT_COUNTER_NAME, "stored", String.valueOf(keyStored)).increment());
+        }, maxConcurrency)
         .then()
         .block();
   }