Add a crawler for orphaned prekey pages
This commit is contained in:
parent
2bb14892af
commit
aaa36fd8f5
|
@ -0,0 +1,24 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.whispersystems.textsecuregcm.storage;
|
||||||
|
|
||||||
|
import java.time.Instant;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.UUID;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The prekey pages stored for a particular device
|
||||||
|
*
|
||||||
|
* @param identifier The account identifier or phone number identifier that the keys belong to
|
||||||
|
* @param deviceId The device identifier
|
||||||
|
* @param currentPage If present, the active stored page prekeys are being distributed from
|
||||||
|
* @param pageIdToLastModified The last modified time for all the device's stored pages, keyed by the pageId
|
||||||
|
*/
|
||||||
|
public record DeviceKEMPreKeyPages(
|
||||||
|
UUID identifier, byte deviceId,
|
||||||
|
Optional<UUID> currentPage,
|
||||||
|
Map<UUID, Instant> pageIdToLastModified) {}
|
|
@ -5,6 +5,7 @@
|
||||||
|
|
||||||
package org.whispersystems.textsecuregcm.storage;
|
package org.whispersystems.textsecuregcm.storage;
|
||||||
|
|
||||||
|
import java.time.Instant;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
|
@ -12,6 +13,7 @@ import java.util.concurrent.CompletableFuture;
|
||||||
import org.whispersystems.textsecuregcm.entities.ECPreKey;
|
import org.whispersystems.textsecuregcm.entities.ECPreKey;
|
||||||
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
|
import org.whispersystems.textsecuregcm.entities.ECSignedPreKey;
|
||||||
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
|
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
|
||||||
|
import reactor.core.publisher.Flux;
|
||||||
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem;
|
import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem;
|
||||||
|
|
||||||
public class KeysManager {
|
public class KeysManager {
|
||||||
|
@ -131,8 +133,32 @@ public class KeysManager {
|
||||||
|
|
||||||
public CompletableFuture<Void> deleteSingleUsePreKeys(final UUID accountUuid, final byte deviceId) {
|
public CompletableFuture<Void> deleteSingleUsePreKeys(final UUID accountUuid, final byte deviceId) {
|
||||||
return CompletableFuture.allOf(
|
return CompletableFuture.allOf(
|
||||||
ecPreKeys.delete(accountUuid, deviceId),
|
ecPreKeys.delete(accountUuid, deviceId),
|
||||||
pqPreKeys.delete(accountUuid, deviceId)
|
pqPreKeys.delete(accountUuid, deviceId)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* List all the current remotely stored prekey pages across all devices. Pages that are no longer in use can be
|
||||||
|
* removed with {@link #pruneDeadPage}
|
||||||
|
*
|
||||||
|
* @param lookupConcurrency the number of concurrent lookup operations to perform when populating list results
|
||||||
|
* @return All stored prekey pages
|
||||||
|
*/
|
||||||
|
public Flux<DeviceKEMPreKeyPages> listStoredKEMPreKeyPages(int lookupConcurrency) {
|
||||||
|
return pagedPqPreKeys.listStoredPages(lookupConcurrency);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Remove a prekey page that is no longer in use. A page should only be removed if it is not the active page and
|
||||||
|
* it has no chance of being updated to be.
|
||||||
|
*
|
||||||
|
* @param identifier The owner of the dead page
|
||||||
|
* @param deviceId The device of the dead page
|
||||||
|
* @param pageId The dead page to remove from storage
|
||||||
|
* @return A future that completes when the page has been removed
|
||||||
|
*/
|
||||||
|
public CompletableFuture<Void> pruneDeadPage(final UUID identifier, final byte deviceId, final UUID pageId) {
|
||||||
|
return pagedPqPreKeys.deleteBundleFromS3(identifier, deviceId, pageId);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@ import io.micrometer.core.instrument.Metrics;
|
||||||
import io.micrometer.core.instrument.Timer;
|
import io.micrometer.core.instrument.Timer;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
|
import java.time.Instant;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
@ -19,6 +20,9 @@ import java.util.Optional;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import java.util.concurrent.CompletableFuture;
|
import java.util.concurrent.CompletableFuture;
|
||||||
import java.util.concurrent.CompletionException;
|
import java.util.concurrent.CompletionException;
|
||||||
|
import java.util.function.Function;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
import org.signal.libsignal.protocol.InvalidKeyException;
|
import org.signal.libsignal.protocol.InvalidKeyException;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
@ -40,9 +44,7 @@ import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
|
||||||
import software.amazon.awssdk.services.dynamodb.model.ReturnValue;
|
import software.amazon.awssdk.services.dynamodb.model.ReturnValue;
|
||||||
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
|
import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest;
|
||||||
import software.amazon.awssdk.services.s3.S3AsyncClient;
|
import software.amazon.awssdk.services.s3.S3AsyncClient;
|
||||||
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
|
import software.amazon.awssdk.services.s3.model.*;
|
||||||
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
|
|
||||||
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @implNote This version of a {@link SingleUsePreKeyStore} store bundles prekeys into "pages", which are stored in on
|
* @implNote This version of a {@link SingleUsePreKeyStore} store bundles prekeys into "pages", which are stored in on
|
||||||
|
@ -294,6 +296,40 @@ public class PagedSingleUseKEMPreKeyStore {
|
||||||
.thenRun(() -> sample.stop(deleteForDeviceTimer));
|
.thenRun(() -> sample.stop(deleteForDeviceTimer));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public Flux<DeviceKEMPreKeyPages> listStoredPages(int lookupConcurrency) {
|
||||||
|
return Flux
|
||||||
|
.from(s3AsyncClient.listObjectsV2Paginator(ListObjectsV2Request.builder()
|
||||||
|
.bucket(bucketName)
|
||||||
|
.build()))
|
||||||
|
.flatMapIterable(ListObjectsV2Response::contents)
|
||||||
|
.map(PagedSingleUseKEMPreKeyStore::parseS3Key)
|
||||||
|
.bufferUntilChanged(Function.identity(), S3PageKey::fromSameDevice)
|
||||||
|
.flatMapSequential(pages -> {
|
||||||
|
final UUID identifier = pages.getFirst().identifier();
|
||||||
|
final byte deviceId = pages.getFirst().deviceId();
|
||||||
|
return Mono.fromCompletionStage(() -> dynamoDbAsyncClient.getItem(GetItemRequest.builder()
|
||||||
|
.tableName(tableName)
|
||||||
|
.key(Map.of(
|
||||||
|
KEY_ACCOUNT_UUID, AttributeValues.fromUUID(identifier),
|
||||||
|
KEY_DEVICE_ID, AttributeValues.fromInt(deviceId)))
|
||||||
|
// Make sure we get the most up to date pageId to minimize cases where we see a new page in S3 but
|
||||||
|
// view a stale dynamodb record
|
||||||
|
.consistentRead(true)
|
||||||
|
.projectionExpression("#uuid,#deviceid,#pageid")
|
||||||
|
.expressionAttributeNames(Map.of(
|
||||||
|
"#uuid", KEY_ACCOUNT_UUID,
|
||||||
|
"#deviceid", KEY_DEVICE_ID,
|
||||||
|
"#pageid", ATTR_PAGE_ID))
|
||||||
|
.build())
|
||||||
|
.thenApply(getItemResponse -> new DeviceKEMPreKeyPages(
|
||||||
|
identifier,
|
||||||
|
deviceId,
|
||||||
|
Optional.ofNullable(AttributeValues.getUUID(getItemResponse.item(), ATTR_PAGE_ID, null)),
|
||||||
|
pages.stream().collect(Collectors.toMap(S3PageKey::pageId, S3PageKey::lastModified)))));
|
||||||
|
}, lookupConcurrency);
|
||||||
|
}
|
||||||
|
|
||||||
private CompletableFuture<Void> deleteItems(final UUID identifier,
|
private CompletableFuture<Void> deleteItems(final UUID identifier,
|
||||||
final Flux<Map<String, AttributeValue>> items) {
|
final Flux<Map<String, AttributeValue>> items) {
|
||||||
return items
|
return items
|
||||||
|
@ -322,6 +358,29 @@ public class PagedSingleUseKEMPreKeyStore {
|
||||||
return String.format("%s/%s/%s", identifier, deviceId, pageId);
|
return String.format("%s/%s/%s", identifier, deviceId, pageId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private record S3PageKey(UUID identifier, byte deviceId, UUID pageId, Instant lastModified) {
|
||||||
|
|
||||||
|
boolean fromSameDevice(final S3PageKey other) {
|
||||||
|
return deviceId == other.deviceId && identifier.equals(other.identifier);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static S3PageKey parseS3Key(final S3Object page) {
|
||||||
|
try {
|
||||||
|
final String[] parts = page.key().split("/", 3);
|
||||||
|
if (parts.length != 3 || parts[2].contains("/")) {
|
||||||
|
throw new IllegalArgumentException("wrong number of path components");
|
||||||
|
}
|
||||||
|
return new S3PageKey(
|
||||||
|
UUID.fromString(parts[0]),
|
||||||
|
Byte.parseByte(parts[1]),
|
||||||
|
UUID.fromString(parts[2]), page.lastModified());
|
||||||
|
} catch (IllegalArgumentException e) {
|
||||||
|
throw new IllegalArgumentException("invalid s3 page key: " + page.key(), e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
private CompletableFuture<UUID> writeBundleToS3(final UUID identifier, final byte deviceId,
|
private CompletableFuture<UUID> writeBundleToS3(final UUID identifier, final byte deviceId,
|
||||||
final ByteBuffer bundle) {
|
final ByteBuffer bundle) {
|
||||||
final UUID pageId = UUID.randomUUID();
|
final UUID pageId = UUID.randomUUID();
|
||||||
|
@ -332,7 +391,7 @@ public class PagedSingleUseKEMPreKeyStore {
|
||||||
.thenApply(ignoredResponse -> pageId);
|
.thenApply(ignoredResponse -> pageId);
|
||||||
}
|
}
|
||||||
|
|
||||||
private CompletableFuture<Void> deleteBundleFromS3(final UUID identifier, final byte deviceId, final UUID pageId) {
|
CompletableFuture<Void> deleteBundleFromS3(final UUID identifier, final byte deviceId, final UUID pageId) {
|
||||||
return s3AsyncClient.deleteObject(DeleteObjectRequest.builder()
|
return s3AsyncClient.deleteObject(DeleteObjectRequest.builder()
|
||||||
.bucket(bucketName)
|
.bucket(bucketName)
|
||||||
.key(s3Key(identifier, deviceId, pageId))
|
.key(s3Key(identifier, deviceId, pageId))
|
||||||
|
|
|
@ -213,12 +213,14 @@ record CommandDependencies(
|
||||||
.credentialsProvider(awsCredentialsProvider)
|
.credentialsProvider(awsCredentialsProvider)
|
||||||
.region(Region.of(configuration.getPagedSingleUseKEMPreKeyStore().region()))
|
.region(Region.of(configuration.getPagedSingleUseKEMPreKeyStore().region()))
|
||||||
.build();
|
.build();
|
||||||
|
PagedSingleUseKEMPreKeyStore pagedSingleUseKEMPreKeyStore = new PagedSingleUseKEMPreKeyStore(
|
||||||
|
dynamoDbAsyncClient, asyncKeysS3Client,
|
||||||
|
configuration.getDynamoDbTables().getPagedKemKeys().getTableName(),
|
||||||
|
configuration.getPagedSingleUseKEMPreKeyStore().bucket());
|
||||||
KeysManager keys = new KeysManager(
|
KeysManager keys = new KeysManager(
|
||||||
new SingleUseECPreKeyStore(dynamoDbAsyncClient, configuration.getDynamoDbTables().getEcKeys().getTableName()),
|
new SingleUseECPreKeyStore(dynamoDbAsyncClient, configuration.getDynamoDbTables().getEcKeys().getTableName()),
|
||||||
new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, configuration.getDynamoDbTables().getKemKeys().getTableName()),
|
new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, configuration.getDynamoDbTables().getKemKeys().getTableName()),
|
||||||
new PagedSingleUseKEMPreKeyStore(dynamoDbAsyncClient, asyncKeysS3Client,
|
pagedSingleUseKEMPreKeyStore,
|
||||||
configuration.getDynamoDbTables().getPagedKemKeys().getTableName(),
|
|
||||||
configuration.getPagedSingleUseKEMPreKeyStore().bucket()),
|
|
||||||
new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient,
|
new RepeatedUseECSignedPreKeyStore(dynamoDbAsyncClient,
|
||||||
configuration.getDynamoDbTables().getEcSignedPreKeys().getTableName()),
|
configuration.getDynamoDbTables().getEcSignedPreKeys().getTableName()),
|
||||||
new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient,
|
new RepeatedUseKEMSignedPreKeyStore(dynamoDbAsyncClient,
|
||||||
|
|
|
@ -0,0 +1,143 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.whispersystems.textsecuregcm.workers;
|
||||||
|
|
||||||
|
import com.google.common.annotations.VisibleForTesting;
|
||||||
|
import io.dropwizard.core.Application;
|
||||||
|
import io.dropwizard.core.setup.Environment;
|
||||||
|
import io.micrometer.core.instrument.Metrics;
|
||||||
|
import java.time.Clock;
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.time.Instant;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.UUID;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
import net.sourceforge.argparse4j.inf.Namespace;
|
||||||
|
import net.sourceforge.argparse4j.inf.Subparser;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
|
||||||
|
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.DeviceKEMPreKeyPages;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.KeysManager;
|
||||||
|
import reactor.core.publisher.Flux;
|
||||||
|
import reactor.core.publisher.Mono;
|
||||||
|
import reactor.util.retry.Retry;
|
||||||
|
|
||||||
|
public class RemoveOrphanedPreKeyPagesCommand extends AbstractCommandWithDependencies {
|
||||||
|
|
||||||
|
private final Logger logger = LoggerFactory.getLogger(getClass());
|
||||||
|
|
||||||
|
private static final String PAGE_CONSIDERED_COUNTER_NAME = MetricsUtil.name(RemoveOrphanedPreKeyPagesCommand.class,
|
||||||
|
"pageConsidered");
|
||||||
|
|
||||||
|
@VisibleForTesting
|
||||||
|
static final String DRY_RUN_ARGUMENT = "dry-run";
|
||||||
|
|
||||||
|
@VisibleForTesting
|
||||||
|
static final String CONCURRENCY_ARGUMENT = "concurrency";
|
||||||
|
private static final int DEFAULT_CONCURRENCY = 10;
|
||||||
|
|
||||||
|
@VisibleForTesting
|
||||||
|
static final String MINIMUM_ORPHAN_AGE_ARGUMENT = "orphan-age";
|
||||||
|
private static final Duration DEFAULT_MINIMUM_ORPHAN_AGE = Duration.ofDays(7);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
private final Clock clock;
|
||||||
|
|
||||||
|
public RemoveOrphanedPreKeyPagesCommand(final Clock clock) {
|
||||||
|
super(new Application<>() {
|
||||||
|
@Override
|
||||||
|
public void run(final WhisperServerConfiguration configuration, final Environment environment) {
|
||||||
|
}
|
||||||
|
}, "remove-orphaned-pre-key-pages", "Remove pre-key pages that are unreferenced");
|
||||||
|
this.clock = clock;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void configure(final Subparser subparser) {
|
||||||
|
super.configure(subparser);
|
||||||
|
|
||||||
|
subparser.addArgument("--concurrency")
|
||||||
|
.type(Integer.class)
|
||||||
|
.dest(CONCURRENCY_ARGUMENT)
|
||||||
|
.required(false)
|
||||||
|
.setDefault(DEFAULT_CONCURRENCY)
|
||||||
|
.help("The maximum number of parallel dynamodb operations to process concurrently");
|
||||||
|
|
||||||
|
subparser.addArgument("--dry-run")
|
||||||
|
.type(Boolean.class)
|
||||||
|
.dest(DRY_RUN_ARGUMENT)
|
||||||
|
.required(false)
|
||||||
|
.setDefault(true)
|
||||||
|
.help("If true, don't actually remove orphaned pre-key pages");
|
||||||
|
|
||||||
|
subparser.addArgument("--minimum-orphan-age")
|
||||||
|
.type(String.class)
|
||||||
|
.dest(MINIMUM_ORPHAN_AGE_ARGUMENT)
|
||||||
|
.required(false)
|
||||||
|
.setDefault(DEFAULT_MINIMUM_ORPHAN_AGE.toString())
|
||||||
|
.help("Only remove orphans that are at least this old. Provide as an ISO-8601 duration string");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void run(final Environment environment, final Namespace namespace,
|
||||||
|
final WhisperServerConfiguration configuration, final CommandDependencies commandDependencies) throws Exception {
|
||||||
|
|
||||||
|
final int concurrency = Objects.requireNonNull(namespace.getInt(CONCURRENCY_ARGUMENT));
|
||||||
|
final boolean dryRun = Objects.requireNonNull(namespace.getBoolean(DRY_RUN_ARGUMENT));
|
||||||
|
final Duration orphanAgeMinimum =
|
||||||
|
Duration.parse(Objects.requireNonNull(namespace.getString(MINIMUM_ORPHAN_AGE_ARGUMENT)));
|
||||||
|
final Instant olderThan = clock.instant().minus(orphanAgeMinimum);
|
||||||
|
|
||||||
|
logger.info("Crawling preKey page store with concurrency={}, processors={}, dryRun={}. Removing orphans written before={}",
|
||||||
|
concurrency,
|
||||||
|
Runtime.getRuntime().availableProcessors(),
|
||||||
|
dryRun,
|
||||||
|
olderThan);
|
||||||
|
|
||||||
|
final KeysManager keysManager = commandDependencies.keysManager();
|
||||||
|
final int deletedPages = keysManager.listStoredKEMPreKeyPages(concurrency)
|
||||||
|
.flatMap(storedPages -> Flux.fromStream(getDetetablePages(storedPages, olderThan))
|
||||||
|
.concatMap(pageId -> dryRun
|
||||||
|
? Mono.just(0)
|
||||||
|
: Mono.fromCompletionStage(() ->
|
||||||
|
keysManager.pruneDeadPage(storedPages.identifier(), storedPages.deviceId(), pageId))
|
||||||
|
.retryWhen(Retry.backoff(3, Duration.ofSeconds(1)))
|
||||||
|
.thenReturn(1)), concurrency)
|
||||||
|
.reduce(0, Integer::sum)
|
||||||
|
.block();
|
||||||
|
logger.info("Deleted {} orphaned pages", deletedPages);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Stream<UUID> getDetetablePages(final DeviceKEMPreKeyPages storedPages, final Instant olderThan) {
|
||||||
|
return storedPages.pageIdToLastModified()
|
||||||
|
.entrySet()
|
||||||
|
.stream()
|
||||||
|
.filter(page -> {
|
||||||
|
final UUID pageId = page.getKey();
|
||||||
|
final Instant lastModified = page.getValue();
|
||||||
|
return shouldDeletePage(storedPages.currentPage(), pageId, olderThan, lastModified);
|
||||||
|
})
|
||||||
|
.map(Map.Entry::getKey);
|
||||||
|
}
|
||||||
|
|
||||||
|
@VisibleForTesting
|
||||||
|
static boolean shouldDeletePage(
|
||||||
|
final Optional<UUID> currentPage, final UUID page,
|
||||||
|
final Instant deleteBefore, final Instant lastModified) {
|
||||||
|
final boolean isCurrentPageForDevice = currentPage.map(uuid -> uuid.equals(page)).orElse(false);
|
||||||
|
final boolean isStale = lastModified.isBefore(deleteBefore);
|
||||||
|
Metrics.counter(PAGE_CONSIDERED_COUNTER_NAME,
|
||||||
|
"isCurrentPageForDevice", Boolean.toString(isCurrentPageForDevice),
|
||||||
|
"stale", Boolean.toString(isStale))
|
||||||
|
.increment();
|
||||||
|
return !isCurrentPageForDevice && isStale;
|
||||||
|
}
|
||||||
|
}
|
|
@ -9,37 +9,27 @@ import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.jupiter.api.Assertions.assertNotEquals;
|
import static org.junit.jupiter.api.Assertions.assertNotEquals;
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
import static org.testcontainers.containers.localstack.LocalStackContainer.Service.S3;
|
|
||||||
|
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import java.util.concurrent.ThreadLocalRandom;
|
import java.util.concurrent.ThreadLocalRandom;
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import java.util.stream.IntStream;
|
import java.util.stream.IntStream;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.RepeatedTest;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.extension.RegisterExtension;
|
import org.junit.jupiter.api.extension.RegisterExtension;
|
||||||
import org.signal.libsignal.protocol.ecc.Curve;
|
import org.signal.libsignal.protocol.ecc.Curve;
|
||||||
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
import org.signal.libsignal.protocol.ecc.ECKeyPair;
|
||||||
import org.testcontainers.containers.localstack.LocalStackContainer;
|
|
||||||
import org.testcontainers.junit.jupiter.Container;
|
|
||||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
|
||||||
import org.testcontainers.utility.DockerImageName;
|
|
||||||
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
|
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
|
||||||
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
|
import org.whispersystems.textsecuregcm.tests.util.KeysHelper;
|
||||||
|
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
|
||||||
import reactor.core.publisher.Flux;
|
import reactor.core.publisher.Flux;
|
||||||
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
|
import software.amazon.awssdk.core.async.AsyncRequestBody;
|
||||||
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
|
|
||||||
import software.amazon.awssdk.regions.Region;
|
|
||||||
import software.amazon.awssdk.services.s3.S3AsyncClient;
|
|
||||||
import software.amazon.awssdk.services.s3.model.CreateBucketRequest;
|
|
||||||
import software.amazon.awssdk.services.s3.model.ListObjectsV2Request;
|
import software.amazon.awssdk.services.s3.model.ListObjectsV2Request;
|
||||||
|
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
|
||||||
import software.amazon.awssdk.services.s3.model.S3Object;
|
import software.amazon.awssdk.services.s3.model.S3Object;
|
||||||
|
|
||||||
class PagedSingleUseKEMPreKeyStoreTest {
|
class PagedSingleUseKEMPreKeyStoreTest {
|
||||||
|
@ -77,7 +67,7 @@ class PagedSingleUseKEMPreKeyStoreTest {
|
||||||
assertDoesNotThrow(() -> keyStore.store(accountIdentifier, deviceId, preKeys).join());
|
assertDoesNotThrow(() -> keyStore.store(accountIdentifier, deviceId, preKeys).join());
|
||||||
|
|
||||||
final List<KEMSignedPreKey> sortedPreKeys = preKeys.stream()
|
final List<KEMSignedPreKey> sortedPreKeys = preKeys.stream()
|
||||||
.sorted(Comparator.comparing(preKey -> preKey.keyId()))
|
.sorted(Comparator.comparing(KEMSignedPreKey::keyId))
|
||||||
.toList();
|
.toList();
|
||||||
|
|
||||||
assertEquals(Optional.of(sortedPreKeys.get(0)), keyStore.take(accountIdentifier, deviceId).join());
|
assertEquals(Optional.of(sortedPreKeys.get(0)), keyStore.take(accountIdentifier, deviceId).join());
|
||||||
|
@ -91,18 +81,18 @@ class PagedSingleUseKEMPreKeyStoreTest {
|
||||||
|
|
||||||
final List<KEMSignedPreKey> preKeys1 = generateRandomPreKeys();
|
final List<KEMSignedPreKey> preKeys1 = generateRandomPreKeys();
|
||||||
keyStore.store(accountIdentifier, deviceId, preKeys1).join();
|
keyStore.store(accountIdentifier, deviceId, preKeys1).join();
|
||||||
List<String> oldPages = listPages(accountIdentifier).stream().map(S3Object::key).collect(Collectors.toList());
|
List<String> oldPages = listPages(accountIdentifier).stream().map(S3Object::key).toList();
|
||||||
assertEquals(1, oldPages.size());
|
assertEquals(1, oldPages.size());
|
||||||
|
|
||||||
final List<KEMSignedPreKey> preKeys2 = generateRandomPreKeys();
|
final List<KEMSignedPreKey> preKeys2 = generateRandomPreKeys();
|
||||||
keyStore.store(accountIdentifier, deviceId, preKeys2).join();
|
keyStore.store(accountIdentifier, deviceId, preKeys2).join();
|
||||||
List<String> newPages = listPages(accountIdentifier).stream().map(S3Object::key).collect(Collectors.toList());
|
List<String> newPages = listPages(accountIdentifier).stream().map(S3Object::key).toList();
|
||||||
assertEquals(1, newPages.size());
|
assertEquals(1, newPages.size());
|
||||||
|
|
||||||
assertNotEquals(oldPages.getFirst(), newPages.getFirst());
|
assertNotEquals(oldPages.getFirst(), newPages.getFirst());
|
||||||
|
|
||||||
assertEquals(
|
assertEquals(
|
||||||
preKeys2.stream().sorted(Comparator.comparing(preKey -> preKey.keyId())).toList(),
|
preKeys2.stream().sorted(Comparator.comparing(KEMSignedPreKey::keyId)).toList(),
|
||||||
|
|
||||||
IntStream.range(0, preKeys2.size())
|
IntStream.range(0, preKeys2.size())
|
||||||
.mapToObj(i -> keyStore.take(accountIdentifier, deviceId).join())
|
.mapToObj(i -> keyStore.take(accountIdentifier, deviceId).join())
|
||||||
|
@ -122,7 +112,7 @@ class PagedSingleUseKEMPreKeyStoreTest {
|
||||||
assertDoesNotThrow(() -> keyStore.store(accountIdentifier, deviceId, preKeys).join());
|
assertDoesNotThrow(() -> keyStore.store(accountIdentifier, deviceId, preKeys).join());
|
||||||
|
|
||||||
final List<KEMSignedPreKey> sortedPreKeys = preKeys.stream()
|
final List<KEMSignedPreKey> sortedPreKeys = preKeys.stream()
|
||||||
.sorted(Comparator.comparing(preKey -> preKey.keyId()))
|
.sorted(Comparator.comparing(KEMSignedPreKey::keyId))
|
||||||
.toList();
|
.toList();
|
||||||
|
|
||||||
for (int i = 0; i < KEY_COUNT; i++) {
|
for (int i = 0; i < KEY_COUNT; i++) {
|
||||||
|
@ -171,7 +161,7 @@ class PagedSingleUseKEMPreKeyStoreTest {
|
||||||
|
|
||||||
final List<S3Object> pages = listPages(accountIdentifier);
|
final List<S3Object> pages = listPages(accountIdentifier);
|
||||||
assertEquals(1, pages.size());
|
assertEquals(1, pages.size());
|
||||||
assertTrue(pages.get(0).key().startsWith("%s/%s".formatted(accountIdentifier, deviceId + 1)));
|
assertTrue(pages.getFirst().key().startsWith("%s/%s".formatted(accountIdentifier, deviceId + 1)));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -194,6 +184,66 @@ class PagedSingleUseKEMPreKeyStoreTest {
|
||||||
assertEquals(0, listPages(accountIdentifier).size());
|
assertEquals(0, listPages(accountIdentifier).size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void listPages() {
|
||||||
|
final UUID aci1 = UUID.randomUUID();
|
||||||
|
final UUID aci2 = new UUID(aci1.getMostSignificantBits(), aci1.getLeastSignificantBits() + 1);
|
||||||
|
final byte deviceId = 1;
|
||||||
|
|
||||||
|
keyStore.store(aci1, deviceId, generateRandomPreKeys()).join();
|
||||||
|
keyStore.store(aci1, (byte) (deviceId + 1), generateRandomPreKeys()).join();
|
||||||
|
keyStore.store(aci2, deviceId, generateRandomPreKeys()).join();
|
||||||
|
|
||||||
|
List<DeviceKEMPreKeyPages> stored = keyStore.listStoredPages(1).collectList().block();
|
||||||
|
assertEquals(3, stored.size());
|
||||||
|
for (DeviceKEMPreKeyPages pages : stored) {
|
||||||
|
assertEquals(1, pages.pageIdToLastModified().size());
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(List.of(aci1, aci1, aci2), stored.stream().map(DeviceKEMPreKeyPages::identifier).toList());
|
||||||
|
assertEquals(
|
||||||
|
List.of(deviceId, (byte) (deviceId + 1), deviceId),
|
||||||
|
stored.stream().map(DeviceKEMPreKeyPages::deviceId).toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void listPagesWithOrphans() {
|
||||||
|
final UUID aci1 = UUID.randomUUID();
|
||||||
|
final UUID aci2 = new UUID(aci1.getMostSignificantBits(), aci1.getLeastSignificantBits() + 1);
|
||||||
|
final byte deviceId = 1;
|
||||||
|
|
||||||
|
// Two orphans
|
||||||
|
keyStore.store(aci1, deviceId, generateRandomPreKeys()).join();
|
||||||
|
writeOrphanedS3Object(aci1, deviceId);
|
||||||
|
writeOrphanedS3Object(aci1, deviceId);
|
||||||
|
|
||||||
|
// No orphans
|
||||||
|
keyStore.store(aci1, (byte) (deviceId + 1), generateRandomPreKeys()).join();
|
||||||
|
|
||||||
|
// One orphan
|
||||||
|
keyStore.store(aci2, deviceId, generateRandomPreKeys()).join();
|
||||||
|
writeOrphanedS3Object(aci2, deviceId);
|
||||||
|
|
||||||
|
// Orphan with no database record
|
||||||
|
writeOrphanedS3Object(aci2, (byte) (deviceId + 2));
|
||||||
|
|
||||||
|
List<DeviceKEMPreKeyPages> stored = keyStore.listStoredPages(1).collectList().block();
|
||||||
|
assertEquals(4, stored.size());
|
||||||
|
|
||||||
|
assertEquals(
|
||||||
|
List.of(3, 1, 2, 1),
|
||||||
|
stored.stream().map(s -> s.pageIdToLastModified().size()).toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
private void writeOrphanedS3Object(final UUID identifier, final byte deviceId) {
|
||||||
|
S3_EXTENSION.getS3Client()
|
||||||
|
.putObject(PutObjectRequest.builder()
|
||||||
|
.bucket(BUCKET_NAME)
|
||||||
|
.key("%s/%s/%s".formatted(identifier, deviceId, UUID.randomUUID())).build(),
|
||||||
|
AsyncRequestBody.fromBytes(TestRandomUtil.nextBytes(10)))
|
||||||
|
.join();
|
||||||
|
}
|
||||||
|
|
||||||
private List<S3Object> listPages(final UUID identifier) {
|
private List<S3Object> listPages(final UUID identifier) {
|
||||||
return Flux.from(S3_EXTENSION.getS3Client().listObjectsV2Paginator(ListObjectsV2Request.builder()
|
return Flux.from(S3_EXTENSION.getS3Client().listObjectsV2Paginator(ListObjectsV2Request.builder()
|
||||||
.bucket(BUCKET_NAME)
|
.bucket(BUCKET_NAME)
|
||||||
|
|
|
@ -0,0 +1,138 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2025 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.whispersystems.textsecuregcm.workers;
|
||||||
|
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.ArgumentMatchers.anyByte;
|
||||||
|
import static org.mockito.ArgumentMatchers.eq;
|
||||||
|
import static org.mockito.Mockito.anyInt;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.times;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
import io.dropwizard.core.setup.Environment;
|
||||||
|
import java.time.Clock;
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.time.Instant;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.UUID;
|
||||||
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
import net.sourceforge.argparse4j.inf.Namespace;
|
||||||
|
import org.assertj.core.api.Assertions;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
|
import org.junit.jupiter.params.provider.ValueSource;
|
||||||
|
import org.junitpioneer.jupiter.cartesian.CartesianTest;
|
||||||
|
import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.DeviceKEMPreKeyPages;
|
||||||
|
import org.whispersystems.textsecuregcm.storage.KeysManager;
|
||||||
|
import org.whispersystems.textsecuregcm.util.TestClock;
|
||||||
|
import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
|
public class RemoveOrphanedPreKeyPagesCommandTest {
|
||||||
|
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@ValueSource(booleans = {true, false})
|
||||||
|
public void removeStalePages(boolean dryRun) throws Exception {
|
||||||
|
final TestClock clock = TestClock.pinned(Instant.EPOCH.plus(Duration.ofSeconds(10)));
|
||||||
|
final KeysManager keysManager = mock(KeysManager.class);
|
||||||
|
|
||||||
|
final UUID currentPage = UUID.randomUUID();
|
||||||
|
final UUID freshOrphanedPage = UUID.randomUUID();
|
||||||
|
final UUID staleOrphanedPage = UUID.randomUUID();
|
||||||
|
|
||||||
|
when(keysManager.listStoredKEMPreKeyPages(anyInt())).thenReturn(Flux.fromIterable(List.of(
|
||||||
|
new DeviceKEMPreKeyPages(UUID.randomUUID(), (byte) 1, Optional.of(currentPage), Map.of(
|
||||||
|
currentPage, Instant.EPOCH,
|
||||||
|
staleOrphanedPage, Instant.EPOCH.plus(Duration.ofSeconds(4)),
|
||||||
|
freshOrphanedPage, Instant.EPOCH.plus(Duration.ofSeconds(5)))))));
|
||||||
|
|
||||||
|
when(keysManager.pruneDeadPage(any(), anyByte(), any()))
|
||||||
|
.thenReturn(CompletableFuture.completedFuture(null));
|
||||||
|
|
||||||
|
runCommand(clock, Duration.ofSeconds(5), dryRun, keysManager);
|
||||||
|
verify(keysManager, times(dryRun ? 0 : 1))
|
||||||
|
.pruneDeadPage(any(), eq((byte) 1), eq(staleOrphanedPage));
|
||||||
|
verify(keysManager, times(1)).listStoredKEMPreKeyPages(anyInt());
|
||||||
|
verifyNoMoreInteractions(keysManager);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void noCurrentPage() throws Exception {
|
||||||
|
final TestClock clock = TestClock.pinned(Instant.EPOCH.plus(Duration.ofSeconds(10)));
|
||||||
|
final KeysManager keysManager = mock(KeysManager.class);
|
||||||
|
|
||||||
|
final UUID freshOrphanedPage = UUID.randomUUID();
|
||||||
|
final UUID staleOrphanedPage = UUID.randomUUID();
|
||||||
|
|
||||||
|
when(keysManager.listStoredKEMPreKeyPages(anyInt())).thenReturn(Flux.fromIterable(List.of(
|
||||||
|
new DeviceKEMPreKeyPages(UUID.randomUUID(), (byte) 1, Optional.empty(), Map.of(
|
||||||
|
staleOrphanedPage, Instant.EPOCH.plus(Duration.ofSeconds(4)),
|
||||||
|
freshOrphanedPage, Instant.EPOCH.plus(Duration.ofSeconds(5)))))));
|
||||||
|
|
||||||
|
when(keysManager.pruneDeadPage(any(), anyByte(), any()))
|
||||||
|
.thenReturn(CompletableFuture.completedFuture(null));
|
||||||
|
|
||||||
|
runCommand(clock, Duration.ofSeconds(5), false, keysManager);
|
||||||
|
verify(keysManager, times(1))
|
||||||
|
.pruneDeadPage(any(), eq((byte) 1), eq(staleOrphanedPage));
|
||||||
|
verify(keysManager, times(1)).listStoredKEMPreKeyPages(anyInt());
|
||||||
|
verifyNoMoreInteractions(keysManager);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void noPages() throws Exception {
|
||||||
|
final TestClock clock = TestClock.pinned(Instant.EPOCH);
|
||||||
|
final KeysManager keysManager = mock(KeysManager.class);
|
||||||
|
when(keysManager.listStoredKEMPreKeyPages(anyInt())).thenReturn(Flux.empty());
|
||||||
|
runCommand(clock, Duration.ofSeconds(5), false, keysManager);
|
||||||
|
verify(keysManager).listStoredKEMPreKeyPages(anyInt());
|
||||||
|
verifyNoMoreInteractions(keysManager);
|
||||||
|
}
|
||||||
|
|
||||||
|
private enum PageStatus {NO_CURRENT, MATCH_CURRENT, MISMATCH_CURRENT}
|
||||||
|
|
||||||
|
@CartesianTest
|
||||||
|
void shouldDeletePage(
|
||||||
|
@CartesianTest.Enum final PageStatus pageStatus,
|
||||||
|
@CartesianTest.Values(booleans = {false, true}) final boolean isOld) {
|
||||||
|
final Optional<UUID> currentPage = pageStatus == PageStatus.NO_CURRENT
|
||||||
|
? Optional.empty()
|
||||||
|
: Optional.of(UUID.randomUUID());
|
||||||
|
final UUID page = switch (pageStatus) {
|
||||||
|
case MATCH_CURRENT -> currentPage.orElseThrow();
|
||||||
|
case NO_CURRENT, MISMATCH_CURRENT -> UUID.randomUUID();
|
||||||
|
};
|
||||||
|
|
||||||
|
final Instant threshold = Instant.EPOCH.plus(Duration.ofSeconds(10));
|
||||||
|
final Instant lastModified = isOld ? threshold.minus(Duration.ofSeconds(1)) : threshold;
|
||||||
|
|
||||||
|
final boolean shouldDelete = pageStatus != PageStatus.MATCH_CURRENT && isOld;
|
||||||
|
Assertions.assertThat(RemoveOrphanedPreKeyPagesCommand.shouldDeletePage(currentPage, page, threshold, lastModified))
|
||||||
|
.isEqualTo(shouldDelete);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private void runCommand(final Clock clock, final Duration minimumOrphanAge, final boolean dryRun,
|
||||||
|
final KeysManager keysManager) throws Exception {
|
||||||
|
final CommandDependencies commandDependencies = mock(CommandDependencies.class);
|
||||||
|
when(commandDependencies.keysManager()).thenReturn(keysManager);
|
||||||
|
|
||||||
|
final Namespace namespace = mock(Namespace.class);
|
||||||
|
when(namespace.getBoolean(RemoveOrphanedPreKeyPagesCommand.DRY_RUN_ARGUMENT)).thenReturn(dryRun);
|
||||||
|
when(namespace.getInt(RemoveOrphanedPreKeyPagesCommand.CONCURRENCY_ARGUMENT)).thenReturn(2);
|
||||||
|
when(namespace.getString(RemoveOrphanedPreKeyPagesCommand.MINIMUM_ORPHAN_AGE_ARGUMENT))
|
||||||
|
.thenReturn(minimumOrphanAge.toString());
|
||||||
|
|
||||||
|
final RemoveOrphanedPreKeyPagesCommand command = new RemoveOrphanedPreKeyPagesCommand(clock);
|
||||||
|
command.run(mock(Environment.class), namespace, mock(WhisperServerConfiguration.class), commandDependencies);
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue