From f495ff483a94b21248211a841bc13c4045535d22 Mon Sep 17 00:00:00 2001 From: Chris Eager <79161849+eager-signal@users.noreply.github.com> Date: Fri, 5 Jan 2024 17:38:34 -0600 Subject: [PATCH] Update RemoveExpiredLinkedDevicesCommand to retry failures --- .../RemoveExpiredLinkedDevicesCommand.java | 93 +++++++++++++++---- ...RemoveExpiredLinkedDevicesCommandTest.java | 13 +-- 2 files changed, 76 insertions(+), 30 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredLinkedDevicesCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredLinkedDevicesCommand.java index 1295c397a..bfe326019 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredLinkedDevicesCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredLinkedDevicesCommand.java @@ -10,10 +10,14 @@ import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; import com.google.common.annotations.VisibleForTesting; import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Metrics; -import io.micrometer.shaded.reactor.util.function.Tuple2; -import io.micrometer.shaded.reactor.util.function.Tuples; -import java.util.Collection; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Set; +import java.util.function.Function; import java.util.function.Predicate; +import java.util.stream.Collectors; import net.sourceforge.argparse4j.inf.Subparser; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -21,17 +25,27 @@ import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; +import reactor.util.retry.Retry; public class RemoveExpiredLinkedDevicesCommand extends AbstractSinglePassCrawlAccountsCommand { private static final int DEFAULT_MAX_CONCURRENCY = 16; + private static final int DEFAULT_BUFFER_SIZE = 16_384; + private static final int DEFAULT_RETRIES = 3; private static final String DRY_RUN_ARGUMENT = "dry-run"; private static final String MAX_CONCURRENCY_ARGUMENT = "max-concurrency"; + private static final String BUFFER_ARGUMENT = "buffer"; + private static final String RETRIES_ARGUMENT = "retries"; private static final String REMOVED_DEVICES_COUNTER_NAME = name(RemoveExpiredLinkedDevicesCommand.class, "removedDevices"); + private static final String RETRIED_UPDATES_COUNTER_NAME = name(RemoveExpiredLinkedDevicesCommand.class, + "retries"); + private static final String FAILED_UPDATES_COUNTER_NAME = name(RemoveExpiredLinkedDevicesCommand.class, "failedUpdates"); private static final Logger logger = LoggerFactory.getLogger(RemoveExpiredLinkedDevicesCommand.class); @@ -56,6 +70,18 @@ public class RemoveExpiredLinkedDevicesCommand extends AbstractSinglePassCrawlAc .dest(MAX_CONCURRENCY_ARGUMENT) .setDefault(DEFAULT_MAX_CONCURRENCY) .help("Max concurrency for DynamoDB operations"); + + subparser.addArgument("--buffer") + .type(Integer.class) + .dest(BUFFER_ARGUMENT) + .setDefault(DEFAULT_BUFFER_SIZE) + .help("Accounts to buffer"); + + subparser.addArgument("--retries") + .type(Integer.class) + .dest(RETRIES_ARGUMENT) + .setDefault(DEFAULT_RETRIES) + .help("Maximum number of retries permitted per device"); } @Override @@ -63,23 +89,33 @@ public class RemoveExpiredLinkedDevicesCommand extends AbstractSinglePassCrawlAc final boolean dryRun = getNamespace().getBoolean(DRY_RUN_ARGUMENT); final int maxConcurrency = getNamespace().getInt(MAX_CONCURRENCY_ARGUMENT); + final int bufferSize = getNamespace().getInt(BUFFER_ARGUMENT); + final int maxRetries = getNamespace().getInt(RETRIES_ARGUMENT); final Counter successCounter = Metrics.counter(REMOVED_DEVICES_COUNTER_NAME, "dryRun", String.valueOf(dryRun)); - final Counter errorCounter = Metrics.counter(FAILED_UPDATES_COUNTER_NAME); - accounts.flatMap(account -> Flux.fromIterable(getExpiredLinkedDeviceIds(account))) - .flatMap(accountAndExpiredDeviceId -> { - final Account account = accountAndExpiredDeviceId.getT1(); - final byte deviceId = accountAndExpiredDeviceId.getT2(); + accounts.map(a -> Tuples.of(a, getExpiredLinkedDeviceIds(a.getDevices()))) + .filter(accountAndExpiredDevices -> !accountAndExpiredDevices.getT2().isEmpty()) + .buffer(bufferSize) + .map(source -> { + final List>> shuffled = new ArrayList<>(source); + Collections.shuffle(shuffled); + return shuffled; + }) + .limitRate(2) + .flatMapIterable(Function.identity()) + .flatMap(accountAndExpiredDevices -> { + final Account account = accountAndExpiredDevices.getT1(); + final Set expiredDevices = accountAndExpiredDevices.getT2(); - Mono removeDevice = dryRun - ? Mono.just(account) - : Mono.fromFuture(() -> getCommandDependencies().accountsManager().removeDevice(account, deviceId)); + final Mono accountUpdate = dryRun + ? Mono.just((long) expiredDevices.size()) + : deleteDevices(account, expiredDevices, maxRetries); - return removeDevice.doOnNext(ignored -> successCounter.increment()) + return accountUpdate + .doOnNext(successCounter::increment) .onErrorResume(t -> { - logger.warn("Failed to remove expired linked device {}.{}", account.getUuid(), deviceId, t); - errorCounter.increment(); + logger.warn("Failed to remove expired linked devices for {}", account.getUuid(), t); return Mono.empty(); }); }, maxConcurrency) @@ -87,14 +123,35 @@ public class RemoveExpiredLinkedDevicesCommand extends AbstractSinglePassCrawlAc .block(); } + private Mono deleteDevices(final Account account, final Set expiredDevices, final int maxRetries) { + + final Counter retryCounter = Metrics.counter(RETRIED_UPDATES_COUNTER_NAME); + final Counter errorCounter = Metrics.counter(FAILED_UPDATES_COUNTER_NAME); + + return Flux.fromIterable(expiredDevices) + .flatMap(deviceId -> + Mono.fromFuture(() -> getCommandDependencies().accountsManager().removeDevice(account, deviceId)) + .retryWhen(Retry.backoff(maxRetries, Duration.ofSeconds(1)) + .doAfterRetry(ignored -> retryCounter.increment()) + .onRetryExhaustedThrow((spec, rs) -> rs.failure())) + .onErrorResume(t -> { + logger.info("Failed to remove expired linked device {}.{}", account.getUuid(), deviceId, t); + errorCounter.increment(); + return Mono.empty(); + }), + // limit concurrency to avoid contested updates + 1) + .count(); + } + @VisibleForTesting - protected static Collection> getExpiredLinkedDeviceIds(Account account) { - return account.getDevices().stream() + protected static Set getExpiredLinkedDeviceIds(List devices) { + return devices.stream() // linked devices .filter(Predicate.not(Device::isPrimary)) // that are expired .filter(Device::isExpired) - .map(device -> Tuples.of(account, device.getId())) - .toList(); + .map(Device::getId) + .collect(Collectors.toSet()); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredLinkedDevicesCommandTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredLinkedDevicesCommandTest.java index 01a24c336..d0c397d6d 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredLinkedDevicesCommandTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredLinkedDevicesCommandTest.java @@ -9,15 +9,12 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import io.micrometer.shaded.reactor.util.function.Tuple2; import java.util.List; import java.util.Set; -import java.util.stream.Collectors; import java.util.stream.Stream; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; class RemoveExpiredLinkedDevicesCommandTest { @@ -52,14 +49,6 @@ class RemoveExpiredLinkedDevicesCommandTest { @ParameterizedTest @MethodSource void getDeviceIdsToRemove(final List devices, final Set expectedIds) { - final Account account = mock(Account.class); - when(account.getDevices()).thenReturn(devices); - - final Set actualIds = RemoveExpiredLinkedDevicesCommand.getExpiredLinkedDeviceIds(account) - .stream() - .map(Tuple2::getT2) - .collect(Collectors.toSet()); - - assertEquals(expectedIds, actualIds); + assertEquals(expectedIds, RemoveExpiredLinkedDevicesCommand.getExpiredLinkedDeviceIds(devices)); } }