diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 2e397bbcc..73ed28fc9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -266,13 +266,16 @@ import org.whispersystems.textsecuregcm.workers.CertificateCommand; import org.whispersystems.textsecuregcm.workers.CheckDynamicConfigurationCommand; import org.whispersystems.textsecuregcm.workers.DeleteUserCommand; import org.whispersystems.textsecuregcm.workers.IdleDeviceNotificationSchedulerFactory; +import org.whispersystems.textsecuregcm.workers.LockAccountsWithoutPqKeysCommand; import org.whispersystems.textsecuregcm.workers.MessagePersisterServiceCommand; import org.whispersystems.textsecuregcm.workers.NotifyIdleDevicesCommand; import org.whispersystems.textsecuregcm.workers.ProcessScheduledJobsServiceCommand; +import org.whispersystems.textsecuregcm.workers.RemoveAccountsWithoutPqKeysCommand; import org.whispersystems.textsecuregcm.workers.RemoveExpiredAccountsCommand; import org.whispersystems.textsecuregcm.workers.RemoveExpiredBackupsCommand; import org.whispersystems.textsecuregcm.workers.RemoveExpiredLinkedDevicesCommand; import org.whispersystems.textsecuregcm.workers.RemoveExpiredUsernameHoldsCommand; +import org.whispersystems.textsecuregcm.workers.RemoveLinkedDevicesWithoutPqKeysCommand; import org.whispersystems.textsecuregcm.workers.ScheduledApnPushNotificationSenderServiceCommand; import org.whispersystems.textsecuregcm.workers.ServerVersionCommand; import org.whispersystems.textsecuregcm.workers.SetRequestLoggingEnabledTask; @@ -335,6 +338,10 @@ public class WhisperServerService extends Application accounts) { + final boolean dryRun = getNamespace().getBoolean(DRY_RUN_ARGUMENT); + final int maxConcurrency = getNamespace().getInt(MAX_CONCURRENCY_ARGUMENT); + final int maxRetries = getNamespace().getInt(RETRIES_ARGUMENT); + + final AccountsManager accountsManager = getCommandDependencies().accountsManager(); + final PqKeysUtil pqKeysUtil = new PqKeysUtil(getCommandDependencies().keysManager(), maxConcurrency, maxRetries); + + accounts + .transform(pqKeysUtil::getAccountsWithoutPqKeys) + .flatMap(accountWithoutPqKeys -> { + final String platform = DevicePlatformUtil.getDevicePlatform(accountWithoutPqKeys.getPrimaryDevice()) + .map(Enum::name) + .orElse("unknown"); + + return dryRun + ? Mono.just(platform) + : Mono.fromFuture(() -> accountsManager.updateAsync(accountWithoutPqKeys, Account::lockAuthTokenHash)) + .retryWhen(Retry.backoff(maxRetries, Duration.ofSeconds(1)) + .onRetryExhaustedThrow((spec, rs) -> rs.failure())) + .thenReturn(platform) + .onErrorResume(throwable -> { + log.warn("Failed to lock account without PQ keys {}", accountWithoutPqKeys.getIdentifier(IdentityType.ACI), throwable); + return Mono.empty(); + }); + }) + .doOnNext(deletedAccountPlatform -> { + Metrics.counter(LOCKED_ACCOUNT_COUNTER_NAME, + "dryRun", String.valueOf(dryRun), + "platform", deletedAccountPlatform) + .increment(); + }) + .then() + .block(); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/PqKeysUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/PqKeysUtil.java new file mode 100644 index 000000000..0293bd7d3 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/PqKeysUtil.java @@ -0,0 +1,47 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.workers; + +import java.time.Duration; +import java.util.Optional; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.KeysManager; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; + +class PqKeysUtil { + + private final KeysManager keysManager; + private final int maxConcurrency; + private final int maxRetries; + + private static final Logger log = LoggerFactory.getLogger(PqKeysUtil.class); + + PqKeysUtil(final KeysManager keysManager, final int maxConcurrency, final int maxRetries) { + this.keysManager = keysManager; + this.maxConcurrency = maxConcurrency; + this.maxRetries = maxRetries; + } + + public Flux getAccountsWithoutPqKeys(final Flux accounts) { + return accounts.flatMap(account -> Mono.fromFuture( + () -> keysManager.getLastResort(account.getIdentifier(IdentityType.ACI), Device.PRIMARY_ID)) + .retryWhen(Retry.backoff(maxRetries, Duration.ofSeconds(1)) + .onRetryExhaustedThrow((spec, rs) -> rs.failure())) + .onErrorResume(throwable -> { + log.warn("Failed to get last-resort key for {}", account.getIdentifier(IdentityType.ACI), throwable); + return Mono.empty(); + }) + .filter(Optional::isEmpty) + .map(ignored -> account), + maxConcurrency); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveAccountsWithoutPqKeysCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveAccountsWithoutPqKeysCommand.java new file mode 100644 index 000000000..e137e7492 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveAccountsWithoutPqKeysCommand.java @@ -0,0 +1,121 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.workers; + +import com.google.common.annotations.VisibleForTesting; +import io.micrometer.core.instrument.Metrics; +import java.time.Duration; +import net.sourceforge.argparse4j.inf.Subparser; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.metrics.DevicePlatformUtil; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; + +public class RemoveAccountsWithoutPqKeysCommand extends AbstractSinglePassCrawlAccountsCommand { + + @VisibleForTesting + static final String DRY_RUN_ARGUMENT = "dry-run"; + + @VisibleForTesting + static final String MAX_CONCURRENCY_ARGUMENT = "max-concurrency"; + + @VisibleForTesting + static final String RETRIES_ARGUMENT = "retries"; + + @VisibleForTesting + static final String MAX_ACCOUNTS_ARGUMENT = "max-accounts"; + + private static final String REMOVED_ACCOUNT_COUNTER_NAME = + MetricsUtil.name(RemoveAccountsWithoutPqKeysCommand.class, "removedAccount"); + + private static final Logger log = LoggerFactory.getLogger(RemoveAccountsWithoutPqKeysCommand.class); + + public RemoveAccountsWithoutPqKeysCommand() { + super("remove-accounts-without-pq-keys", "Removes accounts with primary devices that don't have PQ keys"); + } + + @Override + public void configure(final Subparser subparser) { + super.configure(subparser); + + subparser.addArgument("--dry-run") + .type(Boolean.class) + .dest(DRY_RUN_ARGUMENT) + .required(false) + .setDefault(true) + .help("If true, don’t actually modify accounts with expired linked devices"); + + subparser.addArgument("--max-concurrency") + .type(Integer.class) + .dest(MAX_CONCURRENCY_ARGUMENT) + .setDefault(16) + .help("Max concurrency for DynamoDB operations"); + + subparser.addArgument("--retries") + .type(Integer.class) + .dest(RETRIES_ARGUMENT) + .setDefault(3) + .help("Maximum number of DynamoDB retries permitted per device"); + + subparser.addArgument("--max-accounts") + .type(Integer.class) + .required(true) + .dest(MAX_ACCOUNTS_ARGUMENT) + .help("Maximum number of accounts to remove per run"); + } + + @Override + protected void crawlAccounts(final Flux accounts) { + final boolean dryRun = getNamespace().getBoolean(DRY_RUN_ARGUMENT); + final int maxConcurrency = getNamespace().getInt(MAX_CONCURRENCY_ARGUMENT); + final int maxRetries = getNamespace().getInt(RETRIES_ARGUMENT); + final int maxAccounts = getNamespace().getInt(MAX_ACCOUNTS_ARGUMENT); + + final AccountsManager accountsManager = getCommandDependencies().accountsManager(); + final PqKeysUtil pqKeysUtil = new PqKeysUtil(getCommandDependencies().keysManager(), maxConcurrency, maxRetries); + + accounts + .transform(pqKeysUtil::getAccountsWithoutPqKeys) + .take(maxAccounts) + .filter(accountWithoutPqKeys -> { + if (!accountWithoutPqKeys.hasLockedCredentials()) { + log.warn("Account {} is not locked", accountWithoutPqKeys.getIdentifier(IdentityType.ACI)); + } + + return accountWithoutPqKeys.hasLockedCredentials(); + }) + .flatMap(accountWithoutPqKeys -> { + final String platform = DevicePlatformUtil.getDevicePlatform(accountWithoutPqKeys.getPrimaryDevice()) + .map(Enum::name) + .orElse("unknown"); + + return dryRun + ? Mono.just(platform) + : Mono.fromFuture(() -> accountsManager.delete(accountWithoutPqKeys, AccountsManager.DeletionReason.ADMIN_DELETED)) + .retryWhen(Retry.backoff(maxRetries, Duration.ofSeconds(1)) + .onRetryExhaustedThrow((spec, rs) -> rs.failure())) + .thenReturn(platform) + .onErrorResume(throwable -> { + log.warn("Failed to remove account without PQ keys {}", accountWithoutPqKeys.getIdentifier(IdentityType.ACI), throwable); + return Mono.empty(); + }); + }) + .doOnNext(deletedAccountPlatform -> { + Metrics.counter(REMOVED_ACCOUNT_COUNTER_NAME, + "dryRun", String.valueOf(dryRun), + "platform", deletedAccountPlatform) + .increment(); + }) + .then() + .block(); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveLinkedDevicesWithoutPqKeysCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveLinkedDevicesWithoutPqKeysCommand.java new file mode 100644 index 000000000..ded0502fd --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveLinkedDevicesWithoutPqKeysCommand.java @@ -0,0 +1,110 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.workers; + +import com.google.common.annotations.VisibleForTesting; +import io.micrometer.core.instrument.Metrics; +import java.time.Duration; +import net.sourceforge.argparse4j.inf.Subparser; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.metrics.DevicePlatformUtil; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.KeysManager; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuples; +import reactor.util.retry.Retry; + +public class RemoveLinkedDevicesWithoutPqKeysCommand extends AbstractSinglePassCrawlAccountsCommand { + + @VisibleForTesting + static final String DRY_RUN_ARGUMENT = "dry-run"; + + @VisibleForTesting + static final String MAX_CONCURRENCY_ARGUMENT = "max-concurrency"; + + @VisibleForTesting + static final String RETRIES_ARGUMENT = "retries"; + + private static final String REMOVED_DEVICE_COUNTER_NAME = + MetricsUtil.name(RemoveLinkedDevicesWithoutPqKeysCommand.class, "removedDevice"); + + private static final Logger log = LoggerFactory.getLogger(RemoveLinkedDevicesWithoutPqKeysCommand.class); + + public RemoveLinkedDevicesWithoutPqKeysCommand() { + super("remove-linked-devices-without-pq-keys", "Removes linked devices that don't have PQ keys"); + } + + @Override + public void configure(final Subparser subparser) { + super.configure(subparser); + + subparser.addArgument("--dry-run") + .type(Boolean.class) + .dest(DRY_RUN_ARGUMENT) + .required(false) + .setDefault(true) + .help("If true, don’t actually modify accounts with expired linked devices"); + + subparser.addArgument("--max-concurrency") + .type(Integer.class) + .dest(MAX_CONCURRENCY_ARGUMENT) + .setDefault(16) + .help("Max concurrency for DynamoDB operations"); + + subparser.addArgument("--retries") + .type(Integer.class) + .dest(RETRIES_ARGUMENT) + .setDefault(3) + .help("Maximum number of DynamoDB retries permitted per device"); + } + + @Override + protected void crawlAccounts(final Flux accounts) { + final boolean dryRun = getNamespace().getBoolean(DRY_RUN_ARGUMENT); + final int maxConcurrency = getNamespace().getInt(MAX_CONCURRENCY_ARGUMENT); + final int maxRetries = getNamespace().getInt(RETRIES_ARGUMENT); + + final AccountsManager accountsManager = getCommandDependencies().accountsManager(); + final KeysManager keysManager = getCommandDependencies().keysManager(); + + accounts + .flatMap( + account -> Mono.fromFuture(() -> keysManager.getPqEnabledDevices(account.getIdentifier(IdentityType.ACI))) + .retryWhen(Retry.backoff(maxRetries, Duration.ofSeconds(1)) + .onRetryExhaustedThrow((spec, rs) -> rs.failure())) + .onErrorResume(throwable -> { + log.warn("Failed to get PQ key presence for account: {}", account.getIdentifier(IdentityType.ACI)); + return Mono.empty(); + }) + .flatMapMany(pqEnabledDeviceIds -> Flux.fromIterable(account.getDevices()) + .filter(device -> !device.isPrimary()) + .filter(device -> !pqEnabledDeviceIds.contains(device.getId())) + .map(device -> Tuples.of(account, device))), maxConcurrency) + .flatMap(accountAndDevice -> dryRun + ? Mono.just(accountAndDevice.getT2()) + : Mono.fromFuture(() -> accountsManager.removeDevice(accountAndDevice.getT1(), accountAndDevice.getT2().getId())) + .retryWhen(Retry.backoff(maxRetries, Duration.ofSeconds(1)) + .onRetryExhaustedThrow((spec, rs) -> rs.failure())) + .onErrorResume(throwable -> { + log.warn("Failed to remove linked device without PQ keys: {}:{}", + accountAndDevice.getT1().getIdentifier(IdentityType.ACI), accountAndDevice.getT2().getId()); + + return Mono.empty(); + }) + .map(ignored -> accountAndDevice.getT2()), maxConcurrency) + .doOnNext(removedDevice -> Metrics.counter(REMOVED_DEVICE_COUNTER_NAME, + "dryRun", String.valueOf(dryRun), + "platform", DevicePlatformUtil.getDevicePlatform(removedDevice).map(Enum::name).orElse("unknown")) + .increment()) + .then() + .block(); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/workers/LockAccountsWithoutPqKeysCommandTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/workers/LockAccountsWithoutPqKeysCommandTest.java new file mode 100644 index 000000000..ddd645ab9 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/workers/LockAccountsWithoutPqKeysCommandTest.java @@ -0,0 +1,131 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.workers; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyByte; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; +import net.sourceforge.argparse4j.inf.Namespace; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; +import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.KeysManager; +import reactor.core.publisher.Flux; + +class LockAccountsWithoutPqKeysCommandTest { + + private AccountsManager accountsManager; + private KeysManager keysManager; + + private static class TestLockAccountsWithoutPqKeysCommand extends LockAccountsWithoutPqKeysCommand { + + private final CommandDependencies commandDependencies; + private final Namespace namespace; + + TestLockAccountsWithoutPqKeysCommand(final AccountsManager accountsManager, + final KeysManager keysManager, + final boolean dryRun) { + + commandDependencies = new CommandDependencies(accountsManager, + null, + null, + null, + null, + keysManager, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null); + + namespace = new Namespace(Map.of( + LockAccountsWithoutPqKeysCommand.DRY_RUN_ARGUMENT, dryRun, + LockAccountsWithoutPqKeysCommand.MAX_CONCURRENCY_ARGUMENT, 16, + LockAccountsWithoutPqKeysCommand.RETRIES_ARGUMENT, 3)); + } + + @Override + protected CommandDependencies getCommandDependencies() { + return commandDependencies; + } + + @Override + protected Namespace getNamespace() { + return namespace; + } + } + + @BeforeEach + void setUp() { + accountsManager = mock(AccountsManager.class); + keysManager = mock(KeysManager.class); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void crawlAccounts(final boolean dryRun) { + final UUID accountIdentifierWithPqKeys = UUID.randomUUID(); + + final Account accountWithPqKeys = mock(Account.class); + when(accountWithPqKeys.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifierWithPqKeys); + + final Account accountWithoutPqKeys = mock(Account.class); + when(accountWithoutPqKeys.getIdentifier(IdentityType.ACI)).thenReturn(UUID.randomUUID()); + when(accountWithoutPqKeys.getPrimaryDevice()).thenReturn(mock(Device.class)); + + when(keysManager.getLastResort(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); + when(keysManager.getLastResort(accountIdentifierWithPqKeys, Device.PRIMARY_ID)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(mock(KEMSignedPreKey.class)))); + + when(accountsManager.delete(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + when(accountsManager.updateAsync(any(), any())).thenAnswer(invocation -> { + final Account account = invocation.getArgument(0); + final Consumer updater = invocation.getArgument(1); + + updater.accept(account); + + return CompletableFuture.completedFuture(account); + }); + + final LockAccountsWithoutPqKeysCommand lockAccountsWithoutPqKeysCommand = + new TestLockAccountsWithoutPqKeysCommand(accountsManager, keysManager, dryRun); + + lockAccountsWithoutPqKeysCommand.crawlAccounts(Flux.just(accountWithPqKeys, accountWithoutPqKeys)); + + if (dryRun) { + verify(accountsManager, never()).updateAsync(any(), any()); + } else { + verify(accountsManager).updateAsync(eq(accountWithoutPqKeys), any()); + verifyNoMoreInteractions(accountsManager); + + verify(accountWithoutPqKeys).lockAuthTokenHash(); + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/workers/PqKeysUtilTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/workers/PqKeysUtilTest.java new file mode 100644 index 000000000..5521c027c --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/workers/PqKeysUtilTest.java @@ -0,0 +1,58 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.workers; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyByte; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.List; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; +import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.KeysManager; +import reactor.core.publisher.Flux; + +class PqKeysUtilTest { + + private KeysManager keysManager; + private PqKeysUtil pqKeysUtil; + + @BeforeEach + void setUp() { + keysManager = mock(KeysManager.class); + pqKeysUtil = new PqKeysUtil(keysManager, 16, 3); + } + + @Test + void getAccountsWithoutPqKeys() { + final UUID aciWithPqKeys = UUID.randomUUID(); + + final Account accountWithPqKeys = mock(Account.class); + when(accountWithPqKeys.getIdentifier(IdentityType.ACI)).thenReturn(aciWithPqKeys); + + final Account accountWithoutPqKeys = mock(Account.class); + when(accountWithoutPqKeys.getIdentifier(IdentityType.ACI)).thenReturn(UUID.randomUUID()); + + when(keysManager.getLastResort(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); + when(keysManager.getLastResort(aciWithPqKeys, Device.PRIMARY_ID)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(mock(KEMSignedPreKey.class)))); + + assertEquals(List.of(accountWithoutPqKeys), + Flux.just(accountWithPqKeys, accountWithoutPqKeys) + .transform(pqKeysUtil::getAccountsWithoutPqKeys) + .collectList() + .block()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveAccountsWithoutPqKeysCommandTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveAccountsWithoutPqKeysCommandTest.java new file mode 100644 index 000000000..c89a5bc49 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveAccountsWithoutPqKeysCommandTest.java @@ -0,0 +1,131 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.workers; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyByte; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import net.sourceforge.argparse4j.inf.Namespace; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; +import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.KeysManager; +import reactor.core.publisher.Flux; + +class RemoveAccountsWithoutPqKeysCommandTest { + + private AccountsManager accountsManager; + private KeysManager keysManager; + + private static class TestRemoveAccountsWithoutPqKeysCommand extends RemoveAccountsWithoutPqKeysCommand { + + private final CommandDependencies commandDependencies; + private final Namespace namespace; + + TestRemoveAccountsWithoutPqKeysCommand(final AccountsManager accountsManager, + final KeysManager keysManager, + final int maxAccounts, + final boolean dryRun) { + + commandDependencies = new CommandDependencies(accountsManager, + null, + null, + null, + null, + keysManager, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null); + + namespace = new Namespace(Map.of( + RemoveAccountsWithoutPqKeysCommand.DRY_RUN_ARGUMENT, dryRun, + RemoveAccountsWithoutPqKeysCommand.MAX_CONCURRENCY_ARGUMENT, 16, + RemoveAccountsWithoutPqKeysCommand.RETRIES_ARGUMENT, 3, + RemoveAccountsWithoutPqKeysCommand.MAX_ACCOUNTS_ARGUMENT, maxAccounts)); + } + + @Override + protected CommandDependencies getCommandDependencies() { + return commandDependencies; + } + + @Override + protected Namespace getNamespace() { + return namespace; + } + } + + @BeforeEach + void setUp() { + accountsManager = mock(AccountsManager.class); + keysManager = mock(KeysManager.class); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void crawlAccounts(final boolean dryRun) { + final UUID accountIdentifierWithPqKeys = UUID.randomUUID(); + + final Account accountWithPqKeys = mock(Account.class); + when(accountWithPqKeys.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifierWithPqKeys); + when(accountWithPqKeys.getPrimaryDevice()).thenReturn(mock(Device.class)); + when(accountWithPqKeys.hasLockedCredentials()).thenReturn(true); + + when(keysManager.getLastResort(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); + when(keysManager.getLastResort(accountIdentifierWithPqKeys, Device.PRIMARY_ID)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(mock(KEMSignedPreKey.class)))); + + when(accountsManager.delete(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + + final int maxAccounts = 5; + + final RemoveAccountsWithoutPqKeysCommand removeAccountsWithoutPqKeysCommand = + new TestRemoveAccountsWithoutPqKeysCommand(accountsManager, keysManager, maxAccounts, dryRun); + + removeAccountsWithoutPqKeysCommand.crawlAccounts(Flux.concat( + Flux.just(accountWithPqKeys), + Flux.generate(sink -> { + final Account accountWithoutPqKeys = mock(Account.class); + when(accountWithoutPqKeys.getIdentifier(IdentityType.ACI)).thenReturn(UUID.randomUUID()); + when(accountWithoutPqKeys.getPrimaryDevice()).thenReturn(mock(Device.class)); + when(accountWithoutPqKeys.hasLockedCredentials()).thenReturn(true); + + sink.next(accountWithoutPqKeys); + }))); + + if (dryRun) { + verify(accountsManager, never()).delete(any(), any()); + } else { + verify(accountsManager, times(maxAccounts)).delete(any(), eq(AccountsManager.DeletionReason.ADMIN_DELETED)); + verify(accountsManager, never()).delete(eq(accountWithPqKeys), any()); + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveLinkedDevicesWithoutPqKeysCommandTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveLinkedDevicesWithoutPqKeysCommandTest.java new file mode 100644 index 000000000..03e617764 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveLinkedDevicesWithoutPqKeysCommandTest.java @@ -0,0 +1,129 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.workers; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyByte; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import net.sourceforge.argparse4j.inf.Namespace; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.KeysManager; +import reactor.core.publisher.Flux; + +class RemoveLinkedDevicesWithoutPqKeysCommandTest { + + private AccountsManager accountsManager; + private KeysManager keysManager; + + private static class TestRemoveLinkedDevicesWithoutPqKeysCommand extends RemoveLinkedDevicesWithoutPqKeysCommand { + + private final CommandDependencies commandDependencies; + private final Namespace namespace; + + TestRemoveLinkedDevicesWithoutPqKeysCommand(final AccountsManager accountsManager, + final KeysManager keysManager, + final boolean dryRun) { + + commandDependencies = new CommandDependencies(accountsManager, + null, + null, + null, + null, + keysManager, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null); + + namespace = new Namespace(Map.of( + RemoveLinkedDevicesWithoutPqKeysCommand.DRY_RUN_ARGUMENT, dryRun, + RemoveLinkedDevicesWithoutPqKeysCommand.MAX_CONCURRENCY_ARGUMENT, 16, + RemoveLinkedDevicesWithoutPqKeysCommand.RETRIES_ARGUMENT, 3)); + } + + @Override + protected CommandDependencies getCommandDependencies() { + return commandDependencies; + } + + @Override + protected Namespace getNamespace() { + return namespace; + } + } + + @BeforeEach + void setUp() { + accountsManager = mock(AccountsManager.class); + keysManager = mock(KeysManager.class); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void crawlAccounts(final boolean dryRun) { + final UUID accountIdentifier = UUID.randomUUID(); + final byte deviceIdWithPqKeys = Device.PRIMARY_ID + 1; + final byte deviceIdWithoutPqKeys = deviceIdWithPqKeys + 1; + + final Device primaryDevice = mock(Device.class); + when(primaryDevice.isPrimary()).thenReturn(true); + when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID); + + final Device linkedDeviceWithPqKeys = mock(Device.class); + when(linkedDeviceWithPqKeys.isPrimary()).thenReturn(false); + when(linkedDeviceWithPqKeys.getId()).thenReturn(deviceIdWithPqKeys); + + final Device linkedDeviceWithoutPqKeys = mock(Device.class); + when(linkedDeviceWithoutPqKeys.isPrimary()).thenReturn(false); + when(linkedDeviceWithoutPqKeys.getId()).thenReturn(deviceIdWithoutPqKeys); + + final Account account = mock(Account.class); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountIdentifier); + when(account.getDevices()).thenReturn(List.of(primaryDevice, linkedDeviceWithPqKeys, linkedDeviceWithoutPqKeys)); + + when(keysManager.getPqEnabledDevices(accountIdentifier)) + .thenReturn(CompletableFuture.completedFuture(List.of(deviceIdWithPqKeys))); + + when(accountsManager.removeDevice(any(), anyByte())) + .thenAnswer(invocation -> CompletableFuture.completedFuture(invocation.getArgument(0))); + + final RemoveLinkedDevicesWithoutPqKeysCommand removeLinkedDevicesWithoutPqKeysCommand = + new TestRemoveLinkedDevicesWithoutPqKeysCommand(accountsManager, keysManager, dryRun); + + removeLinkedDevicesWithoutPqKeysCommand.crawlAccounts(Flux.just(account)); + + if (dryRun) { + verify(accountsManager, never()).removeDevice(any(), anyByte()); + } else { + verify(accountsManager).removeDevice(account, deviceIdWithoutPqKeys); + verifyNoMoreInteractions(accountsManager); + } + } +}