Add command to remove expired linked devices

This commit is contained in:
Chris Eager 2023-12-18 16:35:22 -05:00 committed by Chris Eager
parent 5b7f91827a
commit 3b509bf820
5 changed files with 198 additions and 5 deletions

View File

@ -223,6 +223,7 @@ import org.whispersystems.textsecuregcm.workers.MessagePersisterServiceCommand;
import org.whispersystems.textsecuregcm.workers.MigrateSignedECPreKeysCommand;
import org.whispersystems.textsecuregcm.workers.ProcessPushNotificationFeedbackCommand;
import org.whispersystems.textsecuregcm.workers.RemoveExpiredAccountsCommand;
import org.whispersystems.textsecuregcm.workers.RemoveExpiredLinkedDevicesCommand;
import org.whispersystems.textsecuregcm.workers.ScheduledApnPushNotificationSenderServiceCommand;
import org.whispersystems.textsecuregcm.workers.ServerVersionCommand;
import org.whispersystems.textsecuregcm.workers.SetRequestLoggingEnabledTask;
@ -280,6 +281,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
bootstrap.addCommand(new MigrateSignedECPreKeysCommand());
bootstrap.addCommand(new RemoveExpiredAccountsCommand(Clock.systemUTC()));
bootstrap.addCommand(new ProcessPushNotificationFeedbackCommand(Clock.systemUTC()));
bootstrap.addCommand(new RemoveExpiredLinkedDevicesCommand());
}
@Override

View File

@ -8,9 +8,9 @@ package org.whispersystems.textsecuregcm.storage;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import java.time.Duration;
import java.util.List;
import java.util.OptionalInt;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.annotation.Nullable;
@ -28,6 +28,9 @@ public class Device {
public static final List<Byte> ALL_POSSIBLE_DEVICE_IDS = IntStream.range(Device.PRIMARY_ID, MAXIMUM_DEVICE_ID).boxed()
.map(Integer::byteValue).collect(Collectors.toList());
private static final long ALLOWED_LINKED_IDLE_MILLIS = Duration.ofDays(30).toMillis();
private static final long ALLOWED_PRIMARY_IDLE_MILLIS = Duration.ofDays(180).toMillis();
@JsonDeserialize(using = DeviceIdDeserializer.class)
@JsonProperty
private byte id;
@ -206,8 +209,13 @@ public class Device {
public boolean isEnabled() {
boolean hasChannel = fetchesMessages || StringUtils.isNotEmpty(getApnId()) || StringUtils.isNotEmpty(getGcmId());
return (id == PRIMARY_ID && hasChannel) ||
(id != PRIMARY_ID && hasChannel && lastSeen > (System.currentTimeMillis() - TimeUnit.DAYS.toMillis(30)));
return (id == PRIMARY_ID && hasChannel) || (id != PRIMARY_ID && hasChannel && !isExpired());
}
public boolean isExpired() {
return isPrimary()
? lastSeen < (System.currentTimeMillis() - ALLOWED_PRIMARY_IDLE_MILLIS)
: lastSeen < (System.currentTimeMillis() - ALLOWED_LINKED_IDLE_MILLIS);
}
public boolean getFetchesMessages() {

View File

@ -0,0 +1,100 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.workers;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.shaded.reactor.util.function.Tuples;
import java.util.List;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import net.sourceforge.argparse4j.inf.Subparser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
public class RemoveExpiredLinkedDevicesCommand extends AbstractSinglePassCrawlAccountsCommand {
private static final int MAX_CONCURRENCY = 16;
private static final String DRY_RUN_ARGUMENT = "dry-run";
private static final String REMOVED_DEVICES_COUNTER_NAME = name(RemoveExpiredLinkedDevicesCommand.class,
"removedDevices");
private static final String UPDATED_ACCOUNTS_COUNTER_NAME = name(RemoveExpiredLinkedDevicesCommand.class,
"updatedAccounts");
private static final Logger logger = LoggerFactory.getLogger(RemoveExpiredLinkedDevicesCommand.class);
public RemoveExpiredLinkedDevicesCommand() {
super("remove-expired-devices", "Removes expired linked devices");
}
@Override
public void configure(final Subparser subparser) {
super.configure(subparser);
subparser.addArgument("--dry-run")
.type(Boolean.class)
.dest(DRY_RUN_ARGUMENT)
.required(false)
.setDefault(true)
.help("If true, dont actually modify accounts with expired linked devices");
}
@Override
protected void crawlAccounts(final Flux<Account> accounts) {
final boolean dryRun = getNamespace().getBoolean(DRY_RUN_ARGUMENT);
accounts.map(a -> Tuples.of(a, getExpiredLinkedDeviceIds(a.getDevices())))
.filter(accountAndExpiredDevices -> !accountAndExpiredDevices.getT2().isEmpty())
.flatMap(accountAndExpiredDevices -> {
final Account account = accountAndExpiredDevices.getT1();
final Set<Byte> expiredDevices = accountAndExpiredDevices.getT2();
final Mono<Void> accountUpdate = dryRun
? Mono.empty()
: deleteDevices(account, expiredDevices);
return accountUpdate.thenReturn(expiredDevices.size())
.onErrorResume(t -> {
logger.warn("Failed to remove expired linked devices {}", account.getUuid(),
t);
return Mono.empty();
});
}, MAX_CONCURRENCY)
.doOnNext(removedDevices -> {
Metrics.counter(REMOVED_DEVICES_COUNTER_NAME, "dryRun", String.valueOf(dryRun)).increment(removedDevices);
Metrics.counter(UPDATED_ACCOUNTS_COUNTER_NAME, "dryRun", String.valueOf(dryRun)).increment();
})
.then()
.block();
}
private Mono<Void> deleteDevices(final Account account, final Set<Byte> expiredDevices) {
return Flux.fromIterable(expiredDevices)
.flatMap(deviceId ->
Mono.fromFuture(() -> getCommandDependencies().accountsManager().removeDevice(account, deviceId)),
// limit concurrency to avoid contested updates
1)
.then();
}
protected static Set<Byte> getExpiredLinkedDeviceIds(List<Device> devices) {
return devices.stream()
// linked devices
.filter(Predicate.not(Device::isPrimary))
// that are expired
.filter(Device::isExpired)
.map(Device::getId)
.collect(Collectors.toSet());
}
}

View File

@ -6,14 +6,14 @@
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock;
import java.time.Duration;
import java.time.Instant;
import java.util.stream.Stream;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
class DeviceTest {
@ -56,4 +56,33 @@ class DeviceTest {
Arguments.of(false, true, null, null, Duration.ofDays(1), true)
);
}
@ParameterizedTest
@CsvSource({
"true, P1D, false",
"true, P30D, false",
"true, P31D, false",
"true, P180D, false",
"true, P181D, true",
"false, P1D, false",
"false, P30D, false",
"false, P31D, true",
"false, P180D, true",
})
public void testIsExpired(final boolean primary, final Duration timeSinceLastSeen, final boolean expectExpired) {
final long lastSeen = Instant.now()
.minus(timeSinceLastSeen)
// buffer for test runtime
.plusSeconds(1)
.toEpochMilli();
final Device device = new Device();
device.setId(primary ? Device.PRIMARY_ID : Device.PRIMARY_ID + 1);
device.setCreated(lastSeen);
device.setLastSeen(lastSeen);
assertEquals(expectExpired, device.isExpired());
}
}

View File

@ -0,0 +1,54 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.workers;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.util.List;
import java.util.Set;
import java.util.stream.Stream;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.storage.Device;
class RemoveExpiredLinkedDevicesCommandTest {
public static Stream<Arguments> getDeviceIdsToRemove() {
final Device primary = device(Device.PRIMARY_ID, false);
final byte expiredDevice2Id = 2;
final Device expiredDevice2 = device(expiredDevice2Id, true);
final byte deviceId3 = 3;
final Device device3 = device(deviceId3, false);
final Device expiredPrimary = device(Device.PRIMARY_ID, true);
return Stream.of(
Arguments.of(List.of(primary), Set.of()),
Arguments.of(List.of(primary, expiredDevice2), Set.of(expiredDevice2Id)),
Arguments.of(List.of(primary, expiredDevice2, device3), Set.of(expiredDevice2Id)),
Arguments.of(List.of(expiredPrimary, expiredDevice2, device3), Set.of(expiredDevice2Id))
);
}
private static Device device(byte id, boolean expired) {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(id);
when(device.isExpired()).thenReturn(expired);
when(device.isPrimary()).thenCallRealMethod();
return device;
}
@ParameterizedTest
@MethodSource
void getDeviceIdsToRemove(final List<Device> devices, final Set<Byte> expectedIds) {
assertEquals(expectedIds, RemoveExpiredLinkedDevicesCommand.getExpiredLinkedDeviceIds(devices));
}
}