Add support for distinct media backup credentials

Co-authored-by: Ravi Khadiwala <ravi@signal.org>
This commit is contained in:
Jon Chambers 2024-10-29 16:03:10 -04:00 committed by GitHub
parent d335b7a033
commit b21b50873f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 566 additions and 258 deletions

View File

@ -18,6 +18,7 @@ import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.IdentityKeyPair; import org.signal.libsignal.protocol.IdentityKeyPair;
import org.signal.libsignal.protocol.InvalidKeyException;
import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.signal.libsignal.protocol.state.SignedPreKeyRecord; import org.signal.libsignal.protocol.state.SignedPreKeyRecord;
import org.signal.libsignal.protocol.util.KeyHelper; import org.signal.libsignal.protocol.util.KeyHelper;
@ -161,15 +162,19 @@ public class TestUser {
: aciIdentityKey; : aciIdentityKey;
final TestDevice device = requireNonNull(devices.get(deviceId)); final TestDevice device = requireNonNull(devices.get(deviceId));
final SignedPreKeyRecord signedPreKeyRecord = device.latestSignedPreKey(identity); final SignedPreKeyRecord signedPreKeyRecord = device.latestSignedPreKey(identity);
return new PreKeySetPublicView( try {
Collections.emptyList(), return new PreKeySetPublicView(
identity.getPublicKey(), Collections.emptyList(),
new SignedPreKeyPublicView( identity.getPublicKey(),
signedPreKeyRecord.getId(), new SignedPreKeyPublicView(
signedPreKeyRecord.getKeyPair().getPublicKey(), signedPreKeyRecord.getId(),
signedPreKeyRecord.getSignature() signedPreKeyRecord.getKeyPair().getPublicKey(),
) signedPreKeyRecord.getSignature()
); )
);
} catch (InvalidKeyException e) {
throw new RuntimeException(e);
}
} }
public record SignedPreKeyPublicView( public record SignedPreKeyPublicView(

View File

@ -272,7 +272,7 @@
<dependency> <dependency>
<groupId>org.signal</groupId> <groupId>org.signal</groupId>
<artifactId>libsignal-server</artifactId> <artifactId>libsignal-server</artifactId>
<version>0.54.2</version> <version>0.60.0</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.signal.forks</groupId> <groupId>org.signal.forks</groupId>

View File

@ -5,6 +5,12 @@
package org.whispersystems.textsecuregcm.auth; package org.whispersystems.textsecuregcm.auth;
import org.signal.libsignal.zkgroup.backups.BackupCredentialType;
import org.signal.libsignal.zkgroup.backups.BackupLevel; import org.signal.libsignal.zkgroup.backups.BackupLevel;
public record AuthenticatedBackupUser(byte[] backupId, BackupLevel backupLevel, String backupDir, String mediaDir) {} public record AuthenticatedBackupUser(byte[] backupId,
BackupCredentialType credentialType,
BackupLevel backupLevel,
String backupDir,
String mediaDir) {
}

View File

@ -21,6 +21,7 @@ import org.signal.libsignal.zkgroup.InvalidInputException;
import org.signal.libsignal.zkgroup.VerificationFailedException; import org.signal.libsignal.zkgroup.VerificationFailedException;
import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialRequest; import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialRequest;
import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialResponse; import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialResponse;
import org.signal.libsignal.zkgroup.backups.BackupCredentialType;
import org.signal.libsignal.zkgroup.backups.BackupLevel; import org.signal.libsignal.zkgroup.backups.BackupLevel;
import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialPresentation; import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialPresentation;
import org.signal.libsignal.zkgroup.receipts.ReceiptSerial; import org.signal.libsignal.zkgroup.receipts.ReceiptSerial;
@ -29,7 +30,6 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
@ -81,22 +81,34 @@ public class BackupAuthManager {
} }
/** /**
* Store a credential request containing a blinded backup-id for future use. * Store credential requests containing blinded backup-ids for future use.
* *
* @param account The account using the backup-id * @param account The account using the backup-id
* @param backupAuthCredentialRequest A request containing the blinded backup-id * @param messagesBackupCredentialRequest A request containing the blinded backup-id the client will use to upload
* message backups
* @param mediaBackupCredentialRequest A request containing the blinded backup-id the client will use to upload
* media backups
* @return A future that completes when the credentialRequest has been stored * @return A future that completes when the credentialRequest has been stored
* @throws RateLimitExceededException If too many backup-ids have been committed * @throws RateLimitExceededException If too many backup-ids have been committed
*/ */
public CompletableFuture<Void> commitBackupId(final Account account, public CompletableFuture<Void> commitBackupId(final Account account,
final BackupAuthCredentialRequest backupAuthCredentialRequest) { final BackupAuthCredentialRequest messagesBackupCredentialRequest,
final BackupAuthCredentialRequest mediaBackupCredentialRequest) {
if (configuredBackupLevel(account).isEmpty()) { if (configuredBackupLevel(account).isEmpty()) {
throw Status.PERMISSION_DENIED.withDescription("Backups not allowed on account").asRuntimeException(); throw Status.PERMISSION_DENIED.withDescription("Backups not allowed on account").asRuntimeException();
} }
final byte[] serializedMessageCredentialRequest = messagesBackupCredentialRequest.serialize();
final byte[] serializedMediaCredentialRequest = mediaBackupCredentialRequest.serialize();
byte[] serializedRequest = backupAuthCredentialRequest.serialize(); final boolean messageCredentialRequestMatches = account.getBackupCredentialRequest(BackupCredentialType.MESSAGES)
byte[] existingRequest = account.getBackupCredentialRequest(); .map(storedCredentialRequest -> MessageDigest.isEqual(storedCredentialRequest, serializedMessageCredentialRequest))
if (existingRequest != null && MessageDigest.isEqual(serializedRequest, existingRequest)) { .orElse(false);
final boolean mediaCredentialRequestMatches = account.getBackupCredentialRequest(BackupCredentialType.MEDIA)
.map(storedCredentialRequest -> MessageDigest.isEqual(storedCredentialRequest, serializedMediaCredentialRequest))
.orElse(false);
if (messageCredentialRequestMatches && mediaCredentialRequestMatches) {
// No need to update or enforce rate limits, this is the credential that the user has already // No need to update or enforce rate limits, this is the credential that the user has already
// committed to. // committed to.
return CompletableFuture.completedFuture(null); return CompletableFuture.completedFuture(null);
@ -105,7 +117,7 @@ public class BackupAuthManager {
return rateLimiters.forDescriptor(RateLimiters.For.SET_BACKUP_ID) return rateLimiters.forDescriptor(RateLimiters.For.SET_BACKUP_ID)
.validateAsync(account.getUuid()) .validateAsync(account.getUuid())
.thenCompose(ignored -> this.accountsManager .thenCompose(ignored -> this.accountsManager
.updateAsync(account, acc -> acc.setBackupCredentialRequest(serializedRequest)) .updateAsync(account, a -> a.setBackupCredentialRequests(serializedMessageCredentialRequest, serializedMediaCredentialRequest))
.thenRun(Util.NOOP)) .thenRun(Util.NOOP))
.toCompletableFuture(); .toCompletableFuture();
} }
@ -123,12 +135,14 @@ public class BackupAuthManager {
* method will also remove the expired voucher from the account. * method will also remove the expired voucher from the account.
* *
* @param account The account to create the credentials for * @param account The account to create the credentials for
* @param credentialType The type of backup credentials to create
* @param redemptionStart The day (must be truncated to a day boundary) the first credential should be valid * @param redemptionStart The day (must be truncated to a day boundary) the first credential should be valid
* @param redemptionEnd The day (must be truncated to a day boundary) the last credential should be valid * @param redemptionEnd The day (must be truncated to a day boundary) the last credential should be valid
* @return Credentials and the day on which they may be redeemed * @return Credentials and the day on which they may be redeemed
*/ */
public CompletableFuture<List<Credential>> getBackupAuthCredentials( public CompletableFuture<List<Credential>> getBackupAuthCredentials(
final Account account, final Account account,
final BackupCredentialType credentialType,
final Instant redemptionStart, final Instant redemptionStart,
final Instant redemptionEnd) { final Instant redemptionEnd) {
@ -139,7 +153,7 @@ public class BackupAuthManager {
if (hasExpiredVoucher(a)) { if (hasExpiredVoucher(a)) {
a.setBackupVoucher(null); a.setBackupVoucher(null);
} }
}).thenCompose(updated -> getBackupAuthCredentials(updated, redemptionStart, redemptionEnd)); }).thenCompose(updated -> getBackupAuthCredentials(updated, credentialType, redemptionStart, redemptionEnd));
} }
// If this account isn't allowed some level of backup access via configuration, don't continue // If this account isn't allowed some level of backup access via configuration, don't continue
@ -157,23 +171,20 @@ public class BackupAuthManager {
} }
// fetch the blinded backup-id the account should have previously committed to // fetch the blinded backup-id the account should have previously committed to
final byte[] committedBytes = account.getBackupCredentialRequest(); final byte[] committedBytes = account.getBackupCredentialRequest(credentialType)
if (committedBytes == null) { .orElseThrow(() -> Status.NOT_FOUND.withDescription("No blinded backup-id has been added to the account").asRuntimeException());
throw Status.NOT_FOUND.withDescription("No blinded backup-id has been added to the account").asRuntimeException();
}
try { try {
// create a credential for every day in the requested period // create a credential for every day in the requested period
final BackupAuthCredentialRequest credentialReq = new BackupAuthCredentialRequest(committedBytes); final BackupAuthCredentialRequest credentialReq = new BackupAuthCredentialRequest(committedBytes);
return CompletableFuture.completedFuture(Stream return CompletableFuture.completedFuture(Stream
.iterate(redemptionStart, curr -> curr.plus(Duration.ofDays(1))) .iterate(redemptionStart, redemptionTime -> !redemptionTime.isAfter(redemptionEnd), curr -> curr.plus(Duration.ofDays(1)))
.takeWhile(redemptionTime -> !redemptionTime.isAfter(redemptionEnd))
.map(redemptionTime -> { .map(redemptionTime -> {
// Check if the account has a voucher that's good for a certain receiptLevel at redemption time, otherwise // Check if the account has a voucher that's good for a certain receiptLevel at redemption time, otherwise
// use the default receipt level // use the default receipt level
final BackupLevel backupLevel = storedBackupLevel(account, redemptionTime).orElse(configuredBackupLevel); final BackupLevel backupLevel = storedBackupLevel(account, redemptionTime).orElse(configuredBackupLevel);
return new Credential( return new Credential(
credentialReq.issueCredential(redemptionTime, backupLevel, serverSecretParams), credentialReq.issueCredential(redemptionTime, backupLevel, credentialType, serverSecretParams),
redemptionTime); redemptionTime);
}) })
.toList()); .toList());
@ -210,7 +221,7 @@ public class BackupAuthManager {
final long receiptLevel = receiptCredentialPresentation.getReceiptLevel(); final long receiptLevel = receiptCredentialPresentation.getReceiptLevel();
if (BackupLevelUtil.fromReceiptLevel(receiptLevel) != BackupLevel.MEDIA) { if (BackupLevelUtil.fromReceiptLevel(receiptLevel) != BackupLevel.PAID) {
throw Status.INVALID_ARGUMENT throw Status.INVALID_ARGUMENT
.withDescription("server does not recognize the requested receipt level") .withDescription("server does not recognize the requested receipt level")
.asRuntimeException(); .asRuntimeException();
@ -281,10 +292,10 @@ public class BackupAuthManager {
*/ */
private Optional<BackupLevel> configuredBackupLevel(final Account account) { private Optional<BackupLevel> configuredBackupLevel(final Account account) {
if (inExperiment(BACKUP_MEDIA_EXPERIMENT_NAME, account)) { if (inExperiment(BACKUP_MEDIA_EXPERIMENT_NAME, account)) {
return Optional.of(BackupLevel.MEDIA); return Optional.of(BackupLevel.PAID);
} }
if (inExperiment(BACKUP_EXPERIMENT_NAME, account)) { if (inExperiment(BACKUP_EXPERIMENT_NAME, account)) {
return Optional.of(BackupLevel.MESSAGES); return Optional.of(BackupLevel.FREE);
} }
return Optional.empty(); return Optional.empty();
} }

View File

@ -28,25 +28,22 @@ import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.signal.libsignal.zkgroup.GenericServerSecretParams; import org.signal.libsignal.zkgroup.GenericServerSecretParams;
import org.signal.libsignal.zkgroup.VerificationFailedException; import org.signal.libsignal.zkgroup.VerificationFailedException;
import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialPresentation; import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialPresentation;
import org.signal.libsignal.zkgroup.backups.BackupCredentialType;
import org.signal.libsignal.zkgroup.backups.BackupLevel; import org.signal.libsignal.zkgroup.backups.BackupLevel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.attachments.AttachmentGenerator; import org.whispersystems.textsecuregcm.attachments.AttachmentGenerator;
import org.whispersystems.textsecuregcm.attachments.TusAttachmentGenerator; import org.whispersystems.textsecuregcm.attachments.TusAttachmentGenerator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedBackupUser; import org.whispersystems.textsecuregcm.auth.AuthenticatedBackupUser;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.util.AsyncTimerUtil; import org.whispersystems.textsecuregcm.util.AsyncTimerUtil;
import org.whispersystems.textsecuregcm.util.ExceptionUtils; import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.Pair;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Scheduler;
public class BackupManager { public class BackupManager {
private static final Logger logger = LoggerFactory.getLogger(BackupManager.class);
static final String MESSAGE_BACKUP_NAME = "messageBackup"; static final String MESSAGE_BACKUP_NAME = "messageBackup";
public static final long MAX_TOTAL_BACKUP_MEDIA_BYTES = DataSize.gibibytes(100).toBytes(); public static final long MAX_TOTAL_BACKUP_MEDIA_BYTES = DataSize.gibibytes(100).toBytes();
static final long MAX_MEDIA_OBJECT_SIZE = DataSize.mebibytes(101).toBytes(); static final long MAX_MEDIA_OBJECT_SIZE = DataSize.mebibytes(101).toBytes();
@ -120,8 +117,10 @@ public class BackupManager {
// Note: this is a special case where we can't validate the presentation signature against the stored public key // Note: this is a special case where we can't validate the presentation signature against the stored public key
// because we are currently setting it. We check against the provided public key, but we must also verify that // because we are currently setting it. We check against the provided public key, but we must also verify that
// there isn't an existing, different stored public key for the backup-id (verified with a condition expression) // there isn't an existing, different stored public key for the backup-id (verified with a condition expression)
final BackupLevel backupLevel = verifyPresentation(presentation).verifySignature(signature, publicKey); final Pair<BackupCredentialType, BackupLevel> credentialTypeAndBackupLevel =
return backupsDb.setPublicKey(presentation.getBackupId(), backupLevel, publicKey) verifyPresentation(presentation).verifySignature(signature, publicKey);
return backupsDb.setPublicKey(presentation.getBackupId(), credentialTypeAndBackupLevel.second(), publicKey)
.exceptionally(ExceptionUtils.exceptionallyHandler(PublicKeyConflictException.class, ex -> { .exceptionally(ExceptionUtils.exceptionallyHandler(PublicKeyConflictException.class, ex -> {
Metrics.counter(ZK_AUTHN_COUNTER_NAME, Metrics.counter(ZK_AUTHN_COUNTER_NAME,
SUCCESS_TAG_NAME, String.valueOf(false), SUCCESS_TAG_NAME, String.valueOf(false),
@ -144,7 +143,8 @@ public class BackupManager {
*/ */
public CompletableFuture<BackupUploadDescriptor> createMessageBackupUploadDescriptor( public CompletableFuture<BackupUploadDescriptor> createMessageBackupUploadDescriptor(
final AuthenticatedBackupUser backupUser) { final AuthenticatedBackupUser backupUser) {
checkBackupLevel(backupUser, BackupLevel.MESSAGES); checkBackupLevel(backupUser, BackupLevel.FREE);
checkBackupCredentialType(backupUser, BackupCredentialType.MESSAGES);
// this could race with concurrent updates, but the only effect would be last-writer-wins on the timestamp // this could race with concurrent updates, but the only effect would be last-writer-wins on the timestamp
return backupsDb return backupsDb
@ -154,7 +154,8 @@ public class BackupManager {
public CompletableFuture<BackupUploadDescriptor> createTemporaryAttachmentUploadDescriptor( public CompletableFuture<BackupUploadDescriptor> createTemporaryAttachmentUploadDescriptor(
final AuthenticatedBackupUser backupUser) { final AuthenticatedBackupUser backupUser) {
checkBackupLevel(backupUser, BackupLevel.MEDIA); checkBackupLevel(backupUser, BackupLevel.PAID);
checkBackupCredentialType(backupUser, BackupCredentialType.MEDIA);
return rateLimiters.forDescriptor(RateLimiters.For.BACKUP_ATTACHMENT) return rateLimiters.forDescriptor(RateLimiters.For.BACKUP_ATTACHMENT)
.validateAsync(rateLimitKey(backupUser)).thenApply(ignored -> { .validateAsync(rateLimitKey(backupUser)).thenApply(ignored -> {
@ -172,7 +173,7 @@ public class BackupManager {
* @param backupUser an already ZK authenticated backup user * @param backupUser an already ZK authenticated backup user
*/ */
public CompletableFuture<Void> ttlRefresh(final AuthenticatedBackupUser backupUser) { public CompletableFuture<Void> ttlRefresh(final AuthenticatedBackupUser backupUser) {
checkBackupLevel(backupUser, BackupLevel.MESSAGES); checkBackupLevel(backupUser, BackupLevel.FREE);
// update message backup TTL // update message backup TTL
return backupsDb.ttlRefresh(backupUser); return backupsDb.ttlRefresh(backupUser);
} }
@ -187,7 +188,7 @@ public class BackupManager {
* @return Information about the existing backup * @return Information about the existing backup
*/ */
public CompletableFuture<BackupInfo> backupInfo(final AuthenticatedBackupUser backupUser) { public CompletableFuture<BackupInfo> backupInfo(final AuthenticatedBackupUser backupUser) {
checkBackupLevel(backupUser, BackupLevel.MESSAGES); checkBackupLevel(backupUser, BackupLevel.FREE);
return backupsDb.describeBackup(backupUser) return backupsDb.describeBackup(backupUser)
.thenApply(backupDescription -> new BackupInfo( .thenApply(backupDescription -> new BackupInfo(
backupDescription.cdn(), backupDescription.cdn(),
@ -210,7 +211,8 @@ public class BackupManager {
* detailing why the object could not be copied. * detailing why the object could not be copied.
*/ */
public Flux<CopyResult> copyToBackup(final AuthenticatedBackupUser backupUser, List<CopyParameters> toCopy) { public Flux<CopyResult> copyToBackup(final AuthenticatedBackupUser backupUser, List<CopyParameters> toCopy) {
checkBackupLevel(backupUser, BackupLevel.MEDIA); checkBackupLevel(backupUser, BackupLevel.PAID);
checkBackupCredentialType(backupUser, BackupCredentialType.MEDIA);
return Mono return Mono
// Figure out how many objects we're allowed to copy, updating the quota usage for the amount we are allowed // Figure out how many objects we're allowed to copy, updating the quota usage for the amount we are allowed
@ -349,7 +351,7 @@ public class BackupManager {
* @return A map of headers to include with CDN requests * @return A map of headers to include with CDN requests
*/ */
public Map<String, String> generateReadAuth(final AuthenticatedBackupUser backupUser, final int cdnNumber) { public Map<String, String> generateReadAuth(final AuthenticatedBackupUser backupUser, final int cdnNumber) {
checkBackupLevel(backupUser, BackupLevel.MESSAGES); checkBackupLevel(backupUser, BackupLevel.FREE);
if (cdnNumber != 3) { if (cdnNumber != 3) {
throw Status.INVALID_ARGUMENT.withDescription("unknown cdn").asRuntimeException(); throw Status.INVALID_ARGUMENT.withDescription("unknown cdn").asRuntimeException();
} }
@ -377,7 +379,7 @@ public class BackupManager {
final AuthenticatedBackupUser backupUser, final AuthenticatedBackupUser backupUser,
final Optional<String> cursor, final Optional<String> cursor,
final int limit) { final int limit) {
checkBackupLevel(backupUser, BackupLevel.MESSAGES); checkBackupLevel(backupUser, BackupLevel.FREE);
return remoteStorageManager.list(cdnMediaDirectory(backupUser), cursor, limit) return remoteStorageManager.list(cdnMediaDirectory(backupUser), cursor, limit)
.thenApply(result -> .thenApply(result ->
new ListMediaResult( new ListMediaResult(
@ -395,7 +397,7 @@ public class BackupManager {
} }
public CompletableFuture<Void> deleteEntireBackup(final AuthenticatedBackupUser backupUser) { public CompletableFuture<Void> deleteEntireBackup(final AuthenticatedBackupUser backupUser) {
checkBackupLevel(backupUser, BackupLevel.MESSAGES); checkBackupLevel(backupUser, BackupLevel.FREE);
return backupsDb return backupsDb
// Try to swap out the backupDir for the user // Try to swap out the backupDir for the user
.scheduleBackupDeletion(backupUser) .scheduleBackupDeletion(backupUser)
@ -408,7 +410,8 @@ public class BackupManager {
public Flux<StorageDescriptor> deleteMedia(final AuthenticatedBackupUser backupUser, public Flux<StorageDescriptor> deleteMedia(final AuthenticatedBackupUser backupUser,
final List<StorageDescriptor> storageDescriptors) { final List<StorageDescriptor> storageDescriptors) {
checkBackupLevel(backupUser, BackupLevel.MESSAGES); checkBackupLevel(backupUser, BackupLevel.FREE);
checkBackupCredentialType(backupUser, BackupCredentialType.MEDIA);
// Check for a cdn we don't know how to process // Check for a cdn we don't know how to process
if (storageDescriptors.stream().anyMatch(sd -> sd.cdn() != remoteStorageManager.cdnNumber())) { if (storageDescriptors.stream().anyMatch(sd -> sd.cdn() != remoteStorageManager.cdnNumber())) {
@ -492,10 +495,16 @@ public class BackupManager {
// There was no stored public key, use a bunk public key so that validation will fail // There was no stored public key, use a bunk public key so that validation will fail
return new BackupsDb.AuthenticationData(INVALID_PUBLIC_KEY, null, null); return new BackupsDb.AuthenticationData(INVALID_PUBLIC_KEY, null, null);
}); });
final Pair<BackupCredentialType, BackupLevel> credentialTypeAndBackupLevel =
signatureVerifier.verifySignature(signature, authenticationData.publicKey());
return new AuthenticatedBackupUser( return new AuthenticatedBackupUser(
presentation.getBackupId(), presentation.getBackupId(),
signatureVerifier.verifySignature(signature, authenticationData.publicKey()), credentialTypeAndBackupLevel.first(),
authenticationData.backupDir(), authenticationData.mediaDir()); credentialTypeAndBackupLevel.second(),
authenticationData.backupDir(),
authenticationData.mediaDir());
}) })
.thenApply(result -> { .thenApply(result -> {
Metrics.counter(ZK_AUTHN_COUNTER_NAME, SUCCESS_TAG_NAME, String.valueOf(true)).increment(); Metrics.counter(ZK_AUTHN_COUNTER_NAME, SUCCESS_TAG_NAME, String.valueOf(true)).increment();
@ -579,7 +588,7 @@ public class BackupManager {
interface PresentationSignatureVerifier { interface PresentationSignatureVerifier {
BackupLevel verifySignature(byte[] signature, ECPublicKey publicKey); Pair<BackupCredentialType, BackupLevel> verifySignature(byte[] signature, ECPublicKey publicKey);
} }
/** /**
@ -611,7 +620,7 @@ public class BackupManager {
.withDescription("backup auth credential presentation signature verification failed") .withDescription("backup auth credential presentation signature verification failed")
.asRuntimeException(); .asRuntimeException();
} }
return presentation.getBackupLevel(); return new Pair<>(presentation.getType(), presentation.getBackupLevel());
}; };
} }
@ -622,9 +631,34 @@ public class BackupManager {
* @param backupLevel The authorization level to verify the backupUser has access to * @param backupLevel The authorization level to verify the backupUser has access to
* @throws {@link Status#PERMISSION_DENIED} error if the backup user is not authorized to access {@code backupLevel} * @throws {@link Status#PERMISSION_DENIED} error if the backup user is not authorized to access {@code backupLevel}
*/ */
private static void checkBackupLevel(final AuthenticatedBackupUser backupUser, final BackupLevel backupLevel) { @VisibleForTesting
static void checkBackupLevel(final AuthenticatedBackupUser backupUser, final BackupLevel backupLevel) {
if (backupUser.backupLevel().compareTo(backupLevel) < 0) { if (backupUser.backupLevel().compareTo(backupLevel) < 0) {
Metrics.counter(ZK_AUTHZ_FAILURE_COUNTER_NAME).increment(); Metrics.counter(ZK_AUTHZ_FAILURE_COUNTER_NAME,
FAILURE_REASON_TAG_NAME, "level")
.increment();
throw Status.PERMISSION_DENIED
.withDescription("credential does not support the requested operation")
.asRuntimeException();
}
}
/**
* Check that the authenticated backup user is authenticated with the given credential type
*
* @param backupUser The backup user to check
* @param credentialType The credential type to require
* @throws {@link Status#PERMISSION_DENIED} error if the backup user is not authenticated with the given
* {@code credentialType}
*/
@VisibleForTesting
static void checkBackupCredentialType(final AuthenticatedBackupUser backupUser, final BackupCredentialType credentialType) {
if (backupUser.credentialType() != credentialType) {
Metrics.counter(ZK_AUTHZ_FAILURE_COUNTER_NAME,
FAILURE_REASON_TAG_NAME, "credential_type")
.increment();
throw Status.PERMISSION_DENIED throw Status.PERMISSION_DENIED
.withDescription("credential does not support the requested operation") .withDescription("credential does not support the requested operation")
.asRuntimeException(); .asRuntimeException();

View File

@ -87,7 +87,7 @@ public class BackupsDb {
// garbage collection of archive objects. // garbage collection of archive objects.
public static final String ATTR_LAST_REFRESH = "R"; public static final String ATTR_LAST_REFRESH = "R";
// N: Time in seconds since epoch of the last backup media refresh. This timestamp can only be updated if the client // N: Time in seconds since epoch of the last backup media refresh. This timestamp can only be updated if the client
// has BackupLevel.MEDIA, and must be periodically updated to avoid garbage collection of media objects. // has BackupLevel.PAID, and must be periodically updated to avoid garbage collection of media objects.
public static final String ATTR_LAST_MEDIA_REFRESH = "MR"; public static final String ATTR_LAST_MEDIA_REFRESH = "MR";
// B: A 32 byte public key that should be used to sign the presentation used to authenticate requests against the // B: A 32 byte public key that should be used to sign the presentation used to authenticate requests against the
// backup-id // backup-id
@ -265,7 +265,7 @@ public class BackupsDb {
* Indicates that we couldn't schedule a deletion because one was already scheduled. The caller may want to delete the * Indicates that we couldn't schedule a deletion because one was already scheduled. The caller may want to delete the
* objects directly. * objects directly.
*/ */
class PendingDeletionException extends IOException {} static class PendingDeletionException extends IOException {}
/** /**
* Attempt to mark a backup as expired and swap in a new empty backupDir for the user. * Attempt to mark a backup as expired and swap in a new empty backupDir for the user.
@ -285,7 +285,7 @@ public class BackupsDb {
final byte[] hashedBackupId = hashedBackupId(backupUser); final byte[] hashedBackupId = hashedBackupId(backupUser);
// Clear usage metadata, swap names of things we intend to delete, and record our intent to delete in attr_expired_prefix // Clear usage metadata, swap names of things we intend to delete, and record our intent to delete in attr_expired_prefix
return dynamoClient.updateItem(new UpdateBuilder(backupTableName, BackupLevel.MEDIA, hashedBackupId) return dynamoClient.updateItem(new UpdateBuilder(backupTableName, BackupLevel.PAID, hashedBackupId)
.clearMediaUsage(clock) .clearMediaUsage(clock)
.expireDirectoryNames(secureRandom, ExpiredBackup.ExpirationType.ALL) .expireDirectoryNames(secureRandom, ExpiredBackup.ExpirationType.ALL)
.setRefreshTimes(Instant.ofEpochSecond(0)) .setRefreshTimes(Instant.ofEpochSecond(0))
@ -300,7 +300,7 @@ public class BackupsDb {
// is toggling backups on and off. In this case, it should be pretty cheap to directly delete the backup. // is toggling backups on and off. In this case, it should be pretty cheap to directly delete the backup.
// Instead of changing the backupDir, just make sure the row has expired/ timestamps and tell the caller we // Instead of changing the backupDir, just make sure the row has expired/ timestamps and tell the caller we
// couldn't schedule the deletion. // couldn't schedule the deletion.
dynamoClient.updateItem(new UpdateBuilder(backupTableName, BackupLevel.MEDIA, hashedBackupId) dynamoClient.updateItem(new UpdateBuilder(backupTableName, BackupLevel.PAID, hashedBackupId)
.setRefreshTimes(Instant.ofEpochSecond(0)) .setRefreshTimes(Instant.ofEpochSecond(0))
.updateItemBuilder() .updateItemBuilder()
.build()) .build())
@ -399,7 +399,7 @@ public class BackupsDb {
} }
// Clear usage metadata, swap names of things we intend to delete, and record our intent to delete in attr_expired_prefix // Clear usage metadata, swap names of things we intend to delete, and record our intent to delete in attr_expired_prefix
return dynamoClient.updateItem(new UpdateBuilder(backupTableName, BackupLevel.MEDIA, expiredBackup.hashedBackupId()) return dynamoClient.updateItem(new UpdateBuilder(backupTableName, BackupLevel.PAID, expiredBackup.hashedBackupId())
.clearMediaUsage(clock) .clearMediaUsage(clock)
.expireDirectoryNames(secureRandom, expiredBackup.expirationType()) .expireDirectoryNames(secureRandom, expiredBackup.expirationType())
.addRemoveExpression(Map.entry("#mediaRefresh", ATTR_LAST_MEDIA_REFRESH)) .addRemoveExpression(Map.entry("#mediaRefresh", ATTR_LAST_MEDIA_REFRESH))
@ -433,7 +433,7 @@ public class BackupsDb {
.build()) .build())
.thenRun(Util.NOOP); .thenRun(Util.NOOP);
} else { } else {
return dynamoClient.updateItem(new UpdateBuilder(backupTableName, BackupLevel.MEDIA, hashedBackupId) return dynamoClient.updateItem(new UpdateBuilder(backupTableName, BackupLevel.PAID, hashedBackupId)
.addRemoveExpression(Map.entry("#expiredPrefixes", ATTR_EXPIRED_PREFIX)) .addRemoveExpression(Map.entry("#expiredPrefixes", ATTR_EXPIRED_PREFIX))
.updateItemBuilder() .updateItemBuilder()
.build()) .build())
@ -722,7 +722,7 @@ public class BackupsDb {
Map.entry("#lastRefreshTime", ATTR_LAST_REFRESH), Map.entry("#lastRefreshTime", ATTR_LAST_REFRESH),
Map.entry(":lastRefreshTime", AttributeValues.n(refreshTime.getEpochSecond()))); Map.entry(":lastRefreshTime", AttributeValues.n(refreshTime.getEpochSecond())));
if (backupLevel.compareTo(BackupLevel.MEDIA) >= 0) { if (backupLevel.compareTo(BackupLevel.PAID) >= 0) {
// update the media time if we have the appropriate level // update the media time if we have the appropriate level
addSetExpression("#lastMediaRefreshTime = :lastMediaRefreshTime", addSetExpression("#lastMediaRefreshTime = :lastMediaRefreshTime",
Map.entry("#lastMediaRefreshTime", ATTR_LAST_MEDIA_REFRESH), Map.entry("#lastMediaRefreshTime", ATTR_LAST_MEDIA_REFRESH),

View File

@ -23,11 +23,14 @@ import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy; import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target; import java.lang.annotation.Target;
import java.time.Instant; import java.time.Instant;
import java.util.Arrays;
import java.util.Base64; import java.util.Base64;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage; import java.util.concurrent.CompletionStage;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Stream; import java.util.stream.Stream;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.Max; import javax.validation.constraints.Max;
@ -52,6 +55,7 @@ import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.signal.libsignal.zkgroup.InvalidInputException; import org.signal.libsignal.zkgroup.InvalidInputException;
import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialPresentation; import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialPresentation;
import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialRequest; import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialRequest;
import org.signal.libsignal.zkgroup.backups.BackupCredentialType;
import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialPresentation; import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialPresentation;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.backup.BackupAuthManager; import org.whispersystems.textsecuregcm.backup.BackupAuthManager;
@ -89,11 +93,21 @@ public class ArchiveController {
public record SetBackupIdRequest( public record SetBackupIdRequest(
@Schema(description = """ @Schema(description = """
A BackupAuthCredentialRequest containing a blinded encrypted backup-id, encoded in standard padded base64 A BackupAuthCredentialRequest containing a blinded encrypted backup-id, encoded in standard padded base64.
This backup-id should be used for message backups only, and must have the message backup type set on the
credential.
""", implementation = String.class) """, implementation = String.class)
@JsonDeserialize(using = BackupAuthCredentialAdapter.CredentialRequestDeserializer.class) @JsonDeserialize(using = BackupAuthCredentialAdapter.CredentialRequestDeserializer.class)
@JsonSerialize(using = BackupAuthCredentialAdapter.CredentialRequestSerializer.class) @JsonSerialize(using = BackupAuthCredentialAdapter.CredentialRequestSerializer.class)
@NotNull BackupAuthCredentialRequest backupAuthCredentialRequest) {} @NotNull BackupAuthCredentialRequest messagesBackupAuthCredentialRequest,
@Schema(description = """
A BackupAuthCredentialRequest containing a blinded encrypted backup-id, encoded in standard padded base64.
This backup-id should be used for media only, and must have the media type set on the credential.
""", implementation = String.class)
@JsonDeserialize(using = BackupAuthCredentialAdapter.CredentialRequestDeserializer.class)
@JsonSerialize(using = BackupAuthCredentialAdapter.CredentialRequestSerializer.class)
@NotNull BackupAuthCredentialRequest mediaBackupAuthCredentialRequest) {}
@PUT @PUT
@ -115,8 +129,9 @@ public class ArchiveController {
public CompletionStage<Response> setBackupId( public CompletionStage<Response> setBackupId(
@Mutable @Auth final AuthenticatedDevice account, @Mutable @Auth final AuthenticatedDevice account,
@Valid @NotNull final SetBackupIdRequest setBackupIdRequest) throws RateLimitExceededException { @Valid @NotNull final SetBackupIdRequest setBackupIdRequest) throws RateLimitExceededException {
return this.backupAuthManager return this.backupAuthManager
.commitBackupId(account.getAccount(), setBackupIdRequest.backupAuthCredentialRequest) .commitBackupId(account.getAccount(), setBackupIdRequest.messagesBackupAuthCredentialRequest, setBackupIdRequest.mediaBackupAuthCredentialRequest)
.thenApply(Util.ASYNC_EMPTY_RESPONSE); .thenApply(Util.ASYNC_EMPTY_RESPONSE);
} }
@ -166,8 +181,8 @@ public class ArchiveController {
} }
public record BackupAuthCredentialsResponse( public record BackupAuthCredentialsResponse(
@Schema(description = "A list of BackupAuthCredentials and their validity periods") @Schema(description = "A map of credential types to lists of BackupAuthCredentials and their validity periods")
List<BackupAuthCredential> credentials) { Map<BackupCredentialType, List<BackupAuthCredential>> credentials) {
public record BackupAuthCredential( public record BackupAuthCredential(
@Schema(description = "A BackupAuthCredential, encoded in standard padded base64") @Schema(description = "A BackupAuthCredential, encoded in standard padded base64")
@ -202,14 +217,21 @@ public class ArchiveController {
@NotNull @QueryParam("redemptionStartSeconds") Long startSeconds, @NotNull @QueryParam("redemptionStartSeconds") Long startSeconds,
@NotNull @QueryParam("redemptionEndSeconds") Long endSeconds) { @NotNull @QueryParam("redemptionEndSeconds") Long endSeconds) {
return this.backupAuthManager.getBackupAuthCredentials( final Map<BackupCredentialType, List<BackupAuthCredentialsResponse.BackupAuthCredential>> credentialsByType =
auth.getAccount(), new ConcurrentHashMap<>();
Instant.ofEpochSecond(startSeconds), Instant.ofEpochSecond(endSeconds))
.thenApply(creds -> new BackupAuthCredentialsResponse(creds.stream() return CompletableFuture.allOf(Arrays.stream(BackupCredentialType.values())
.map(cred -> new BackupAuthCredentialsResponse.BackupAuthCredential( .map(credentialType -> this.backupAuthManager.getBackupAuthCredentials(
cred.credential().serialize(), auth.getAccount(),
cred.redemptionTime().getEpochSecond())) credentialType,
.toList())); Instant.ofEpochSecond(startSeconds), Instant.ofEpochSecond(endSeconds))
.thenAccept(credentials -> credentialsByType.put(credentialType, credentials.stream()
.map(credential -> new BackupAuthCredentialsResponse.BackupAuthCredential(
credential.credential().serialize(),
credential.redemptionTime().getEpochSecond()))
.toList())))
.toArray(CompletableFuture[]::new))
.thenApply(ignored -> new BackupAuthCredentialsResponse(credentialsByType));
} }
@ -227,7 +249,8 @@ public class ArchiveController {
@ApiResponse(responseCode = "401", description = """ @ApiResponse(responseCode = "401", description = """
The provided backup auth credential presentation could not be verified or The provided backup auth credential presentation could not be verified or
The public key signature was invalid or The public key signature was invalid or
There is no backup associated with the backup-id in the presentation""") There is no backup associated with the backup-id in the presentation or
The credential was of the wrong type (messages/media)""")
@ApiResponse(responseCode = "400", description = "Bad arguments. The request may have been made on an authenticated channel") @ApiResponse(responseCode = "400", description = "Bad arguments. The request may have been made on an authenticated channel")
@interface ApiResponseZkAuth {} @interface ApiResponseZkAuth {}
@ -453,7 +476,7 @@ public class ArchiveController {
throw new BadRequestException("must not use authenticated connection for anonymous operations"); throw new BadRequestException("must not use authenticated connection for anonymous operations");
} }
return backupManager.authenticateBackupUser(presentation.presentation, signature.signature) return backupManager.authenticateBackupUser(presentation.presentation, signature.signature)
.thenCompose(backupUser -> backupManager.createTemporaryAttachmentUploadDescriptor(backupUser)) .thenCompose(backupManager::createTemporaryAttachmentUploadDescriptor)
.thenApply(result -> new UploadDescriptorResponse( .thenApply(result -> new UploadDescriptorResponse(
result.cdn(), result.cdn(),
result.key(), result.key(),

View File

@ -22,6 +22,7 @@ import java.util.function.Predicate;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.zkgroup.backups.BackupCredentialType;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
@ -116,7 +117,11 @@ public class Account {
@JsonProperty("bcr") @JsonProperty("bcr")
@Nullable @Nullable
private byte[] backupCredentialRequest; private byte[] messagesBackupCredentialRequest;
@JsonProperty("mbcr")
@Nullable
private byte[] mediaBackupCredentialRequest;
@JsonProperty("bv") @JsonProperty("bv")
@Nullable @Nullable
@ -284,7 +289,7 @@ public class Account {
requireNotStale(); requireNotStale();
return Optional.ofNullable(getPrimaryDevice().getCapabilities()) return Optional.ofNullable(getPrimaryDevice().getCapabilities())
.map(Device.DeviceCapabilities::transfer) .map(DeviceCapabilities::transfer)
.orElse(false); .orElse(false);
} }
@ -509,12 +514,22 @@ public class Account {
this.svr3ShareSet = svr3ShareSet; this.svr3ShareSet = svr3ShareSet;
} }
public byte[] getBackupCredentialRequest() { public void setBackupCredentialRequests(final byte[] messagesBackupCredentialRequest,
return backupCredentialRequest; final byte[] mediaBackupCredentialRequest) {
requireNotStale();
this.messagesBackupCredentialRequest = messagesBackupCredentialRequest;
this.mediaBackupCredentialRequest = mediaBackupCredentialRequest;
} }
public void setBackupCredentialRequest(final byte[] backupCredentialRequest) { public Optional<byte[]> getBackupCredentialRequest(final BackupCredentialType credentialType) {
this.backupCredentialRequest = backupCredentialRequest; requireNotStale();
return Optional.ofNullable(switch (credentialType) {
case MESSAGES -> messagesBackupCredentialRequest;
case MEDIA -> mediaBackupCredentialRequest;
});
} }
public @Nullable BackupVoucher getBackupVoucher() { public @Nullable BackupVoucher getBackupVoucher() {

View File

@ -36,6 +36,7 @@ import java.util.function.Predicate;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.signal.libsignal.zkgroup.backups.BackupCredentialType;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.util.AsyncTimerUtil; import org.whispersystems.textsecuregcm.util.AsyncTimerUtil;
@ -321,7 +322,9 @@ public class Accounts extends AbstractDynamoDbStore {
// Carry over the old backup id commitment. If the new account claimer cannot does not have the secret used to // Carry over the old backup id commitment. If the new account claimer cannot does not have the secret used to
// generate their backup-id, this credential is useless, however if they can produce the same credential they // generate their backup-id, this credential is useless, however if they can produce the same credential they
// won't be rate-limited for setting their backup-id. // won't be rate-limited for setting their backup-id.
accountToCreate.setBackupCredentialRequest(existingAccount.getBackupCredentialRequest()); accountToCreate.setBackupCredentialRequests(
existingAccount.getBackupCredentialRequest(BackupCredentialType.MESSAGES).orElse(null),
existingAccount.getBackupCredentialRequest(BackupCredentialType.MEDIA).orElse(null));
// Carry over the old SVR3 share-set. This is required for an account to restore information from SVR. The share- // Carry over the old SVR3 share-set. This is required for an account to restore information from SVR. The share-
// set is not a secret, if the new account claimer does not have the SVR3 pin, it is useless. // set is not a secret, if the new account claimer does not have the SVR3 pin, it is useless.

View File

@ -6,8 +6,6 @@
package org.whispersystems.textsecuregcm.workers; package org.whispersystems.textsecuregcm.workers;
import io.dropwizard.core.Application; import io.dropwizard.core.Application;
import io.dropwizard.core.cli.Cli;
import io.dropwizard.core.cli.EnvironmentCommand;
import io.dropwizard.core.setup.Environment; import io.dropwizard.core.setup.Environment;
import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
@ -18,8 +16,6 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.WhisperServerConfiguration; import org.whispersystems.textsecuregcm.WhisperServerConfiguration;
import org.whispersystems.textsecuregcm.backup.BackupManager; import org.whispersystems.textsecuregcm.backup.BackupManager;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.util.logging.UncaughtExceptionHandler;
import reactor.core.scheduler.Schedulers; import reactor.core.scheduler.Schedulers;
import java.time.Clock; import java.time.Clock;
@ -69,13 +65,13 @@ public class BackupMetricsCommand extends AbstractCommandWithDependencies {
Runtime.getRuntime().availableProcessors()); Runtime.getRuntime().availableProcessors());
final DistributionSummary numObjectsMediaTier = Metrics.summary(name(getClass(), "numObjects"), final DistributionSummary numObjectsMediaTier = Metrics.summary(name(getClass(), "numObjects"),
"tier", BackupLevel.MEDIA.name()); "tier", BackupLevel.PAID.name());
final DistributionSummary bytesUsedMediaTier = Metrics.summary(name(getClass(), "bytesUsed"), final DistributionSummary bytesUsedMediaTier = Metrics.summary(name(getClass(), "bytesUsed"),
"tier", BackupLevel.MEDIA.name()); "tier", BackupLevel.PAID.name());
final DistributionSummary numObjectsMessagesTier = Metrics.summary(name(getClass(), "numObjects"), final DistributionSummary numObjectsMessagesTier = Metrics.summary(name(getClass(), "numObjects"),
"tier", BackupLevel.MESSAGES.name()); "tier", BackupLevel.FREE.name());
final DistributionSummary bytesUsedMessagesTier = Metrics.summary(name(getClass(), "bytesUsed"), final DistributionSummary bytesUsedMessagesTier = Metrics.summary(name(getClass(), "bytesUsed"),
"tier", BackupLevel.MESSAGES.name()); "tier", BackupLevel.FREE.name());
final DistributionSummary timeSinceLastRefresh = Metrics.summary(name(getClass(), final DistributionSummary timeSinceLastRefresh = Metrics.summary(name(getClass(),
"timeSinceLastRefresh")); "timeSinceLastRefresh"));

View File

@ -24,6 +24,7 @@ import java.time.Duration;
import java.time.Instant; import java.time.Instant;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
import java.util.List; import java.util.List;
import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer; import java.util.function.Consumer;
@ -39,12 +40,14 @@ import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.NullSource; import org.junit.jupiter.params.provider.NullSource;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.signal.libsignal.zkgroup.InvalidInputException; import org.signal.libsignal.zkgroup.InvalidInputException;
import org.signal.libsignal.zkgroup.ServerSecretParams; import org.signal.libsignal.zkgroup.ServerSecretParams;
import org.signal.libsignal.zkgroup.VerificationFailedException; import org.signal.libsignal.zkgroup.VerificationFailedException;
import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialRequest; import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialRequest;
import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialRequestContext; import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialRequestContext;
import org.signal.libsignal.zkgroup.backups.BackupCredentialType;
import org.signal.libsignal.zkgroup.backups.BackupLevel; import org.signal.libsignal.zkgroup.backups.BackupLevel;
import org.signal.libsignal.zkgroup.receipts.ClientZkReceiptOperations; import org.signal.libsignal.zkgroup.receipts.ClientZkReceiptOperations;
import org.signal.libsignal.zkgroup.receipts.ReceiptCredential; import org.signal.libsignal.zkgroup.receipts.ReceiptCredential;
@ -67,7 +70,8 @@ import org.whispersystems.textsecuregcm.util.TestRandomUtil;
public class BackupAuthManagerTest { public class BackupAuthManagerTest {
private final UUID aci = UUID.randomUUID(); private final UUID aci = UUID.randomUUID();
private final byte[] backupKey = TestRandomUtil.nextBytes(32); private final byte[] messagesBackupKey = TestRandomUtil.nextBytes(32);
private final byte[] mediaBackupKey = TestRandomUtil.nextBytes(32);
private final ServerSecretParams receiptParams = ServerSecretParams.generate(); private final ServerSecretParams receiptParams = ServerSecretParams.generate();
private final TestClock clock = TestClock.now(); private final TestClock clock = TestClock.now();
private final BackupAuthTestUtil backupAuthTestUtil = new BackupAuthTestUtil(clock); private final BackupAuthTestUtil backupAuthTestUtil = new BackupAuthTestUtil(clock);
@ -92,6 +96,30 @@ public class BackupAuthManagerTest {
clock); clock);
} }
@Test
void commitBackupId() {
final BackupAuthManager authManager = create(BackupLevel.FREE, false);
final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(aci);
when(accountsManager.updateAsync(any(), any()))
.thenAnswer(invocation -> {
final Account a = invocation.getArgument(0);
final Consumer<Account> updater = invocation.getArgument(1);
updater.accept(a);
return CompletableFuture.completedFuture(a);
});
final BackupAuthCredentialRequest messagesCredentialRequest = backupAuthTestUtil.getRequest(messagesBackupKey, aci);
final BackupAuthCredentialRequest mediaCredentialRequest = backupAuthTestUtil.getRequest(mediaBackupKey, aci);
authManager.commitBackupId(account, messagesCredentialRequest, mediaCredentialRequest).join();
verify(account).setBackupCredentialRequests(messagesCredentialRequest.serialize(), mediaCredentialRequest.serialize());
}
@ParameterizedTest @ParameterizedTest
@EnumSource @EnumSource
@NullSource @NullSource
@ -102,9 +130,11 @@ public class BackupAuthManagerTest {
when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(account)); when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(account));
final ThrowableAssert.ThrowingCallable commit = () -> final ThrowableAssert.ThrowingCallable commit = () ->
authManager.commitBackupId(account, backupAuthTestUtil.getRequest(backupKey, aci)).join(); authManager.commitBackupId(account,
backupAuthTestUtil.getRequest(messagesBackupKey, aci),
backupAuthTestUtil.getRequest(mediaBackupKey, aci)).join();
if (backupLevel == null) { if (backupLevel == null) {
Assertions.assertThatExceptionOfType(StatusRuntimeException.class) assertThatExceptionOfType(StatusRuntimeException.class)
.isThrownBy(commit) .isThrownBy(commit)
.extracting(ex -> ex.getStatus().getCode()) .extracting(ex -> ex.getStatus().getCode())
.isEqualTo(Status.Code.PERMISSION_DENIED); .isEqualTo(Status.Code.PERMISSION_DENIED);
@ -113,46 +143,70 @@ public class BackupAuthManagerTest {
} }
} }
@CartesianTest
void getBackupAuthCredentials(@CartesianTest.Enum final BackupLevel backupLevel,
@CartesianTest.Enum final BackupCredentialType credentialType) {
@ParameterizedTest
@EnumSource
@NullSource
void credentialsRequiresBackupLevel(final BackupLevel backupLevel) {
final BackupAuthManager authManager = create(backupLevel, false); final BackupAuthManager authManager = create(backupLevel, false);
final Account account = mock(Account.class); final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(aci); when(account.getUuid()).thenReturn(aci);
when(account.getBackupCredentialRequest()).thenReturn(backupAuthTestUtil.getRequest(backupKey, aci).serialize()); when(account.getBackupCredentialRequest(BackupCredentialType.MESSAGES))
.thenReturn(Optional.of(backupAuthTestUtil.getRequest(messagesBackupKey, aci).serialize()));
when(account.getBackupCredentialRequest(BackupCredentialType.MEDIA))
.thenReturn(Optional.of(backupAuthTestUtil.getRequest(mediaBackupKey, aci).serialize()));
final ThrowableAssert.ThrowingCallable getCreds = () -> assertThat(authManager.getBackupAuthCredentials(account,
assertThat(authManager.getBackupAuthCredentials(account, credentialType,
clock.instant().truncatedTo(ChronoUnit.DAYS), clock.instant().truncatedTo(ChronoUnit.DAYS),
clock.instant().plus(Duration.ofDays(1)).truncatedTo(ChronoUnit.DAYS)).join()) clock.instant().plus(Duration.ofDays(1)).truncatedTo(ChronoUnit.DAYS)).join())
.hasSize(2); .hasSize(2);
if (backupLevel == null) {
Assertions.assertThatExceptionOfType(StatusRuntimeException.class)
.isThrownBy(getCreds)
.extracting(ex -> ex.getStatus().getCode())
.isEqualTo(Status.Code.PERMISSION_DENIED);
} else {
Assertions.assertThatNoException().isThrownBy(getCreds);
}
} }
@ParameterizedTest @ParameterizedTest
@EnumSource @EnumSource
void getReceiptCredentials(final BackupLevel backupLevel) throws VerificationFailedException { void getBackupAuthCredentialsNoBackupLevel(final BackupCredentialType credentialType) {
final BackupAuthManager authManager = create(backupLevel, false); final BackupAuthManager authManager = create(null, false);
final BackupAuthCredentialRequestContext requestContext = BackupAuthCredentialRequestContext.create(backupKey, aci);
final Account account = mock(Account.class); final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(aci); when(account.getUuid()).thenReturn(aci);
when(account.getBackupCredentialRequest()).thenReturn(requestContext.getRequest().serialize()); when(account.getBackupCredentialRequest(BackupCredentialType.MESSAGES))
.thenReturn(Optional.of(backupAuthTestUtil.getRequest(messagesBackupKey, aci).serialize()));
when(account.getBackupCredentialRequest(BackupCredentialType.MEDIA))
.thenReturn(Optional.of(backupAuthTestUtil.getRequest(mediaBackupKey, aci).serialize()));
assertThatExceptionOfType(StatusRuntimeException.class)
.isThrownBy(() -> authManager.getBackupAuthCredentials(account,
credentialType,
clock.instant().truncatedTo(ChronoUnit.DAYS),
clock.instant().plus(Duration.ofDays(1)).truncatedTo(ChronoUnit.DAYS)).join())
.extracting(ex -> ex.getStatus().getCode())
.isEqualTo(Status.Code.PERMISSION_DENIED);
}
@CartesianTest
void getReceiptCredentials(@CartesianTest.Enum final BackupLevel backupLevel,
@CartesianTest.Enum final BackupCredentialType credentialType) throws VerificationFailedException {
final BackupAuthManager authManager = create(backupLevel, false);
final byte[] backupKey = switch (credentialType) {
case MESSAGES -> messagesBackupKey;
case MEDIA -> mediaBackupKey;
};
final BackupAuthCredentialRequestContext requestContext =
BackupAuthCredentialRequestContext.create(backupKey, aci);
final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(aci);
when(account.getBackupCredentialRequest(BackupCredentialType.MESSAGES))
.thenReturn(Optional.of(backupAuthTestUtil.getRequest(messagesBackupKey, aci).serialize()));
when(account.getBackupCredentialRequest(BackupCredentialType.MEDIA))
.thenReturn(Optional.of(backupAuthTestUtil.getRequest(mediaBackupKey, aci).serialize()));
final Instant start = clock.instant().truncatedTo(ChronoUnit.DAYS); final Instant start = clock.instant().truncatedTo(ChronoUnit.DAYS);
final List<BackupAuthManager.Credential> creds = authManager.getBackupAuthCredentials(account, final List<BackupAuthManager.Credential> creds = authManager.getBackupAuthCredentials(account,
start, start.plus(Duration.ofDays(7))).join(); credentialType, start, start.plus(Duration.ofDays(7))).join();
assertThat(creds).hasSize(8); assertThat(creds).hasSize(8);
Instant redemptionTime = start; Instant redemptionTime = start;
@ -190,16 +244,19 @@ public class BackupAuthManagerTest {
@MethodSource @MethodSource
void invalidCredentialTimeWindows(final Instant requestRedemptionStart, final Instant requestRedemptionEnd, void invalidCredentialTimeWindows(final Instant requestRedemptionStart, final Instant requestRedemptionEnd,
final Instant now) { final Instant now) {
final BackupAuthManager authManager = create(BackupLevel.MESSAGES, false); final BackupAuthManager authManager = create(BackupLevel.FREE, false);
final Account account = mock(Account.class); final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(aci); when(account.getUuid()).thenReturn(aci);
when(account.getBackupCredentialRequest()).thenReturn(backupAuthTestUtil.getRequest(backupKey, aci).serialize()); when(account.getBackupCredentialRequest(BackupCredentialType.MESSAGES))
.thenReturn(Optional.of(backupAuthTestUtil.getRequest(messagesBackupKey, aci).serialize()));
when(account.getBackupCredentialRequest(BackupCredentialType.MEDIA))
.thenReturn(Optional.of(backupAuthTestUtil.getRequest(mediaBackupKey, aci).serialize()));
clock.pin(now); clock.pin(now);
assertThatExceptionOfType(StatusRuntimeException.class) assertThatExceptionOfType(StatusRuntimeException.class)
.isThrownBy( .isThrownBy(
() -> authManager.getBackupAuthCredentials(account, requestRedemptionStart, requestRedemptionEnd).join()) () -> authManager.getBackupAuthCredentials(account, BackupCredentialType.MESSAGES, requestRedemptionStart, requestRedemptionEnd).join())
.extracting(ex -> ex.getStatus().getCode()) .extracting(ex -> ex.getStatus().getCode())
.isEqualTo(Status.Code.INVALID_ARGUMENT); .isEqualTo(Status.Code.INVALID_ARGUMENT);
} }
@ -211,19 +268,23 @@ public class BackupAuthManagerTest {
final Instant day4 = Instant.EPOCH.plus(Duration.ofDays(4)); final Instant day4 = Instant.EPOCH.plus(Duration.ofDays(4));
final Instant dayMax = day0.plus(BackupAuthManager.MAX_REDEMPTION_DURATION); final Instant dayMax = day0.plus(BackupAuthManager.MAX_REDEMPTION_DURATION);
final BackupAuthManager authManager = create(BackupLevel.MESSAGES, false); final BackupAuthManager authManager = create(BackupLevel.FREE, false);
final Account account = mock(Account.class); final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(aci); when(account.getUuid()).thenReturn(aci);
when(account.getBackupCredentialRequest()).thenReturn(backupAuthTestUtil.getRequest(backupKey, aci).serialize()); when(account.getBackupCredentialRequest(BackupCredentialType.MESSAGES))
.thenReturn(Optional.of(backupAuthTestUtil.getRequest(messagesBackupKey, aci).serialize()));
when(account.getBackupCredentialRequest(BackupCredentialType.MEDIA))
.thenReturn(Optional.of(backupAuthTestUtil.getRequest(mediaBackupKey, aci).serialize()));
when(account.getBackupVoucher()).thenReturn(new Account.BackupVoucher(201, day4)); when(account.getBackupVoucher()).thenReturn(new Account.BackupVoucher(201, day4));
final List<BackupAuthManager.Credential> creds = authManager.getBackupAuthCredentials(account, day0, dayMax).join(); final List<BackupAuthManager.Credential> creds = authManager.getBackupAuthCredentials(account, BackupCredentialType.MESSAGES, day0, dayMax).join();
Instant redemptionTime = day0; Instant redemptionTime = day0;
final BackupAuthCredentialRequestContext requestContext = BackupAuthCredentialRequestContext.create(backupKey, aci); final BackupAuthCredentialRequestContext requestContext = BackupAuthCredentialRequestContext.create(
messagesBackupKey, aci);
for (int i = 0; i < creds.size(); i++) { for (int i = 0; i < creds.size(); i++) {
// Before the expiration, credentials should have a media receipt, otherwise messages only // Before the expiration, credentials should have a media receipt, otherwise messages only
final BackupLevel level = i < 5 ? BackupLevel.MEDIA : BackupLevel.MESSAGES; final BackupLevel level = i < 5 ? BackupLevel.PAID : BackupLevel.FREE;
final BackupAuthManager.Credential cred = creds.get(i); final BackupAuthManager.Credential cred = creds.get(i);
assertThat(requestContext assertThat(requestContext
.receiveResponse(cred.credential(), redemptionTime, backupAuthTestUtil.params.getPublicParams()) .receiveResponse(cred.credential(), redemptionTime, backupAuthTestUtil.params.getPublicParams())
@ -240,19 +301,23 @@ public class BackupAuthManagerTest {
final Instant day2 = Instant.EPOCH.plus(Duration.ofDays(2)); final Instant day2 = Instant.EPOCH.plus(Duration.ofDays(2));
final Instant day3 = Instant.EPOCH.plus(Duration.ofDays(3)); final Instant day3 = Instant.EPOCH.plus(Duration.ofDays(3));
final BackupAuthManager authManager = create(BackupLevel.MESSAGES, false); final BackupAuthManager authManager = create(BackupLevel.FREE, false);
final Account account = mock(Account.class); final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(aci); when(account.getUuid()).thenReturn(aci);
when(account.getBackupVoucher()).thenReturn(new Account.BackupVoucher(3, day1)); when(account.getBackupVoucher()).thenReturn(new Account.BackupVoucher(3, day1));
final Account updated = mock(Account.class); final Account updated = mock(Account.class);
when(updated.getUuid()).thenReturn(aci); when(updated.getUuid()).thenReturn(aci);
when(updated.getBackupCredentialRequest()).thenReturn(backupAuthTestUtil.getRequest(backupKey, aci).serialize()); when(updated.getBackupCredentialRequest(BackupCredentialType.MESSAGES))
.thenReturn(Optional.of(backupAuthTestUtil.getRequest(messagesBackupKey, aci).serialize()));
when(updated.getBackupCredentialRequest(BackupCredentialType.MEDIA))
.thenReturn(Optional.of(backupAuthTestUtil.getRequest(mediaBackupKey, aci).serialize()));
when(updated.getBackupVoucher()).thenReturn(null); when(updated.getBackupVoucher()).thenReturn(null);
when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(updated)); when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(updated));
clock.pin(day2.plus(Duration.ofSeconds(1))); clock.pin(day2.plus(Duration.ofSeconds(1)));
assertThat(authManager.getBackupAuthCredentials(account, day2, day2.plus(Duration.ofDays(7))).join()) assertThat(authManager.getBackupAuthCredentials(account, BackupCredentialType.MESSAGES, day2, day2.plus(Duration.ofDays(7))).join())
.hasSize(8); .hasSize(8);
@SuppressWarnings("unchecked") final ArgumentCaptor<Consumer<Account>> accountUpdater = ArgumentCaptor.forClass( @SuppressWarnings("unchecked") final ArgumentCaptor<Consumer<Account>> accountUpdater = ArgumentCaptor.forClass(
@ -276,7 +341,7 @@ public class BackupAuthManagerTest {
@Test @Test
void redeemReceipt() throws InvalidInputException, VerificationFailedException { void redeemReceipt() throws InvalidInputException, VerificationFailedException {
final Instant expirationTime = Instant.EPOCH.plus(Duration.ofDays(1)); final Instant expirationTime = Instant.EPOCH.plus(Duration.ofDays(1));
final BackupAuthManager authManager = create(BackupLevel.MESSAGES, false); final BackupAuthManager authManager = create(BackupLevel.FREE, false);
final Account account = mock(Account.class); final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(aci); when(account.getUuid()).thenReturn(aci);
@ -293,7 +358,7 @@ public class BackupAuthManagerTest {
final Instant newExpirationTime = Instant.EPOCH.plus(Duration.ofDays(1)); final Instant newExpirationTime = Instant.EPOCH.plus(Duration.ofDays(1));
final Instant existingExpirationTime = Instant.EPOCH.plus(Duration.ofDays(1)).plus(Duration.ofSeconds(1)); final Instant existingExpirationTime = Instant.EPOCH.plus(Duration.ofDays(1)).plus(Duration.ofSeconds(1));
final BackupAuthManager authManager = create(BackupLevel.MESSAGES, false); final BackupAuthManager authManager = create(BackupLevel.FREE, false);
final Account account = mock(Account.class); final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(aci); when(account.getUuid()).thenReturn(aci);
@ -318,8 +383,8 @@ public class BackupAuthManagerTest {
void redeemExpiredReceipt() { void redeemExpiredReceipt() {
final Instant expirationTime = Instant.EPOCH.plus(Duration.ofDays(1)); final Instant expirationTime = Instant.EPOCH.plus(Duration.ofDays(1));
clock.pin(expirationTime.plus(Duration.ofSeconds(1))); clock.pin(expirationTime.plus(Duration.ofSeconds(1)));
final BackupAuthManager authManager = create(BackupLevel.MESSAGES, false); final BackupAuthManager authManager = create(BackupLevel.FREE, false);
Assertions.assertThatExceptionOfType(StatusRuntimeException.class) assertThatExceptionOfType(StatusRuntimeException.class)
.isThrownBy(() -> authManager.redeemReceipt(mock(Account.class), receiptPresentation(3, expirationTime)).join()) .isThrownBy(() -> authManager.redeemReceipt(mock(Account.class), receiptPresentation(3, expirationTime)).join())
.extracting(ex -> ex.getStatus().getCode()) .extracting(ex -> ex.getStatus().getCode())
.isEqualTo(Status.Code.INVALID_ARGUMENT); .isEqualTo(Status.Code.INVALID_ARGUMENT);
@ -332,8 +397,8 @@ public class BackupAuthManagerTest {
void redeemInvalidLevel(long level) { void redeemInvalidLevel(long level) {
final Instant expirationTime = Instant.EPOCH.plus(Duration.ofDays(1)); final Instant expirationTime = Instant.EPOCH.plus(Duration.ofDays(1));
clock.pin(expirationTime.plus(Duration.ofSeconds(1))); clock.pin(expirationTime.plus(Duration.ofSeconds(1)));
final BackupAuthManager authManager = create(BackupLevel.MESSAGES, false); final BackupAuthManager authManager = create(BackupLevel.FREE, false);
Assertions.assertThatExceptionOfType(StatusRuntimeException.class) assertThatExceptionOfType(StatusRuntimeException.class)
.isThrownBy(() -> .isThrownBy(() ->
authManager.redeemReceipt(mock(Account.class), receiptPresentation(level, expirationTime)).join()) authManager.redeemReceipt(mock(Account.class), receiptPresentation(level, expirationTime)).join())
.extracting(ex -> ex.getStatus().getCode()) .extracting(ex -> ex.getStatus().getCode())
@ -344,9 +409,9 @@ public class BackupAuthManagerTest {
@Test @Test
void redeemInvalidPresentation() throws InvalidInputException, VerificationFailedException { void redeemInvalidPresentation() throws InvalidInputException, VerificationFailedException {
final BackupAuthManager authManager = create(BackupLevel.MESSAGES, false); final BackupAuthManager authManager = create(BackupLevel.FREE, false);
final ReceiptCredentialPresentation invalid = receiptPresentation(ServerSecretParams.generate(), 3L, Instant.EPOCH); final ReceiptCredentialPresentation invalid = receiptPresentation(ServerSecretParams.generate(), 3L, Instant.EPOCH);
Assertions.assertThatExceptionOfType(StatusRuntimeException.class) assertThatExceptionOfType(StatusRuntimeException.class)
.isThrownBy(() -> authManager.redeemReceipt(mock(Account.class), invalid).join()) .isThrownBy(() -> authManager.redeemReceipt(mock(Account.class), invalid).join())
.extracting(ex -> ex.getStatus().getCode()) .extracting(ex -> ex.getStatus().getCode())
.isEqualTo(Status.Code.INVALID_ARGUMENT); .isEqualTo(Status.Code.INVALID_ARGUMENT);
@ -357,7 +422,7 @@ public class BackupAuthManagerTest {
@Test @Test
void receiptAlreadyRedeemed() throws InvalidInputException, VerificationFailedException { void receiptAlreadyRedeemed() throws InvalidInputException, VerificationFailedException {
final Instant expirationTime = Instant.EPOCH.plus(Duration.ofDays(1)); final Instant expirationTime = Instant.EPOCH.plus(Duration.ofDays(1));
final BackupAuthManager authManager = create(BackupLevel.MESSAGES, false); final BackupAuthManager authManager = create(BackupLevel.FREE, false);
final Account account = mock(Account.class); final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(aci); when(account.getUuid()).thenReturn(aci);
@ -397,28 +462,31 @@ public class BackupAuthManagerTest {
@Test @Test
void testRateLimits() { void testRateLimits() {
final AccountsManager accountsManager = mock(AccountsManager.class); final AccountsManager accountsManager = mock(AccountsManager.class);
final BackupAuthManager authManager = create(BackupLevel.MESSAGES, true); final BackupAuthManager authManager = create(BackupLevel.FREE, true);
final BackupAuthCredentialRequest credentialRequest = backupAuthTestUtil.getRequest(backupKey, aci); final BackupAuthCredentialRequest messagesCredential = backupAuthTestUtil.getRequest(messagesBackupKey, aci);
final BackupAuthCredentialRequest mediaCredential = backupAuthTestUtil.getRequest(mediaBackupKey, aci);
final Account account = mock(Account.class); final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(aci); when(account.getUuid()).thenReturn(aci);
when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(account)); when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(account));
// Should be rate limited // Should be rate limited
final RateLimitExceededException ex = CompletableFutureTestUtil.assertFailsWithCause( CompletableFutureTestUtil.assertFailsWithCause(RateLimitExceededException.class,
RateLimitExceededException.class, authManager.commitBackupId(account, messagesCredential, mediaCredential));
authManager.commitBackupId(account, credentialRequest));
// If we don't change the request, shouldn't be rate limited // If we don't change the request, shouldn't be rate limited
when(account.getBackupCredentialRequest()).thenReturn(credentialRequest.serialize()); when(account.getBackupCredentialRequest(BackupCredentialType.MESSAGES))
assertDoesNotThrow(() -> authManager.commitBackupId(account, credentialRequest).join()); .thenReturn(Optional.of(backupAuthTestUtil.getRequest(messagesBackupKey, aci).serialize()));
when(account.getBackupCredentialRequest(BackupCredentialType.MEDIA))
.thenReturn(Optional.of(backupAuthTestUtil.getRequest(mediaBackupKey, aci).serialize()));
assertDoesNotThrow(() -> authManager.commitBackupId(account, messagesCredential, mediaCredential).join());
} }
private static String experimentName(@Nullable BackupLevel backupLevel) { private static String experimentName(@Nullable BackupLevel backupLevel) {
return switch (backupLevel) { return switch (backupLevel) {
case MESSAGES -> BackupAuthManager.BACKUP_EXPERIMENT_NAME; case FREE -> BackupAuthManager.BACKUP_EXPERIMENT_NAME;
case MEDIA -> BackupAuthManager.BACKUP_MEDIA_EXPERIMENT_NAME; case PAID -> BackupAuthManager.BACKUP_MEDIA_EXPERIMENT_NAME;
case null -> "fake_experiment"; case null -> "fake_experiment";
}; };
} }

View File

@ -12,12 +12,14 @@ import java.time.Clock;
import java.time.Instant; import java.time.Instant;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
import java.util.List; import java.util.List;
import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import org.signal.libsignal.zkgroup.GenericServerSecretParams; import org.signal.libsignal.zkgroup.GenericServerSecretParams;
import org.signal.libsignal.zkgroup.VerificationFailedException; import org.signal.libsignal.zkgroup.VerificationFailedException;
import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialPresentation; import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialPresentation;
import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialRequest; import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialRequest;
import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialRequestContext; import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialRequestContext;
import org.signal.libsignal.zkgroup.backups.BackupCredentialType;
import org.signal.libsignal.zkgroup.backups.BackupLevel; import org.signal.libsignal.zkgroup.backups.BackupLevel;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.tests.util.ExperimentHelper; import org.whispersystems.textsecuregcm.tests.util.ExperimentHelper;
@ -48,7 +50,7 @@ public class BackupAuthTestUtil {
final BackupAuthCredentialRequestContext ctx = BackupAuthCredentialRequestContext.create(backupKey, aci); final BackupAuthCredentialRequestContext ctx = BackupAuthCredentialRequestContext.create(backupKey, aci);
return ctx.receiveResponse( return ctx.receiveResponse(
ctx.getRequest() ctx.getRequest()
.issueCredential(clock.instant().truncatedTo(ChronoUnit.DAYS), backupLevel, params), .issueCredential(clock.instant().truncatedTo(ChronoUnit.DAYS), backupLevel, BackupCredentialType.MESSAGES, params),
redemptionTime, redemptionTime,
params.getPublicParams()) params.getPublicParams())
.present(params.getPublicParams()); .present(params.getPublicParams());
@ -57,19 +59,20 @@ public class BackupAuthTestUtil {
public List<BackupAuthManager.Credential> getCredentials( public List<BackupAuthManager.Credential> getCredentials(
final BackupLevel backupLevel, final BackupLevel backupLevel,
final BackupAuthCredentialRequest request, final BackupAuthCredentialRequest request,
final BackupCredentialType credentialType,
final Instant redemptionStart, final Instant redemptionStart,
final Instant redemptionEnd) { final Instant redemptionEnd) {
final UUID aci = UUID.randomUUID(); final UUID aci = UUID.randomUUID();
final String experimentName = switch (backupLevel) { final String experimentName = switch (backupLevel) {
case MESSAGES -> BackupAuthManager.BACKUP_EXPERIMENT_NAME; case FREE -> BackupAuthManager.BACKUP_EXPERIMENT_NAME;
case MEDIA -> BackupAuthManager.BACKUP_MEDIA_EXPERIMENT_NAME; case PAID -> BackupAuthManager.BACKUP_MEDIA_EXPERIMENT_NAME;
}; };
final BackupAuthManager issuer = new BackupAuthManager( final BackupAuthManager issuer = new BackupAuthManager(
ExperimentHelper.withEnrollment(experimentName, aci), null, null, null, null, params, clock); ExperimentHelper.withEnrollment(experimentName, aci), null, null, null, null, params, clock);
Account account = mock(Account.class); Account account = mock(Account.class);
when(account.getUuid()).thenReturn(aci); when(account.getUuid()).thenReturn(aci);
when(account.getBackupCredentialRequest()).thenReturn(request.serialize()); when(account.getBackupCredentialRequest(credentialType)).thenReturn(Optional.of(request.serialize()));
return issuer.getBackupAuthCredentials(account, redemptionStart, redemptionEnd).join(); return issuer.getBackupAuthCredentials(account, credentialType, redemptionStart, redemptionEnd).join();
} }
} }

View File

@ -45,10 +45,12 @@ import java.util.function.Function;
import java.util.stream.IntStream; import java.util.stream.IntStream;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.assertj.core.api.ThrowableAssert;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.CartesianTest; import org.junitpioneer.jupiter.cartesian.CartesianTest;
@ -58,6 +60,7 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.zkgroup.GenericServerSecretParams; import org.signal.libsignal.zkgroup.GenericServerSecretParams;
import org.signal.libsignal.zkgroup.VerificationFailedException; import org.signal.libsignal.zkgroup.VerificationFailedException;
import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialPresentation; import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialPresentation;
import org.signal.libsignal.zkgroup.backups.BackupCredentialType;
import org.signal.libsignal.zkgroup.backups.BackupLevel; import org.signal.libsignal.zkgroup.backups.BackupLevel;
import org.whispersystems.textsecuregcm.attachments.TusAttachmentGenerator; import org.whispersystems.textsecuregcm.attachments.TusAttachmentGenerator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedBackupUser; import org.whispersystems.textsecuregcm.auth.AuthenticatedBackupUser;
@ -87,7 +90,6 @@ public class BackupManagerTest {
private static final CopyParameters COPY_PARAM = new CopyParameters( private static final CopyParameters COPY_PARAM = new CopyParameters(
3, "abc", 100, 3, "abc", 100,
COPY_ENCRYPTION_PARAM, TestRandomUtil.nextBytes(15)); COPY_ENCRYPTION_PARAM, TestRandomUtil.nextBytes(15));
private static final String COPY_DEST_STRING = Base64.getEncoder().encodeToString(COPY_PARAM.destinationMediaId());
private final TestClock testClock = TestClock.now(); private final TestClock testClock = TestClock.now();
private final BackupAuthTestUtil backupAuthTestUtil = new BackupAuthTestUtil(testClock); private final BackupAuthTestUtil backupAuthTestUtil = new BackupAuthTestUtil(testClock);
@ -125,6 +127,62 @@ public class BackupManagerTest {
testClock); testClock);
} }
@ParameterizedTest
@CsvSource({
"FREE, FREE, false",
"FREE, PAID, true",
"PAID, FREE, false",
"PAID, PAID, false"
})
void checkBackupLevel(final BackupLevel authenticateBackupLevel,
final BackupLevel requiredLevel,
final boolean expectException) {
final AuthenticatedBackupUser backupUser =
backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MESSAGES, authenticateBackupLevel);
final ThrowableAssert.ThrowingCallable checkBackupLevel =
() -> BackupManager.checkBackupLevel(backupUser, requiredLevel);
if (expectException) {
assertThatExceptionOfType(StatusRuntimeException.class)
.isThrownBy(checkBackupLevel)
.extracting(StatusRuntimeException::getStatus)
.extracting(Status::getCode)
.isEqualTo(Status.Code.PERMISSION_DENIED);
} else {
assertThatNoException().isThrownBy(checkBackupLevel);
}
}
@ParameterizedTest
@CsvSource({
"MESSAGES, MESSAGES, false",
"MESSAGES, MEDIA, true",
"MEDIA, MESSAGES, true",
"MEDIA, MEDIA, false"
})
void checkBackupCredentialType(final BackupCredentialType authenticateCredentialType,
final BackupCredentialType requiredCredentialType,
final boolean expectException) {
final AuthenticatedBackupUser backupUser =
backupUser(TestRandomUtil.nextBytes(16), authenticateCredentialType, BackupLevel.FREE);
final ThrowableAssert.ThrowingCallable checkCredentialType =
() -> BackupManager.checkBackupCredentialType(backupUser, requiredCredentialType);
if (expectException) {
assertThatExceptionOfType(StatusRuntimeException.class)
.isThrownBy(checkCredentialType)
.extracting(StatusRuntimeException::getStatus)
.extracting(Status::getCode)
.isEqualTo(Status.Code.PERMISSION_DENIED);
} else {
assertThatNoException().isThrownBy(checkCredentialType);
}
}
@ParameterizedTest @ParameterizedTest
@EnumSource @EnumSource
public void createBackup(final BackupLevel backupLevel) { public void createBackup(final BackupLevel backupLevel) {
@ -132,7 +190,7 @@ public class BackupManagerTest {
final Instant now = Instant.ofEpochSecond(Duration.ofDays(1).getSeconds()); final Instant now = Instant.ofEpochSecond(Duration.ofDays(1).getSeconds());
testClock.pin(now); testClock.pin(now);
final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), backupLevel); final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MESSAGES, backupLevel);
backupManager.createMessageBackupUploadDescriptor(backupUser).join(); backupManager.createMessageBackupUploadDescriptor(backupUser).join();
verify(tusCredentialGenerator, times(1)) verify(tusCredentialGenerator, times(1))
@ -144,22 +202,46 @@ public class BackupManagerTest {
assertThat(info.mediaUsedSpace()).isEqualTo(Optional.empty()); assertThat(info.mediaUsedSpace()).isEqualTo(Optional.empty());
// Check that the initial expiration times are the initial write times // Check that the initial expiration times are the initial write times
checkExpectedExpirations(now, backupLevel == BackupLevel.MEDIA ? now : null, backupUser); checkExpectedExpirations(now, backupLevel == BackupLevel.PAID ? now : null, backupUser);
}
@ParameterizedTest
@EnumSource
public void createBackupWrongCredentialType(final BackupLevel backupLevel) {
final Instant now = Instant.ofEpochSecond(Duration.ofDays(1).getSeconds());
testClock.pin(now);
final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, backupLevel);
assertThatExceptionOfType(StatusRuntimeException.class)
.isThrownBy(() -> backupManager.createMessageBackupUploadDescriptor(backupUser).join())
.matches(exception -> exception.getStatus().getCode() == Status.PERMISSION_DENIED.getCode());
} }
@Test @Test
public void createTemporaryMediaAttachmentRateLimited() { public void createTemporaryMediaAttachmentRateLimited() {
final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID);
when(mediaUploadLimiter.validateAsync(eq(BackupManager.rateLimitKey(backupUser)))) when(mediaUploadLimiter.validateAsync(eq(BackupManager.rateLimitKey(backupUser))))
.thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null))); .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null)));
final RateLimitExceededException e = CompletableFutureTestUtil.assertFailsWithCause( CompletableFutureTestUtil.assertFailsWithCause(
RateLimitExceededException.class, RateLimitExceededException.class,
backupManager.createTemporaryAttachmentUploadDescriptor(backupUser).toCompletableFuture()); backupManager.createTemporaryAttachmentUploadDescriptor(backupUser).toCompletableFuture());
} }
@Test @Test
public void createTemporaryMediaAttachmentWrongTier() { public void createTemporaryMediaAttachmentWrongTier() {
final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MESSAGES); final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.FREE);
assertThatExceptionOfType(StatusRuntimeException.class)
.isThrownBy(() -> backupManager.createTemporaryAttachmentUploadDescriptor(backupUser))
.extracting(StatusRuntimeException::getStatus)
.extracting(Status::getCode)
.isEqualTo(Status.Code.PERMISSION_DENIED);
}
@Test
public void createTemporaryMediaAttachmentWrongCredentialType() {
final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MESSAGES, BackupLevel.PAID);
assertThatExceptionOfType(StatusRuntimeException.class) assertThatExceptionOfType(StatusRuntimeException.class)
.isThrownBy(() -> backupManager.createTemporaryAttachmentUploadDescriptor(backupUser)) .isThrownBy(() -> backupManager.createTemporaryAttachmentUploadDescriptor(backupUser))
.extracting(StatusRuntimeException::getStatus) .extracting(StatusRuntimeException::getStatus)
@ -170,7 +252,7 @@ public class BackupManagerTest {
@ParameterizedTest @ParameterizedTest
@EnumSource @EnumSource
public void ttlRefresh(final BackupLevel backupLevel) { public void ttlRefresh(final BackupLevel backupLevel) {
final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), backupLevel); final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MESSAGES, backupLevel);
final Instant tstart = Instant.ofEpochSecond(1).plus(Duration.ofDays(1)); final Instant tstart = Instant.ofEpochSecond(1).plus(Duration.ofDays(1));
final Instant tnext = tstart.plus(Duration.ofSeconds(1)); final Instant tnext = tstart.plus(Duration.ofSeconds(1));
@ -185,7 +267,7 @@ public class BackupManagerTest {
checkExpectedExpirations( checkExpectedExpirations(
tnext, tnext,
backupLevel == BackupLevel.MEDIA ? tnext : null, backupLevel == BackupLevel.PAID ? tnext : null,
backupUser); backupUser);
} }
@ -195,7 +277,7 @@ public class BackupManagerTest {
final Instant tstart = Instant.ofEpochSecond(1).plus(Duration.ofDays(1)); final Instant tstart = Instant.ofEpochSecond(1).plus(Duration.ofDays(1));
final Instant tnext = tstart.plus(Duration.ofSeconds(1)); final Instant tnext = tstart.plus(Duration.ofSeconds(1));
final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), backupLevel); final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MESSAGES, backupLevel);
// create backup at t=tstart // create backup at t=tstart
testClock.pin(tstart); testClock.pin(tstart);
@ -207,7 +289,7 @@ public class BackupManagerTest {
checkExpectedExpirations( checkExpectedExpirations(
tnext, tnext,
backupLevel == BackupLevel.MEDIA ? tnext : null, backupLevel == BackupLevel.PAID ? tnext : null,
backupUser); backupUser);
} }
@ -215,7 +297,7 @@ public class BackupManagerTest {
public void invalidPresentationNoPublicKey() throws VerificationFailedException { public void invalidPresentationNoPublicKey() throws VerificationFailedException {
final BackupAuthCredentialPresentation invalidPresentation = backupAuthTestUtil.getPresentation( final BackupAuthCredentialPresentation invalidPresentation = backupAuthTestUtil.getPresentation(
GenericServerSecretParams.generate(), GenericServerSecretParams.generate(),
BackupLevel.MESSAGES, backupKey, aci); BackupLevel.FREE, backupKey, aci);
final ECKeyPair keyPair = Curve.generateKeyPair(); final ECKeyPair keyPair = Curve.generateKeyPair();
@ -233,10 +315,10 @@ public class BackupManagerTest {
@Test @Test
public void invalidPresentationCorrectSignature() throws VerificationFailedException { public void invalidPresentationCorrectSignature() throws VerificationFailedException {
final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation(
BackupLevel.MESSAGES, backupKey, aci); BackupLevel.FREE, backupKey, aci);
final BackupAuthCredentialPresentation invalidPresentation = backupAuthTestUtil.getPresentation( final BackupAuthCredentialPresentation invalidPresentation = backupAuthTestUtil.getPresentation(
GenericServerSecretParams.generate(), GenericServerSecretParams.generate(),
BackupLevel.MESSAGES, backupKey, aci); BackupLevel.FREE, backupKey, aci);
final ECKeyPair keyPair = Curve.generateKeyPair(); final ECKeyPair keyPair = Curve.generateKeyPair();
backupManager.setPublicKey( backupManager.setPublicKey(
@ -256,7 +338,7 @@ public class BackupManagerTest {
@Test @Test
public void unknownPublicKey() throws VerificationFailedException { public void unknownPublicKey() throws VerificationFailedException {
final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation(
BackupLevel.MESSAGES, backupKey, aci); BackupLevel.FREE, backupKey, aci);
final ECKeyPair keyPair = Curve.generateKeyPair(); final ECKeyPair keyPair = Curve.generateKeyPair();
final byte[] signature = keyPair.getPrivateKey().calculateSignature(presentation.serialize()); final byte[] signature = keyPair.getPrivateKey().calculateSignature(presentation.serialize());
@ -272,7 +354,7 @@ public class BackupManagerTest {
@Test @Test
public void mismatchedPublicKey() throws VerificationFailedException { public void mismatchedPublicKey() throws VerificationFailedException {
final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation(
BackupLevel.MESSAGES, backupKey, aci); BackupLevel.FREE, backupKey, aci);
final ECKeyPair keyPair1 = Curve.generateKeyPair(); final ECKeyPair keyPair1 = Curve.generateKeyPair();
final ECKeyPair keyPair2 = Curve.generateKeyPair(); final ECKeyPair keyPair2 = Curve.generateKeyPair();
@ -295,7 +377,7 @@ public class BackupManagerTest {
@Test @Test
public void signatureValidation() throws VerificationFailedException { public void signatureValidation() throws VerificationFailedException {
final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation(
BackupLevel.MESSAGES, backupKey, aci); BackupLevel.FREE, backupKey, aci);
final ECKeyPair keyPair = Curve.generateKeyPair(); final ECKeyPair keyPair = Curve.generateKeyPair();
final byte[] signature = keyPair.getPrivateKey().calculateSignature(presentation.serialize()); final byte[] signature = keyPair.getPrivateKey().calculateSignature(presentation.serialize());
@ -322,7 +404,7 @@ public class BackupManagerTest {
// correct signature // correct signature
final AuthenticatedBackupUser user = backupManager.authenticateBackupUser(presentation, signature).join(); final AuthenticatedBackupUser user = backupManager.authenticateBackupUser(presentation, signature).join();
assertThat(user.backupId()).isEqualTo(presentation.getBackupId()); assertThat(user.backupId()).isEqualTo(presentation.getBackupId());
assertThat(user.backupLevel()).isEqualTo(BackupLevel.MESSAGES); assertThat(user.backupLevel()).isEqualTo(BackupLevel.FREE);
} }
@Test @Test
@ -330,7 +412,7 @@ public class BackupManagerTest {
// credential for 1 day after epoch // credential for 1 day after epoch
testClock.pin(Instant.ofEpochSecond(1).plus(Duration.ofDays(1))); testClock.pin(Instant.ofEpochSecond(1).plus(Duration.ofDays(1)));
final BackupAuthCredentialPresentation oldCredential = backupAuthTestUtil.getPresentation(BackupLevel.MESSAGES, final BackupAuthCredentialPresentation oldCredential = backupAuthTestUtil.getPresentation(BackupLevel.FREE,
backupKey, aci); backupKey, aci);
final ECKeyPair keyPair = Curve.generateKeyPair(); final ECKeyPair keyPair = Curve.generateKeyPair();
final byte[] signature = keyPair.getPrivateKey().calculateSignature(oldCredential.serialize()); final byte[] signature = keyPair.getPrivateKey().calculateSignature(oldCredential.serialize());
@ -355,7 +437,7 @@ public class BackupManagerTest {
@Test @Test
public void copySuccess() { public void copySuccess() {
final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID);
final CopyResult copied = copy(backupUser); final CopyResult copied = copy(backupUser);
assertThat(copied.cdn()).isEqualTo(3); assertThat(copied.cdn()).isEqualTo(3);
@ -372,7 +454,7 @@ public class BackupManagerTest {
@Test @Test
public void copyFailure() { public void copyFailure() {
final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID);
assertThat(copyError(backupUser, new SourceObjectNotFoundException()).outcome()) assertThat(copyError(backupUser, new SourceObjectNotFoundException()).outcome())
.isEqualTo(CopyResult.Outcome.SOURCE_NOT_FOUND); .isEqualTo(CopyResult.Outcome.SOURCE_NOT_FOUND);
@ -384,7 +466,7 @@ public class BackupManagerTest {
@Test @Test
public void copyPartialSuccess() { public void copyPartialSuccess() {
final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID);
final List<CopyParameters> toCopy = List.of( final List<CopyParameters> toCopy = List.of(
new CopyParameters(3, "success", 100, COPY_ENCRYPTION_PARAM, TestRandomUtil.nextBytes(15)), new CopyParameters(3, "success", 100, COPY_ENCRYPTION_PARAM, TestRandomUtil.nextBytes(15)),
new CopyParameters(3, "missing", 200, COPY_ENCRYPTION_PARAM, TestRandomUtil.nextBytes(15)), new CopyParameters(3, "missing", 200, COPY_ENCRYPTION_PARAM, TestRandomUtil.nextBytes(15)),
@ -413,9 +495,20 @@ public class BackupManagerTest {
assertThat(AttributeValues.getLong(backup, BackupsDb.ATTR_MEDIA_COUNT, -1L)).isEqualTo(1L); assertThat(AttributeValues.getLong(backup, BackupsDb.ATTR_MEDIA_COUNT, -1L)).isEqualTo(1L);
} }
@Test
public void copyWrongCredentialType() {
final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MESSAGES, BackupLevel.PAID);
assertThatExceptionOfType(StatusRuntimeException.class)
.isThrownBy(() -> copy(backupUser))
.extracting(StatusRuntimeException::getStatus)
.extracting(Status::getCode)
.isEqualTo(Status.Code.PERMISSION_DENIED);
}
@Test @Test
public void quotaEnforcementNoRecalculation() { public void quotaEnforcementNoRecalculation() {
final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID);
verifyNoInteractions(remoteStorageManager); verifyNoInteractions(remoteStorageManager);
// set the backupsDb to be out of quota at t=0 // set the backupsDb to be out of quota at t=0
@ -432,7 +525,7 @@ public class BackupManagerTest {
@Test @Test
public void quotaEnforcementRecalculation() { public void quotaEnforcementRecalculation() {
final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID);
final String backupMediaPrefix = "%s/%s/".formatted(backupUser.backupDir(), backupUser.mediaDir()); final String backupMediaPrefix = "%s/%s/".formatted(backupUser.backupDir(), backupUser.mediaDir());
final long remainingAfterRecalc = BackupManager.MAX_TOTAL_BACKUP_MEDIA_BYTES - COPY_PARAM.destinationObjectSize(); final long remainingAfterRecalc = BackupManager.MAX_TOTAL_BACKUP_MEDIA_BYTES - COPY_PARAM.destinationObjectSize();
@ -462,7 +555,7 @@ public class BackupManagerTest {
@CartesianTest.Values(booleans = {true, false}) boolean hasSpaceBeforeRecalc, @CartesianTest.Values(booleans = {true, false}) boolean hasSpaceBeforeRecalc,
@CartesianTest.Values(booleans = {true, false}) boolean hasSpaceAfterRecalc, @CartesianTest.Values(booleans = {true, false}) boolean hasSpaceAfterRecalc,
@CartesianTest.Values(booleans = {true, false}) boolean doesReaclc) { @CartesianTest.Values(booleans = {true, false}) boolean doesReaclc) {
final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID);
final String backupMediaPrefix = "%s/%s/".formatted(backupUser.backupDir(), backupUser.mediaDir()); final String backupMediaPrefix = "%s/%s/".formatted(backupUser.backupDir(), backupUser.mediaDir());
final long destSize = COPY_PARAM.destinationObjectSize(); final long destSize = COPY_PARAM.destinationObjectSize();
@ -496,7 +589,7 @@ public class BackupManagerTest {
@ValueSource(strings = {"", "cursor"}) @ValueSource(strings = {"", "cursor"})
public void list(final String cursorVal) { public void list(final String cursorVal) {
final Optional<String> cursor = Optional.of(cursorVal).filter(StringUtils::isNotBlank); final Optional<String> cursor = Optional.of(cursorVal).filter(StringUtils::isNotBlank);
final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MESSAGES, BackupLevel.PAID);
final String backupMediaPrefix = "%s/%s/".formatted(backupUser.backupDir(), backupUser.mediaDir()); final String backupMediaPrefix = "%s/%s/".formatted(backupUser.backupDir(), backupUser.mediaDir());
when(remoteStorageManager.cdnNumber()).thenReturn(13); when(remoteStorageManager.cdnNumber()).thenReturn(13);
@ -519,14 +612,14 @@ public class BackupManagerTest {
@Test @Test
public void deleteEntireBackup() { public void deleteEntireBackup() {
final AuthenticatedBackupUser original = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); final AuthenticatedBackupUser original = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MESSAGES, BackupLevel.PAID);
testClock.pin(Instant.ofEpochSecond(10)); testClock.pin(Instant.ofEpochSecond(10));
// Deleting should swap the backupDir for the user // Deleting should swap the backupDir for the user
backupManager.deleteEntireBackup(original).join(); backupManager.deleteEntireBackup(original).join();
verifyNoInteractions(remoteStorageManager); verifyNoInteractions(remoteStorageManager);
final AuthenticatedBackupUser after = retrieveBackupUser(original.backupId(), BackupLevel.MEDIA); final AuthenticatedBackupUser after = retrieveBackupUser(original.backupId(), BackupCredentialType.MESSAGES, BackupLevel.PAID);
assertThat(original.backupDir()).isNotEqualTo(after.backupDir()); assertThat(original.backupDir()).isNotEqualTo(after.backupDir());
assertThat(original.mediaDir()).isNotEqualTo(after.mediaDir()); assertThat(original.mediaDir()).isNotEqualTo(after.mediaDir());
@ -552,7 +645,7 @@ public class BackupManagerTest {
@Test @Test
public void delete() { public void delete() {
final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID);
final byte[] mediaId = TestRandomUtil.nextBytes(16); final byte[] mediaId = TestRandomUtil.nextBytes(16);
final String backupMediaKey = "%s/%s/%s".formatted( final String backupMediaKey = "%s/%s/%s".formatted(
backupUser.backupDir(), backupUser.backupDir(),
@ -571,9 +664,24 @@ public class BackupManagerTest {
.isEqualTo(new UsageInfo(93, 999)); .isEqualTo(new UsageInfo(93, 999));
} }
@Test
public void deleteWrongCredentialType() {
final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MESSAGES, BackupLevel.PAID);
final byte[] mediaId = TestRandomUtil.nextBytes(16);
final String backupMediaKey = "%s/%s/%s".formatted(
backupUser.backupDir(),
backupUser.mediaDir(),
BackupManager.encodeMediaIdForCdn(mediaId));
assertThatThrownBy(() ->
backupManager.deleteMedia(backupUser, List.of(new BackupManager.StorageDescriptor(5, mediaId))).then().block())
.isInstanceOf(StatusRuntimeException.class)
.matches(e -> ((StatusRuntimeException) e).getStatus().getCode() == Status.PERMISSION_DENIED.getCode());
}
@Test @Test
public void deleteUnknownCdn() { public void deleteUnknownCdn() {
final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID);
final BackupManager.StorageDescriptor sd = new BackupManager.StorageDescriptor(4, TestRandomUtil.nextBytes(15)); final BackupManager.StorageDescriptor sd = new BackupManager.StorageDescriptor(4, TestRandomUtil.nextBytes(15));
when(remoteStorageManager.cdnNumber()).thenReturn(5); when(remoteStorageManager.cdnNumber()).thenReturn(5);
assertThatThrownBy(() -> assertThatThrownBy(() ->
@ -584,7 +692,7 @@ public class BackupManagerTest {
@Test @Test
public void deletePartialFailure() { public void deletePartialFailure() {
final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID);
final List<BackupManager.StorageDescriptor> descriptors = new ArrayList<>(); final List<BackupManager.StorageDescriptor> descriptors = new ArrayList<>();
long initialBytes = 0; long initialBytes = 0;
@ -621,7 +729,7 @@ public class BackupManagerTest {
@Test @Test
public void alreadyDeleted() { public void alreadyDeleted() {
final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID);
final byte[] mediaId = TestRandomUtil.nextBytes(16); final byte[] mediaId = TestRandomUtil.nextBytes(16);
final String backupMediaKey = "%s/%s/%s".formatted( final String backupMediaKey = "%s/%s/%s".formatted(
backupUser.backupDir(), backupUser.backupDir(),
@ -642,7 +750,7 @@ public class BackupManagerTest {
@Test @Test
public void listExpiredBackups() { public void listExpiredBackups() {
final List<AuthenticatedBackupUser> backupUsers = IntStream.range(0, 10) final List<AuthenticatedBackupUser> backupUsers = IntStream.range(0, 10)
.mapToObj(i -> backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA)) .mapToObj(i -> backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MESSAGES, BackupLevel.PAID))
.toList(); .toList();
for (int i = 0; i < backupUsers.size(); i++) { for (int i = 0; i < backupUsers.size(); i++) {
testClock.pin(Instant.ofEpochSecond(i)); testClock.pin(Instant.ofEpochSecond(i));
@ -680,11 +788,11 @@ public class BackupManagerTest {
// refreshed media timestamp at t=5 // refreshed media timestamp at t=5
testClock.pin(Instant.ofEpochSecond(5)); testClock.pin(Instant.ofEpochSecond(5));
backupManager.createMessageBackupUploadDescriptor(backupUser(backupId, BackupLevel.MEDIA)).join(); backupManager.createMessageBackupUploadDescriptor(backupUser(backupId, BackupCredentialType.MESSAGES, BackupLevel.PAID)).join();
// refreshed messages timestamp at t=6 // refreshed messages timestamp at t=6
testClock.pin(Instant.ofEpochSecond(6)); testClock.pin(Instant.ofEpochSecond(6));
backupManager.createMessageBackupUploadDescriptor(backupUser(backupId, BackupLevel.MESSAGES)).join(); backupManager.createMessageBackupUploadDescriptor(backupUser(backupId, BackupCredentialType.MESSAGES, BackupLevel.FREE)).join();
Function<Instant, List<ExpiredBackup>> getExpired = time -> backupManager Function<Instant, List<ExpiredBackup>> getExpired = time -> backupManager
.getExpiredBackups(1, Schedulers.immediate(), time) .getExpiredBackups(1, Schedulers.immediate(), time)
@ -704,7 +812,7 @@ public class BackupManagerTest {
@ParameterizedTest @ParameterizedTest
@EnumSource(mode = EnumSource.Mode.INCLUDE, names = {"MEDIA", "ALL"}) @EnumSource(mode = EnumSource.Mode.INCLUDE, names = {"MEDIA", "ALL"})
public void expireBackup(ExpiredBackup.ExpirationType expirationType) { public void expireBackup(ExpiredBackup.ExpirationType expirationType) {
final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MESSAGES, BackupLevel.PAID);
backupManager.createMessageBackupUploadDescriptor(backupUser).join(); backupManager.createMessageBackupUploadDescriptor(backupUser).join();
final String expectedPrefixToDelete = switch (expirationType) { final String expectedPrefixToDelete = switch (expirationType) {
@ -746,7 +854,7 @@ public class BackupManagerTest {
@Test @Test
public void deleteBackupPaginated() { public void deleteBackupPaginated() {
final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MESSAGES, BackupLevel.PAID);
backupManager.createMessageBackupUploadDescriptor(backupUser).join(); backupManager.createMessageBackupUploadDescriptor(backupUser).join();
final ExpiredBackup expiredBackup = expiredBackup(ExpiredBackup.ExpirationType.MEDIA, backupUser); final ExpiredBackup expiredBackup = expiredBackup(ExpiredBackup.ExpirationType.MEDIA, backupUser);
@ -845,9 +953,9 @@ public class BackupManagerTest {
} }
/** /**
* Create BackupUser with the provided backupId and tier * Create BackupUser with the provided backupId, credential type, and tier
*/ */
private AuthenticatedBackupUser backupUser(final byte[] backupId, final BackupLevel backupLevel) { private AuthenticatedBackupUser backupUser(final byte[] backupId, final BackupCredentialType credentialType, final BackupLevel backupLevel) {
// Won't actually validate the public key, but need to have a public key to perform BackupsDB operations // Won't actually validate the public key, but need to have a public key to perform BackupsDB operations
byte[] privateKey = new byte[32]; byte[] privateKey = new byte[32];
ByteBuffer.wrap(privateKey).put(backupId); ByteBuffer.wrap(privateKey).put(backupId);
@ -856,14 +964,14 @@ public class BackupManagerTest {
} catch (InvalidKeyException e) { } catch (InvalidKeyException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
return retrieveBackupUser(backupId, backupLevel); return retrieveBackupUser(backupId, credentialType, backupLevel);
} }
/** /**
* Retrieve an existing BackupUser from the database * Retrieve an existing BackupUser from the database
*/ */
private AuthenticatedBackupUser retrieveBackupUser(final byte[] backupId, final BackupLevel backupLevel) { private AuthenticatedBackupUser retrieveBackupUser(final byte[] backupId, final BackupCredentialType credentialType, final BackupLevel backupLevel) {
final BackupsDb.AuthenticationData authData = backupsDb.retrieveAuthenticationData(backupId).join().get(); final BackupsDb.AuthenticationData authData = backupsDb.retrieveAuthenticationData(backupId).join().get();
return new AuthenticatedBackupUser(backupId, backupLevel, authData.backupDir(), authData.mediaDir()); return new AuthenticatedBackupUser(backupId, credentialType, backupLevel, authData.backupDir(), authData.mediaDir());
} }
} }

View File

@ -23,6 +23,7 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.zkgroup.backups.BackupCredentialType;
import org.signal.libsignal.zkgroup.backups.BackupLevel; import org.signal.libsignal.zkgroup.backups.BackupLevel;
import org.whispersystems.textsecuregcm.auth.AuthenticatedBackupUser; import org.whispersystems.textsecuregcm.auth.AuthenticatedBackupUser;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtension; import org.whispersystems.textsecuregcm.storage.DynamoDbExtension;
@ -51,7 +52,7 @@ public class BackupsDbTest {
@Test @Test
public void trackMediaStats() { public void trackMediaStats() {
final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID);
// add at least one message backup so we can describe it // add at least one message backup so we can describe it
backupsDb.addMessageBackup(backupUser).join(); backupsDb.addMessageBackup(backupUser).join();
int total = 0; int total = 0;
@ -74,7 +75,7 @@ public class BackupsDbTest {
@ValueSource(booleans = {false, true}) @ValueSource(booleans = {false, true})
public void setUsage(boolean mediaAlreadyExists) { public void setUsage(boolean mediaAlreadyExists) {
testClock.pin(Instant.ofEpochSecond(5)); testClock.pin(Instant.ofEpochSecond(5));
final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID);
if (mediaAlreadyExists) { if (mediaAlreadyExists) {
this.backupsDb.trackMedia(backupUser, 1, 10).join(); this.backupsDb.trackMedia(backupUser, 1, 10).join();
} }
@ -90,12 +91,12 @@ public class BackupsDbTest {
final byte[] backupId = TestRandomUtil.nextBytes(16); final byte[] backupId = TestRandomUtil.nextBytes(16);
// Refresh media/messages at t=0 // Refresh media/messages at t=0
testClock.pin(Instant.ofEpochSecond(0L)); testClock.pin(Instant.ofEpochSecond(0L));
backupsDb.setPublicKey(backupId, BackupLevel.MEDIA, Curve.generateKeyPair().getPublicKey()).join(); backupsDb.setPublicKey(backupId, BackupLevel.PAID, Curve.generateKeyPair().getPublicKey()).join();
this.backupsDb.ttlRefresh(backupUser(backupId, BackupLevel.MEDIA)).join(); this.backupsDb.ttlRefresh(backupUser(backupId, BackupCredentialType.MEDIA, BackupLevel.PAID)).join();
// refresh only messages at t=2 // refresh only messages at t=2
testClock.pin(Instant.ofEpochSecond(2L)); testClock.pin(Instant.ofEpochSecond(2L));
this.backupsDb.ttlRefresh(backupUser(backupId, BackupLevel.MESSAGES)).join(); this.backupsDb.ttlRefresh(backupUser(backupId, BackupCredentialType.MEDIA, BackupLevel.FREE)).join();
final Function<Instant, List<ExpiredBackup>> expiredBackups = purgeTime -> backupsDb final Function<Instant, List<ExpiredBackup>> expiredBackups = purgeTime -> backupsDb
.getExpiredBackups(1, Schedulers.immediate(), purgeTime) .getExpiredBackups(1, Schedulers.immediate(), purgeTime)
@ -132,13 +133,13 @@ public class BackupsDbTest {
final byte[] backupId = TestRandomUtil.nextBytes(16); final byte[] backupId = TestRandomUtil.nextBytes(16);
// Refresh media/messages at t=0 // Refresh media/messages at t=0
testClock.pin(Instant.ofEpochSecond(0L)); testClock.pin(Instant.ofEpochSecond(0L));
backupsDb.setPublicKey(backupId, BackupLevel.MEDIA, Curve.generateKeyPair().getPublicKey()).join(); backupsDb.setPublicKey(backupId, BackupLevel.PAID, Curve.generateKeyPair().getPublicKey()).join();
this.backupsDb.ttlRefresh(backupUser(backupId, BackupLevel.MEDIA)).join(); this.backupsDb.ttlRefresh(backupUser(backupId, BackupCredentialType.MEDIA, BackupLevel.PAID)).join();
if (expirationType == ExpiredBackup.ExpirationType.MEDIA) { if (expirationType == ExpiredBackup.ExpirationType.MEDIA) {
// refresh only messages at t=2 so that we only expire media at t=1 // refresh only messages at t=2 so that we only expire media at t=1
testClock.pin(Instant.ofEpochSecond(2L)); testClock.pin(Instant.ofEpochSecond(2L));
this.backupsDb.ttlRefresh(backupUser(backupId, BackupLevel.MESSAGES)).join(); this.backupsDb.ttlRefresh(backupUser(backupId, BackupCredentialType.MEDIA, BackupLevel.FREE)).join();
} }
final Function<Instant, Optional<ExpiredBackup>> expiredBackups = purgeTime -> { final Function<Instant, Optional<ExpiredBackup>> expiredBackups = purgeTime -> {
@ -192,7 +193,7 @@ public class BackupsDbTest {
// should be nothing to expire at t=1 // should be nothing to expire at t=1
assertThat(opt).isEmpty(); assertThat(opt).isEmpty();
// The backup should still exist // The backup should still exist
backupsDb.describeBackup(backupUser(backupId, BackupLevel.MEDIA)).join(); backupsDb.describeBackup(backupUser(backupId, BackupCredentialType.MEDIA, BackupLevel.PAID)).join();
} else { } else {
// Cleaned up the failed attempt, now should tell us to clean the whole backup // Cleaned up the failed attempt, now should tell us to clean the whole backup
assertThat(opt.get()).matches(eb -> eb.expirationType() == ExpiredBackup.ExpirationType.ALL, assertThat(opt.get()).matches(eb -> eb.expirationType() == ExpiredBackup.ExpirationType.ALL,
@ -202,7 +203,7 @@ public class BackupsDbTest {
// The backup entry should be gone // The backup entry should be gone
assertThat(CompletableFutureTestUtil.assertFailsWithCause(StatusRuntimeException.class, assertThat(CompletableFutureTestUtil.assertFailsWithCause(StatusRuntimeException.class,
backupsDb.describeBackup(backupUser(backupId, BackupLevel.MEDIA))) backupsDb.describeBackup(backupUser(backupId, BackupCredentialType.MEDIA, BackupLevel.PAID)))
.getStatus().getCode()) .getStatus().getCode())
.isEqualTo(Status.Code.NOT_FOUND); .isEqualTo(Status.Code.NOT_FOUND);
assertThat(expiredBackups.apply(Instant.ofEpochSecond(10))).isEmpty(); assertThat(expiredBackups.apply(Instant.ofEpochSecond(10))).isEmpty();
@ -211,9 +212,9 @@ public class BackupsDbTest {
@Test @Test
public void list() { public void list() {
final AuthenticatedBackupUser u1 = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MESSAGES); final AuthenticatedBackupUser u1 = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.FREE);
final AuthenticatedBackupUser u2 = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); final AuthenticatedBackupUser u2 = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID);
final AuthenticatedBackupUser u3 = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); final AuthenticatedBackupUser u3 = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID);
// add at least one message backup, so we can describe it // add at least one message backup, so we can describe it
testClock.pin(Instant.ofEpochSecond(10)); testClock.pin(Instant.ofEpochSecond(10));
@ -248,7 +249,7 @@ public class BackupsDbTest {
assertThat(sbm3.lastRefresh()).isEqualTo(sbm3.lastMediaRefresh()).isEqualTo(Instant.ofEpochSecond(30)); assertThat(sbm3.lastRefresh()).isEqualTo(sbm3.lastMediaRefresh()).isEqualTo(Instant.ofEpochSecond(30));
} }
private AuthenticatedBackupUser backupUser(final byte[] backupId, final BackupLevel backupLevel) { private AuthenticatedBackupUser backupUser(final byte[] backupId, final BackupCredentialType credentialType, final BackupLevel backupLevel) {
return new AuthenticatedBackupUser(backupId, backupLevel, "myBackupDir", "myMediaDir"); return new AuthenticatedBackupUser(backupId, credentialType, backupLevel, "myBackupDir", "myMediaDir");
} }
} }

View File

@ -10,6 +10,7 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset; import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import io.dropwizard.auth.AuthValueFactoryProvider; import io.dropwizard.auth.AuthValueFactoryProvider;
@ -51,6 +52,7 @@ import org.signal.libsignal.zkgroup.InvalidInputException;
import org.signal.libsignal.zkgroup.ServerSecretParams; import org.signal.libsignal.zkgroup.ServerSecretParams;
import org.signal.libsignal.zkgroup.VerificationFailedException; import org.signal.libsignal.zkgroup.VerificationFailedException;
import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialPresentation; import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialPresentation;
import org.signal.libsignal.zkgroup.backups.BackupCredentialType;
import org.signal.libsignal.zkgroup.backups.BackupLevel; import org.signal.libsignal.zkgroup.backups.BackupLevel;
import org.signal.libsignal.zkgroup.receipts.ClientZkReceiptOperations; import org.signal.libsignal.zkgroup.receipts.ClientZkReceiptOperations;
import org.signal.libsignal.zkgroup.receipts.ReceiptCredential; import org.signal.libsignal.zkgroup.receipts.ReceiptCredential;
@ -59,8 +61,8 @@ import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialRequestContext;
import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialResponse; import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialResponse;
import org.signal.libsignal.zkgroup.receipts.ReceiptSerial; import org.signal.libsignal.zkgroup.receipts.ReceiptSerial;
import org.signal.libsignal.zkgroup.receipts.ServerZkReceiptOperations; import org.signal.libsignal.zkgroup.receipts.ServerZkReceiptOperations;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.AuthenticatedBackupUser; import org.whispersystems.textsecuregcm.auth.AuthenticatedBackupUser;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.backup.BackupAuthManager; import org.whispersystems.textsecuregcm.backup.BackupAuthManager;
import org.whispersystems.textsecuregcm.backup.BackupAuthTestUtil; import org.whispersystems.textsecuregcm.backup.BackupAuthTestUtil;
import org.whispersystems.textsecuregcm.backup.BackupManager; import org.whispersystems.textsecuregcm.backup.BackupManager;
@ -71,6 +73,7 @@ import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.GrpcStatusRuntimeExceptionMapper; import org.whispersystems.textsecuregcm.mappers.GrpcStatusRuntimeExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.EnumMapUtil;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
@ -95,7 +98,8 @@ public class ArchiveControllerTest {
.build(); .build();
private final UUID aci = UUID.randomUUID(); private final UUID aci = UUID.randomUUID();
private final byte[] backupKey = TestRandomUtil.nextBytes(32); private final byte[] messagesBackupKey = TestRandomUtil.nextBytes(32);
private final byte[] mediaBackupKey = TestRandomUtil.nextBytes(32);
@BeforeEach @BeforeEach
public void setUp() { public void setUp() {
@ -132,7 +136,7 @@ public class ArchiveControllerTest {
public void anonymousAuthOnly(final String method, final String path, final String body) public void anonymousAuthOnly(final String method, final String path, final String body)
throws VerificationFailedException { throws VerificationFailedException {
final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation(
BackupLevel.MEDIA, backupKey, aci); BackupLevel.PAID, messagesBackupKey, aci);
final Invocation.Builder request = resources.getJerseyTest() final Invocation.Builder request = resources.getJerseyTest()
.target(path) .target(path)
.request() .request()
@ -152,15 +156,22 @@ public class ArchiveControllerTest {
@Test @Test
public void setBackupId() throws RateLimitExceededException { public void setBackupId() throws RateLimitExceededException {
when(backupAuthManager.commitBackupId(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(backupAuthManager.commitBackupId(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null));
final Response response = resources.getJerseyTest() final Response response = resources.getJerseyTest()
.target("v1/archives/backupid") .target("v1/archives/backupid")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ArchiveController.SetBackupIdRequest(backupAuthTestUtil.getRequest(backupKey, aci)), .put(Entity.entity(new ArchiveController.SetBackupIdRequest(
backupAuthTestUtil.getRequest(messagesBackupKey, aci),
backupAuthTestUtil.getRequest(mediaBackupKey, aci)),
MediaType.APPLICATION_JSON_TYPE)); MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(204); assertThat(response.getStatus()).isEqualTo(204);
verify(backupAuthManager).commitBackupId(AuthHelper.VALID_ACCOUNT,
backupAuthTestUtil.getRequest(messagesBackupKey, aci),
backupAuthTestUtil.getRequest(mediaBackupKey, aci));
} }
@Test @Test
@ -191,7 +202,7 @@ public class ArchiveControllerTest {
when(backupManager.setPublicKey(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(backupManager.setPublicKey(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null));
final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation(
BackupLevel.MEDIA, backupKey, aci); BackupLevel.PAID, messagesBackupKey, aci);
final Response response = resources.getJerseyTest() final Response response = resources.getJerseyTest()
.target("v1/archives/keys") .target("v1/archives/keys")
.request() .request()
@ -208,7 +219,7 @@ public class ArchiveControllerTest {
when(backupManager.setPublicKey(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(backupManager.setPublicKey(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null));
final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation(
BackupLevel.MEDIA, backupKey, aci); BackupLevel.PAID, messagesBackupKey, aci);
final Response response = resources.getJerseyTest() final Response response = resources.getJerseyTest()
.target("v1/archives/keys") .target("v1/archives/keys")
.request() .request()
@ -223,7 +234,7 @@ public class ArchiveControllerTest {
when(backupManager.setPublicKey(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null)); when(backupManager.setPublicKey(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null));
final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation(
BackupLevel.MEDIA, backupKey, aci); BackupLevel.PAID, messagesBackupKey, aci);
final Response response = resources.getJerseyTest() final Response response = resources.getJerseyTest()
.target("v1/archives/keys") .target("v1/archives/keys")
.request() .request()
@ -239,8 +250,8 @@ public class ArchiveControllerTest {
@ParameterizedTest @ParameterizedTest
@CsvSource(textBlock = """ @CsvSource(textBlock = """
{}, 422 {}, 422
'{"backupAuthCredentialRequest": "aaa"}', 400 '{"messagesBackupAuthCredentialRequest": "aaa", "mediaBackupAuthCredentialRequest": "aaa"}', 400
'{"backupAuthCredentialRequest": ""}', 400 '{"messagesBackupAuthCredentialRequest": "", "mediaBackupAuthCredentialRequest": ""}', 400
""") """)
public void setBackupIdInvalid(final String requestBody, final int expectedStatus) { public void setBackupIdInvalid(final String requestBody, final int expectedStatus) {
final Response response = resources.getJerseyTest() final Response response = resources.getJerseyTest()
@ -264,15 +275,17 @@ public class ArchiveControllerTest {
public void setBackupIdException(final Exception ex, final boolean sync, final int expectedStatus) public void setBackupIdException(final Exception ex, final boolean sync, final int expectedStatus)
throws RateLimitExceededException { throws RateLimitExceededException {
if (sync) { if (sync) {
when(backupAuthManager.commitBackupId(any(), any())).thenThrow(ex); when(backupAuthManager.commitBackupId(any(), any(), any())).thenThrow(ex);
} else { } else {
when(backupAuthManager.commitBackupId(any(), any())).thenReturn(CompletableFuture.failedFuture(ex)); when(backupAuthManager.commitBackupId(any(), any(), any())).thenReturn(CompletableFuture.failedFuture(ex));
} }
final Response response = resources.getJerseyTest() final Response response = resources.getJerseyTest()
.target("v1/archives/backupid") .target("v1/archives/backupid")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new ArchiveController.SetBackupIdRequest(backupAuthTestUtil.getRequest(backupKey, aci)), .put(Entity.entity(new ArchiveController.SetBackupIdRequest(
backupAuthTestUtil.getRequest(messagesBackupKey, aci),
backupAuthTestUtil.getRequest(mediaBackupKey, aci)),
MediaType.APPLICATION_JSON_TYPE)); MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(expectedStatus); assertThat(response.getStatus()).isEqualTo(expectedStatus);
} }
@ -281,18 +294,36 @@ public class ArchiveControllerTest {
public void getCredentials() { public void getCredentials() {
final Instant start = Instant.now().truncatedTo(ChronoUnit.DAYS); final Instant start = Instant.now().truncatedTo(ChronoUnit.DAYS);
final Instant end = start.plus(Duration.ofDays(1)); final Instant end = start.plus(Duration.ofDays(1));
final List<BackupAuthManager.Credential> expectedResponse = backupAuthTestUtil.getCredentials(
BackupLevel.MEDIA, backupAuthTestUtil.getRequest(backupKey, aci), start, end); final Map<BackupCredentialType, List<BackupAuthManager.Credential>> expectedCredentialsByType =
when(backupAuthManager.getBackupAuthCredentials(any(), eq(start), eq(end))).thenReturn( EnumMapUtil.toEnumMap(BackupCredentialType.class, credentialType -> backupAuthTestUtil.getCredentials(
CompletableFuture.completedFuture(expectedResponse)); BackupLevel.PAID, backupAuthTestUtil.getRequest(messagesBackupKey, aci), credentialType, start, end));
final ArchiveController.BackupAuthCredentialsResponse creds = resources.getJerseyTest()
expectedCredentialsByType.forEach((credentialType, expectedCredentials) ->
when(backupAuthManager.getBackupAuthCredentials(any(), eq(credentialType), eq(start), eq(end))).thenReturn(
CompletableFuture.completedFuture(expectedCredentials)));
final ArchiveController.BackupAuthCredentialsResponse credentialResponse = resources.getJerseyTest()
.target("v1/archives/auth") .target("v1/archives/auth")
.queryParam("redemptionStartSeconds", start.getEpochSecond()) .queryParam("redemptionStartSeconds", start.getEpochSecond())
.queryParam("redemptionEndSeconds", end.getEpochSecond()) .queryParam("redemptionEndSeconds", end.getEpochSecond())
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(ArchiveController.BackupAuthCredentialsResponse.class); .get(ArchiveController.BackupAuthCredentialsResponse.class);
assertThat(creds.credentials().getFirst().redemptionTime()).isEqualTo(start.getEpochSecond());
expectedCredentialsByType.forEach((credentialType, expectedCredentials) -> {
assertThat(credentialResponse.credentials().get(credentialType)).size().isEqualTo(expectedCredentials.size());
assertThat(credentialResponse.credentials().get(credentialType).getFirst().redemptionTime())
.isEqualTo(start.getEpochSecond());
for (int i = 0; i < expectedCredentials.size(); i++) {
assertThat(credentialResponse.credentials().get(credentialType).get(i).redemptionTime())
.isEqualTo(expectedCredentials.get(i).redemptionTime().getEpochSecond());
assertThat(credentialResponse.credentials().get(credentialType).get(i).credential())
.isEqualTo(expectedCredentials.get(i).credential().serialize());
}
});
} }
public enum BadCredentialsType {MISSING_START, MISSING_END, MISSING_BOTH} public enum BadCredentialsType {MISSING_START, MISSING_END, MISSING_BOTH}
@ -322,9 +353,9 @@ public class ArchiveControllerTest {
@Test @Test
public void getBackupInfo() throws VerificationFailedException { public void getBackupInfo() throws VerificationFailedException {
final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation(
BackupLevel.MEDIA, backupKey, aci); BackupLevel.PAID, messagesBackupKey, aci);
when(backupManager.authenticateBackupUser(any(), any())) when(backupManager.authenticateBackupUser(any(), any()))
.thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupLevel.MEDIA))); .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupCredentialType.MESSAGES, BackupLevel.PAID)));
when(backupManager.backupInfo(any())).thenReturn(CompletableFuture.completedFuture(new BackupManager.BackupInfo( when(backupManager.backupInfo(any())).thenReturn(CompletableFuture.completedFuture(new BackupManager.BackupInfo(
1, "myBackupDir", "myMediaDir", "filename", Optional.empty()))); 1, "myBackupDir", "myMediaDir", "filename", Optional.empty())));
final ArchiveController.BackupInfoResponse response = resources.getJerseyTest() final ArchiveController.BackupInfoResponse response = resources.getJerseyTest()
@ -342,9 +373,9 @@ public class ArchiveControllerTest {
@Test @Test
public void putMediaBatchSuccess() throws VerificationFailedException { public void putMediaBatchSuccess() throws VerificationFailedException {
final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation(
BackupLevel.MEDIA, backupKey, aci); BackupLevel.PAID, messagesBackupKey, aci);
when(backupManager.authenticateBackupUser(any(), any())) when(backupManager.authenticateBackupUser(any(), any()))
.thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupLevel.MEDIA))); .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupCredentialType.MESSAGES, BackupLevel.PAID)));
final byte[][] mediaIds = new byte[][]{TestRandomUtil.nextBytes(15), TestRandomUtil.nextBytes(15)}; final byte[][] mediaIds = new byte[][]{TestRandomUtil.nextBytes(15), TestRandomUtil.nextBytes(15)};
when(backupManager.copyToBackup(any(), any())) when(backupManager.copyToBackup(any(), any()))
.thenReturn(Flux.just( .thenReturn(Flux.just(
@ -389,9 +420,9 @@ public class ArchiveControllerTest {
public void putMediaBatchPartialFailure() throws VerificationFailedException { public void putMediaBatchPartialFailure() throws VerificationFailedException {
final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation(
BackupLevel.MEDIA, backupKey, aci); BackupLevel.PAID, messagesBackupKey, aci);
when(backupManager.authenticateBackupUser(any(), any())) when(backupManager.authenticateBackupUser(any(), any()))
.thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupLevel.MEDIA))); .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupCredentialType.MESSAGES, BackupLevel.PAID)));
final byte[][] mediaIds = IntStream.range(0, 4).mapToObj(i -> TestRandomUtil.nextBytes(15)).toArray(byte[][]::new); final byte[][] mediaIds = IntStream.range(0, 4).mapToObj(i -> TestRandomUtil.nextBytes(15)).toArray(byte[][]::new);
when(backupManager.copyToBackup(any(), any())) when(backupManager.copyToBackup(any(), any()))
@ -448,9 +479,9 @@ public class ArchiveControllerTest {
@Test @Test
public void copyMediaWithNegativeLength() throws VerificationFailedException { public void copyMediaWithNegativeLength() throws VerificationFailedException {
final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation(
BackupLevel.MEDIA, backupKey, aci); BackupLevel.PAID, messagesBackupKey, aci);
when(backupManager.authenticateBackupUser(any(), any())) when(backupManager.authenticateBackupUser(any(), any()))
.thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupLevel.MEDIA))); .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupCredentialType.MESSAGES, BackupLevel.PAID)));
final byte[][] mediaIds = new byte[][]{TestRandomUtil.nextBytes(15), TestRandomUtil.nextBytes(15)}; final byte[][] mediaIds = new byte[][]{TestRandomUtil.nextBytes(15), TestRandomUtil.nextBytes(15)};
final Response r = resources.getJerseyTest() final Response r = resources.getJerseyTest()
.target("v1/archives/media/batch") .target("v1/archives/media/batch")
@ -483,9 +514,9 @@ public class ArchiveControllerTest {
@CartesianTest.Values(booleans = {true, false}) final boolean cursorReturned) @CartesianTest.Values(booleans = {true, false}) final boolean cursorReturned)
throws VerificationFailedException { throws VerificationFailedException {
final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation(
BackupLevel.MEDIA, backupKey, aci); BackupLevel.PAID, messagesBackupKey, aci);
when(backupManager.authenticateBackupUser(any(), any())) when(backupManager.authenticateBackupUser(any(), any()))
.thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupLevel.MEDIA))); .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupCredentialType.MESSAGES, BackupLevel.PAID)));
final byte[] mediaId = TestRandomUtil.nextBytes(15); final byte[] mediaId = TestRandomUtil.nextBytes(15);
final Optional<String> expectedCursor = cursorProvided ? Optional.of("myCursor") : Optional.empty(); final Optional<String> expectedCursor = cursorProvided ? Optional.of("myCursor") : Optional.empty();
@ -517,10 +548,10 @@ public class ArchiveControllerTest {
@Test @Test
public void delete() throws VerificationFailedException { public void delete() throws VerificationFailedException {
final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation(BackupLevel.MEDIA, final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation(BackupLevel.PAID,
backupKey, aci); messagesBackupKey, aci);
when(backupManager.authenticateBackupUser(any(), any())) when(backupManager.authenticateBackupUser(any(), any()))
.thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupLevel.MEDIA))); .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupCredentialType.MESSAGES, BackupLevel.PAID)));
final ArchiveController.DeleteMedia deleteRequest = new ArchiveController.DeleteMedia( final ArchiveController.DeleteMedia deleteRequest = new ArchiveController.DeleteMedia(
IntStream IntStream
@ -544,9 +575,9 @@ public class ArchiveControllerTest {
@Test @Test
public void mediaUploadForm() throws RateLimitExceededException, VerificationFailedException { public void mediaUploadForm() throws RateLimitExceededException, VerificationFailedException {
final BackupAuthCredentialPresentation presentation = final BackupAuthCredentialPresentation presentation =
backupAuthTestUtil.getPresentation(BackupLevel.MEDIA, backupKey, aci); backupAuthTestUtil.getPresentation(BackupLevel.PAID, messagesBackupKey, aci);
when(backupManager.authenticateBackupUser(any(), any())) when(backupManager.authenticateBackupUser(any(), any()))
.thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupLevel.MEDIA))); .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupCredentialType.MESSAGES, BackupLevel.PAID)));
when(backupManager.createTemporaryAttachmentUploadDescriptor(any())) when(backupManager.createTemporaryAttachmentUploadDescriptor(any()))
.thenReturn(CompletableFuture.completedFuture( .thenReturn(CompletableFuture.completedFuture(
new BackupUploadDescriptor(3, "abc", Map.of("k", "v"), "example.org"))); new BackupUploadDescriptor(3, "abc", Map.of("k", "v"), "example.org")));
@ -576,9 +607,9 @@ public class ArchiveControllerTest {
@Test @Test
public void readAuth() throws VerificationFailedException { public void readAuth() throws VerificationFailedException {
final BackupAuthCredentialPresentation presentation = final BackupAuthCredentialPresentation presentation =
backupAuthTestUtil.getPresentation(BackupLevel.MEDIA, backupKey, aci); backupAuthTestUtil.getPresentation(BackupLevel.PAID, messagesBackupKey, aci);
when(backupManager.authenticateBackupUser(any(), any())) when(backupManager.authenticateBackupUser(any(), any()))
.thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupLevel.MEDIA))); .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupCredentialType.MESSAGES, BackupLevel.PAID)));
when(backupManager.generateReadAuth(any(), eq(3))).thenReturn(Map.of("key", "value")); when(backupManager.generateReadAuth(any(), eq(3))).thenReturn(Map.of("key", "value"));
final ArchiveController.ReadAuthResponse response = resources.getJerseyTest() final ArchiveController.ReadAuthResponse response = resources.getJerseyTest()
.target("v1/archives/auth/read") .target("v1/archives/auth/read")
@ -593,7 +624,7 @@ public class ArchiveControllerTest {
@Test @Test
public void readAuthInvalidParam() throws VerificationFailedException { public void readAuthInvalidParam() throws VerificationFailedException {
final BackupAuthCredentialPresentation presentation = final BackupAuthCredentialPresentation presentation =
backupAuthTestUtil.getPresentation(BackupLevel.MEDIA, backupKey, aci); backupAuthTestUtil.getPresentation(BackupLevel.PAID, messagesBackupKey, aci);
Response response = resources.getJerseyTest() Response response = resources.getJerseyTest()
.target("v1/archives/auth/read") .target("v1/archives/auth/read")
.request() .request()
@ -615,9 +646,9 @@ public class ArchiveControllerTest {
@Test @Test
public void deleteEntireBackup() throws VerificationFailedException { public void deleteEntireBackup() throws VerificationFailedException {
final BackupAuthCredentialPresentation presentation = final BackupAuthCredentialPresentation presentation =
backupAuthTestUtil.getPresentation(BackupLevel.MEDIA, backupKey, aci); backupAuthTestUtil.getPresentation(BackupLevel.PAID, messagesBackupKey, aci);
when(backupManager.authenticateBackupUser(any(), any())) when(backupManager.authenticateBackupUser(any(), any()))
.thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupLevel.MEDIA))); .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupCredentialType.MESSAGES, BackupLevel.PAID)));
when(backupManager.deleteEntireBackup(any())).thenReturn(CompletableFuture.completedFuture(null)); when(backupManager.deleteEntireBackup(any())).thenReturn(CompletableFuture.completedFuture(null));
Response response = resources.getJerseyTest() Response response = resources.getJerseyTest()
.target("v1/archives/") .target("v1/archives/")
@ -631,25 +662,25 @@ public class ArchiveControllerTest {
@Test @Test
public void invalidSourceAttachmentKey() throws VerificationFailedException { public void invalidSourceAttachmentKey() throws VerificationFailedException {
final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation(
BackupLevel.MEDIA, backupKey, aci); BackupLevel.PAID, messagesBackupKey, aci);
when(backupManager.authenticateBackupUser(any(), any())) when(backupManager.authenticateBackupUser(any(), any()))
.thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupLevel.MEDIA))); .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupCredentialType.MESSAGES, BackupLevel.PAID)));
final Response r = resources.getJerseyTest() final Response r = resources.getJerseyTest()
.target("v1/archives/media") .target("v1/archives/media")
.request() .request()
.header("X-Signal-ZK-Auth", Base64.getEncoder().encodeToString(presentation.serialize())) .header("X-Signal-ZK-Auth", Base64.getEncoder().encodeToString(presentation.serialize()))
.header("X-Signal-ZK-Auth-Signature", "aaa") .header("X-Signal-ZK-Auth-Signature", "aaa")
.put(Entity.json(new ArchiveController.CopyMediaRequest( .put(Entity.json(new ArchiveController.CopyMediaRequest(
new RemoteAttachment(3, "invalid/urlBase64"), new RemoteAttachment(3, "invalid/urlBase64"),
100, 100,
TestRandomUtil.nextBytes(15), TestRandomUtil.nextBytes(15),
TestRandomUtil.nextBytes(32), TestRandomUtil.nextBytes(32),
TestRandomUtil.nextBytes(32), TestRandomUtil.nextBytes(32),
TestRandomUtil.nextBytes(16)))); TestRandomUtil.nextBytes(16))));
assertThat(r.getStatus()).isEqualTo(422); assertThat(r.getStatus()).isEqualTo(422);
} }
private static AuthenticatedBackupUser backupUser(byte[] backupId, BackupLevel backupLevel) { private static AuthenticatedBackupUser backupUser(final byte[] backupId, final BackupCredentialType credentialType, final BackupLevel backupLevel) {
return new AuthenticatedBackupUser(backupId, backupLevel, "myBackupDir", "myMediaDir"); return new AuthenticatedBackupUser(backupId, credentialType, backupLevel, "myBackupDir", "myMediaDir");
} }
} }

View File

@ -53,6 +53,7 @@ import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.signal.libsignal.zkgroup.backups.BackupCredentialType;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil; import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables;
@ -426,7 +427,7 @@ class AccountsTest {
generateAccount(e164, existingUuid, UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1))); generateAccount(e164, existingUuid, UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1)));
// the backup credential request and share-set are always preserved across account reclaims // the backup credential request and share-set are always preserved across account reclaims
existingAccount.setBackupCredentialRequest(TestRandomUtil.nextBytes(32)); existingAccount.setBackupCredentialRequests(TestRandomUtil.nextBytes(32), TestRandomUtil.nextBytes(32));
existingAccount.setSvr3ShareSet(TestRandomUtil.nextBytes(100)); existingAccount.setSvr3ShareSet(TestRandomUtil.nextBytes(100));
createAccount(existingAccount); createAccount(existingAccount);
final Account secondAccount = final Account secondAccount =
@ -435,7 +436,10 @@ class AccountsTest {
reclaimAccount(secondAccount); reclaimAccount(secondAccount);
final Account reclaimed = accounts.getByAccountIdentifier(existingUuid).get(); final Account reclaimed = accounts.getByAccountIdentifier(existingUuid).get();
assertThat(reclaimed.getBackupCredentialRequest()).isEqualTo(existingAccount.getBackupCredentialRequest()); assertThat(reclaimed.getBackupCredentialRequest(BackupCredentialType.MESSAGES).get())
.isEqualTo(existingAccount.getBackupCredentialRequest(BackupCredentialType.MESSAGES).get());
assertThat(reclaimed.getBackupCredentialRequest(BackupCredentialType.MEDIA).get())
.isEqualTo(existingAccount.getBackupCredentialRequest(BackupCredentialType.MEDIA).get());
assertThat(reclaimed.getSvr3ShareSet()).isEqualTo(existingAccount.getSvr3ShareSet()); assertThat(reclaimed.getSvr3ShareSet()).isEqualTo(existingAccount.getSvr3ShareSet());
} }