diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthenticatedBackupUser.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthenticatedBackupUser.java index 6c55e9135..0cb9bb7f8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthenticatedBackupUser.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthenticatedBackupUser.java @@ -7,4 +7,4 @@ package org.whispersystems.textsecuregcm.auth; import org.whispersystems.textsecuregcm.backup.BackupTier; -public record AuthenticatedBackupUser(byte[] backupId, BackupTier backupTier) {} +public record AuthenticatedBackupUser(byte[] backupId, BackupTier backupTier, String backupDir, String mediaDir) {} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupManager.java index 51cd56024..22ddd80ed 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupManager.java @@ -16,7 +16,6 @@ import java.time.Duration; import java.time.Instant; import java.util.ArrayList; import java.util.Base64; -import java.util.HexFormat; import java.util.List; import java.util.Map; import java.util.Optional; @@ -24,7 +23,6 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.stream.Collectors; import org.apache.commons.lang3.StringUtils; -import org.signal.libsignal.protocol.InvalidKeyException; import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.signal.libsignal.zkgroup.GenericServerSecretParams; import org.signal.libsignal.zkgroup.VerificationFailedException; @@ -42,7 +40,6 @@ public class BackupManager { private static final Logger logger = LoggerFactory.getLogger(BackupManager.class); - static final String MEDIA_DIRECTORY_NAME = "media"; static final String 
MESSAGE_BACKUP_NAME = "messageBackup"; static final long MAX_TOTAL_BACKUP_MEDIA_BYTES = 1024L * 1024L * 1024L * 50L; static final long MAX_MEDIA_OBJECT_SIZE = 1024L * 1024L * 101L; @@ -53,8 +50,8 @@ public class BackupManager { "authorizationFailure"); private static final String USAGE_RECALCULATION_COUNTER_NAME = MetricsUtil.name(BackupManager.class, "usageRecalculation"); - private static final String DELETE_MEDIA_COUNT_DISTRIBUTION_NAME = MetricsUtil.name(BackupManager.class, - "deleteMediaCount"); + private static final String DELETE_COUNT_DISTRIBUTION_NAME = MetricsUtil.name(BackupManager.class, + "deleteCount"); private static final String SUCCESS_TAG_NAME = "success"; private static final String FAILURE_REASON_TAG_NAME = "reason"; @@ -157,7 +154,8 @@ public class BackupManager { return backupsDb.ttlRefresh(backupUser); } - public record BackupInfo(int cdn, String backupSubdir, String messageBackupKey, Optional mediaUsedSpace) {} + public record BackupInfo(int cdn, String backupSubdir, String mediaSubdir, String messageBackupKey, + Optional mediaUsedSpace) {} /** * Retrieve information about the existing backup @@ -174,7 +172,8 @@ public class BackupManager { return backupsDb.describeBackup(backupUser) .thenApply(backupDescription -> new BackupInfo( backupDescription.cdn(), - encodeBackupIdForCdn(backupUser), + backupUser.backupDir(), + backupUser.mediaDir(), MESSAGE_BACKUP_NAME, backupDescription.mediaUsedSpace())); } @@ -315,8 +314,7 @@ public class BackupManager { .withDescription("credential does not support read auth operation") .asRuntimeException(); } - final String encodedBackupId = encodeBackupIdForCdn(backupUser); - return cdn3BackupCredentialGenerator.readHeaders(encodedBackupId); + return cdn3BackupCredentialGenerator.readHeaders(backupUser.backupDir()); } @@ -354,7 +352,7 @@ public class BackupManager { .stream() .map(entry -> new StorageDescriptorWithLength( remoteStorageManager.cdnNumber(), - decodeFromCdn(entry.key()), + 
decodeMediaIdFromCdn(entry.key()), entry.length() )) .toList(), @@ -447,9 +445,9 @@ public class BackupManager { final BackupAuthCredentialPresentation presentation, final byte[] signature) { return backupsDb - .retrievePublicKey(presentation.getBackupId()) - .thenApply(optionalPublicKey -> { - final byte[] publicKeyBytes = optionalPublicKey + .retrieveAuthenticationData(presentation.getBackupId()) + .thenApply(optionalAuthenticationData -> { + final BackupsDb.AuthenticationData authenticationData = optionalAuthenticationData .orElseThrow(() -> { Metrics.counter(ZK_AUTHN_COUNTER_NAME, SUCCESS_TAG_NAME, String.valueOf(false), @@ -457,23 +455,10 @@ public class BackupManager { .increment(); return Status.NOT_FOUND.withDescription("Backup not found").asRuntimeException(); }); - try { - final ECPublicKey publicKey = new ECPublicKey(publicKeyBytes); - return new AuthenticatedBackupUser( - presentation.getBackupId(), - verifySignatureAndCheckPresentation(presentation, signature, publicKey)); - } catch (InvalidKeyException e) { - Metrics.counter(ZK_AUTHN_COUNTER_NAME, - SUCCESS_TAG_NAME, String.valueOf(false), - FAILURE_REASON_TAG_NAME, "invalid_public_key") - .increment(); - logger.error("Invalid publicKey for backupId hash {}", - HexFormat.of().formatHex(BackupsDb.hashedBackupId(presentation.getBackupId())), e); - throw Status.INTERNAL - .withCause(e) - .withDescription("Could not deserialize stored public key") - .asRuntimeException(); - } + return new AuthenticatedBackupUser( + presentation.getBackupId(), + verifySignatureAndCheckPresentation(presentation, signature, authenticationData.publicKey()), + authenticationData.backupDir(), authenticationData.mediaDir()); }) .thenApply(result -> { Metrics.counter(ZK_AUTHN_COUNTER_NAME, SUCCESS_TAG_NAME, String.valueOf(true)).increment(); @@ -498,49 +483,43 @@ public class BackupManager { /** * Delete some or all of the objects associated with the backup, and update the backup database. 
* - * @param backupTierToRemove If {@link BackupTier#MEDIA}, will only delete media associated with the backup, if - * {@link BackupTier#MESSAGES} will also delete the messageBackup and remove any db record - * of the backup - * @param hashedBackupId The hashed backup-id for the backup + * @param expiredBackup The backup to expire. If the {@link ExpiredBackup} is a media expiration, only the media + * objects will be deleted, otherwise all backup objects will be deleted * @return A stage that completes when the deletion operation is finished */ - public CompletableFuture deleteBackup(final BackupTier backupTierToRemove, final byte[] hashedBackupId) { - return switch (backupTierToRemove) { - case NONE -> CompletableFuture.completedFuture(null); - // Delete any media associated with the backup id, the message backup, and the row in our backups db table - case MESSAGES -> deleteAllMedia(hashedBackupId) - .thenCompose(ignored -> this.remoteStorageManager.delete( - "%s/%s".formatted(encodeForCdn(hashedBackupId), MESSAGE_BACKUP_NAME))) - .thenCompose(ignored -> this.backupsDb.deleteBackup(hashedBackupId)); - // Delete any media associated with the backup id, and clear any used media bytes - case MEDIA -> deleteAllMedia(hashedBackupId).thenCompose(ignore -> backupsDb.clearMediaUsage(hashedBackupId)); - }; + public CompletableFuture expireBackup(final ExpiredBackup expiredBackup) { + return backupsDb.startExpiration(expiredBackup) + .thenCompose(ignored -> deletePrefix(expiredBackup.prefixToDelete())) + .thenCompose(ignored -> backupsDb.finishExpiration(expiredBackup)); } /** - * List and delete all media associated with a backup. + * List and delete all files associated with a prefix * - * @param hashedBackupId The hashed backup-id for the backup - * @return A stage that completes when all media objects have been deleted + * @param prefixToDelete The prefix to expire. 
+ * @return A stage that completes when all objects with the given prefix have been deleted */ - private CompletableFuture deleteAllMedia(final byte[] hashedBackupId) { - final String mediaPrefix = cdnMediaDirectory(hashedBackupId); + private CompletableFuture deletePrefix(final String prefixToDelete) { + if (prefixToDelete.length() != BackupsDb.BACKUP_DIRECTORY_PATH_LENGTH + && prefixToDelete.length() != BackupsDb.MEDIA_DIRECTORY_PATH_LENGTH) { + throw new IllegalArgumentException("Unexpected prefix deletion for " + prefixToDelete); + } + final String prefix = prefixToDelete + "/"; return Mono - .fromCompletionStage(this.remoteStorageManager.list(mediaPrefix, Optional.empty(), 1000)) + .fromCompletionStage(this.remoteStorageManager.list(prefix, Optional.empty(), 1000)) .expand(listResult -> { if (listResult.cursor().isEmpty()) { return Mono.empty(); } - return Mono.fromCompletionStage(() -> this.remoteStorageManager.list(mediaPrefix, listResult.cursor(), 1000)); + return Mono.fromCompletionStage(() -> this.remoteStorageManager.list(prefix, listResult.cursor(), 1000)); }) .flatMap(listResult -> Flux.fromIterable(listResult.objects())) - // Delete the media objects. concatMap effectively makes the deletion operation single threaded -- it's expected - // the caller can increase/ concurrency by deleting more backups at once, rather than increasing concurrency + // Delete the objects. 
concatMap effectively makes the deletion operation single threaded -- it's expected + // the caller can increase concurrency by deleting more backups at once, rather than increasing concurrency // deleting an individual backup - .concatMap(result -> Mono.fromCompletionStage(() -> - remoteStorageManager.delete("%s%s".formatted(mediaPrefix, result.key())))) + .concatMap(result -> Mono.fromCompletionStage(() -> remoteStorageManager.delete(prefix + result.key()))) .count() - .doOnSuccess(itemsRemoved -> DistributionSummary.builder(DELETE_MEDIA_COUNT_DISTRIBUTION_NAME) + .doOnSuccess(itemsRemoved -> DistributionSummary.builder(DELETE_COUNT_DISTRIBUTION_NAME) .publishPercentileHistogram(true) .register(Metrics.globalRegistry) .record(itemsRemoved)) @@ -593,33 +572,23 @@ public class BackupManager { } @VisibleForTesting - static String encodeBackupIdForCdn(final AuthenticatedBackupUser backupUser) { - return encodeForCdn(BackupsDb.hashedBackupId(backupUser.backupId())); - } - - @VisibleForTesting - static String encodeForCdn(final byte[] bytes) { + static String encodeMediaIdForCdn(final byte[] bytes) { return Base64.getUrlEncoder().encodeToString(bytes); } - private static byte[] decodeFromCdn(final String base64) { + private static byte[] decodeMediaIdFromCdn(final String base64) { return Base64.getUrlDecoder().decode(base64); } private static String cdnMessageBackupName(final AuthenticatedBackupUser backupUser) { - return "%s/%s".formatted(encodeBackupIdForCdn(backupUser), MESSAGE_BACKUP_NAME); + return "%s/%s".formatted(backupUser.backupDir(), MESSAGE_BACKUP_NAME); } private static String cdnMediaDirectory(final AuthenticatedBackupUser backupUser) { - return "%s/%s/".formatted(encodeBackupIdForCdn(backupUser), MEDIA_DIRECTORY_NAME); - } - - private static String cdnMediaDirectory(final byte[] hashedBackupId) { - return "%s/%s/".formatted(encodeForCdn(hashedBackupId), MEDIA_DIRECTORY_NAME); + return "%s/%s/".formatted(backupUser.backupDir(), backupUser.mediaDir()); } 
private static String cdnMediaPath(final AuthenticatedBackupUser backupUser, final byte[] mediaId) { - return "%s%s".formatted(cdnMediaDirectory(backupUser), encodeForCdn(mediaId)); + return "%s%s".formatted(cdnMediaDirectory(backupUser), encodeMediaIdForCdn(mediaId)); } - } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupsDb.java b/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupsDb.java index aa586fd11..6238c6dab 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupsDb.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupsDb.java @@ -7,16 +7,20 @@ package org.whispersystems.textsecuregcm.backup; import io.grpc.Status; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; import java.time.Clock; import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; +import java.util.Base64; import java.util.HashMap; +import java.util.HexFormat; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.function.Predicate; +import org.signal.libsignal.protocol.InvalidKeyException; import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -41,9 +45,29 @@ import software.amazon.awssdk.services.dynamodb.model.UpdateItemRequest; *

* It's assumed that the caller has already validated that the backupUser being operated on has valid credentials and * possesses the appropriate {@link BackupTier} to perform the current operation. + *

+ * Backup records track two timestamps indicating the last time that a user interacted with their backup: one for the + * last refresh that contained a credential including the media tier, and the other for any access. After a period of + * inactivity, stale backups can be purged (either just the media, or the entire backup). Callers can discover which + * backups are stale, and whether only the media or the entire backup is stale, via {@link #getExpiredBackups}. + *

+ * Because backup objects reside on a transactionally unrelated store, expiring anything from the backup requires a + * two-phase process. First, the caller calls {@link #startExpiration}, which atomically updates the user's backup + * directories and records the CDN directory that should be expired. Then the caller must delete the expired directory, + * calling {@link #finishExpiration} to clear the recorded expired prefix when complete. Since the user's backup + * directories have been swapped, the deleter does not have to account for a user coming back and starting to upload + * concurrently with the deletion. + *

+ * If the directory deletion fails, a subsequent call to {@link #getExpiredBackups} will return the backup again + * indicating that the old expired prefix needs to be cleaned up before any other expiration action is taken. For + * example, if a media expiration fails and then in the next expiration pass the backup has become eligible for total + * deletion, the caller still must process the stale media expiration first before processing the full deletion. */ public class BackupsDb { + private static final int DIR_NAME_LENGTH = generateDirName(new SecureRandom()).length(); + public static final int BACKUP_DIRECTORY_PATH_LENGTH = DIR_NAME_LENGTH; + public static final int MEDIA_DIRECTORY_PATH_LENGTH = BACKUP_DIRECTORY_PATH_LENGTH + "/".length() + DIR_NAME_LENGTH; private static final Logger logger = LoggerFactory.getLogger(BackupsDb.class); static final int BACKUP_CDN = 3; @@ -51,6 +75,8 @@ public class BackupsDb { private final String backupTableName; private final Clock clock; + private final SecureRandom secureRandom; + // The backups table // B: 16 bytes that identifies the backup @@ -73,7 +99,12 @@ public class BackupsDb { // N: Time in seconds since epoch of last backup media usage recalculation. This timestamp is updated whenever we // recalculate the up-to-date bytes used by querying the cdn(s) directly. 
public static final String ATTR_MEDIA_USAGE_LAST_RECALCULATION = "MBTS"; - // BOOL: If true, + // S: The name of the user's backup directory on the CDN + public static final String ATTR_BACKUP_DIR = "BD"; + // S: The name of the user's media directory within the backup directory on the CDN + public static final String ATTR_MEDIA_DIR = "MD"; + // S: A prefix pending deletion + public static final String ATTR_EXPIRED_PREFIX = "EP"; public BackupsDb( final DynamoDbAsyncClient dynamoClient, @@ -82,6 +113,7 @@ public class BackupsDb { this.dynamoClient = dynamoClient; this.backupTableName = backupTableName; this.clock = clock; + this.secureRandom = new SecureRandom(); } /** @@ -102,35 +134,76 @@ public class BackupsDb { .addSetExpression("#publicKey = :publicKey", Map.entry("#publicKey", ATTR_PUBLIC_KEY), Map.entry(":publicKey", AttributeValues.b(publicKey.serialize()))) + // When the user sets a public key, we ensure that they have a backupDir/mediaDir assigned + .setDirectoryNamesIfMissing(secureRandom) .setRefreshTimes(clock) .withConditionExpression("attribute_not_exists(#publicKey) OR #publicKey = :publicKey") .updateItemBuilder() .build()) - .exceptionally(throwable -> { - // There was already a row for this backup-id and it contained a different publicKey - if (ExceptionUtils.unwrap(throwable) instanceof ConditionalCheckFailedException) { - throw ExceptionUtils.wrap(new PublicKeyConflictException()); - } - throw ExceptionUtils.wrap(throwable); - }) + .exceptionally(ExceptionUtils.marshal(ConditionalCheckFailedException.class, e -> + // There was already a row for this backup-id and it contained a different publicKey + new PublicKeyConflictException())) .thenRun(Util.NOOP); } - CompletableFuture> retrievePublicKey(byte[] backupId) { + /** + * Data stored to authenticate a backup user + * + * @param publicKey The public key for the backup entry. 
All credentials for this backup user must be signed * by this + * public key for the credential to be valid + * @param backupDir The current backupDir for the backup user. If authentication is successful, the user may be given + * credentials for this backupDir on the CDN + * @param mediaDir The current mediaDir for the backup user. If authentication is successful, the user may be given * + * credentials for the path backupDir/mediaDir on the CDN + */ + record AuthenticationData(ECPublicKey publicKey, String backupDir, String mediaDir) {} + + CompletableFuture> retrieveAuthenticationData(byte[] backupId) { final byte[] hashedBackupId = hashedBackupId(backupId); return dynamoClient.getItem(GetItemRequest.builder() .tableName(backupTableName) .key(Map.of(KEY_BACKUP_ID_HASH, AttributeValues.b(hashedBackupId))) .consistentRead(true) - .projectionExpression("#publicKey") - .expressionAttributeNames(Map.of("#publicKey", ATTR_PUBLIC_KEY)) + .projectionExpression("#publicKey,#backupDir,#mediaDir") + .expressionAttributeNames(Map.of( + "#publicKey", ATTR_PUBLIC_KEY, + "#backupDir", ATTR_BACKUP_DIR, + "#mediaDir", ATTR_MEDIA_DIR)) .build()) - .thenApply(response -> - AttributeValues.get(response.item(), ATTR_PUBLIC_KEY) - .map(AttributeValue::b) - .map(SdkBytes::asByteArray)); + .thenApply(response -> extractStoredPublicKey(response.item()) + .map(pubKey -> new AuthenticationData( + pubKey, + getDirName(response.item(), ATTR_BACKUP_DIR), + getDirName(response.item(), ATTR_MEDIA_DIR)))); } + private static String getDirName(final Map item, final String attr) { + return AttributeValues.get(item, attr).map(AttributeValue::s).orElseThrow(() -> { + logger.error("Backups with public keys should have directory names"); + return Status.INTERNAL + .withDescription("Backups with public keys must have directory names") + .asRuntimeException(); + }); + } + + private static Optional extractStoredPublicKey(final Map item) { + return AttributeValues.get(item, ATTR_PUBLIC_KEY) + 
.map(AttributeValue::b) + .map(SdkBytes::asByteArray) + .map(BackupsDb::deserializeStoredPublicKey); + } + + private static ECPublicKey deserializeStoredPublicKey(final byte[] publicKeyBytes) { + try { + return new ECPublicKey(publicKeyBytes); + } catch (InvalidKeyException e) { + logger.error("Invalid publicKey {}", HexFormat.of().formatHex(publicKeyBytes), e); + throw Status.INTERNAL + .withCause(e) + .withDescription("Could not deserialize stored public key") + .asRuntimeException(); + } + } /** * Update the quota in the backup table @@ -186,15 +259,6 @@ public class BackupsDb { .thenRun(Util.NOOP); } - CompletableFuture deleteBackup(final byte[] hashedBackupId) { - return dynamoClient.deleteItem(DeleteItemRequest.builder() - .tableName(backupTableName) - .key(Map.of(KEY_BACKUP_ID_HASH, AttributeValues.b(hashedBackupId))) - .build()) - .thenRun(Util.NOOP); - } - - record BackupDescription(int cdn, Optional mediaUsedSpace) {} /** @@ -214,13 +278,10 @@ public class BackupsDb { .build()) .thenApply(response -> { if (!response.hasItem()) { - throw Status.NOT_FOUND.withDescription("Backup not found").asRuntimeException(); + throw Status.NOT_FOUND.withDescription("Backup ID not found").asRuntimeException(); } - final int cdn = AttributeValues.get(response.item(), ATTR_CDN) - .map(AttributeValue::n) - .map(Integer::parseInt) - .orElseThrow(() -> Status.NOT_FOUND.withDescription("Stored backup not found").asRuntimeException()); - + // If the client hasn't already uploaded a backup, return the cdn we would return if they did create one + final int cdn = AttributeValues.getInt(response.item(), ATTR_CDN, BACKUP_CDN); final Optional mediaUsed = AttributeValues.get(response.item(), ATTR_MEDIA_BYTES_USED) .map(AttributeValue::n) .map(Long::parseLong); @@ -269,24 +330,74 @@ public class BackupsDb { .thenRun(Util.NOOP); } - CompletableFuture clearMediaUsage(final byte[] hashedBackupId) { - return dynamoClient.updateItem( - new UpdateBuilder(backupTableName, BackupTier.MEDIA, 
hashedBackupId) - .addSetExpression("#mediaBytesUsed = :mediaBytesUsed", - Map.entry("#mediaBytesUsed", ATTR_MEDIA_BYTES_USED), - Map.entry(":mediaBytesUsed", AttributeValues.n(0L))) - .addSetExpression("#mediaCount = :mediaCount", - Map.entry("#mediaCount", ATTR_MEDIA_COUNT), - Map.entry(":mediaCount", AttributeValues.n(0L))) - .addSetExpression("#mediaRecalc = :mediaRecalc", - Map.entry("#mediaRecalc", ATTR_MEDIA_USAGE_LAST_RECALCULATION), - Map.entry(":mediaRecalc", AttributeValues.n(clock.instant().getEpochSecond()))) - .addRemoveExpression(Map.entry("#mediaRefresh", ATTR_LAST_MEDIA_REFRESH)) - .updateItemBuilder() - .build()) + + /** + * Marks the backup as undergoing expiration. + *

+ * This must be called before beginning to delete items in the CDN with the prefix specified by + * {@link ExpiredBackup#prefixToDelete()}. If the prefix has been successfully deleted, {@link #finishExpiration} must + * be called. + * + * @param expiredBackup The backup to expire + * @return A stage that completes when the backup has been marked for expiration + */ + CompletableFuture startExpiration(final ExpiredBackup expiredBackup) { + if (expiredBackup.expirationType() == ExpiredBackup.ExpirationType.GARBAGE_COLLECTION) { + // We've already updated the row on a prior (failed) attempt, just need to remove the data from the cdn now + return CompletableFuture.completedFuture(null); + } + + // Clear usage metadata, swap names of things we intend to delete, and record our intent to delete in attr_expired_prefix + return dynamoClient.updateItem(new UpdateBuilder(backupTableName, BackupTier.MEDIA, expiredBackup.hashedBackupId()) + .addSetExpression("#mediaBytesUsed = :mediaBytesUsed", + Map.entry("#mediaBytesUsed", ATTR_MEDIA_BYTES_USED), + Map.entry(":mediaBytesUsed", AttributeValues.n(0L))) + .addSetExpression("#mediaCount = :mediaCount", + Map.entry("#mediaCount", ATTR_MEDIA_COUNT), + Map.entry(":mediaCount", AttributeValues.n(0L))) + .addSetExpression("#mediaRecalc = :mediaRecalc", + Map.entry("#mediaRecalc", ATTR_MEDIA_USAGE_LAST_RECALCULATION), + Map.entry(":mediaRecalc", AttributeValues.n(clock.instant().getEpochSecond()))) + .expireDirectoryNames(secureRandom, expiredBackup.expirationType()) + .addRemoveExpression(Map.entry("#mediaRefresh", ATTR_LAST_MEDIA_REFRESH)) + .addSetExpression("#expiredPrefix = :expiredPrefix", + Map.entry("#expiredPrefix", ATTR_EXPIRED_PREFIX), + Map.entry(":expiredPrefix", AttributeValues.s(expiredBackup.prefixToDelete()))) + .withConditionExpression("attribute_not_exists(#expiredPrefix) OR #expiredPrefix = :expiredPrefix") + .updateItemBuilder() + .build()) .thenRun(Util.NOOP); } + /** + * Complete expiration of a backup started 
with {@link #startExpiration} + *

+ * If the expiration was for the entire backup, this will delete the entire item for the backup. + * + * @param expiredBackup The backup to expire + * @return A stage that completes when the expiration is marked as finished + */ + CompletableFuture finishExpiration(final ExpiredBackup expiredBackup) { + final byte[] hashedBackupId = expiredBackup.hashedBackupId(); + if (expiredBackup.expirationType() == ExpiredBackup.ExpirationType.ALL) { + final long expectedLastRefresh = expiredBackup.lastRefresh().getEpochSecond(); + return dynamoClient.deleteItem(DeleteItemRequest.builder() + .tableName(backupTableName) + .key(Map.of(KEY_BACKUP_ID_HASH, AttributeValues.b(hashedBackupId))) + .conditionExpression("#lastRefresh <= :expectedLastRefresh") + .expressionAttributeNames(Map.of("#lastRefresh", ATTR_LAST_REFRESH)) + .expressionAttributeValues(Map.of(":expectedLastRefresh", AttributeValues.n(expectedLastRefresh))) + .build()) + .thenRun(Util.NOOP); + } else { + return dynamoClient.updateItem(new UpdateBuilder(backupTableName, BackupTier.MEDIA, hashedBackupId) + .addRemoveExpression(Map.entry("#expiredPrefixes", ATTR_EXPIRED_PREFIX)) + .updateItemBuilder() + .build()) + .thenRun(Util.NOOP); + } + } + Flux getExpiredBackups(final int segments, final Scheduler scheduler, final Instant purgeTime) { if (segments < 1) { throw new IllegalArgumentException("Total number of segments must be positive"); @@ -303,10 +414,14 @@ public class BackupsDb { .expressionAttributeNames(Map.of( "#backupIdHash", KEY_BACKUP_ID_HASH, "#refresh", ATTR_LAST_REFRESH, - "#mediaRefresh", ATTR_LAST_MEDIA_REFRESH)) + "#mediaRefresh", ATTR_LAST_MEDIA_REFRESH, + "#backupDir", ATTR_BACKUP_DIR, + "#mediaDir", ATTR_MEDIA_DIR, + "#expiredPrefix", ATTR_EXPIRED_PREFIX)) .expressionAttributeValues(Map.of(":purgeTime", AttributeValues.n(purgeTime.getEpochSecond()))) - .projectionExpression("#backupIdHash, #refresh, #mediaRefresh") - .filterExpression("(#refresh < :purgeTime) OR (#mediaRefresh < :purgeTime)") + 
.projectionExpression("#backupIdHash, #refresh, #mediaRefresh, #backupDir, #mediaDir, #expiredPrefix") + .filterExpression( + "(#refresh < :purgeTime) OR (#mediaRefresh < :purgeTime) OR attribute_exists(#expiredPrefix)") .build()) .items()) .sequential() @@ -316,19 +431,60 @@ public class BackupsDb { if (hashedBackupId == null) { return null; } + final String backupDir = AttributeValues.getString(item, ATTR_BACKUP_DIR, null); + final String mediaDir = AttributeValues.getString(item, ATTR_MEDIA_DIR, null); + if (backupDir == null || mediaDir == null) { + // Could be the case for backups that have not yet set a public key + return null; + } final long lastRefresh = AttributeValues.getLong(item, ATTR_LAST_REFRESH, Long.MAX_VALUE); final long lastMediaRefresh = AttributeValues.getLong(item, ATTR_LAST_MEDIA_REFRESH, Long.MAX_VALUE); + final String existingExpiration = AttributeValues.getString(item, ATTR_EXPIRED_PREFIX, null); - if (lastRefresh < purgeTime.getEpochSecond()) { - return new ExpiredBackup(hashedBackupId, BackupTier.MESSAGES); + final ExpiredBackup expiredBackup; + if (existingExpiration != null) { + // If we have work from a failed previous expiration, handle that before worrying about any new expirations. 
+ // This guarantees we won't accumulate expirations + expiredBackup = new ExpiredBackup(hashedBackupId, ExpiredBackup.ExpirationType.GARBAGE_COLLECTION, + Instant.ofEpochSecond(lastRefresh), existingExpiration); + } else if (lastRefresh < purgeTime.getEpochSecond()) { + // The whole backup was expired + expiredBackup = new ExpiredBackup(hashedBackupId, ExpiredBackup.ExpirationType.ALL, + Instant.ofEpochSecond(lastRefresh), backupDir); } else if (lastMediaRefresh < purgeTime.getEpochSecond()) { - return new ExpiredBackup(hashedBackupId, BackupTier.MEDIA); + // The media was expired + expiredBackup = new ExpiredBackup(hashedBackupId, ExpiredBackup.ExpirationType.MEDIA, + Instant.ofEpochSecond(lastRefresh), backupDir + "/" + mediaDir); } else { return null; } + + if (!isValid(expiredBackup)) { + logger.error("Not expiring backup {} for backupId {} with invalid cdn path prefixes", + expiredBackup, + HexFormat.of().formatHex(expiredBackup.hashedBackupId())); + return null; + } + return expiredBackup; }); } + /** + * Backup expiration will expire any prefix we tell it to, so confirm that the directory names that came out of the + * database have the correct shape before handing them off. 
+ * + * @param expiredBackup The ExpiredBackup object to check + * @return Whether this is a valid expiration object + */ + private static boolean isValid(final ExpiredBackup expiredBackup) { + // expired prefixes should be of the form "backupDir" or "backupDir/mediaDir" + return switch (expiredBackup.expirationType()) { + case MEDIA -> expiredBackup.prefixToDelete().length() == MEDIA_DIRECTORY_PATH_LENGTH; + case ALL -> expiredBackup.prefixToDelete().length() == BACKUP_DIRECTORY_PATH_LENGTH; + case GARBAGE_COLLECTION -> expiredBackup.prefixToDelete().length() == MEDIA_DIRECTORY_PATH_LENGTH || + expiredBackup.prefixToDelete().length() == BACKUP_DIRECTORY_PATH_LENGTH; + }; + } /** * Build ddb update statements for the backups table @@ -431,6 +587,39 @@ public class BackupsDb { return this; } + UpdateBuilder setDirectoryNamesIfMissing(final SecureRandom secureRandom) { + final String backupDir = generateDirName(secureRandom); + final String mediaDir = generateDirName(secureRandom); + addSetExpression("#backupDir = if_not_exists(#backupDir, :backupDir)", + Map.entry("#backupDir", ATTR_BACKUP_DIR), + Map.entry(":backupDir", AttributeValues.s(backupDir))); + + addSetExpression("#mediaDir = if_not_exists(#mediaDir, :mediaDir)", + Map.entry("#mediaDir", ATTR_MEDIA_DIR), + Map.entry(":mediaDir", AttributeValues.s(mediaDir))); + return this; + } + + UpdateBuilder expireDirectoryNames( + final SecureRandom secureRandom, + final ExpiredBackup.ExpirationType expirationType) { + final String backupDir = generateDirName(secureRandom); + final String mediaDir = generateDirName(secureRandom); + return switch (expirationType) { + case GARBAGE_COLLECTION -> this; + case MEDIA -> this.addSetExpression("#mediaDir = :mediaDir", + Map.entry("#mediaDir", ATTR_MEDIA_DIR), + Map.entry(":mediaDir", AttributeValues.s(mediaDir))); + case ALL -> this + .addSetExpression("#mediaDir = :mediaDir", + Map.entry("#mediaDir", ATTR_MEDIA_DIR), + Map.entry(":mediaDir", AttributeValues.s(mediaDir))) + 
.addSetExpression("#backupDir = :backupDir", + Map.entry("#backupDir", ATTR_BACKUP_DIR), + Map.entry(":backupDir", AttributeValues.s(backupDir))); + }; + } + /** * Set the lastRefresh time as part of the update *

@@ -478,8 +667,10 @@ public class BackupsDb { .tableName(tableName) .key(Map.of(KEY_BACKUP_ID_HASH, AttributeValues.b(hashedBackupId))) .updateExpression(updateExpression()) - .expressionAttributeNames(attrNames) - .expressionAttributeValues(attrValues); + .expressionAttributeNames(attrNames); + if (!this.attrValues.isEmpty()) { + bldr.expressionAttributeValues(attrValues); + } if (this.conditionExpression != null) { bldr.conditionExpression(conditionExpression); } @@ -505,6 +696,11 @@ public class BackupsDb { } } + static String generateDirName(final SecureRandom secureRandom) { + final byte[] bytes = new byte[16]; + secureRandom.nextBytes(bytes); + return Base64.getUrlEncoder().withoutPadding().encodeToString(bytes); + } private static byte[] hashedBackupId(final AuthenticatedBackupUser backupId) { return hashedBackupId(backupId.backupId()); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/backup/ExpiredBackup.java b/service/src/main/java/org/whispersystems/textsecuregcm/backup/ExpiredBackup.java index 1405b9c38..78ad972a8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/backup/ExpiredBackup.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/backup/ExpiredBackup.java @@ -4,4 +4,29 @@ */ package org.whispersystems.textsecuregcm.backup; -public record ExpiredBackup(byte[] hashedBackupId, BackupTier backupTierToRemove) {} +import java.time.Instant; +import java.util.List; + +/** + * Represents a backup that requires some or all of its content to be deleted + * + * @param hashedBackupId The hashedBackupId that owns this content + * @param expirationType What triggered the expiration + * @param lastRefresh The timestamp of the last time the backup user was seen + * @param prefixToDelete The prefix on the CDN associated with this backup that should be deleted + */ +public record ExpiredBackup( + byte[] hashedBackupId, + ExpirationType expirationType, + Instant lastRefresh, + String prefixToDelete) { + + public enum 
ExpirationType { + // The prefixToDelete expiration is for the entire backup + ALL, + // The prefixToDelete is for the media associated with the backup + MEDIA, + // The prefixToDelete is from a prior expiration attempt + GARBAGE_COLLECTION + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ArchiveController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ArchiveController.java index 81fb94af6..b7c24e403 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ArchiveController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ArchiveController.java @@ -232,16 +232,22 @@ public class ArchiveController { } public record BackupInfoResponse( - @Schema(description = "If present, the CDN type where the message backup is stored") + @Schema(description = "The CDN type where the message backup is stored. Media may be stored elsewhere.") int cdn, @Schema(description = """ - If present, the directory of your backup data on the cdn. The message backup can be found at /backupDir/backupName - and stored media can be found at /backupDir/media/mediaId. + The base directory of your backup data on the cdn. The message backup can be found in the returned cdn at + /backupDir/backupName and stored media can be found at /backupDir/mediaDir/mediaId """) String backupDir, - @Schema(description = "If present, the name of the most recent message backup on the cdn. The backup is at /backupDir/backupName") + @Schema(description = """ + The prefix path component for media objects on a cdn. Stored media for mediaId can be found at + /backupDir/mediaDir/mediaId. + """) + String mediaDir, + + @Schema(description = "The name of the most recent message backup on the cdn. 
The backup is at /backupDir/backupName") String backupName, @Nullable @@ -276,6 +282,7 @@ public class ArchiveController { .thenApply(backupInfo -> new BackupInfoResponse( backupInfo.cdn(), backupInfo.backupSubdir(), + backupInfo.mediaSubdir(), backupInfo.messageBackupKey(), backupInfo.mediaUsedSpace().orElse(null))); } @@ -641,6 +648,15 @@ public class ArchiveController { @Schema(description = "A page of media objects stored for this backup ID") List storedMediaObjects, + @Schema(description = """ + The base directory of your backup data on the cdn. The stored media can be found at /backupDir/mediaDir/mediaId + """) + String backupDir, + + @Schema(description = """ + The prefix path component for the media objects. The stored media for mediaId can be found at /backupDir/mediaDir/mediaId. + """) + String mediaDir, @Schema(description = "If set, the cursor value to pass to the next list request to continue listing. If absent, all objects have been listed") String cursor) {} @@ -679,12 +695,14 @@ public class ArchiveController { } return backupManager .authenticateBackupUser(presentation.presentation, signature.signature) - .thenCompose(backupUser -> backupManager.list(backupUser, cursor, limit.orElse(1000))) - .thenApply(result -> new ListResponse( - result.media() - .stream().map(entry -> new StoredMediaObject(entry.cdn(), entry.key(), entry.length())) - .toList(), - result.cursor().orElse(null))); + .thenCompose(backupUser ->backupManager.list(backupUser, cursor, limit.orElse(1000)) + .thenApply(result -> new ListResponse( + result.media() + .stream().map(entry -> new StoredMediaObject(entry.cdn(), entry.key(), entry.length())) + .toList(), + backupUser.backupDir(), + backupUser.mediaDir(), + result.cursor().orElse(null)))); } public record DeleteMedia(@Size(min = 1, max = 1000) List<@Valid MediaToDelete> mediaToDelete) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/ExceptionUtils.java 
b/service/src/main/java/org/whispersystems/textsecuregcm/util/ExceptionUtils.java index dab21ec3b..99a728386 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/ExceptionUtils.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/ExceptionUtils.java @@ -44,8 +44,8 @@ public final class ExceptionUtils { * * @param exceptionType The class of exception that will be handled * @param fn A function that handles exceptions of type exceptionType - * @param The type of the stage that will be mapped - * @param The type of the exception that will be handled + * @param The type of the stage that will be mapped + * @param The type of the exception that will be handled * @return A function suitable for use with {@link java.util.concurrent.CompletionStage#exceptionally} */ public static Function exceptionallyHandler( @@ -62,4 +62,23 @@ public final class ExceptionUtils { throw wrap(anyException); }; } + + /** + * Create a handler suitable for use with {@link java.util.concurrent.CompletionStage#exceptionally} that converts + * exceptions of a specific type to another type. 
+ * + * @param exceptionType The class of exception that will be handled + * @param fn A function that marshals exceptions of type E to type F + * @param The type of the stage that will be mapped + * @param The type of the exception that will be handled + * @param The type of the exception that will be produced + * @return A function suitable for use with {@link java.util.concurrent.CompletionStage#exceptionally} + */ + public static Function marshal( + final Class exceptionType, + final Function fn) { + return exceptionallyHandler(exceptionType, e -> { + throw wrap(fn.apply(e)); + }); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredBackupsCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredBackupsCommand.java index c58fada19..765aace05 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredBackupsCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/RemoveExpiredBackupsCommand.java @@ -12,6 +12,7 @@ import io.dropwizard.core.setup.Environment; import io.micrometer.core.instrument.Metrics; import java.time.Clock; import java.time.Duration; +import java.util.HexFormat; import java.util.Objects; import java.util.concurrent.atomic.AtomicLong; import net.sourceforge.argparse4j.inf.Namespace; @@ -142,19 +143,19 @@ public class RemoveExpiredBackupsCommand extends EnvironmentCommand - backupManager.deleteBackup(expiredBackup.backupTierToRemove(), expiredBackup.hashedBackupId())); + mono = Mono.fromCompletionStage(() -> backupManager.expireBackup(expiredBackup)); } return mono .doOnSuccess(ignored -> Metrics .counter(EXPIRED_BACKUPS_COUNTER_NAME, - "tier", expiredBackup.backupTierToRemove().name(), + "tier", expiredBackup.expirationType().name(), "dryRun", String.valueOf(dryRun)) .increment()) .onErrorResume(throwable -> { - logger.warn("Failed to remove tier {} for backup {}", expiredBackup.backupTierToRemove(), - 
expiredBackup.hashedBackupId()); + logger.warn("Failed to remove tier {} for backup {}", + expiredBackup.expirationType(), + HexFormat.of().formatHex(expiredBackup.hashedBackupId())); return Mono.empty(); }); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupManagerTest.java index aa5c6affc..7171ff4de 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupManagerTest.java @@ -29,6 +29,7 @@ import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; import java.time.Duration; import java.time.Instant; import java.util.ArrayList; @@ -53,6 +54,7 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.ValueSource; +import org.signal.libsignal.protocol.InvalidKeyException; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.zkgroup.VerificationFailedException; @@ -80,6 +82,7 @@ public class BackupManagerTest { private final RemoteStorageManager remoteStorageManager = mock(RemoteStorageManager.class); private final byte[] backupKey = TestRandomUtil.nextBytes(32); private final UUID aci = UUID.randomUUID(); + private final SecureRandom secureRandom = new SecureRandom(); private BackupManager backupManager; private BackupsDb backupsDb; @@ -109,14 +112,13 @@ public class BackupManagerTest { testClock.pin(now); final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), backupTier); - final String encodedBackupId = 
Base64.getUrlEncoder().encodeToString(hashedBackupId(backupUser.backupId())); backupManager.createMessageBackupUploadDescriptor(backupUser).join(); verify(tusCredentialGenerator, times(1)) - .generateUpload("%s/%s".formatted(encodedBackupId, BackupManager.MESSAGE_BACKUP_NAME)); + .generateUpload("%s/%s".formatted(backupUser.backupDir(), BackupManager.MESSAGE_BACKUP_NAME)); final BackupManager.BackupInfo info = backupManager.backupInfo(backupUser).join(); - assertThat(info.backupSubdir()).isEqualTo(encodedBackupId); + assertThat(info.backupSubdir()).isEqualTo(backupUser.backupDir()).isNotBlank(); assertThat(info.messageBackupKey()).isEqualTo(BackupManager.MESSAGE_BACKUP_NAME); assertThat(info.mediaUsedSpace()).isEqualTo(Optional.empty()); @@ -344,9 +346,7 @@ public class BackupManagerTest { @Test public void quotaEnforcementRecalculation() { final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupTier.MEDIA); - final String backupMediaPrefix = "%s/%s/".formatted( - BackupManager.encodeBackupIdForCdn(backupUser), - BackupManager.MEDIA_DIRECTORY_NAME); + final String backupMediaPrefix = "%s/%s/".formatted(backupUser.backupDir(), backupUser.mediaDir()); // on recalculation, say there's actually 10 bytes left when(remoteStorageManager.calculateBytesUsed(eq(backupMediaPrefix))) @@ -383,9 +383,7 @@ public class BackupManagerTest { final long mediaToAddSize, boolean shouldAccept) { final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupTier.MEDIA); - final String backupMediaPrefix = "%s/%s/".formatted( - BackupManager.encodeBackupIdForCdn(backupUser), - BackupManager.MEDIA_DIRECTORY_NAME); + final String backupMediaPrefix = "%s/%s/".formatted(backupUser.backupDir(), backupUser.mediaDir()); // set the backupsDb to be out of quota at t=0 testClock.pin(Instant.ofEpochSecond(0)); @@ -410,9 +408,7 @@ public class BackupManagerTest { public void list(final String cursorVal) { final Optional cursor = 
Optional.of(cursorVal).filter(StringUtils::isNotBlank); final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupTier.MEDIA); - final String backupMediaPrefix = "%s/%s/".formatted( - BackupManager.encodeBackupIdForCdn(backupUser), - BackupManager.MEDIA_DIRECTORY_NAME); + final String backupMediaPrefix = "%s/%s/".formatted(backupUser.backupDir(), backupUser.mediaDir()); when(remoteStorageManager.cdnNumber()).thenReturn(13); when(remoteStorageManager.list(eq(backupMediaPrefix), eq(cursor), eq(17L))) @@ -437,9 +433,9 @@ public class BackupManagerTest { final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupTier.MEDIA); final byte[] mediaId = TestRandomUtil.nextBytes(16); final String backupMediaKey = "%s/%s/%s".formatted( - BackupManager.encodeBackupIdForCdn(backupUser), - BackupManager.MEDIA_DIRECTORY_NAME, - BackupManager.encodeForCdn(mediaId)); + backupUser.backupDir(), + backupUser.mediaDir(), + BackupManager.encodeMediaIdForCdn(mediaId)); backupsDb.setMediaUsage(backupUser, new UsageInfo(100, 1000)).join(); @@ -474,9 +470,9 @@ public class BackupManagerTest { TestRandomUtil.nextBytes(15)); descriptors.add(descriptor); final String backupMediaKey = "%s/%s/%s".formatted( - BackupManager.encodeBackupIdForCdn(backupUser), - BackupManager.MEDIA_DIRECTORY_NAME, - BackupManager.encodeForCdn(descriptor.key())); + backupUser.backupDir(), + backupUser.mediaDir(), + BackupManager.encodeMediaIdForCdn(descriptor.key())); initialBytes += i; // fail 2 deletions, otherwise return the corresponding object's size as i @@ -501,9 +497,9 @@ public class BackupManagerTest { final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupTier.MEDIA); final byte[] mediaId = TestRandomUtil.nextBytes(16); final String backupMediaKey = "%s/%s/%s".formatted( - BackupManager.encodeBackupIdForCdn(backupUser), - BackupManager.MEDIA_DIRECTORY_NAME, - BackupManager.encodeForCdn(mediaId)); + 
backupUser.backupDir(), + backupUser.mediaDir(), + BackupManager.encodeMediaIdForCdn(mediaId)); backupsDb.setMediaUsage(backupUser, new UsageInfo(100, 5)).join(); @@ -545,7 +541,7 @@ public class BackupManagerTest { .map(ExpiredBackup::hashedBackupId) .map(ByteBuffer::wrap) .allMatch(expectedHashes::contains)).isTrue(); - assertThat(expired.stream().allMatch(eb -> eb.backupTierToRemove() == BackupTier.MESSAGES)).isTrue(); + assertThat(expired.stream().allMatch(eb -> eb.expirationType() == ExpiredBackup.ExpirationType.ALL)).isTrue(); // on next iteration, backup i should be expired expectedHashes.add(ByteBuffer.wrap(hashedBackupId(backupUsers.get(i).backupId()))); @@ -572,50 +568,53 @@ public class BackupManagerTest { assertThat(getExpired.apply(Instant.ofEpochSecond(6))) .hasSize(1).first() - .matches(eb -> eb.backupTierToRemove() == BackupTier.MEDIA, "is media tier"); + .matches(eb -> eb.expirationType() == ExpiredBackup.ExpirationType.MEDIA, "is media tier"); assertThat(getExpired.apply(Instant.ofEpochSecond(7))) .hasSize(1).first() - .matches(eb -> eb.backupTierToRemove() == BackupTier.MESSAGES, "is messages tier"); + .matches(eb -> eb.expirationType() == ExpiredBackup.ExpirationType.ALL, "is messages tier"); } @ParameterizedTest - @EnumSource(mode = EnumSource.Mode.INCLUDE, names = {"MESSAGES", "MEDIA"}) - public void deleteBackup(BackupTier backupTier) { + @EnumSource(mode = EnumSource.Mode.INCLUDE, names = {"MEDIA", "ALL"}) + public void expireBackup(ExpiredBackup.ExpirationType expirationType) { final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupTier.MEDIA); backupManager.createMessageBackupUploadDescriptor(backupUser).join(); - final String mediaPrefix = "%s/%s/" - .formatted(BackupManager.encodeBackupIdForCdn(backupUser), BackupManager.MEDIA_DIRECTORY_NAME); - when(remoteStorageManager.list(eq(mediaPrefix), eq(Optional.empty()), anyLong())) + + final String expectedPrefixToDelete = switch (expirationType) { + case ALL 
-> backupUser.backupDir(); + case MEDIA -> backupUser.backupDir() + "/" + backupUser.mediaDir(); + case GARBAGE_COLLECTION -> throw new IllegalArgumentException(); + } + "/"; + + when(remoteStorageManager.list(eq(expectedPrefixToDelete), eq(Optional.empty()), anyLong())) .thenReturn(CompletableFuture.completedFuture(new RemoteStorageManager.ListResult(List.of( new RemoteStorageManager.ListResult.Entry("abc", 1), new RemoteStorageManager.ListResult.Entry("def", 1), new RemoteStorageManager.ListResult.Entry("ghi", 1)), Optional.empty()))); when(remoteStorageManager.delete(anyString())).thenReturn(CompletableFuture.completedFuture(1L)); - backupManager.deleteBackup(backupTier, hashedBackupId(backupUser.backupId())).join(); + backupManager.expireBackup(expiredBackup(expirationType, backupUser)).join(); verify(remoteStorageManager, times(1)).list(anyString(), any(), anyLong()); - verify(remoteStorageManager, times(1)).delete(mediaPrefix + "abc"); - verify(remoteStorageManager, times(1)).delete(mediaPrefix + "def"); - verify(remoteStorageManager, times(1)).delete(mediaPrefix + "ghi"); - verify(remoteStorageManager, times(backupTier == BackupTier.MESSAGES ? 
1 : 0)) - .delete("%s/%s".formatted(BackupManager.encodeBackupIdForCdn(backupUser), BackupManager.MESSAGE_BACKUP_NAME)); + verify(remoteStorageManager, times(1)).delete(expectedPrefixToDelete + "abc"); + verify(remoteStorageManager, times(1)).delete(expectedPrefixToDelete + "def"); + verify(remoteStorageManager, times(1)).delete(expectedPrefixToDelete + "ghi"); verifyNoMoreInteractions(remoteStorageManager); final BackupsDb.TimestampedUsageInfo usage = backupsDb.getMediaUsage(backupUser).join(); assertThat(usage.usageInfo().bytesUsed()).isEqualTo(0L); assertThat(usage.usageInfo().numObjects()).isEqualTo(0L); - if (backupTier == BackupTier.MEDIA) { - // should have deleted all the media, but left the backup descriptor in place - assertThatNoException().isThrownBy(() -> backupsDb.describeBackup(backupUser).join()); - } else { + if (expirationType == ExpiredBackup.ExpirationType.ALL) { // should have deleted the db row for the backup assertThat(CompletableFutureTestUtil.assertFailsWithCause( StatusRuntimeException.class, backupsDb.describeBackup(backupUser)) .getStatus().getCode()) .isEqualTo(Status.NOT_FOUND.getCode()); + } else { + // should have deleted all the media, but left the backup descriptor in place + assertThatNoException().isThrownBy(() -> backupsDb.describeBackup(backupUser).join()); } } @@ -623,8 +622,9 @@ public class BackupManagerTest { public void deleteBackupPaginated() { final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupTier.MEDIA); backupManager.createMessageBackupUploadDescriptor(backupUser).join(); - final String mediaPrefix = "%s/%s/".formatted(BackupManager.encodeBackupIdForCdn(backupUser), - BackupManager.MEDIA_DIRECTORY_NAME); + + final ExpiredBackup expiredBackup = expiredBackup(ExpiredBackup.ExpirationType.MEDIA, backupUser); + final String mediaPrefix = expiredBackup.prefixToDelete() + "/"; // Return 1 item per page. Initially the provided cursor is empty and we'll return the cursor string "1". 
// When we get the cursor "1", we'll return "2", when "2" we'll return empty indicating listing @@ -647,7 +647,7 @@ public class BackupManagerTest { })); }); when(remoteStorageManager.delete(anyString())).thenReturn(CompletableFuture.completedFuture(1L)); - backupManager.deleteBackup(BackupTier.MEDIA, hashedBackupId(backupUser.backupId())).join(); + backupManager.expireBackup(expiredBackup).join(); verify(remoteStorageManager, times(3)).list(anyString(), any(), anyLong()); verify(remoteStorageManager, times(1)).delete(mediaPrefix + "abc"); verify(remoteStorageManager, times(1)).delete(mediaPrefix + "def"); @@ -655,6 +655,19 @@ public class BackupManagerTest { verifyNoMoreInteractions(remoteStorageManager); } + private static ExpiredBackup expiredBackup(final ExpiredBackup.ExpirationType expirationType, + final AuthenticatedBackupUser backupUser) { + return new ExpiredBackup( + hashedBackupId(backupUser.backupId()), + expirationType, + Instant.now(), + switch (expirationType) { + case ALL -> backupUser.backupDir(); + case MEDIA -> backupUser.backupDir() + "/" + backupUser.mediaDir(); + case GARBAGE_COLLECTION -> null; + }); + } + private Map getBackupItem(final AuthenticatedBackupUser backupUser) { return DYNAMO_DB_EXTENSION.getDynamoDbClient().getItem(GetItemRequest.builder() .tableName(DynamoDbExtensionSchema.Tables.BACKUPS.tableName()) @@ -688,6 +701,14 @@ public class BackupManagerTest { } private AuthenticatedBackupUser backupUser(final byte[] backupId, final BackupTier backupTier) { - return new AuthenticatedBackupUser(backupId, backupTier); + // Won't actually validate the public key, but need to have a public key to perform BackupsDB operations + byte[] privateKey = new byte[32]; + ByteBuffer.wrap(privateKey).put(backupId); + try { + backupsDb.setPublicKey(backupId, backupTier, Curve.decodePrivatePoint(privateKey).publicKey()).join(); + } catch (InvalidKeyException e) { + throw new RuntimeException(e); + } + return new AuthenticatedBackupUser(backupId, 
backupTier, BackupsDb.generateDirName(secureRandom), BackupsDb.generateDirName(secureRandom)); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupsDbTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupsDbTest.java index e91245bae..a8e2ba5a0 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupsDbTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupsDbTest.java @@ -6,21 +6,15 @@ package org.whispersystems.textsecuregcm.backup; -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; - -import java.nio.charset.StandardCharsets; -import java.security.MessageDigest; -import java.security.NoSuchAlgorithmException; -import java.time.Instant; -import java.util.Arrays; -import java.util.List; -import java.util.function.Function; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.ValueSource; +import org.signal.libsignal.protocol.ecc.Curve; import org.whispersystems.textsecuregcm.auth.AuthenticatedBackupUser; import org.whispersystems.textsecuregcm.storage.DynamoDbExtension; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema; @@ -28,6 +22,12 @@ import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil; import org.whispersystems.textsecuregcm.util.TestClock; import org.whispersystems.textsecuregcm.util.TestRandomUtil; import reactor.core.scheduler.Schedulers; +import java.time.Instant; +import java.util.List; +import java.util.Optional; +import java.util.function.Function; + +import static org.assertj.core.api.Assertions.assertThat; public class BackupsDbTest { @@ 
-75,7 +75,7 @@ public class BackupsDbTest {
     if (mediaAlreadyExists) {
       this.backupsDb.trackMedia(backupUser, 1, 10).join();
     }
-    backupsDb.setMediaUsage(backupUser, new UsageInfo( 113, 17)).join();
+    backupsDb.setMediaUsage(backupUser, new UsageInfo(113, 17)).join();
     final BackupsDb.TimestampedUsageInfo info = backupsDb.getMediaUsage(backupUser).join();
     assertThat(info.lastRecalculationTime()).isEqualTo(Instant.ofEpochSecond(5));
     assertThat(info.usageInfo().bytesUsed()).isEqualTo(113L);
@@ -87,6 +87,7 @@ public class BackupsDbTest {
     final byte[] backupId = TestRandomUtil.nextBytes(16);
     // Refresh media/messages at t=0
     testClock.pin(Instant.ofEpochSecond(0L));
+    backupsDb.setPublicKey(backupId, BackupTier.MEDIA, Curve.generateKeyPair().getPublicKey()).join();
     this.backupsDb.ttlRefresh(backupUser(backupId, BackupTier.MEDIA)).join();
 
     // refresh only messages at t=2
@@ -100,10 +101,11 @@ public class BackupsDbTest {
 
     List<ExpiredBackup> expired = expiredBackups.apply(Instant.ofEpochSecond(1));
     assertThat(expired).hasSize(1).first()
-        .matches(eb -> eb.backupTierToRemove() == BackupTier.MEDIA);
+        .matches(eb -> eb.expirationType() == ExpiredBackup.ExpirationType.MEDIA);
 
     // Expire the media
-    backupsDb.clearMediaUsage(expired.get(0).hashedBackupId()).join();
+    backupsDb.startExpiration(expired.get(0)).join();
+    backupsDb.finishExpiration(expired.get(0)).join();
 
     // should be nothing to expire at t=1
     assertThat(expiredBackups.apply(Instant.ofEpochSecond(1))).isEmpty();
@@ -111,16 +113,109 @@ public class BackupsDbTest {
     // at t=3, should now expire messages as well
     expired = expiredBackups.apply(Instant.ofEpochSecond(3));
     assertThat(expired).hasSize(1).first()
-        .matches(eb -> eb.backupTierToRemove() == BackupTier.MESSAGES);
+        .matches(eb -> eb.expirationType() == ExpiredBackup.ExpirationType.ALL);
 
     // Expire the messages
-    backupsDb.deleteBackup(expired.get(0).hashedBackupId()).join();
+    backupsDb.startExpiration(expired.get(0)).join();
+    backupsDb.finishExpiration(expired.get(0)).join();
 
     // should be nothing to expire at t=3
     assertThat(expiredBackups.apply(Instant.ofEpochSecond(3))).isEmpty();
   }
 
+  /**
+   * Simulates an expiration that is started but never finished, then checks that the next listing
+   * reports the same prefix as {@code GARBAGE_COLLECTION} so the cleanup can be retried. After the
+   * retry finishes, a MEDIA expiration leaves nothing to expire and the backup still describable;
+   * an ALL expiration is listed once more and, once finished, the backup entry is NOT_FOUND.
+   *
+   * @param expirationType the expiration type the first listing at t=1 is expected to report
+   */
+  @ParameterizedTest
+  @EnumSource(names = {"MEDIA", "ALL"})
+  public void expirationFailed(ExpiredBackup.ExpirationType expirationType) {
+    final byte[] backupId = TestRandomUtil.nextBytes(16);
+    // Refresh media/messages at t=0
+    testClock.pin(Instant.ofEpochSecond(0L));
+    backupsDb.setPublicKey(backupId, BackupTier.MEDIA, Curve.generateKeyPair().getPublicKey()).join();
+    this.backupsDb.ttlRefresh(backupUser(backupId, BackupTier.MEDIA)).join();
+
+    if (expirationType == ExpiredBackup.ExpirationType.MEDIA) {
+      // refresh only messages at t=2 so that we only expire media at t=1
+      testClock.pin(Instant.ofEpochSecond(2L));
+      this.backupsDb.ttlRefresh(backupUser(backupId, BackupTier.MESSAGES)).join();
+    }
+
+    final Function<Instant, Optional<ExpiredBackup>> expiredBackups = purgeTime -> {
+      final List<ExpiredBackup> res = backupsDb
+          .getExpiredBackups(1, Schedulers.immediate(), purgeTime)
+          .collectList()
+          .block();
+      assertThat(res).hasSizeLessThanOrEqualTo(1);
+      return res.stream().findFirst();
+    };
+
+    BackupsDb.AuthenticationData info = backupsDb.retrieveAuthenticationData(backupId).join().get();
+    final String originalBackupDir = info.backupDir();
+    final String originalMediaDir = info.mediaDir();
+
+    ExpiredBackup expired = expiredBackups.apply(Instant.ofEpochSecond(1)).get();
+    assertThat(expired).matches(eb -> eb.expirationType() == expirationType);
+
+    // expire but fail (don't call finishExpiration)
+    backupsDb.startExpiration(expired).join();
+    info = backupsDb.retrieveAuthenticationData(backupId).join().get();
+    if (expirationType == ExpiredBackup.ExpirationType.MEDIA) {
+      // Media expiration should swap the media name and keep the backup name, marking the old media name for expiration
+      // withFailMessage only takes effect when set before the assertion it describes
+      assertThat(expired.prefixToDelete())
+          .withFailMessage("Should expire media directory, expired %s", expired.prefixToDelete())
+          .isEqualTo(originalBackupDir + "/" + originalMediaDir);
+      assertThat(info.backupDir()).withFailMessage("should keep backupDir").isEqualTo(originalBackupDir);
+      assertThat(info.mediaDir()).withFailMessage("should change mediaDir").isNotEqualTo(originalMediaDir);
+    } else {
+      // Full expiration should swap the media name and the backup name, marking the old backup name for expiration
+      assertThat(expired.prefixToDelete())
+          .withFailMessage("Should expire whole backupDir, expired %s", expired.prefixToDelete())
+          .isEqualTo(originalBackupDir);
+      assertThat(info.backupDir()).withFailMessage("should change backupDir").isNotEqualTo(originalBackupDir);
+      assertThat(info.mediaDir()).withFailMessage("should change mediaDir").isNotEqualTo(originalMediaDir);
+    }
+    final String expiredPrefix = expired.prefixToDelete();
+
+    // We failed, so we should see the same prefix on the next expiration listing
+    expired = expiredBackups.apply(Instant.ofEpochSecond(1)).get();
+    assertThat(expired).matches(eb -> eb.expirationType() == ExpiredBackup.ExpirationType.GARBAGE_COLLECTION,
+        "Expiration should be garbage collection ");
+    assertThat(expired.prefixToDelete()).isEqualTo(expiredPrefix);
+    backupsDb.startExpiration(expired).join();
+
+    // Successfully finish the expiration
+    backupsDb.finishExpiration(expired).join();
+
+    Optional<ExpiredBackup> opt = expiredBackups.apply(Instant.ofEpochSecond(1));
+    if (expirationType == ExpiredBackup.ExpirationType.MEDIA) {
+      // should be nothing to expire at t=1
+      assertThat(opt).isEmpty();
+      // The backup should still exist
+      backupsDb.describeBackup(backupUser(backupId, BackupTier.MEDIA)).join();
+    } else {
+      // Cleaned up the failed attempt, now should tell us to clean the whole backup
+      assertThat(opt.get()).matches(eb -> eb.expirationType() == ExpiredBackup.ExpirationType.ALL,
+          "Expiration should be all ");
+      backupsDb.startExpiration(opt.get()).join();
+      backupsDb.finishExpiration(opt.get()).join();
+
+      // The backup entry should be gone
+      assertThat(CompletableFutureTestUtil.assertFailsWithCause(StatusRuntimeException.class,
+          backupsDb.describeBackup(backupUser(backupId, BackupTier.MEDIA)))
+          .getStatus().getCode())
+          .isEqualTo(Status.Code.NOT_FOUND);
+      assertThat(expiredBackups.apply(Instant.ofEpochSecond(10))).isEmpty();
+    }
+  }
+
   private AuthenticatedBackupUser backupUser(final byte[] backupId, final BackupTier backupTier) {
-    return new AuthenticatedBackupUser(backupId, backupTier);
+    return new AuthenticatedBackupUser(backupId, backupTier, "myBackupDir", "myMediaDir");
   }
 }
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ArchiveControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ArchiveControllerTest.java
index 663bfe472..0991724da 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ArchiveControllerTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ArchiveControllerTest.java
@@ -36,6 +36,7 @@ import javax.ws.rs.client.Invocation;
 import javax.ws.rs.client.WebTarget;
 import javax.ws.rs.core.MediaType;
 import javax.ws.rs.core.Response;
+import org.checkerframework.checker.units.qual.A;
 import org.glassfish.jersey.server.ServerProperties;
 import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
 import org.junit.jupiter.api.BeforeEach;
@@ -289,17 +290,16 @@ public class ArchiveControllerTest {
     final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation(
         BackupTier.MEDIA, backupKey, aci);
     when(backupManager.authenticateBackupUser(any(), any()))
-        .thenReturn(CompletableFuture.completedFuture(
-            new AuthenticatedBackupUser(presentation.getBackupId(), BackupTier.MEDIA)));
+        .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupTier.MEDIA)));
     when(backupManager.backupInfo(any())).thenReturn(CompletableFuture.completedFuture(new BackupManager.BackupInfo(
-        1, "subdir", "filename",
Optional.empty()))); + 1, "myBackupDir", "myMediaDir", "filename", Optional.empty()))); final ArchiveController.BackupInfoResponse response = resources.getJerseyTest() .target("v1/archives") .request() .header("X-Signal-ZK-Auth", Base64.getEncoder().encodeToString(presentation.serialize())) .header("X-Signal-ZK-Auth-Signature", "aaa") .get(ArchiveController.BackupInfoResponse.class); - assertThat(response.backupDir()).isEqualTo("subdir"); + assertThat(response.backupDir()).isEqualTo("myBackupDir"); assertThat(response.backupName()).isEqualTo("filename"); assertThat(response.cdn()).isEqualTo(1); assertThat(response.usedSpace()).isNull(); @@ -310,8 +310,7 @@ public class ArchiveControllerTest { final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( BackupTier.MEDIA, backupKey, aci); when(backupManager.authenticateBackupUser(any(), any())) - .thenReturn(CompletableFuture.completedFuture( - new AuthenticatedBackupUser(presentation.getBackupId(), BackupTier.MEDIA))); + .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupTier.MEDIA))); when(backupManager.canStoreMedia(any(), anyLong())).thenReturn(CompletableFuture.completedFuture(true)); when(backupManager.copyToBackup(any(), anyInt(), any(), anyInt(), any(), any())) .thenAnswer(invocation -> { @@ -361,8 +360,7 @@ public class ArchiveControllerTest { final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( BackupTier.MEDIA, backupKey, aci); when(backupManager.authenticateBackupUser(any(), any())) - .thenReturn(CompletableFuture.completedFuture( - new AuthenticatedBackupUser(presentation.getBackupId(), BackupTier.MEDIA))); + .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupTier.MEDIA))); final byte[][] mediaIds = IntStream.range(0, 3).mapToObj(i -> TestRandomUtil.nextBytes(15)).toArray(byte[][]::new); when(backupManager.canStoreMedia(any(), 
anyLong())).thenReturn(CompletableFuture.completedFuture(true)); @@ -417,8 +415,7 @@ public class ArchiveControllerTest { final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( BackupTier.MEDIA, backupKey, aci); when(backupManager.authenticateBackupUser(any(), any())) - .thenReturn(CompletableFuture.completedFuture( - new AuthenticatedBackupUser(presentation.getBackupId(), BackupTier.MEDIA))); + .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupTier.MEDIA))); when(backupManager.canStoreMedia(any(), eq(1L + 2L + 3L))) .thenReturn(CompletableFuture.completedFuture(false)); @@ -448,8 +445,7 @@ public class ArchiveControllerTest { final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( BackupTier.MEDIA, backupKey, aci); when(backupManager.authenticateBackupUser(any(), any())) - .thenReturn(CompletableFuture.completedFuture( - new AuthenticatedBackupUser(presentation.getBackupId(), BackupTier.MEDIA))); + .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupTier.MEDIA))); final byte[] mediaId = TestRandomUtil.nextBytes(15); final Optional expectedCursor = cursorProvided ? 
Optional.of("myCursor") : Optional.empty();
@@ -484,8 +480,7 @@ public class ArchiveControllerTest {
     final BackupAuthCredentialPresentation presentation =
         backupAuthTestUtil.getPresentation(BackupTier.MEDIA, backupKey, aci);
     when(backupManager.authenticateBackupUser(any(), any()))
-        .thenReturn(CompletableFuture.completedFuture(
-            new AuthenticatedBackupUser(presentation.getBackupId(), BackupTier.MEDIA)));
+        .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupTier.MEDIA)));
 
     final ArchiveController.DeleteMedia deleteRequest = new ArchiveController.DeleteMedia(
         IntStream
@@ -503,4 +498,11 @@ public class ArchiveControllerTest {
         .post(Entity.json(deleteRequest));
     assertThat(response.getStatus()).isEqualTo(204);
   }
+
+  /** Test fixture: wraps the given id and tier with fixed placeholder directory names. */
+  private static AuthenticatedBackupUser backupUser(byte[] backupId, BackupTier backupTier) {
+    final String backupDir = "myBackupDir";
+    final String mediaDir = "myMediaDir";
+    return new AuthenticatedBackupUser(backupId, backupTier, backupDir, mediaDir);
+  }
 }