From de371418126fdd0ece57385f2dd3b49f1fd7db48 Mon Sep 17 00:00:00 2001 From: Ravi Khadiwala Date: Fri, 23 Feb 2024 13:05:53 -0600 Subject: [PATCH] Add a crawler that expires old backups --- .../textsecuregcm/WhisperServerService.java | 2 + .../textsecuregcm/backup/BackupManager.java | 77 ++++++++ .../textsecuregcm/backup/BackupsDb.java | 107 ++++++++++- .../textsecuregcm/backup/ExpiredBackup.java | 7 + .../workers/CommandDependencies.java | 31 ++++ .../workers/RemoveExpiredBackupsCommand.java | 166 ++++++++++++++++++ .../backup/BackupManagerTest.java | 153 +++++++++++++++- .../textsecuregcm/backup/BackupsDbTest.java | 41 +++++ 8 files changed, 577 insertions(+), 7 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/backup/ExpiredBackup.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredBackupsCommand.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index d2bf1cfb3..7c92e0000 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -239,6 +239,7 @@ import org.whispersystems.textsecuregcm.workers.DeleteUserCommand; import org.whispersystems.textsecuregcm.workers.MessagePersisterServiceCommand; import org.whispersystems.textsecuregcm.workers.ProcessPushNotificationFeedbackCommand; import org.whispersystems.textsecuregcm.workers.RemoveExpiredAccountsCommand; +import org.whispersystems.textsecuregcm.workers.RemoveExpiredBackupsCommand; import org.whispersystems.textsecuregcm.workers.RemoveExpiredLinkedDevicesCommand; import org.whispersystems.textsecuregcm.workers.RemoveExpiredUsernameHoldsCommand; import org.whispersystems.textsecuregcm.workers.ScheduledApnPushNotificationSenderServiceCommand; @@ -301,6 +302,7 @@ public class 
WhisperServerService extends Application getExpiredBackups(final int segments, final Scheduler scheduler, final Instant purgeTime) { + return this.backupsDb.getExpiredBackups(segments, scheduler, purgeTime); + } + + /** + * Delete some or all of the objects associated with the backup, and update the backup database. + * + * @param backupTierToRemove If {@link BackupTier#MEDIA}, will only delete media associated with the backup, if + * {@link BackupTier#MESSAGES} will also delete the messageBackup and remove any db record + * of the backup + * @param hashedBackupId The hashed backup-id for the backup + * @return A stage that completes when the deletion operation is finished + */ + public CompletableFuture deleteBackup(final BackupTier backupTierToRemove, final byte[] hashedBackupId) { + return switch (backupTierToRemove) { + case NONE -> CompletableFuture.completedFuture(null); + // Delete any media associated with the backup id, the message backup, and the row in our backups db table + case MESSAGES -> deleteAllMedia(hashedBackupId) + .thenCompose(ignored -> this.remoteStorageManager.delete( + "%s/%s".formatted(encodeForCdn(hashedBackupId), MESSAGE_BACKUP_NAME))) + .thenCompose(ignored -> this.backupsDb.deleteBackup(hashedBackupId)); + // Delete any media associated with the backup id, and clear any used media bytes + case MEDIA -> deleteAllMedia(hashedBackupId).thenCompose(ignore -> backupsDb.clearMediaUsage(hashedBackupId)); + }; + } + + /** + * List and delete all media associated with a backup. 
+ * * @param hashedBackupId The hashed backup-id for the backup + * @return A stage that completes when all media objects have been deleted + */ + private CompletableFuture deleteAllMedia(final byte[] hashedBackupId) { + final String mediaPrefix = cdnMediaDirectory(hashedBackupId); + return Mono + .fromCompletionStage(this.remoteStorageManager.list(mediaPrefix, Optional.empty(), 1000)) + .expand(listResult -> { + if (listResult.cursor().isEmpty()) { + return Mono.empty(); + } + return Mono.fromCompletionStage(() -> this.remoteStorageManager.list(mediaPrefix, listResult.cursor(), 1000)); + }) + .flatMap(listResult -> Flux.fromIterable(listResult.objects())) + // Delete the media objects. concatMap effectively makes the deletion operation single threaded -- it's expected + // the caller can increase concurrency by deleting more backups at once, rather than increasing concurrency + // when deleting an individual backup + .concatMap(result -> Mono.fromCompletionStage(() -> + remoteStorageManager.delete("%s%s".formatted(mediaPrefix, result.key())))) + .count() + .doOnSuccess(itemsRemoved -> DistributionSummary.builder(DELETE_MEDIA_COUNT_DISTRIBUTION_NAME) + .publishPercentileHistogram(true) + .register(Metrics.globalRegistry) + .record(itemsRemoved)) + .then() + .toFuture(); + } + /** * Verify the presentation and return the extracted backup tier @@ -541,6 +614,10 @@ public class BackupManager { return "%s/%s/".formatted(encodeBackupIdForCdn(backupUser), MEDIA_DIRECTORY_NAME); } + private static String cdnMediaDirectory(final byte[] hashedBackupId) { + return "%s/%s/".formatted(encodeForCdn(hashedBackupId), MEDIA_DIRECTORY_NAME); + } + private static String cdnMediaPath(final AuthenticatedBackupUser backupUser, final byte[] mediaId) { return "%s%s".formatted(cdnMediaDirectory(backupUser), encodeForCdn(mediaId)); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupsDb.java 
b/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupsDb.java index 9476f76e7..aa586fd11 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupsDb.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupsDb.java @@ -1,3 +1,7 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ package org.whispersystems.textsecuregcm.backup; import io.grpc.Status; @@ -12,6 +16,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.function.Predicate; import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -19,11 +24,15 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedBackupUser; import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.ExceptionUtils; import org.whispersystems.textsecuregcm.util.Util; +import reactor.core.publisher.Flux; +import reactor.core.scheduler.Scheduler; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.ConditionalCheckFailedException; +import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; +import software.amazon.awssdk.services.dynamodb.model.ScanRequest; import software.amazon.awssdk.services.dynamodb.model.Update; import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; @@ -64,6 +73,7 @@ public class BackupsDb { // N: Time in seconds since epoch of last backup media usage recalculation. This timestamp is updated whenever we // recalculate the up-to-date bytes used by querying the cdn(s) directly. 
public static final String ATTR_MEDIA_USAGE_LAST_RECALCULATION = "MBTS"; + // BOOL: If true, public BackupsDb( final DynamoDbAsyncClient dynamoClient, @@ -125,12 +135,13 @@ public class BackupsDb { /** * Update the quota in the backup table * - * @param backupUser The backup user + * @param backupUser The backup user * @param mediaBytesDelta The length of the media after encryption. A negative length implies media being removed * @param mediaCountDelta The number of media objects being added, or if negative, removed * @return A stage that completes successfully once the table are updated. */ - CompletableFuture trackMedia(final AuthenticatedBackupUser backupUser, final long mediaCountDelta, final long mediaBytesDelta) { + CompletableFuture trackMedia(final AuthenticatedBackupUser backupUser, final long mediaCountDelta, + final long mediaBytesDelta) { final Instant now = clock.instant(); return dynamoClient .updateItem( @@ -175,6 +186,14 @@ public class BackupsDb { .thenRun(Util.NOOP); } + CompletableFuture deleteBackup(final byte[] hashedBackupId) { + return dynamoClient.deleteItem(DeleteItemRequest.builder() + .tableName(backupTableName) + .key(Map.of(KEY_BACKUP_ID_HASH, AttributeValues.b(hashedBackupId))) + .build()) + .thenRun(Util.NOOP); + } + record BackupDescription(int cdn, Optional mediaUsedSpace) {} @@ -250,6 +269,66 @@ public class BackupsDb { .thenRun(Util.NOOP); } + CompletableFuture clearMediaUsage(final byte[] hashedBackupId) { + return dynamoClient.updateItem( + new UpdateBuilder(backupTableName, BackupTier.MEDIA, hashedBackupId) + .addSetExpression("#mediaBytesUsed = :mediaBytesUsed", + Map.entry("#mediaBytesUsed", ATTR_MEDIA_BYTES_USED), + Map.entry(":mediaBytesUsed", AttributeValues.n(0L))) + .addSetExpression("#mediaCount = :mediaCount", + Map.entry("#mediaCount", ATTR_MEDIA_COUNT), + Map.entry(":mediaCount", AttributeValues.n(0L))) + .addSetExpression("#mediaRecalc = :mediaRecalc", + Map.entry("#mediaRecalc", 
ATTR_MEDIA_USAGE_LAST_RECALCULATION), + Map.entry(":mediaRecalc", AttributeValues.n(clock.instant().getEpochSecond()))) + .addRemoveExpression(Map.entry("#mediaRefresh", ATTR_LAST_MEDIA_REFRESH)) + .updateItemBuilder() + .build()) + .thenRun(Util.NOOP); + } + + Flux getExpiredBackups(final int segments, final Scheduler scheduler, final Instant purgeTime) { + if (segments < 1) { + throw new IllegalArgumentException("Total number of segments must be positive"); + } + + return Flux.range(0, segments) + .parallel() + .runOn(scheduler) + .flatMap(segment -> dynamoClient.scanPaginator(ScanRequest.builder() + .tableName(backupTableName) + .consistentRead(true) + .segment(segment) + .totalSegments(segments) + .expressionAttributeNames(Map.of( + "#backupIdHash", KEY_BACKUP_ID_HASH, + "#refresh", ATTR_LAST_REFRESH, + "#mediaRefresh", ATTR_LAST_MEDIA_REFRESH)) + .expressionAttributeValues(Map.of(":purgeTime", AttributeValues.n(purgeTime.getEpochSecond()))) + .projectionExpression("#backupIdHash, #refresh, #mediaRefresh") + .filterExpression("(#refresh < :purgeTime) OR (#mediaRefresh < :purgeTime)") + .build()) + .items()) + .sequential() + .filter(Predicate.not(Map::isEmpty)) + .mapNotNull(item -> { + final byte[] hashedBackupId = AttributeValues.getByteArray(item, KEY_BACKUP_ID_HASH, null); + if (hashedBackupId == null) { + return null; + } + final long lastRefresh = AttributeValues.getLong(item, ATTR_LAST_REFRESH, Long.MAX_VALUE); + final long lastMediaRefresh = AttributeValues.getLong(item, ATTR_LAST_MEDIA_REFRESH, Long.MAX_VALUE); + + if (lastRefresh < purgeTime.getEpochSecond()) { + return new ExpiredBackup(hashedBackupId, BackupTier.MESSAGES); + } else if (lastMediaRefresh < purgeTime.getEpochSecond()) { + return new ExpiredBackup(hashedBackupId, BackupTier.MEDIA); + } else { + return null; + } + }); + } + /** * Build ddb update statements for the backups table @@ -257,6 +336,7 @@ public class BackupsDb { private static class UpdateBuilder { private final List 
setStatements = new ArrayList<>(); + private final List removeStatements = new ArrayList<>(); private final Map attrValues = new HashMap<>(); private final Map attrNames = new HashMap<>(); @@ -308,6 +388,12 @@ public class BackupsDb { return this; } + UpdateBuilder addRemoveExpression(final Map.Entry attrName) { + addAttrName(attrName); + removeStatements.add(attrName.getKey()); + return this; + } + UpdateBuilder withConditionExpression(final String conditionExpression) { this.conditionExpression = conditionExpression; return this; @@ -369,6 +455,19 @@ public class BackupsDb { return this; } + private String updateExpression() { + final StringBuilder sb = new StringBuilder(); + if (!setStatements.isEmpty()) { + sb.append("SET "); + sb.append(String.join(",", setStatements)); + } + if (!removeStatements.isEmpty()) { + sb.append(" REMOVE "); + sb.append(String.join(",", removeStatements)); + } + return sb.toString(); + } + /** * Prepare a non-transactional update * @@ -378,7 +477,7 @@ public class BackupsDb { final UpdateItemRequest.Builder bldr = UpdateItemRequest.builder() .tableName(tableName) .key(Map.of(KEY_BACKUP_ID_HASH, AttributeValues.b(hashedBackupId))) - .updateExpression("SET %s".formatted(String.join(",", setStatements))) + .updateExpression(updateExpression()) .expressionAttributeNames(attrNames) .expressionAttributeValues(attrValues); if (this.conditionExpression != null) { @@ -396,7 +495,7 @@ public class BackupsDb { final Update.Builder bldr = Update.builder() .tableName(tableName) .key(Map.of(KEY_BACKUP_ID_HASH, AttributeValues.b(hashedBackupId))) - .updateExpression("SET %s".formatted(String.join(",", setStatements))) + .updateExpression(updateExpression()) .expressionAttributeNames(attrNames) .expressionAttributeValues(attrValues); if (this.conditionExpression != null) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/backup/ExpiredBackup.java b/service/src/main/java/org/whispersystems/textsecuregcm/backup/ExpiredBackup.java 
new file mode 100644 index 000000000..1405b9c38 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/backup/ExpiredBackup.java @@ -0,0 +1,7 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.backup; + +public record ExpiredBackup(byte[] hashedBackupId, BackupTier backupTierToRemove) {} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java index cb211ac07..a776ef2bf 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java @@ -15,9 +15,15 @@ import java.security.cert.CertificateException; import java.time.Clock; import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService; +import org.signal.libsignal.zkgroup.GenericServerSecretParams; +import org.signal.libsignal.zkgroup.InvalidInputException; import org.whispersystems.textsecuregcm.WhisperServerConfiguration; import org.whispersystems.textsecuregcm.WhisperServerService; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator; +import org.whispersystems.textsecuregcm.backup.BackupManager; +import org.whispersystems.textsecuregcm.backup.BackupsDb; +import org.whispersystems.textsecuregcm.backup.Cdn3BackupCredentialGenerator; +import org.whispersystems.textsecuregcm.backup.Cdn3RemoteStorageManager; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.controllers.SecureStorageController; import org.whispersystems.textsecuregcm.controllers.SecureValueRecovery2Controller; @@ -61,6 +67,7 @@ record CommandDependencies( KeysManager keysManager, FaultTolerantRedisCluster cacheCluster, ClientResources 
redisClusterClientResources, + BackupManager backupManager, DynamicConfigurationManager dynamicConfigurationManager) { static CommandDependencies build( @@ -105,6 +112,8 @@ record CommandDependencies( ScheduledExecutorService secureValueRecoveryServiceRetryExecutor = environment.lifecycle() .scheduledExecutorService(name(name, "secureValueRecoveryServiceRetry-%d")).threads(1).build(); + ScheduledExecutorService remoteStorageExecutor = environment.lifecycle() + .scheduledExecutorService(name(name, "remoteStorageRetry-%d")).threads(1).build(); ScheduledExecutorService storageServiceRetryExecutor = environment.lifecycle() .scheduledExecutorService(name(name, "storageServiceRetry-%d")).threads(1).build(); @@ -185,6 +194,27 @@ record CommandDependencies( secureStorageClient, secureValueRecovery2Client, clientPresenceManager, registrationRecoveryPasswordsManager, accountLockExecutor, clientPresenceExecutor, clock); + final BackupsDb backupsDb = + new BackupsDb(dynamoDbAsyncClient, configuration.getDynamoDbTables().getBackups().getTableName(), clock); + final GenericServerSecretParams backupsGenericZkSecretParams; + try { + backupsGenericZkSecretParams = + new GenericServerSecretParams(configuration.getBackupsZkConfig().serverSecret().value()); + } catch (InvalidInputException e) { + throw new IllegalArgumentException(e); + } + final BackupManager backupManager = new BackupManager( + backupsDb, + backupsGenericZkSecretParams, + new Cdn3BackupCredentialGenerator(configuration.getTus()), + new Cdn3RemoteStorageManager( + remoteStorageExecutor, + configuration.getClientCdnConfiguration().getCircuitBreaker(), + configuration.getClientCdnConfiguration().getRetry(), + configuration.getClientCdnConfiguration().getCaCertificates(), + configuration.getCdn3StorageManagerConfiguration()), + configuration.getClientCdnConfiguration().getAttachmentUrls(), + clock); environment.lifecycle().manage(messagesCache); environment.lifecycle().manage(clientPresenceManager); @@ -200,6 +230,7 @@ 
record CommandDependencies( keys, cacheCluster, redisClusterClientResources, + backupManager, dynamicConfigurationManager ); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredBackupsCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredBackupsCommand.java new file mode 100644 index 000000000..c58fada19 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredBackupsCommand.java @@ -0,0 +1,166 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.workers; + +import io.dropwizard.core.Application; +import io.dropwizard.core.cli.Cli; +import io.dropwizard.core.cli.EnvironmentCommand; +import io.dropwizard.core.setup.Environment; +import io.micrometer.core.instrument.Metrics; +import java.time.Clock; +import java.time.Duration; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicLong; +import net.sourceforge.argparse4j.inf.Namespace; +import net.sourceforge.argparse4j.inf.Subparser; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.WhisperServerConfiguration; +import org.whispersystems.textsecuregcm.backup.BackupManager; +import org.whispersystems.textsecuregcm.backup.ExpiredBackup; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; +import org.whispersystems.textsecuregcm.util.logging.UncaughtExceptionHandler; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + +public class RemoveExpiredBackupsCommand extends EnvironmentCommand { + + private final Logger logger = LoggerFactory.getLogger(getClass()); + + private static final String SEGMENT_COUNT_ARGUMENT = "segments"; + private static final String DRY_RUN_ARGUMENT = "dry-run"; + private static final String MAX_CONCURRENCY_ARGUMENT = "max-concurrency"; + private static final String GRACE_PERIOD_ARGUMENT = 
"grace-period"; + + // A backup that has not been refreshed after a grace period is eligible for deletion + private static final Duration DEFAULT_GRACE_PERIOD = Duration.ofDays(60); + private static final int DEFAULT_SEGMENT_COUNT = 1; + private static final int DEFAULT_CONCURRENCY = 16; + + private static final String EXPIRED_BACKUPS_COUNTER_NAME = MetricsUtil.name(RemoveExpiredBackupsCommand.class, + "expiredBackups"); + + private final Clock clock; + + public RemoveExpiredBackupsCommand(final Clock clock) { + super(new Application<>() { + @Override + public void run(final WhisperServerConfiguration configuration, final Environment environment) { + } + }, "remove-expired-backups", "Removes backups that have expired"); + this.clock = clock; + } + + @Override + public void configure(final Subparser subparser) { + super.configure(subparser); + + subparser.addArgument("--segments") + .type(Integer.class) + .dest(SEGMENT_COUNT_ARGUMENT) + .required(false) + .setDefault(DEFAULT_SEGMENT_COUNT) + .help("The total number of segments for a DynamoDB scan"); + + subparser.addArgument("--grace-period") + .type(Long.class) + .dest(GRACE_PERIOD_ARGUMENT) + .required(false) + .setDefault(DEFAULT_GRACE_PERIOD.getSeconds()) + .help("The number of seconds after which a backup is eligible for removal"); + + subparser.addArgument("--max-concurrency") + .type(Integer.class) + .dest(MAX_CONCURRENCY_ARGUMENT) + .required(false) + .setDefault(DEFAULT_CONCURRENCY) + .help("Max concurrency for backup expirations. 
Each expiration may do multiple cdn operations"); + + subparser.addArgument("--dry-run") + .type(Boolean.class) + .dest(DRY_RUN_ARGUMENT) + .required(false) + .setDefault(true) + .help("If true, don't actually remove expired backups"); + } + + @Override + protected void run(final Environment environment, final Namespace namespace, + final WhisperServerConfiguration configuration) throws Exception { + + UncaughtExceptionHandler.register(); + final CommandDependencies commandDependencies = CommandDependencies.build(getName(), environment, configuration); + MetricsUtil.configureRegistries(configuration, environment, commandDependencies.dynamicConfigurationManager()); + + final int segments = Objects.requireNonNull(namespace.getInt(SEGMENT_COUNT_ARGUMENT)); + final int concurrency = Objects.requireNonNull(namespace.getInt(MAX_CONCURRENCY_ARGUMENT)); + final boolean dryRun = namespace.getBoolean(DRY_RUN_ARGUMENT); + final Duration gracePeriod = Duration.ofSeconds(Objects.requireNonNull(namespace.getLong(GRACE_PERIOD_ARGUMENT))); + + logger.info("Crawling backups with {} segments and {} processors, grace period {}", + segments, + Runtime.getRuntime().availableProcessors(), + gracePeriod); + + try { + environment.lifecycle().getManagedObjects().forEach(managedObject -> { + try { + managedObject.start(); + } catch (final Exception e) { + logger.error("Failed to start managed object", e); + throw new RuntimeException(e); + } + }); + final AtomicLong backupsExpired = new AtomicLong(); + final BackupManager backupManager = commandDependencies.backupManager(); + backupManager + .getExpiredBackups(segments, Schedulers.parallel(), clock.instant().minus(gracePeriod)) + .doOnNext(ignored -> backupsExpired.incrementAndGet()) + .flatMap(expiredBackup -> removeExpiredBackup(backupManager, expiredBackup, dryRun), concurrency) + .then() + .block(); + logger.info("Processed {} expired backups", backupsExpired.get()); + } finally { + environment.lifecycle().getManagedObjects().forEach(managedObject 
-> { + try { + managedObject.stop(); + } catch (final Exception e) { + logger.error("Failed to stop managed object", e); + } + }); + } + } + + private Mono removeExpiredBackup( + final BackupManager backupManager, final ExpiredBackup expiredBackup, + final boolean dryRun) { + + final Mono mono; + if (dryRun) { + mono = Mono.empty(); + } else { + mono = Mono.fromCompletionStage(() -> + backupManager.deleteBackup(expiredBackup.backupTierToRemove(), expiredBackup.hashedBackupId())); + } + + return mono + .doOnSuccess(ignored -> Metrics + .counter(EXPIRED_BACKUPS_COUNTER_NAME, + "tier", expiredBackup.backupTierToRemove().name(), + "dryRun", String.valueOf(dryRun)) + .increment()) + .onErrorResume(throwable -> { + logger.warn("Failed to remove tier {} for backup {}", expiredBackup.backupTierToRemove(), + expiredBackup.hashedBackupId(), throwable); + return Mono.empty(); + }); + } + + @Override + public void onError(final Cli cli, final Namespace namespace, final Throwable throwable) { + logger.error("Unhandled error", throwable); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupManagerTest.java index 8fa04e9eb..aa5c6affc 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupManagerTest.java @@ -10,6 +10,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatNoException; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; @@ -17,12 +18,14 @@ import static org.mockito.Mockito.reset; import static 
org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; import io.grpc.Status; import io.grpc.StatusRuntimeException; import java.io.IOException; import java.net.URI; +import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; @@ -32,11 +35,15 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.function.Function; +import java.util.stream.IntStream; import javax.annotation.Nullable; import org.apache.commons.lang3.StringUtils; import org.junit.jupiter.api.BeforeEach; @@ -57,6 +64,7 @@ import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil; import org.whispersystems.textsecuregcm.util.TestClock; import org.whispersystems.textsecuregcm.util.TestRandomUtil; +import reactor.core.scheduler.Schedulers; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; @@ -417,9 +425,10 @@ public class BackupManagerTest { .toCompletableFuture().join(); assertThat(result.media()).hasSize(1); assertThat(result.media().get(0).cdn()).isEqualTo(13); - assertThat(result.media().get(0).key()).isEqualTo(Base64.getDecoder().decode("aaa".getBytes(StandardCharsets.UTF_8))); + assertThat(result.media().get(0).key()).isEqualTo( + Base64.getDecoder().decode("aaa".getBytes(StandardCharsets.UTF_8))); assertThat(result.media().get(0).length()).isEqualTo(123); - assertThat(result.cursor()).get().isEqualTo("newCursor"); 
+ assertThat(result.cursor().get()).isEqualTo("newCursor"); } @@ -449,7 +458,7 @@ public class BackupManagerTest { final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupTier.MEDIA); when(remoteStorageManager.cdnNumber()).thenReturn(5); assertThatThrownBy(() -> - backupManager.delete( backupUser, List.of(new BackupManager.StorageDescriptor(4, TestRandomUtil.nextBytes(15))))) + backupManager.delete(backupUser, List.of(new BackupManager.StorageDescriptor(4, TestRandomUtil.nextBytes(15))))) .isInstanceOf(StatusRuntimeException.class) .matches(e -> ((StatusRuntimeException) e).getStatus().getCode() == Status.INVALID_ARGUMENT.getCode()); } @@ -508,6 +517,144 @@ public class BackupManagerTest { .isEqualTo(new UsageInfo(100, 5)); } + @Test + public void listExpiredBackups() { + final List backupUsers = IntStream.range(0, 10) + .mapToObj(i -> backupUser(TestRandomUtil.nextBytes(16), BackupTier.MEDIA)) + .toList(); + for (int i = 0; i < backupUsers.size(); i++) { + testClock.pin(Instant.ofEpochSecond(i)); + backupManager.createMessageBackupUploadDescriptor(backupUsers.get(i)).join(); + } + + // set of backup-id hashes that should be expired (initially t=0) + final Set expectedHashes = new HashSet<>(); + + for (int i = 0; i < backupUsers.size(); i++) { + testClock.pin(Instant.ofEpochSecond(i)); + + // get backups expired at t=i + final List expired = backupManager + .getExpiredBackups(1, Schedulers.immediate(), Instant.ofEpochSecond(i)) + .collectList() + .block(); + + // all the backups that should be expired at t=i should be returned (ones with expiration time 0,1,...i-1) + assertThat(expired.size()).isEqualTo(expectedHashes.size()); + assertThat(expired.stream() + .map(ExpiredBackup::hashedBackupId) + .map(ByteBuffer::wrap) + .allMatch(expectedHashes::contains)).isTrue(); + assertThat(expired.stream().allMatch(eb -> eb.backupTierToRemove() == BackupTier.MESSAGES)).isTrue(); + + // on next iteration, backup i should be expired + 
expectedHashes.add(ByteBuffer.wrap(hashedBackupId(backupUsers.get(i).backupId())));
+    }
+  }
+
+  @Test
+  public void listExpiredBackupsByTier() {
+    final byte[] backupId = TestRandomUtil.nextBytes(16);
+
+    // refreshed media timestamp at t=5
+    testClock.pin(Instant.ofEpochSecond(5));
+    backupManager.createMessageBackupUploadDescriptor(backupUser(backupId, BackupTier.MEDIA)).join();
+
+    // refreshed messages timestamp at t=6
+    testClock.pin(Instant.ofEpochSecond(6));
+    backupManager.createMessageBackupUploadDescriptor(backupUser(backupId, BackupTier.MESSAGES)).join();
+
+    Function<Instant, List<ExpiredBackup>> getExpired = time -> backupManager
+        .getExpiredBackups(1, Schedulers.immediate(), time)
+        .collectList().block();
+
+    assertThat(getExpired.apply(Instant.ofEpochSecond(5))).isEmpty();
+
+    assertThat(getExpired.apply(Instant.ofEpochSecond(6)))
+        .hasSize(1).first()
+        .matches(eb -> eb.backupTierToRemove() == BackupTier.MEDIA, "is media tier");
+
+    assertThat(getExpired.apply(Instant.ofEpochSecond(7)))
+        .hasSize(1).first()
+        .matches(eb -> eb.backupTierToRemove() == BackupTier.MESSAGES, "is messages tier");
+  }
+
+  @ParameterizedTest
+  @EnumSource(mode = EnumSource.Mode.INCLUDE, names = {"MESSAGES", "MEDIA"})
+  public void deleteBackup(BackupTier backupTier) {
+    final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupTier.MEDIA);
+    backupManager.createMessageBackupUploadDescriptor(backupUser).join();
+    final String mediaPrefix = "%s/%s/"
+        .formatted(BackupManager.encodeBackupIdForCdn(backupUser), BackupManager.MEDIA_DIRECTORY_NAME);
+    when(remoteStorageManager.list(eq(mediaPrefix), eq(Optional.empty()), anyLong()))
+        .thenReturn(CompletableFuture.completedFuture(new RemoteStorageManager.ListResult(List.of(
+            new RemoteStorageManager.ListResult.Entry("abc", 1),
+            new RemoteStorageManager.ListResult.Entry("def", 1),
+            new RemoteStorageManager.ListResult.Entry("ghi", 1)), Optional.empty())));
+    when(remoteStorageManager.delete(anyString())).thenReturn(CompletableFuture.completedFuture(1L));
+
+    backupManager.deleteBackup(backupTier, hashedBackupId(backupUser.backupId())).join();
+    verify(remoteStorageManager, times(1)).list(anyString(), any(), anyLong());
+    verify(remoteStorageManager, times(1)).delete(mediaPrefix + "abc");
+    verify(remoteStorageManager, times(1)).delete(mediaPrefix + "def");
+    verify(remoteStorageManager, times(1)).delete(mediaPrefix + "ghi");
+    verify(remoteStorageManager, times(backupTier == BackupTier.MESSAGES ? 1 : 0))
+        .delete("%s/%s".formatted(BackupManager.encodeBackupIdForCdn(backupUser), BackupManager.MESSAGE_BACKUP_NAME));
+    verifyNoMoreInteractions(remoteStorageManager);
+
+    final BackupsDb.TimestampedUsageInfo usage = backupsDb.getMediaUsage(backupUser).join();
+    assertThat(usage.usageInfo().bytesUsed()).isEqualTo(0L);
+    assertThat(usage.usageInfo().numObjects()).isEqualTo(0L);
+
+    if (backupTier == BackupTier.MEDIA) {
+      // should have deleted all the media, but left the backup descriptor in place
+      assertThatNoException().isThrownBy(() -> backupsDb.describeBackup(backupUser).join());
+    } else {
+      // should have deleted the db row for the backup
+      assertThat(CompletableFutureTestUtil.assertFailsWithCause(
+          StatusRuntimeException.class,
+          backupsDb.describeBackup(backupUser))
+          .getStatus().getCode())
+          .isEqualTo(Status.NOT_FOUND.getCode());
+    }
+  }
+
+  @Test
+  public void deleteBackupPaginated() {
+    final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupTier.MEDIA);
+    backupManager.createMessageBackupUploadDescriptor(backupUser).join();
+    final String mediaPrefix = "%s/%s/".formatted(BackupManager.encodeBackupIdForCdn(backupUser),
+        BackupManager.MEDIA_DIRECTORY_NAME);
+
+    // Return 1 item per page. Initially the provided cursor is empty and we'll return the cursor string "1".
+    // When we get the cursor "1", we'll return "2", when "2" we'll return empty indicating listing
+    // is complete
+    when(remoteStorageManager.list(eq(mediaPrefix), any(), anyLong())).thenAnswer(a -> {
+      Optional<String> cursor = a.getArgument(1);
+      return CompletableFuture.completedFuture(
+          new RemoteStorageManager.ListResult(List.of(new RemoteStorageManager.ListResult.Entry(
+              switch (cursor.orElse("0")) {
+                case "0" -> "abc";
+                case "1" -> "def";
+                case "2" -> "ghi";
+                default -> throw new IllegalArgumentException();
+              }, 1L)),
+              switch (cursor.orElse("0")) {
+                case "0" -> Optional.of("1");
+                case "1" -> Optional.of("2");
+                case "2" -> Optional.empty();
+                default -> throw new IllegalArgumentException();
+              }));
+    });
+    when(remoteStorageManager.delete(anyString())).thenReturn(CompletableFuture.completedFuture(1L));
+    backupManager.deleteBackup(BackupTier.MEDIA, hashedBackupId(backupUser.backupId())).join();
+    verify(remoteStorageManager, times(3)).list(anyString(), any(), anyLong());
+    verify(remoteStorageManager, times(1)).delete(mediaPrefix + "abc");
+    verify(remoteStorageManager, times(1)).delete(mediaPrefix + "def");
+    verify(remoteStorageManager, times(1)).delete(mediaPrefix + "ghi");
+    verifyNoMoreInteractions(remoteStorageManager);
+  }
+
   private Map getBackupItem(final AuthenticatedBackupUser backupUser) {
     return DYNAMO_DB_EXTENSION.getDynamoDbClient().getItem(GetItemRequest.builder()
         .tableName(DynamoDbExtensionSchema.Tables.BACKUPS.tableName())
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupsDbTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupsDbTest.java
index 02d226da8..e91245bae 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupsDbTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupsDbTest.java
@@ -14,6 +14,8 @@ import java.security.MessageDigest;
 import java.security.NoSuchAlgorithmException;
 import java.time.Instant;
 import java.util.Arrays;
+import java.util.List;
+import java.util.function.Function;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.RegisterExtension;
@@ -25,6 +27,7 @@ import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema;
 import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
 import org.whispersystems.textsecuregcm.util.TestClock;
 import org.whispersystems.textsecuregcm.util.TestRandomUtil;
+import reactor.core.scheduler.Schedulers;
 
 public class BackupsDbTest {
 
@@ -79,6 +82,44 @@ public class BackupsDbTest {
     assertThat(info.usageInfo().numObjects()).isEqualTo(17L);
   }
 
+  @Test
+  public void expirationDetectedOnce() {
+    final byte[] backupId = TestRandomUtil.nextBytes(16);
+    // Refresh media/messages at t=0
+    testClock.pin(Instant.ofEpochSecond(0L));
+    this.backupsDb.ttlRefresh(backupUser(backupId, BackupTier.MEDIA)).join();
+
+    // refresh only messages at t=2
+    testClock.pin(Instant.ofEpochSecond(2L));
+    this.backupsDb.ttlRefresh(backupUser(backupId, BackupTier.MESSAGES)).join();
+
+    final Function<Instant, List<ExpiredBackup>> expiredBackups = purgeTime -> backupsDb
+        .getExpiredBackups(1, Schedulers.immediate(), purgeTime)
+        .collectList()
+        .block();
+
+    List<ExpiredBackup> expired = expiredBackups.apply(Instant.ofEpochSecond(1));
+    assertThat(expired).hasSize(1).first()
+        .matches(eb -> eb.backupTierToRemove() == BackupTier.MEDIA);
+
+    // Expire the media
+    backupsDb.clearMediaUsage(expired.get(0).hashedBackupId()).join();
+
+    // media usage was just cleared, so there should be nothing left to expire at t=1
+    assertThat(expiredBackups.apply(Instant.ofEpochSecond(1))).isEmpty();
+
+    // at t=3, should now expire messages as well
+    expired = expiredBackups.apply(Instant.ofEpochSecond(3));
+    assertThat(expired).hasSize(1).first()
+        .matches(eb -> eb.backupTierToRemove() == BackupTier.MESSAGES);
+
+    // Expire the messages
+    backupsDb.deleteBackup(expired.get(0).hashedBackupId()).join();
+
+    // the backup was just deleted, so there should be nothing left to expire at t=3
+    assertThat(expiredBackups.apply(Instant.ofEpochSecond(3))).isEmpty();
+  }
+
   private AuthenticatedBackupUser backupUser(final byte[] backupId, final BackupTier backupTier) {
     return new AuthenticatedBackupUser(backupId, backupTier);
   }