From aaa36fd8f53d96311ec9b0d0363c18788882cbca Mon Sep 17 00:00:00 2001
From: Ravi Khadiwala
Date: Mon, 2 Jun 2025 11:52:05 -0500
Subject: [PATCH] Add a crawler for orphaned prekey pages

---
 .../storage/DeviceKEMPreKeyPages.java         |  24 +++
 .../textsecuregcm/storage/KeysManager.java    |  30 +++-
 .../storage/PagedSingleUseKEMPreKeyStore.java |  67 +++++++-
 .../workers/CommandDependencies.java          |   8 +-
 .../RemoveOrphanedPreKeyPagesCommand.java     | 143 ++++++++++++++++++
 .../PagedSingleUseKEMPreKeyStoreTest.java     |  88 ++++++++---
 .../RemoveOrphanedPreKeyPagesCommandTest.java | 138 +++++++++++++++++
 7 files changed, 470 insertions(+), 28 deletions(-)
 create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/DeviceKEMPreKeyPages.java
 create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveOrphanedPreKeyPagesCommand.java
 create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveOrphanedPreKeyPagesCommandTest.java

diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/DeviceKEMPreKeyPages.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/DeviceKEMPreKeyPages.java
new file mode 100644
index 000000000..9ae901ff7
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/DeviceKEMPreKeyPages.java
@@ -0,0 +1,24 @@
+/*
+ * Copyright 2025 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.storage;
+
+import java.time.Instant;
+import java.util.Map;
+import java.util.Optional;
+import java.util.UUID;
+
+/**
+ * The prekey pages stored for a particular device
+ *
+ * @param identifier The account identifier or phone number identifier that the keys belong to
+ * @param deviceId The device identifier
+ * @param currentPage If present, the id of the active stored page that prekeys are being distributed from
+ * @param pageIdToLastModified The last modified time for all the device's stored pages, keyed by the pageId
+ */
+public record DeviceKEMPreKeyPages(
+    UUID identifier, byte deviceId,
+    Optional<UUID> currentPage,
+    Map<UUID, Instant> pageIdToLastModified) {}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java index 602ae61d0..13e9c4b64 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java @@ -5,6 +5,7 @@ package org.whispersystems.textsecuregcm.storage; +import java.time.Instant; import java.util.List; import java.util.Optional; import java.util.UUID; @@ -12,6 +13,7 @@ import java.util.concurrent.CompletableFuture; import org.whispersystems.textsecuregcm.entities.ECPreKey; import org.whispersystems.textsecuregcm.entities.ECSignedPreKey; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; +import reactor.core.publisher.Flux; import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem; public class KeysManager { @@ -131,8 +133,32 @@ public class KeysManager { public CompletableFuture deleteSingleUsePreKeys(final UUID accountUuid, final byte deviceId) { return CompletableFuture.allOf( - ecPreKeys.delete(accountUuid, deviceId), - pqPreKeys.delete(accountUuid, deviceId) + ecPreKeys.delete(accountUuid, deviceId), + pqPreKeys.delete(accountUuid, deviceId) ); } + + /** + * List all the current remotely stored prekey pages across all devices. Pages that are no longer in use can be + * removed with {@link #pruneDeadPage} + * + * @param lookupConcurrency the number of concurrent lookup operations to perform when populating list results + * @return All stored prekey pages + */ + public Flux listStoredKEMPreKeyPages(int lookupConcurrency) { + return pagedPqPreKeys.listStoredPages(lookupConcurrency); + } + + /** + * Remove a prekey page that is no longer in use. 
A page should only be removed if it is not the active page and + * it has no chance of being updated to be. + * + * @param identifier The owner of the dead page + * @param deviceId The device of the dead page + * @param pageId The dead page to remove from storage + * @return A future that completes when the page has been removed + */ + public CompletableFuture pruneDeadPage(final UUID identifier, final byte deviceId, final UUID pageId) { + return pagedPqPreKeys.deleteBundleFromS3(identifier, deviceId, pageId); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/PagedSingleUseKEMPreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PagedSingleUseKEMPreKeyStore.java index 3ebd1cecf..d15582993 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/PagedSingleUseKEMPreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PagedSingleUseKEMPreKeyStore.java @@ -12,6 +12,7 @@ import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Timer; import java.io.IOException; import java.nio.ByteBuffer; +import java.time.Instant; import java.util.Comparator; import java.util.List; import java.util.Map; @@ -19,6 +20,9 @@ import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; +import java.util.function.Function; +import java.util.stream.Collectors; + import org.signal.libsignal.protocol.InvalidKeyException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -40,9 +44,7 @@ import software.amazon.awssdk.services.dynamodb.model.QueryRequest; import software.amazon.awssdk.services.dynamodb.model.ReturnValue; import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; import software.amazon.awssdk.services.s3.S3AsyncClient; -import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; -import 
software.amazon.awssdk.services.s3.model.GetObjectRequest; -import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.*; /** * @implNote This version of a {@link SingleUsePreKeyStore} store bundles prekeys into "pages", which are stored in on @@ -294,6 +296,40 @@ public class PagedSingleUseKEMPreKeyStore { .thenRun(() -> sample.stop(deleteForDeviceTimer)); } + + public Flux listStoredPages(int lookupConcurrency) { + return Flux + .from(s3AsyncClient.listObjectsV2Paginator(ListObjectsV2Request.builder() + .bucket(bucketName) + .build())) + .flatMapIterable(ListObjectsV2Response::contents) + .map(PagedSingleUseKEMPreKeyStore::parseS3Key) + .bufferUntilChanged(Function.identity(), S3PageKey::fromSameDevice) + .flatMapSequential(pages -> { + final UUID identifier = pages.getFirst().identifier(); + final byte deviceId = pages.getFirst().deviceId(); + return Mono.fromCompletionStage(() -> dynamoDbAsyncClient.getItem(GetItemRequest.builder() + .tableName(tableName) + .key(Map.of( + KEY_ACCOUNT_UUID, AttributeValues.fromUUID(identifier), + KEY_DEVICE_ID, AttributeValues.fromInt(deviceId))) + // Make sure we get the most up to date pageId to minimize cases where we see a new page in S3 but + // view a stale dynamodb record + .consistentRead(true) + .projectionExpression("#uuid,#deviceid,#pageid") + .expressionAttributeNames(Map.of( + "#uuid", KEY_ACCOUNT_UUID, + "#deviceid", KEY_DEVICE_ID, + "#pageid", ATTR_PAGE_ID)) + .build()) + .thenApply(getItemResponse -> new DeviceKEMPreKeyPages( + identifier, + deviceId, + Optional.ofNullable(AttributeValues.getUUID(getItemResponse.item(), ATTR_PAGE_ID, null)), + pages.stream().collect(Collectors.toMap(S3PageKey::pageId, S3PageKey::lastModified))))); + }, lookupConcurrency); + } + private CompletableFuture deleteItems(final UUID identifier, final Flux> items) { return items @@ -322,6 +358,29 @@ public class PagedSingleUseKEMPreKeyStore { return String.format("%s/%s/%s", 
identifier, deviceId, pageId); } + private record S3PageKey(UUID identifier, byte deviceId, UUID pageId, Instant lastModified) { + + boolean fromSameDevice(final S3PageKey other) { + return deviceId == other.deviceId && identifier.equals(other.identifier); + } + } + + private static S3PageKey parseS3Key(final S3Object page) { + try { + final String[] parts = page.key().split("/", 3); + if (parts.length != 3 || parts[2].contains("/")) { + throw new IllegalArgumentException("wrong number of path components"); + } + return new S3PageKey( + UUID.fromString(parts[0]), + Byte.parseByte(parts[1]), + UUID.fromString(parts[2]), page.lastModified()); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("invalid s3 page key: " + page.key(), e); + } + } + + private CompletableFuture writeBundleToS3(final UUID identifier, final byte deviceId, final ByteBuffer bundle) { final UUID pageId = UUID.randomUUID(); @@ -332,7 +391,7 @@ public class PagedSingleUseKEMPreKeyStore { .thenApply(ignoredResponse -> pageId); } - private CompletableFuture deleteBundleFromS3(final UUID identifier, final byte deviceId, final UUID pageId) { + CompletableFuture deleteBundleFromS3(final UUID identifier, final byte deviceId, final UUID pageId) { return s3AsyncClient.deleteObject(DeleteObjectRequest.builder() .bucket(bucketName) .key(s3Key(identifier, deviceId, pageId)) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java index a1260faa7..2a015f887 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java @@ -213,12 +213,14 @@ record CommandDependencies( .credentialsProvider(awsCredentialsProvider) .region(Region.of(configuration.getPagedSingleUseKEMPreKeyStore().region())) .build(); + 
PagedSingleUseKEMPreKeyStore pagedSingleUseKEMPreKeyStore = new PagedSingleUseKEMPreKeyStore( + dynamoDbAsyncClient, asyncKeysS3Client, + configuration.getDynamoDbTables().getPagedKemKeys().getTableName(), + configuration.getPagedSingleUseKEMPreKeyStore().bucket()); KeysManager keys = new KeysManager( new SingleUseECPreKeyStore(dynamoDbAsyncClient, configuration.getDynamoDbTables().getEcKeys().getTableName()), new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, configuration.getDynamoDbTables().getKemKeys().getTableName()), - new PagedSingleUseKEMPreKeyStore(dynamoDbAsyncClient, asyncKeysS3Client, - configuration.getDynamoDbTables().getPagedKemKeys().getTableName(), - configuration.getPagedSingleUseKEMPreKeyStore().bucket()), + pagedSingleUseKEMPreKeyStore, new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient, configuration.getDynamoDbTables().getEcSignedPreKeys().getTableName()), new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveOrphanedPreKeyPagesCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveOrphanedPreKeyPagesCommand.java new file mode 100644 index 000000000..8770c966e --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveOrphanedPreKeyPagesCommand.java @@ -0,0 +1,143 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.workers; + +import com.google.common.annotations.VisibleForTesting; +import io.dropwizard.core.Application; +import io.dropwizard.core.setup.Environment; +import io.micrometer.core.instrument.Metrics; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.UUID; +import java.util.stream.Stream; +import net.sourceforge.argparse4j.inf.Namespace; +import 
net.sourceforge.argparse4j.inf.Subparser;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
+import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
+import org.whispersystems.textsecuregcm.storage.DeviceKEMPreKeyPages;
+import org.whispersystems.textsecuregcm.storage.KeysManager;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+import reactor.util.retry.Retry;
+
+public class RemoveOrphanedPreKeyPagesCommand extends AbstractCommandWithDependencies {
+
+  private final Logger logger = LoggerFactory.getLogger(getClass());
+
+  private static final String PAGE_CONSIDERED_COUNTER_NAME = MetricsUtil.name(RemoveOrphanedPreKeyPagesCommand.class,
+      "pageConsidered");
+
+  @VisibleForTesting
+  static final String DRY_RUN_ARGUMENT = "dry-run";
+
+  @VisibleForTesting
+  static final String CONCURRENCY_ARGUMENT = "concurrency";
+  private static final int DEFAULT_CONCURRENCY = 10;
+
+  @VisibleForTesting
+  static final String MINIMUM_ORPHAN_AGE_ARGUMENT = "orphan-age";
+  private static final Duration DEFAULT_MINIMUM_ORPHAN_AGE = Duration.ofDays(7);
+
+  // Time source used to compute the orphan-age cutoff ("now" minus the minimum orphan age).
+  // Injected through the constructor rather than read from the system so tests can pin "now".
+  private final Clock clock;
+
+  public RemoveOrphanedPreKeyPagesCommand(final Clock clock) {
+    super(new Application<>() {
+      @Override
+      public void run(final WhisperServerConfiguration configuration, final Environment environment) {
+      }
+    }, "remove-orphaned-pre-key-pages", "Remove pre-key pages that are unreferenced");
+    this.clock = clock;
+  }
+
+  @Override
+  public void configure(final Subparser subparser) {
+    super.configure(subparser);
+
+    subparser.addArgument("--concurrency")
+        .type(Integer.class)
+        .dest(CONCURRENCY_ARGUMENT)
+        .required(false)
+        .setDefault(DEFAULT_CONCURRENCY)
+        .help("The maximum number of parallel dynamodb operations to process concurrently");
+
+    subparser.addArgument("--dry-run")
+        .type(Boolean.class)
+        .dest(DRY_RUN_ARGUMENT)
+        .required(false)
+        .setDefault(true) // safe by default: nothing is deleted unless dry-run is explicitly turned off
+        .help("If true, don't actually remove orphaned pre-key pages");
+
+    subparser.addArgument("--minimum-orphan-age")
+        .type(String.class)
+        .dest(MINIMUM_ORPHAN_AGE_ARGUMENT)
+        .required(false)
+        .setDefault(DEFAULT_MINIMUM_ORPHAN_AGE.toString())
+        .help("Only remove orphans that are at least this old. Provide as an ISO-8601 duration string");
+  }
+
+  @Override
+  protected void run(final Environment environment, final Namespace namespace,
+      final WhisperServerConfiguration configuration, final CommandDependencies commandDependencies) throws Exception {
+
+    final int concurrency = Objects.requireNonNull(namespace.getInt(CONCURRENCY_ARGUMENT));
+    final boolean dryRun = Objects.requireNonNull(namespace.getBoolean(DRY_RUN_ARGUMENT));
+    final Duration orphanAgeMinimum =
+        Duration.parse(Objects.requireNonNull(namespace.getString(MINIMUM_ORPHAN_AGE_ARGUMENT)));
+    final Instant olderThan = clock.instant().minus(orphanAgeMinimum); // pages modified before this are stale
+
+    logger.info("Crawling preKey page store with concurrency={}, processors={}, dryRun={}. Removing orphans written before={}",
+        concurrency,
+        Runtime.getRuntime().availableProcessors(),
+        dryRun,
+        olderThan);
+
+    final KeysManager keysManager = commandDependencies.keysManager();
+    final int deletedPages = keysManager.listStoredKEMPreKeyPages(concurrency)
+        .flatMap(storedPages -> Flux.fromStream(getDeletablePages(storedPages, olderThan))
+            .concatMap(pageId -> dryRun
+                ? Mono.just(0) // dry run: the page is only counted in metrics, never deleted
+                : Mono.fromCompletionStage(() ->
+                        keysManager.pruneDeadPage(storedPages.identifier(), storedPages.deviceId(), pageId))
+                    .retryWhen(Retry.backoff(3, Duration.ofSeconds(1)))
+                    .thenReturn(1)), concurrency)
+        .reduce(0, Integer::sum)
+        .block();
+    logger.info("Deleted {} orphaned pages", deletedPages);
+  }
+
+  private static Stream<UUID> getDeletablePages(final DeviceKEMPreKeyPages storedPages, final Instant olderThan) {
+    return storedPages.pageIdToLastModified()
+        .entrySet()
+        .stream()
+        .filter(page -> {
+          final UUID pageId = page.getKey();
+          final Instant lastModified = page.getValue();
+          return shouldDeletePage(storedPages.currentPage(), pageId, olderThan, lastModified);
+        })
+        .map(Map.Entry::getKey);
+  }
+
+  @VisibleForTesting
+  static boolean shouldDeletePage(
+      final Optional<UUID> currentPage, final UUID page,
+      final Instant deleteBefore, final Instant lastModified) {
+    final boolean isCurrentPageForDevice = currentPage.map(uuid -> uuid.equals(page)).orElse(false);
+    final boolean isStale = lastModified.isBefore(deleteBefore); // strictly older than the cutoff
+    Metrics.counter(PAGE_CONSIDERED_COUNTER_NAME,
+            "isCurrentPageForDevice", Boolean.toString(isCurrentPageForDevice),
+            "stale", Boolean.toString(isStale))
+        .increment();
+    return !isCurrentPageForDevice && isStale;
+  }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/PagedSingleUseKEMPreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/PagedSingleUseKEMPreKeyStoreTest.java
index 1b3f906bc..c90ab02de 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/PagedSingleUseKEMPreKeyStoreTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/PagedSingleUseKEMPreKeyStoreTest.java
@@ -9,37 +9,27 @@ import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNotEquals;
 import static 
org.junit.jupiter.api.Assertions.assertTrue; -import static org.testcontainers.containers.localstack.LocalStackContainer.Service.S3; import java.util.Comparator; import java.util.HashSet; import java.util.List; -import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.UUID; import java.util.concurrent.ThreadLocalRandom; -import java.util.stream.Collectors; import java.util.stream.IntStream; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.RepeatedTest; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.testcontainers.containers.localstack.LocalStackContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.utility.DockerImageName; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; +import org.whispersystems.textsecuregcm.util.TestRandomUtil; import reactor.core.publisher.Flux; -import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; -import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; -import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.s3.S3AsyncClient; -import software.amazon.awssdk.services.s3.model.CreateBucketRequest; +import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.services.s3.model.ListObjectsV2Request; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.S3Object; class PagedSingleUseKEMPreKeyStoreTest { @@ -77,7 +67,7 @@ class PagedSingleUseKEMPreKeyStoreTest { assertDoesNotThrow(() -> keyStore.store(accountIdentifier, deviceId, preKeys).join()); final List sortedPreKeys = preKeys.stream() - 
.sorted(Comparator.comparing(preKey -> preKey.keyId())) + .sorted(Comparator.comparing(KEMSignedPreKey::keyId)) .toList(); assertEquals(Optional.of(sortedPreKeys.get(0)), keyStore.take(accountIdentifier, deviceId).join()); @@ -91,18 +81,18 @@ class PagedSingleUseKEMPreKeyStoreTest { final List preKeys1 = generateRandomPreKeys(); keyStore.store(accountIdentifier, deviceId, preKeys1).join(); - List oldPages = listPages(accountIdentifier).stream().map(S3Object::key).collect(Collectors.toList()); + List oldPages = listPages(accountIdentifier).stream().map(S3Object::key).toList(); assertEquals(1, oldPages.size()); final List preKeys2 = generateRandomPreKeys(); keyStore.store(accountIdentifier, deviceId, preKeys2).join(); - List newPages = listPages(accountIdentifier).stream().map(S3Object::key).collect(Collectors.toList()); + List newPages = listPages(accountIdentifier).stream().map(S3Object::key).toList(); assertEquals(1, newPages.size()); assertNotEquals(oldPages.getFirst(), newPages.getFirst()); assertEquals( - preKeys2.stream().sorted(Comparator.comparing(preKey -> preKey.keyId())).toList(), + preKeys2.stream().sorted(Comparator.comparing(KEMSignedPreKey::keyId)).toList(), IntStream.range(0, preKeys2.size()) .mapToObj(i -> keyStore.take(accountIdentifier, deviceId).join()) @@ -122,7 +112,7 @@ class PagedSingleUseKEMPreKeyStoreTest { assertDoesNotThrow(() -> keyStore.store(accountIdentifier, deviceId, preKeys).join()); final List sortedPreKeys = preKeys.stream() - .sorted(Comparator.comparing(preKey -> preKey.keyId())) + .sorted(Comparator.comparing(KEMSignedPreKey::keyId)) .toList(); for (int i = 0; i < KEY_COUNT; i++) { @@ -171,7 +161,7 @@ class PagedSingleUseKEMPreKeyStoreTest { final List pages = listPages(accountIdentifier); assertEquals(1, pages.size()); - assertTrue(pages.get(0).key().startsWith("%s/%s".formatted(accountIdentifier, deviceId + 1))); + assertTrue(pages.getFirst().key().startsWith("%s/%s".formatted(accountIdentifier, deviceId + 1))); } @Test @@ 
-194,6 +184,66 @@ class PagedSingleUseKEMPreKeyStoreTest { assertEquals(0, listPages(accountIdentifier).size()); } + @Test + void listPages() { + final UUID aci1 = UUID.randomUUID(); + final UUID aci2 = new UUID(aci1.getMostSignificantBits(), aci1.getLeastSignificantBits() + 1); + final byte deviceId = 1; + + keyStore.store(aci1, deviceId, generateRandomPreKeys()).join(); + keyStore.store(aci1, (byte) (deviceId + 1), generateRandomPreKeys()).join(); + keyStore.store(aci2, deviceId, generateRandomPreKeys()).join(); + + List stored = keyStore.listStoredPages(1).collectList().block(); + assertEquals(3, stored.size()); + for (DeviceKEMPreKeyPages pages : stored) { + assertEquals(1, pages.pageIdToLastModified().size()); + } + + assertEquals(List.of(aci1, aci1, aci2), stored.stream().map(DeviceKEMPreKeyPages::identifier).toList()); + assertEquals( + List.of(deviceId, (byte) (deviceId + 1), deviceId), + stored.stream().map(DeviceKEMPreKeyPages::deviceId).toList()); + } + + @Test + void listPagesWithOrphans() { + final UUID aci1 = UUID.randomUUID(); + final UUID aci2 = new UUID(aci1.getMostSignificantBits(), aci1.getLeastSignificantBits() + 1); + final byte deviceId = 1; + + // Two orphans + keyStore.store(aci1, deviceId, generateRandomPreKeys()).join(); + writeOrphanedS3Object(aci1, deviceId); + writeOrphanedS3Object(aci1, deviceId); + + // No orphans + keyStore.store(aci1, (byte) (deviceId + 1), generateRandomPreKeys()).join(); + + // One orphan + keyStore.store(aci2, deviceId, generateRandomPreKeys()).join(); + writeOrphanedS3Object(aci2, deviceId); + + // Orphan with no database record + writeOrphanedS3Object(aci2, (byte) (deviceId + 2)); + + List stored = keyStore.listStoredPages(1).collectList().block(); + assertEquals(4, stored.size()); + + assertEquals( + List.of(3, 1, 2, 1), + stored.stream().map(s -> s.pageIdToLastModified().size()).toList()); + } + + private void writeOrphanedS3Object(final UUID identifier, final byte deviceId) { + S3_EXTENSION.getS3Client() + 
.putObject(PutObjectRequest.builder() + .bucket(BUCKET_NAME) + .key("%s/%s/%s".formatted(identifier, deviceId, UUID.randomUUID())).build(), + AsyncRequestBody.fromBytes(TestRandomUtil.nextBytes(10))) + .join(); + } + private List listPages(final UUID identifier) { return Flux.from(S3_EXTENSION.getS3Client().listObjectsV2Paginator(ListObjectsV2Request.builder() .bucket(BUCKET_NAME) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveOrphanedPreKeyPagesCommandTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveOrphanedPreKeyPagesCommandTest.java new file mode 100644 index 000000000..7d16ac7af --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/workers/RemoveOrphanedPreKeyPagesCommandTest.java @@ -0,0 +1,138 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.workers; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyByte; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import io.dropwizard.core.setup.Environment; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import net.sourceforge.argparse4j.inf.Namespace; +import org.assertj.core.api.Assertions; +import org.junit.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.junitpioneer.jupiter.cartesian.CartesianTest; +import org.whispersystems.textsecuregcm.WhisperServerConfiguration; 
+import org.whispersystems.textsecuregcm.storage.DeviceKEMPreKeyPages; +import org.whispersystems.textsecuregcm.storage.KeysManager; +import org.whispersystems.textsecuregcm.util.TestClock; +import reactor.core.publisher.Flux; + +public class RemoveOrphanedPreKeyPagesCommandTest { + + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void removeStalePages(boolean dryRun) throws Exception { + final TestClock clock = TestClock.pinned(Instant.EPOCH.plus(Duration.ofSeconds(10))); + final KeysManager keysManager = mock(KeysManager.class); + + final UUID currentPage = UUID.randomUUID(); + final UUID freshOrphanedPage = UUID.randomUUID(); + final UUID staleOrphanedPage = UUID.randomUUID(); + + when(keysManager.listStoredKEMPreKeyPages(anyInt())).thenReturn(Flux.fromIterable(List.of( + new DeviceKEMPreKeyPages(UUID.randomUUID(), (byte) 1, Optional.of(currentPage), Map.of( + currentPage, Instant.EPOCH, + staleOrphanedPage, Instant.EPOCH.plus(Duration.ofSeconds(4)), + freshOrphanedPage, Instant.EPOCH.plus(Duration.ofSeconds(5))))))); + + when(keysManager.pruneDeadPage(any(), anyByte(), any())) + .thenReturn(CompletableFuture.completedFuture(null)); + + runCommand(clock, Duration.ofSeconds(5), dryRun, keysManager); + verify(keysManager, times(dryRun ? 
0 : 1)) + .pruneDeadPage(any(), eq((byte) 1), eq(staleOrphanedPage)); + verify(keysManager, times(1)).listStoredKEMPreKeyPages(anyInt()); + verifyNoMoreInteractions(keysManager); + } + + @Test + public void noCurrentPage() throws Exception { + final TestClock clock = TestClock.pinned(Instant.EPOCH.plus(Duration.ofSeconds(10))); + final KeysManager keysManager = mock(KeysManager.class); + + final UUID freshOrphanedPage = UUID.randomUUID(); + final UUID staleOrphanedPage = UUID.randomUUID(); + + when(keysManager.listStoredKEMPreKeyPages(anyInt())).thenReturn(Flux.fromIterable(List.of( + new DeviceKEMPreKeyPages(UUID.randomUUID(), (byte) 1, Optional.empty(), Map.of( + staleOrphanedPage, Instant.EPOCH.plus(Duration.ofSeconds(4)), + freshOrphanedPage, Instant.EPOCH.plus(Duration.ofSeconds(5))))))); + + when(keysManager.pruneDeadPage(any(), anyByte(), any())) + .thenReturn(CompletableFuture.completedFuture(null)); + + runCommand(clock, Duration.ofSeconds(5), false, keysManager); + verify(keysManager, times(1)) + .pruneDeadPage(any(), eq((byte) 1), eq(staleOrphanedPage)); + verify(keysManager, times(1)).listStoredKEMPreKeyPages(anyInt()); + verifyNoMoreInteractions(keysManager); + } + + @Test + public void noPages() throws Exception { + final TestClock clock = TestClock.pinned(Instant.EPOCH); + final KeysManager keysManager = mock(KeysManager.class); + when(keysManager.listStoredKEMPreKeyPages(anyInt())).thenReturn(Flux.empty()); + runCommand(clock, Duration.ofSeconds(5), false, keysManager); + verify(keysManager).listStoredKEMPreKeyPages(anyInt()); + verifyNoMoreInteractions(keysManager); + } + + private enum PageStatus {NO_CURRENT, MATCH_CURRENT, MISMATCH_CURRENT} + + @CartesianTest + void shouldDeletePage( + @CartesianTest.Enum final PageStatus pageStatus, + @CartesianTest.Values(booleans = {false, true}) final boolean isOld) { + final Optional currentPage = pageStatus == PageStatus.NO_CURRENT + ? 
Optional.empty() + : Optional.of(UUID.randomUUID()); + final UUID page = switch (pageStatus) { + case MATCH_CURRENT -> currentPage.orElseThrow(); + case NO_CURRENT, MISMATCH_CURRENT -> UUID.randomUUID(); + }; + + final Instant threshold = Instant.EPOCH.plus(Duration.ofSeconds(10)); + final Instant lastModified = isOld ? threshold.minus(Duration.ofSeconds(1)) : threshold; + + final boolean shouldDelete = pageStatus != PageStatus.MATCH_CURRENT && isOld; + Assertions.assertThat(RemoveOrphanedPreKeyPagesCommand.shouldDeletePage(currentPage, page, threshold, lastModified)) + .isEqualTo(shouldDelete); + } + + + private void runCommand(final Clock clock, final Duration minimumOrphanAge, final boolean dryRun, + final KeysManager keysManager) throws Exception { + final CommandDependencies commandDependencies = mock(CommandDependencies.class); + when(commandDependencies.keysManager()).thenReturn(keysManager); + + final Namespace namespace = mock(Namespace.class); + when(namespace.getBoolean(RemoveOrphanedPreKeyPagesCommand.DRY_RUN_ARGUMENT)).thenReturn(dryRun); + when(namespace.getInt(RemoveOrphanedPreKeyPagesCommand.CONCURRENCY_ARGUMENT)).thenReturn(2); + when(namespace.getString(RemoveOrphanedPreKeyPagesCommand.MINIMUM_ORPHAN_AGE_ARGUMENT)) + .thenReturn(minimumOrphanAge.toString()); + + final RemoveOrphanedPreKeyPagesCommand command = new RemoveOrphanedPreKeyPagesCommand(clock); + command.run(mock(Environment.class), namespace, mock(WhisperServerConfiguration.class), commandDependencies); + } +}