From 030d8e8dd486f8b09438ae23f9e0e197b322f6d5 Mon Sep 17 00:00:00 2001 From: ravi-signal <99042880+ravi-signal@users.noreply.github.com> Date: Wed, 28 May 2025 15:25:32 -0500 Subject: [PATCH] Reduce drift between tracked and actual backup usage --- .../textsecuregcm/backup/BackupManager.java | 157 +++++++++++------- .../backup/Cdn3RemoteStorageManager.java | 17 +- .../backup/BackupManagerTest.java | 104 ++++++++++++ 3 files changed, 217 insertions(+), 61 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupManager.java index 6d35ac2f5..c71b1e28b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupManager.java @@ -21,7 +21,6 @@ import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; -import java.util.concurrent.atomic.AtomicLong; import java.util.function.Function; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECPublicKey; @@ -30,6 +29,8 @@ import org.signal.libsignal.zkgroup.VerificationFailedException; import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialPresentation; import org.signal.libsignal.zkgroup.backups.BackupCredentialType; import org.signal.libsignal.zkgroup.backups.BackupLevel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.attachments.AttachmentGenerator; import org.whispersystems.textsecuregcm.attachments.TusAttachmentGenerator; import org.whispersystems.textsecuregcm.auth.AuthenticatedBackupUser; @@ -57,6 +58,10 @@ public class BackupManager { // How many cdn object copy requests can be outstanding at a time per batch copy-to-backup operation private static final int COPY_CONCURRENCY = 10; + // How often we should persist the 
current usage + @VisibleForTesting + static int USAGE_CHECKPOINT_COUNT = 10; + private static final String ZK_AUTHN_COUNTER_NAME = MetricsUtil.name(BackupManager.class, "authentication"); private static final String ZK_AUTHZ_FAILURE_COUNTER_NAME = MetricsUtil.name(BackupManager.class, @@ -71,6 +76,8 @@ public class BackupManager { private static final String SUCCESS_TAG_NAME = "success"; private static final String FAILURE_REASON_TAG_NAME = "reason"; + private static final Logger log = LoggerFactory.getLogger(BackupManager.class); + private final BackupsDb backupsDb; private final GenericServerSecretParams serverSecretParams; private final RateLimiters rateLimiters; @@ -214,29 +221,39 @@ public class BackupManager { checkBackupLevel(backupUser, BackupLevel.PAID); checkBackupCredentialType(backupUser, BackupCredentialType.MEDIA); - return Mono - // Figure out how many objects we're allowed to copy, updating the quota usage for the amount we are allowed - .fromFuture(enforceQuota(backupUser, toCopy)) - - // Copy the ones we have enough quota to hold + return Mono.fromFuture(() -> allowedCopies(backupUser, toCopy)) .flatMapMany(quotaResult -> Flux.concat( - // These fit in our remaining quota, so perform the copy. If the copy fails, our estimated quota usage may not - // be exact since we already updated our usage. We make a best-effort attempt to undo the usage update if we - // know that the copied failed for sure though. - Flux.fromIterable(quotaResult.requestsToCopy()).flatMapSequential( - copyParams -> copyToBackup(backupUser, copyParams) - .flatMap(copyResult -> switch (copyResult.outcome()) { - case SUCCESS -> Mono.just(copyResult); - case SOURCE_WRONG_LENGTH, SOURCE_NOT_FOUND, OUT_OF_QUOTA -> Mono - .fromFuture(this.backupsDb.trackMedia(backupUser, -1, -copyParams.destinationObjectSize())) - .thenReturn(copyResult); - }), - COPY_CONCURRENCY), + // Perform copies for requests that fit in our quota, first updating the usage. 
If the copy fails, our + // estimated quota usage may not be exact since we update usage first. We make a best-effort attempt + // to undo the usage update if we know that the copy failed for sure. + Flux.fromIterable(quotaResult.requestsToCopy()) + + // Update the usage in reasonable chunk sizes to bound how out of sync our claimed and actual usage gets + .buffer(USAGE_CHECKPOINT_COUNT) + .concatMap(copyParameters -> { + final long quotaToConsume = copyParameters.stream() + .mapToLong(CopyParameters::destinationObjectSize) + .sum(); + return Mono + .fromFuture(backupsDb.trackMedia(backupUser, copyParameters.size(), quotaToConsume)) + .thenMany(Flux.fromIterable(copyParameters)); + }) + + // Actually perform the copies now that we've updated the quota + .flatMapSequential(copyParams -> copyToBackup(backupUser, copyParams) + .flatMap(copyResult -> switch (copyResult.outcome()) { + case SUCCESS -> Mono.just(copyResult); + case SOURCE_WRONG_LENGTH, SOURCE_NOT_FOUND, OUT_OF_QUOTA -> Mono + .fromFuture(this.backupsDb.trackMedia(backupUser, -1, -copyParams.destinationObjectSize())) + .thenReturn(copyResult); + }), + COPY_CONCURRENCY, 1), // There wasn't enough quota remaining to perform these copies Flux.fromIterable(quotaResult.requestsToReject()) - .map(arg -> new CopyResult(CopyResult.Outcome.OUT_OF_QUOTA, arg.destinationMediaId(), null)))); + .map(arg -> new CopyResult(CopyResult.Outcome.OUT_OF_QUOTA, arg.destinationMediaId(), null)) + )); } private Mono copyToBackup(final AuthenticatedBackupUser backupUser, final CopyParameters copyParameters) { @@ -262,15 +279,14 @@ public class BackupManager { private record QuotaResult(List requestsToCopy, List requestsToReject) {} /** - * Determine which copy requests can be performed with the user's remaining quota and update the used quota. If a copy - * request subsequently fails, the caller should attempt to restore the quota for the failed copy. 
+ * Determine which copy requests can be performed with the user's remaining quota. This does not update the quota. * - * @param backupUser The user quota to update + * @param backupUser The user quota to check against * @param toCopy The proposed copy requests - * @return QuotaResult indicating which requests fit into the remaining quota and which requests should be rejected - * with {@link CopyResult.Outcome#OUT_OF_QUOTA} + * @return a QuotaResult indicating which requests fit into the remaining quota and which requests should be + * rejected with {@link CopyResult.Outcome#OUT_OF_QUOTA} */ - private CompletableFuture enforceQuota( + private CompletableFuture allowedCopies( final AuthenticatedBackupUser backupUser, final List toCopy) { final long totalBytesAdded = toCopy.stream() @@ -305,22 +321,11 @@ public class BackupManager { }) .thenApply(newUsage -> MAX_TOTAL_BACKUP_MEDIA_BYTES - newUsage.bytesUsed()); }) - .thenCompose(remainingQuota -> { + .thenApply(remainingQuota -> { // Figure out how many of the requested objects fit in the remaining quota final int index = indexWhereTotalExceeds(toCopy, CopyParameters::destinationObjectSize, remainingQuota); - final QuotaResult result = new QuotaResult(toCopy.subList(0, index), - toCopy.subList(index, toCopy.size())); - if (index == 0) { - // Skip the usage update if we're not able to write anything - return CompletableFuture.completedFuture(result); - } - - // Update the usage - final long quotaToConsume = result.requestsToCopy.stream() - .mapToLong(CopyParameters::destinationObjectSize) - .sum(); - return backupsDb.trackMedia(backupUser, index, quotaToConsume).thenApply(ignored -> result); + return new QuotaResult(toCopy.subList(0, index), toCopy.subList(index, toCopy.size())); }); } @@ -422,45 +427,79 @@ public class BackupManager { return Flux.usingWhen( - // Gather usage updates into the UsageBatcher to apply during the cleanup operation + // Gather usage updates into the UsageBatcher so we don't have to 
update our backup record on every delete Mono.just(new UsageBatcher()), // Deletes the objects, returning their former location. Tracks bytes removed so the quota can be updated on // completion batcher -> Flux.fromIterable(storageDescriptors) - .flatMapSequential(sd -> Mono - // Delete the object - .fromCompletionStage(remoteStorageManager.delete(cdnMediaPath(backupUser, sd.key()))) - // Track how much the remote storage manager indicated was deleted as part of the operation - .doOnNext(deletedBytes -> batcher.update(-deletedBytes)) - .thenReturn(sd), DELETION_CONCURRENCY), - // On cleanup, update the quota using whatever updates were accumulated in the batcher - batcher -> - Mono.fromFuture(backupsDb.trackMedia(backupUser, batcher.countDelta.get(), batcher.usageDelta.get()))); + // Delete the objects, allowing DELETION_CONCURRENCY operations out at a time + .flatMapSequential( + sd -> Mono.fromCompletionStage(remoteStorageManager.delete(cdnMediaPath(backupUser, sd.key()))), + DELETION_CONCURRENCY) + .zipWithIterable(storageDescriptors) + + // Track how much the remote storage manager indicated was deleted as part of the operation + .concatMap(deletedBytesAndStorageDescriptor -> { + final long deletedBytes = deletedBytesAndStorageDescriptor.getT1(); + final StorageDescriptor sd = deletedBytesAndStorageDescriptor.getT2(); + + // If it has been a while, perform a checkpoint to make sure our usage doesn't drift too much + if (batcher.update(-deletedBytes)) { + final UsageBatcher.UsageUpdate usageUpdate = batcher.getAndReset(); + return Mono + .fromFuture(backupsDb.trackMedia(backupUser, usageUpdate.countDelta, usageUpdate.bytesDelta)) + .doOnError(throwable -> + log.warn("Failed to update delta {} after successful delete operation", usageUpdate, throwable)) + .thenReturn(sd); + } else { + return Mono.just(sd); + } + }), + + // On cleanup, update the quota using whatever remaining updates were accumulated in the batcher + batcher -> { + final UsageBatcher.UsageUpdate 
update = batcher.getAndReset(); + return Mono + .fromFuture(backupsDb.trackMedia(backupUser, update.countDelta, update.bytesDelta)) + .doOnError(throwable -> + log.warn("Failed to update delta {} after successful delete operation", update, throwable)); + }); } /** - * Track pending media usage updates + * Track pending media usage updates. Not thread safe! */ private static class UsageBatcher { - AtomicLong countDelta = new AtomicLong(); - AtomicLong usageDelta = new AtomicLong(); + private long runningCountDelta = 0; + private long runningBytesDelta = 0; + + record UsageUpdate(long countDelta, long bytesDelta) {} /** - * Stage a usage update that will be applied later + * Stage a usage update. Returns true when it is time to make a checkpoint * * @param bytesDelta The amount of bytes that should be tracked as used (or if negative, freed). If the delta is * non-zero, the count will also be updated. + * @return true if we should persist the usage */ - void update(long bytesDelta) { - if (bytesDelta < 0) { - countDelta.decrementAndGet(); - } else if (bytesDelta > 0) { - countDelta.incrementAndGet(); - } - usageDelta.addAndGet(bytesDelta); + boolean update(long bytesDelta) { + this.runningCountDelta += Long.signum(bytesDelta); + this.runningBytesDelta += bytesDelta; + return Math.abs(runningCountDelta) >= USAGE_CHECKPOINT_COUNT; + } + + /** + * Get the current usage delta, and set the delta to 0 + * @return A {@link UsageUpdate} to apply + */ + UsageUpdate getAndReset() { + final UsageUpdate update = new UsageUpdate(runningCountDelta, runningBytesDelta); + runningCountDelta = 0; + runningBytesDelta = 0; + return update; } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/backup/Cdn3RemoteStorageManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/backup/Cdn3RemoteStorageManager.java index 1df40636d..725dc383a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/backup/Cdn3RemoteStorageManager.java +++ 
b/service/src/main/java/org/whispersystems/textsecuregcm/backup/Cdn3RemoteStorageManager.java @@ -54,6 +54,8 @@ public class Cdn3RemoteStorageManager implements RemoteStorageManager { private static final String OPERATION_TAG_NAME = "op"; private static final String STATUS_TAG_NAME = "status"; + private static final String OBJECT_REMOVED_ON_DELETE_COUNTER_NAME = MetricsUtil.name(Cdn3RemoteStorageManager.class, "objectRemovedOnDelete"); + public Cdn3RemoteStorageManager( final ExecutorService httpExecutor, final ScheduledExecutorService retryExecutor, @@ -111,6 +113,10 @@ public class Cdn3RemoteStorageManager implements RemoteStorageManager { .build(); return this.storageManagerHttpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()) .thenAccept(response -> { + Metrics.counter(STORAGE_MANAGER_STATUS_COUNTER_NAME, + OPERATION_TAG_NAME, "copy", + STATUS_TAG_NAME, Integer.toString(response.statusCode())) + .increment(); if (response.statusCode() == Response.Status.NOT_FOUND.getStatusCode()) { throw ExceptionUtils.wrap(new SourceObjectNotFoundException()); } else if (response.statusCode() == Response.Status.CONFLICT.getStatusCode()) { @@ -259,6 +265,7 @@ public class Cdn3RemoteStorageManager implements RemoteStorageManager { record DeleteResponse(@NotNull long bytesDeleted) {} public CompletionStage delete(final String key) { + final Timer.Sample sample = Timer.start(); final HttpRequest request = HttpRequest.newBuilder().DELETE() .uri(URI.create(deleteUrl(key))) .header(CLIENT_ID_HEADER, clientId) @@ -271,11 +278,17 @@ public class Cdn3RemoteStorageManager implements RemoteStorageManager { STATUS_TAG_NAME, Integer.toString(response.statusCode())) .increment(); try { - return parseDeleteResponse(response); + long bytesDeleted = parseDeleteResponse(response); + Metrics.counter(OBJECT_REMOVED_ON_DELETE_COUNTER_NAME, + "removed", Boolean.toString(bytesDeleted > 0)) + .increment(); + return bytesDeleted; } catch (IOException e) { throw ExceptionUtils.wrap(e); } - 
}); + }) + .whenComplete((ignored, ignoredException) -> + sample.stop(Metrics.timer(STORAGE_MANAGER_TIMER_NAME, OPERATION_TAG_NAME, "delete"))); } private long parseDeleteResponse(final HttpResponse httpDeleteResponse) throws IOException { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupManagerTest.java index f5a510a50..5454a5cef 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupManagerTest.java @@ -40,10 +40,16 @@ import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.UUID; +import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.function.Function; +import java.util.stream.Collectors; import java.util.stream.IntStream; import javax.annotation.Nullable; +import org.apache.commons.lang3.RandomStringUtils; import org.apache.commons.lang3.StringUtils; import org.assertj.core.api.ThrowableAssert; import org.junit.jupiter.api.BeforeEach; @@ -73,6 +79,7 @@ import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil; import org.whispersystems.textsecuregcm.util.TestClock; import org.whispersystems.textsecuregcm.util.TestRandomUtil; +import reactor.core.publisher.Flux; import reactor.core.scheduler.Schedulers; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; @@ -451,6 +458,54 @@ public class BackupManagerTest { assertThat(mediaCount).isEqualTo(1); } + @Test + public void copyUsageCheckpoints() throws InterruptedException { + final AuthenticatedBackupUser 
backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID); + backupsDb.setMediaUsage(backupUser, new UsageInfo(0, 0)).join(); + + final List sourceKeys = IntStream.range(0, 50) + .mapToObj(ignore -> RandomStringUtils.insecure().nextAlphanumeric(10)) + .toList(); + final List toCopy = sourceKeys.stream() + .map(source -> new CopyParameters(3, source, 100, COPY_ENCRYPTION_PARAM, TestRandomUtil.nextBytes(15))) + .toList(); + + final int slowIndex = BackupManager.USAGE_CHECKPOINT_COUNT - 1; + final CompletableFuture slow = new CompletableFuture<>(); + when(remoteStorageManager.copy(eq(3), anyString(), eq(100), any(), anyString())) + .thenReturn(CompletableFuture.completedFuture(null)); + when(remoteStorageManager.copy(eq(3), eq(sourceKeys.get(slowIndex)), eq(100), any(), anyString())) + .thenReturn(slow); + final ArrayBlockingQueue copyResults = new ArrayBlockingQueue<>(100); + final CompletableFuture future = backupManager + .copyToBackup(backupUser, toCopy) + .doOnNext(copyResults::add).then().toFuture(); + + for (int i = 0; i < slowIndex; i++) { + assertThat(copyResults.poll(1, TimeUnit.SECONDS)).isNotNull(); + } + + // Copying can start on the next batch of USAGE_CHECKPOINT_COUNT before the current one is done, so we should see + // at least one usage update, and at most 2 + final UsageInfo usage = backupsDb.getMediaUsage(backupUser).join().usageInfo(); + final long bytesPerObject = COPY_ENCRYPTION_PARAM.outputSize(100); + assertThat(backupsDb.getMediaUsage(backupUser).join().usageInfo()).isIn( + new UsageInfo( + bytesPerObject * BackupManager.USAGE_CHECKPOINT_COUNT, + BackupManager.USAGE_CHECKPOINT_COUNT), + new UsageInfo( + 2 * bytesPerObject * BackupManager.USAGE_CHECKPOINT_COUNT, + 2 * BackupManager.USAGE_CHECKPOINT_COUNT)); + + // We should still be waiting since we have a slow copy + assertThat(future).isNotDone(); + + slow.complete(null); + future.join(); + 
assertThat(backupsDb.getMediaUsage(backupUser).join().usageInfo()) + .isEqualTo(new UsageInfo(bytesPerObject * 50, 50)); + } + @Test public void copyFailure() { final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID); @@ -689,6 +744,55 @@ public class BackupManagerTest { .matches(e -> ((StatusRuntimeException) e).getStatus().getCode() == Status.INVALID_ARGUMENT.getCode()); } + @Test + public void deleteUsageCheckpoints() throws InterruptedException { + final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, + BackupLevel.PAID); + + // 100 objects, each 2 bytes large + final List mediaIds = IntStream.range(0, 100).mapToObj(ig -> TestRandomUtil.nextBytes(16)).toList(); + backupsDb.setMediaUsage(backupUser, new UsageInfo(200, 100)).join(); + + // One object is slow to delete + final CompletableFuture slowFuture = new CompletableFuture<>(); + final String slowMediaKey = "%s/%s/%s".formatted( + backupUser.backupDir(), + backupUser.mediaDir(), + BackupManager.encodeMediaIdForCdn(mediaIds.get(BackupManager.USAGE_CHECKPOINT_COUNT + 3))); + + when(remoteStorageManager.delete(anyString())).thenReturn(CompletableFuture.completedFuture(2L)); + when(remoteStorageManager.delete(slowMediaKey)).thenReturn(slowFuture); + when(remoteStorageManager.cdnNumber()).thenReturn(5); + + + final Flux flux = backupManager.deleteMedia(backupUser, + mediaIds.stream() + .map(i -> new BackupManager.StorageDescriptor(5, i)) + .toList()); + final ArrayBlockingQueue sds = new ArrayBlockingQueue<>(100); + final CompletableFuture future = flux.doOnNext(sds::add).then().toFuture(); + for (int i = 0; i < BackupManager.USAGE_CHECKPOINT_COUNT; i++) { + sds.poll(1, TimeUnit.SECONDS); + } + + assertThat(backupsDb.getMediaUsage(backupUser).join().usageInfo()) + .isEqualTo(new UsageInfo( + 200 - (2 * BackupManager.USAGE_CHECKPOINT_COUNT), + 100 - 
BackupManager.USAGE_CHECKPOINT_COUNT)); + // We should still be waiting since we have a slow delete + assertThat(future).isNotDone(); + // But we should checkpoint the usage periodically + assertThat(backupsDb.getMediaUsage(backupUser).join().usageInfo()) + .isEqualTo(new UsageInfo( + 200 - (2 * BackupManager.USAGE_CHECKPOINT_COUNT), + 100 - BackupManager.USAGE_CHECKPOINT_COUNT)); + + slowFuture.complete(2L); + future.join(); + assertThat(backupsDb.getMediaUsage(backupUser).join().usageInfo()) + .isEqualTo(new UsageInfo(0L, 0L)); + } + @Test public void deletePartialFailure() { final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID);