diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredLinkedDevicesCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredLinkedDevicesCommand.java index 18c03ea64..1295c397a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredLinkedDevicesCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredLinkedDevicesCommand.java @@ -7,16 +7,13 @@ package org.whispersystems.textsecuregcm.workers; import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; +import com.google.common.annotations.VisibleForTesting; +import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Metrics; import io.micrometer.shaded.reactor.util.function.Tuple2; import io.micrometer.shaded.reactor.util.function.Tuples; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Set; -import java.util.function.Function; +import java.util.Collection; import java.util.function.Predicate; -import java.util.stream.Collectors; import net.sourceforge.argparse4j.inf.Subparser; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -31,12 +28,9 @@ public class RemoveExpiredLinkedDevicesCommand extends AbstractSinglePassCrawlAc private static final String DRY_RUN_ARGUMENT = "dry-run"; private static final String MAX_CONCURRENCY_ARGUMENT = "max-concurrency"; - private static final String BUFFER_ARGUMENT = "buffer"; private static final String REMOVED_DEVICES_COUNTER_NAME = name(RemoveExpiredLinkedDevicesCommand.class, "removedDevices"); - private static final String UPDATED_ACCOUNTS_COUNTER_NAME = name(RemoveExpiredLinkedDevicesCommand.class, - "updatedAccounts"); private static final String FAILED_UPDATES_COUNTER_NAME = name(RemoveExpiredLinkedDevicesCommand.class, "failedUpdates"); @@ -62,12 +56,6 @@ public class RemoveExpiredLinkedDevicesCommand extends AbstractSinglePassCrawlAc .dest(MAX_CONCURRENCY_ARGUMENT) .setDefault(DEFAULT_MAX_CONCURRENCY) .help("Max concurrency for DynamoDB operations"); - - subparser.addArgument("--buffer") - .type(Integer.class) - .dest(BUFFER_ARGUMENT) - .setDefault(16_384) - .help("Accounts to buffer"); } @Override @@ -75,57 +63,38 @@ public class RemoveExpiredLinkedDevicesCommand extends AbstractSinglePassCrawlAc final boolean dryRun = getNamespace().getBoolean(DRY_RUN_ARGUMENT); final int maxConcurrency = getNamespace().getInt(MAX_CONCURRENCY_ARGUMENT); - final int bufferSize = getNamespace().getInt(BUFFER_ARGUMENT); - accounts.map(a -> Tuples.of(a, getExpiredLinkedDeviceIds(a.getDevices()))) - .filter(accountAndExpiredDevices -> !accountAndExpiredDevices.getT2().isEmpty()) - .buffer(bufferSize) - .map(source -> { - final List>> shuffled = new ArrayList<>(source); - Collections.shuffle(shuffled); - return shuffled; - }) - .flatMapIterable(Function.identity()) - .flatMap(accountAndExpiredDevices -> { - final Account account = accountAndExpiredDevices.getT1(); - final Set expiredDevices = accountAndExpiredDevices.getT2(); + final Counter successCounter = Metrics.counter(REMOVED_DEVICES_COUNTER_NAME, "dryRun", String.valueOf(dryRun)); + final Counter errorCounter = Metrics.counter(FAILED_UPDATES_COUNTER_NAME); - final Mono accountUpdate = dryRun - ? Mono.empty() - : deleteDevices(account, expiredDevices); + accounts.flatMap(account -> Flux.fromIterable(getExpiredLinkedDeviceIds(account))) + .flatMap(accountAndExpiredDeviceId -> { + final Account account = accountAndExpiredDeviceId.getT1(); + final byte deviceId = accountAndExpiredDeviceId.getT2(); - return accountUpdate.thenReturn(expiredDevices.size()); + Mono removeDevice = dryRun + ? Mono.just(account) + : Mono.fromFuture(() -> getCommandDependencies().accountsManager().removeDevice(account, deviceId)); + return removeDevice.doOnNext(ignored -> successCounter.increment()) + .onErrorResume(t -> { + logger.warn("Failed to remove expired linked device {}.{}", account.getUuid(), deviceId, t); + errorCounter.increment(); + return Mono.empty(); + }); }, maxConcurrency) - .doOnNext(removedDevices -> { - Metrics.counter(REMOVED_DEVICES_COUNTER_NAME, "dryRun", String.valueOf(dryRun)).increment(removedDevices); - Metrics.counter(UPDATED_ACCOUNTS_COUNTER_NAME, "dryRun", String.valueOf(dryRun)).increment(); - }) .then() .block(); } - private Mono deleteDevices(final Account account, final Set expiredDevices) { - return Flux.fromIterable(expiredDevices) - .flatMap(deviceId -> - Mono.fromFuture(() -> getCommandDependencies().accountsManager().removeDevice(account, deviceId)), - // limit concurrency to avoid contested updates - 1) - .onErrorResume(t -> { - logger.warn("Failed to remove expired linked device {}", account.getUuid(), t); - Metrics.counter(FAILED_UPDATES_COUNTER_NAME).increment(); - return Mono.empty(); - }) - .count(); - } - - protected static Set getExpiredLinkedDeviceIds(List devices) { - return devices.stream() + @VisibleForTesting + protected static Collection> getExpiredLinkedDeviceIds(Account account) { + return account.getDevices().stream() // linked devices .filter(Predicate.not(Device::isPrimary)) // that are expired .filter(Device::isExpired) - .map(Device::getId) - .collect(Collectors.toSet()); + .map(device -> Tuples.of(account, device.getId())) + .toList(); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredLinkedDevicesCommandTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredLinkedDevicesCommandTest.java index d0c397d6d..01a24c336 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredLinkedDevicesCommandTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredLinkedDevicesCommandTest.java @@ -9,12 +9,15 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import io.micrometer.shaded.reactor.util.function.Tuple2; import java.util.List; import java.util.Set; +import java.util.stream.Collectors; import java.util.stream.Stream; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; class RemoveExpiredLinkedDevicesCommandTest { @@ -49,6 +52,14 @@ class RemoveExpiredLinkedDevicesCommandTest { @ParameterizedTest @MethodSource void getDeviceIdsToRemove(final List devices, final Set expectedIds) { - assertEquals(expectedIds, RemoveExpiredLinkedDevicesCommand.getExpiredLinkedDeviceIds(devices)); + final Account account = mock(Account.class); + when(account.getDevices()).thenReturn(devices); + + final Set actualIds = RemoveExpiredLinkedDevicesCommand.getExpiredLinkedDeviceIds(account) + .stream() + .map(Tuple2::getT2) + .collect(Collectors.toSet()); + + assertEquals(expectedIds, actualIds); } }