Reduce fan-out by processing a single stream of expired linked devices

This commit is contained in:
Chris Eager 2023-12-21 17:02:55 -06:00 committed by Chris Eager
parent 19a8a80a30
commit b9dd9fc47d
2 changed files with 35 additions and 55 deletions

View File

@ -7,16 +7,13 @@ package org.whispersystems.textsecuregcm.workers;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import com.google.common.annotations.VisibleForTesting;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.shaded.reactor.util.function.Tuple2;
import io.micrometer.shaded.reactor.util.function.Tuples;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.Collection;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import net.sourceforge.argparse4j.inf.Subparser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -31,12 +28,9 @@ public class RemoveExpiredLinkedDevicesCommand extends AbstractSinglePassCrawlAc
private static final String DRY_RUN_ARGUMENT = "dry-run";
private static final String MAX_CONCURRENCY_ARGUMENT = "max-concurrency";
private static final String BUFFER_ARGUMENT = "buffer";
private static final String REMOVED_DEVICES_COUNTER_NAME = name(RemoveExpiredLinkedDevicesCommand.class,
"removedDevices");
private static final String UPDATED_ACCOUNTS_COUNTER_NAME = name(RemoveExpiredLinkedDevicesCommand.class,
"updatedAccounts");
private static final String FAILED_UPDATES_COUNTER_NAME = name(RemoveExpiredLinkedDevicesCommand.class,
"failedUpdates");
@ -62,12 +56,6 @@ public class RemoveExpiredLinkedDevicesCommand extends AbstractSinglePassCrawlAc
.dest(MAX_CONCURRENCY_ARGUMENT)
.setDefault(DEFAULT_MAX_CONCURRENCY)
.help("Max concurrency for DynamoDB operations");
subparser.addArgument("--buffer")
.type(Integer.class)
.dest(BUFFER_ARGUMENT)
.setDefault(16_384)
.help("Accounts to buffer");
}
@Override
@ -75,57 +63,38 @@ public class RemoveExpiredLinkedDevicesCommand extends AbstractSinglePassCrawlAc
final boolean dryRun = getNamespace().getBoolean(DRY_RUN_ARGUMENT);
final int maxConcurrency = getNamespace().getInt(MAX_CONCURRENCY_ARGUMENT);
final int bufferSize = getNamespace().getInt(BUFFER_ARGUMENT);
accounts.map(a -> Tuples.of(a, getExpiredLinkedDeviceIds(a.getDevices())))
.filter(accountAndExpiredDevices -> !accountAndExpiredDevices.getT2().isEmpty())
.buffer(bufferSize)
.map(source -> {
final List<Tuple2<Account, Set<Byte>>> shuffled = new ArrayList<>(source);
Collections.shuffle(shuffled);
return shuffled;
})
.flatMapIterable(Function.identity())
.flatMap(accountAndExpiredDevices -> {
final Account account = accountAndExpiredDevices.getT1();
final Set<Byte> expiredDevices = accountAndExpiredDevices.getT2();
final Counter successCounter = Metrics.counter(REMOVED_DEVICES_COUNTER_NAME, "dryRun", String.valueOf(dryRun));
final Counter errorCounter = Metrics.counter(FAILED_UPDATES_COUNTER_NAME);
final Mono<Long> accountUpdate = dryRun
? Mono.empty()
: deleteDevices(account, expiredDevices);
accounts.flatMap(account -> Flux.fromIterable(getExpiredLinkedDeviceIds(account)))
.flatMap(accountAndExpiredDeviceId -> {
final Account account = accountAndExpiredDeviceId.getT1();
final byte deviceId = accountAndExpiredDeviceId.getT2();
return accountUpdate.thenReturn(expiredDevices.size());
Mono<Account> removeDevice = dryRun
? Mono.just(account)
: Mono.fromFuture(() -> getCommandDependencies().accountsManager().removeDevice(account, deviceId));
return removeDevice.doOnNext(ignored -> successCounter.increment())
.onErrorResume(t -> {
logger.warn("Failed to remove expired linked device {}.{}", account.getUuid(), deviceId, t);
errorCounter.increment();
return Mono.empty();
});
}, maxConcurrency)
.doOnNext(removedDevices -> {
Metrics.counter(REMOVED_DEVICES_COUNTER_NAME, "dryRun", String.valueOf(dryRun)).increment(removedDevices);
Metrics.counter(UPDATED_ACCOUNTS_COUNTER_NAME, "dryRun", String.valueOf(dryRun)).increment();
})
.then()
.block();
}
private Mono<Long> deleteDevices(final Account account, final Set<Byte> expiredDevices) {
return Flux.fromIterable(expiredDevices)
.flatMap(deviceId ->
Mono.fromFuture(() -> getCommandDependencies().accountsManager().removeDevice(account, deviceId)),
// limit concurrency to avoid contested updates
1)
.onErrorResume(t -> {
logger.warn("Failed to remove expired linked device {}", account.getUuid(), t);
Metrics.counter(FAILED_UPDATES_COUNTER_NAME).increment();
return Mono.empty();
})
.count();
}
protected static Set<Byte> getExpiredLinkedDeviceIds(List<Device> devices) {
return devices.stream()
@VisibleForTesting
protected static Collection<Tuple2<Account, Byte>> getExpiredLinkedDeviceIds(Account account) {
return account.getDevices().stream()
// linked devices
.filter(Predicate.not(Device::isPrimary))
// that are expired
.filter(Device::isExpired)
.map(Device::getId)
.collect(Collectors.toSet());
.map(device -> Tuples.of(account, device.getId()))
.toList();
}
}

View File

@ -9,12 +9,15 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import io.micrometer.shaded.reactor.util.function.Tuple2;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
class RemoveExpiredLinkedDevicesCommandTest {
@ -49,6 +52,14 @@ class RemoveExpiredLinkedDevicesCommandTest {
@ParameterizedTest
@MethodSource
void getDeviceIdsToRemove(final List<Device> devices, final Set<Byte> expectedIds) {
assertEquals(expectedIds, RemoveExpiredLinkedDevicesCommand.getExpiredLinkedDeviceIds(devices));
final Account account = mock(Account.class);
when(account.getDevices()).thenReturn(devices);
final Set<Byte> actualIds = RemoveExpiredLinkedDevicesCommand.getExpiredLinkedDeviceIds(account)
.stream()
.map(Tuple2::getT2)
.collect(Collectors.toSet());
assertEquals(expectedIds, actualIds);
}
}