Add a crawler to recalculate quota usage

This commit is contained in:
Ravi Khadiwala 2025-05-28 12:51:56 -05:00 committed by ravi-signal
parent 4dc3b19d2a
commit a7ea42adc3
6 changed files with 191 additions and 6 deletions

View File

@ -262,6 +262,7 @@ import org.whispersystems.textsecuregcm.websocket.AuthenticatedConnectListener;
import org.whispersystems.textsecuregcm.websocket.ProvisioningConnectListener;
import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator;
import org.whispersystems.textsecuregcm.workers.BackupMetricsCommand;
import org.whispersystems.textsecuregcm.workers.BackupUsageRecalculationCommand;
import org.whispersystems.textsecuregcm.workers.CertificateCommand;
import org.whispersystems.textsecuregcm.workers.CheckDynamicConfigurationCommand;
import org.whispersystems.textsecuregcm.workers.DeleteUserCommand;
@ -330,6 +331,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
bootstrap.addCommand(new RemoveExpiredUsernameHoldsCommand(Clock.systemUTC()));
bootstrap.addCommand(new RemoveExpiredBackupsCommand(Clock.systemUTC()));
bootstrap.addCommand(new BackupMetricsCommand(Clock.systemUTC()));
bootstrap.addCommand(new BackupUsageRecalculationCommand());
bootstrap.addCommand(new RemoveExpiredLinkedDevicesCommand());
bootstrap.addCommand(new NotifyIdleDevicesCommand());

View File

@ -24,6 +24,7 @@ import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.function.Function;
import org.apache.commons.lang3.StringUtils;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.signal.libsignal.zkgroup.GenericServerSecretParams;
@ -337,6 +338,19 @@ public class BackupManager {
});
}
/**
 * The before/after pair produced by {@link #recalculateQuota(StoredBackupAttributes)}.
 *
 * @param oldUsage the usage built from the bytesUsed and numObjects of the stored attributes passed in
 * @param newUsage the usage reported by the remote storage manager for the backup's media directory
 */
public record RecalculationResult(UsageInfo oldUsage, UsageInfo newUsage) {}
/**
 * Recalculate the media usage of a stored backup from the contents of its cdn media directory and persist the
 * recalculated value.
 *
 * @param storedBackupAttributes the stored attributes of the backup to recalculate
 * @return a stage that completes with an empty optional if the entry has a blank backupDir or mediaDir (nothing is
 * calculated or written in that case), otherwise with the previously stored usage and the newly calculated usage
 * once the new usage has been written
 */
public CompletionStage<Optional<RecalculationResult>> recalculateQuota(final StoredBackupAttributes storedBackupAttributes) {
  final String backupDir = storedBackupAttributes.backupDir();
  final String mediaDir = storedBackupAttributes.mediaDir();

  final boolean hasMediaLocation = !StringUtils.isBlank(backupDir) && !StringUtils.isBlank(mediaDir);
  if (!hasMediaLocation) {
    // Without both directory names there is no cdn prefix to measure
    return CompletableFuture.completedFuture(Optional.empty());
  }

  return this.remoteStorageManager
      .calculateBytesUsed(cdnMediaDirectory(backupDir, mediaDir))
      .thenCompose(newUsage -> backupsDb
          .setMediaUsage(storedBackupAttributes, newUsage)
          .thenApply(ignored -> {
            // The "old" usage is whatever the caller's snapshot of the stored attributes held
            final UsageInfo oldUsage =
                new UsageInfo(storedBackupAttributes.bytesUsed(), storedBackupAttributes.numObjects());
            return Optional.of(new RecalculationResult(oldUsage, newUsage));
          }));
}
/**
* @return the largest index i such that sum(ts[0],...ts[i - 1]) <= max
*/
@ -735,8 +749,12 @@ public class BackupManager {
return "%s/%s".formatted(backupUser.backupDir(), MESSAGE_BACKUP_NAME);
}
/**
 * Builds the cdn media directory prefix for a backup.
 *
 * @param backupDir the cdn backupDir of the backup
 * @param mediaDir  the cdn mediaDir within the backupDir
 * @return {@code backupDir/mediaDir/}, including the trailing slash
 */
private static String cdnMediaDirectory(final String backupDir, final String mediaDir) {
  return backupDir + "/" + mediaDir + "/";
}
/**
 * Builds the cdn media directory prefix for an authenticated backup user from the user's backupDir and mediaDir.
 */
private static String cdnMediaDirectory(final AuthenticatedBackupUser backupUser) {
// NOTE(review): two consecutive returns are visible here. The formatted return looks like the pre-change line kept
// by the diff rendering and the delegating return after it like its replacement; both yield the same value - confirm.
return "%s/%s/".formatted(backupUser.backupDir(), backupUser.mediaDir());
return cdnMediaDirectory(backupUser.backupDir(), backupUser.mediaDir());
}
private static String cdnMediaPath(final AuthenticatedBackupUser backupUser, final byte[] mediaId) {

View File

@ -402,8 +402,16 @@ public class BackupsDb {
}
/**
 * Set the stored media usage for an authenticated backup user.
 *
 * @param backupUser the user whose backup entry should be updated
 * @param usageInfo  the usage to store
 * @return a future that completes when the update has been applied
 */
CompletableFuture<Void> setMediaUsage(final AuthenticatedBackupUser backupUser, UsageInfo usageInfo) {
  final UpdateBuilder userUpdate = UpdateBuilder.forUser(backupTableName, backupUser);
  return setMediaUsage(userUpdate, usageInfo);
}
/**
 * Set the stored media usage for a backup entry identified by its stored attributes rather than by an
 * authenticated user.
 *
 * @param backupAttributes the stored attributes of the entry to update; only the hashed backup-id is read
 * @param usageInfo        the usage to store
 * @return a future that completes when the update has been applied
 */
CompletableFuture<Void> setMediaUsage(final StoredBackupAttributes backupAttributes, UsageInfo usageInfo) {
  // There is no credential to take a level from here, so the update is always built with the PAID level
  final UpdateBuilder attributesUpdate =
      new UpdateBuilder(backupTableName, BackupLevel.PAID, backupAttributes.hashedBackupId());
  return setMediaUsage(attributesUpdate, usageInfo);
}
private CompletableFuture<Void> setMediaUsage(final UpdateBuilder updateBuilder, UsageInfo usageInfo) {
return dynamoClient.updateItem(
UpdateBuilder.forUser(backupTableName, backupUser)
updateBuilder
.addSetExpression("#mediaBytesUsed = :mediaBytesUsed",
Map.entry("#mediaBytesUsed", ATTR_MEDIA_BYTES_USED),
Map.entry(":mediaBytesUsed", AttributeValues.n(usageInfo.bytesUsed())))
@ -496,13 +504,18 @@ public class BackupsDb {
"#refresh", ATTR_LAST_REFRESH,
"#mediaRefresh", ATTR_LAST_MEDIA_REFRESH,
"#bytesUsed", ATTR_MEDIA_BYTES_USED,
"#numObjects", ATTR_MEDIA_COUNT))
.projectionExpression("#backupIdHash, #refresh, #mediaRefresh, #bytesUsed, #numObjects")
"#numObjects", ATTR_MEDIA_COUNT,
"#backupDir", ATTR_BACKUP_DIR,
"#mediaDir", ATTR_MEDIA_DIR))
.projectionExpression("#backupIdHash, #refresh, #mediaRefresh, #bytesUsed, #numObjects, #backupDir, #mediaDir")
.build())
.items())
.sequential()
.filter(item -> item.containsKey(KEY_BACKUP_ID_HASH))
.map(item -> new StoredBackupAttributes(
AttributeValues.getByteArray(item, KEY_BACKUP_ID_HASH, null),
AttributeValues.getString(item, ATTR_BACKUP_DIR, null),
AttributeValues.getString(item, ATTR_MEDIA_DIR, null),
Instant.ofEpochSecond(AttributeValues.getLong(item, ATTR_LAST_REFRESH, 0L)),
Instant.ofEpochSecond(AttributeValues.getLong(item, ATTR_LAST_MEDIA_REFRESH, 0L)),
AttributeValues.getLong(item, ATTR_MEDIA_BYTES_USED, 0L),

View File

@ -9,11 +9,19 @@ import java.time.Instant;
/**
* Attributes stored in the backups table for a single backup id
*
* @param hashedBackupId The hashed backup-id of this entry
* @param backupDir The cdn backupDir of this entry
* @param mediaDir The cdn mediaDir (within the backupDir) of this entry
* @param lastRefresh The last time the record was updated with a messages or media tier credential
* @param lastMediaRefresh The last time the record was updated with a media tier credential
* @param bytesUsed The number of media bytes used by the backup
* @param numObjects The number of media objects used by the backup
*/
public record StoredBackupAttributes(
Instant lastRefresh, Instant lastMediaRefresh,
long bytesUsed, long numObjects) {}
byte[] hashedBackupId,
String backupDir,
String mediaDir,
Instant lastRefresh,
Instant lastMediaRefresh,
long bytesUsed,
long numObjects) {}

View File

@ -0,0 +1,120 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.workers;
import io.dropwizard.core.Application;
import io.dropwizard.core.setup.Environment;
import io.micrometer.core.instrument.Metrics;
import net.sourceforge.argparse4j.inf.Namespace;
import net.sourceforge.argparse4j.inf.Subparser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.backup.BackupManager;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
import java.util.Objects;
public class BackupUsageRecalculationCommand extends AbstractCommandWithDependencies {
private final Logger logger = LoggerFactory.getLogger(getClass());
private static final String SEGMENT_COUNT_ARGUMENT = "segments";
private static final int DEFAULT_SEGMENT_COUNT = 1;
private static final String MAX_CONCURRENCY_ARGUMENT = "max-concurrency";
private static final int DEFAULT_MAX_CONCURRENCY = 4;
private static final String RECALCULATION_COUNT_COUNTER_NAME =
MetricsUtil.name(BackupUsageRecalculationCommand.class, "countRecalculations");
private static final String RECALCULATION_BYTE_COUNTER_NAME =
MetricsUtil.name(BackupUsageRecalculationCommand.class, "byteRecalculations");
public BackupUsageRecalculationCommand() {
super(new Application<>() {
@Override
public void run(final WhisperServerConfiguration configuration, final Environment environment) {
}
}, "backup-usage-recalculation", "Recalculate the usage of backups");
}
@Override
public void configure(final Subparser subparser) {
super.configure(subparser);
subparser.addArgument("--segments")
.type(Integer.class)
.dest(SEGMENT_COUNT_ARGUMENT)
.required(false)
.setDefault(DEFAULT_SEGMENT_COUNT)
.help("The total number of segments for a DynamoDB scan");
subparser.addArgument("--max-concurrency")
.type(Integer.class)
.dest(MAX_CONCURRENCY_ARGUMENT)
.setDefault(DEFAULT_MAX_CONCURRENCY)
.help("Max concurrency for DynamoDB operations");
}
@Override
protected void run(final Environment environment, final Namespace namespace,
final WhisperServerConfiguration configuration, final CommandDependencies commandDependencies) throws Exception {
final int segments = Objects.requireNonNull(namespace.getInt(SEGMENT_COUNT_ARGUMENT));
final int recalculationConcurrency = Objects.requireNonNull(namespace.getInt(MAX_CONCURRENCY_ARGUMENT));
logger.info("Crawling to recalculate usage with {} segments and {} processors",
segments,
Runtime.getRuntime().availableProcessors());
final BackupManager backupManager = commandDependencies.backupManager();
final Long backupsConsidered = backupManager
.listBackupAttributes(segments, Schedulers.parallel())
.flatMap(attrs -> Mono.fromCompletionStage(() -> backupManager.recalculateQuota(attrs)).doOnNext(maybeRecalculationResult -> maybeRecalculationResult.ifPresent(recalculationResult -> {
if (!recalculationResult.newUsage().equals(recalculationResult.oldUsage())) {
logger.info("Recalculated usage. oldUsage={}, newUsage={}, lastRefresh={}, lastMediaRefresh={}",
recalculationResult.oldUsage(),
recalculationResult.newUsage(),
attrs.lastRefresh(),
attrs.lastMediaRefresh());
}
Metrics.counter(RECALCULATION_COUNT_COUNTER_NAME,
"delta", DeltaType.deltaType(
recalculationResult.oldUsage().numObjects(),
recalculationResult.newUsage().numObjects()).name())
.increment();
Metrics.counter(RECALCULATION_BYTE_COUNTER_NAME,
"delta", DeltaType.deltaType(
recalculationResult.oldUsage().bytesUsed(),
recalculationResult.newUsage().bytesUsed()).name())
.increment();
}
)), recalculationConcurrency)
.count()
.block();
logger.info("Crawled {} backups", backupsConsidered);
}
private enum DeltaType {
REDUCED,
SAME,
INCREASED;
static DeltaType deltaType(long oldv, long newv) {
return switch (Long.signum(newv - oldv)) {
case -1 -> REDUCED;
case 0 -> SAME;
case 1 -> INCREASED;
default -> throw new IllegalStateException("Unexpected value: " + (newv - oldv));
};
}
}
}

View File

@ -642,6 +642,30 @@ public class BackupManagerTest {
}
}
/**
 * Recalculating a backup's quota should report the previously stored usage alongside the usage calculated from
 * remote storage, and should persist the calculated usage with the time of the recalculation.
 */
@Test
public void requestRecalculation() {
  final AuthenticatedBackupUser user =
      backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID);
  final String mediaPrefix = user.backupDir() + "/" + user.mediaDir() + "/";

  final UsageInfo storedUsage = new UsageInfo(1000, 100);
  final UsageInfo calculatedUsage = new UsageInfo(2000, 200);
  final Instant initialWriteTime = Instant.ofEpochSecond(123);
  final Instant recalculationTime = Instant.ofEpochSecond(456);

  // Seed the table with the stale usage at the earlier time
  testClock.pin(initialWriteTime);
  backupsDb.setMediaUsage(user, storedUsage).join();

  // Remote storage reports a different usage for the user's media directory
  when(remoteStorageManager.calculateBytesUsed(eq(mediaPrefix)))
      .thenReturn(CompletableFuture.completedFuture(calculatedUsage));

  final StoredBackupAttributes attrs =
      backupManager.listBackupAttributes(1, Schedulers.immediate()).single().block();

  // Recalculate at the later time
  testClock.pin(recalculationTime);
  final BackupManager.RecalculationResult expected =
      new BackupManager.RecalculationResult(storedUsage, calculatedUsage);
  assertThat(backupManager.recalculateQuota(attrs).toCompletableFuture().join())
      .get()
      .isEqualTo(expected);

  // backupsDb should have the new value
  final BackupsDb.TimestampedUsageInfo persisted = backupsDb.getMediaUsage(user).join();
  assertThat(persisted.lastRecalculationTime()).isEqualTo(recalculationTime);
  assertThat(persisted.usageInfo()).isEqualTo(calculatedUsage);
}
@ParameterizedTest
@ValueSource(strings = {"", "cursor"})
public void list(final String cursorVal) {