diff --git a/integration-tests/src/main/java/org/signal/integration/TestUser.java b/integration-tests/src/main/java/org/signal/integration/TestUser.java index 8a7dbeff5..6619637c6 100644 --- a/integration-tests/src/main/java/org/signal/integration/TestUser.java +++ b/integration-tests/src/main/java/org/signal/integration/TestUser.java @@ -18,6 +18,7 @@ import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.IdentityKeyPair; +import org.signal.libsignal.protocol.InvalidKeyException; import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.signal.libsignal.protocol.state.SignedPreKeyRecord; import org.signal.libsignal.protocol.util.KeyHelper; @@ -161,15 +162,19 @@ public class TestUser { : aciIdentityKey; final TestDevice device = requireNonNull(devices.get(deviceId)); final SignedPreKeyRecord signedPreKeyRecord = device.latestSignedPreKey(identity); - return new PreKeySetPublicView( - Collections.emptyList(), - identity.getPublicKey(), - new SignedPreKeyPublicView( - signedPreKeyRecord.getId(), - signedPreKeyRecord.getKeyPair().getPublicKey(), - signedPreKeyRecord.getSignature() - ) - ); + try { + return new PreKeySetPublicView( + Collections.emptyList(), + identity.getPublicKey(), + new SignedPreKeyPublicView( + signedPreKeyRecord.getId(), + signedPreKeyRecord.getKeyPair().getPublicKey(), + signedPreKeyRecord.getSignature() + ) + ); + } catch (InvalidKeyException e) { + throw new RuntimeException(e); + } } public record SignedPreKeyPublicView( diff --git a/pom.xml b/pom.xml index 688d8ef90..bfba1137d 100644 --- a/pom.xml +++ b/pom.xml @@ -272,7 +272,7 @@ org.signal libsignal-server - 0.54.2 + 0.60.0 org.signal.forks diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthenticatedBackupUser.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthenticatedBackupUser.java index 8f20e2062..20709c806 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthenticatedBackupUser.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthenticatedBackupUser.java @@ -5,6 +5,12 @@ package org.whispersystems.textsecuregcm.auth; +import org.signal.libsignal.zkgroup.backups.BackupCredentialType; import org.signal.libsignal.zkgroup.backups.BackupLevel; -public record AuthenticatedBackupUser(byte[] backupId, BackupLevel backupLevel, String backupDir, String mediaDir) {} +public record AuthenticatedBackupUser(byte[] backupId, + BackupCredentialType credentialType, + BackupLevel backupLevel, + String backupDir, + String mediaDir) { +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupAuthManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupAuthManager.java index b3cc70934..7abd416de 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupAuthManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupAuthManager.java @@ -21,6 +21,7 @@ import org.signal.libsignal.zkgroup.InvalidInputException; import org.signal.libsignal.zkgroup.VerificationFailedException; import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialRequest; import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialResponse; +import org.signal.libsignal.zkgroup.backups.BackupCredentialType; import org.signal.libsignal.zkgroup.backups.BackupLevel; import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialPresentation; import org.signal.libsignal.zkgroup.receipts.ReceiptSerial; @@ -29,7 +30,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; -import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; @@ -81,22 +81,34 @@ public class BackupAuthManager { } /** - * Store a credential request containing a blinded backup-id for future use. + * Store credential requests containing blinded backup-ids for future use. * - * @param account The account using the backup-id - * @param backupAuthCredentialRequest A request containing the blinded backup-id + * @param account The account using the backup-id + * @param messagesBackupCredentialRequest A request containing the blinded backup-id the client will use to upload + * message backups + * @param mediaBackupCredentialRequest A request containing the blinded backup-id the client will use to upload + * media backups * @return A future that completes when the credentialRequest has been stored * @throws RateLimitExceededException If too many backup-ids have been committed */ public CompletableFuture commitBackupId(final Account account, - final BackupAuthCredentialRequest backupAuthCredentialRequest) { + final BackupAuthCredentialRequest messagesBackupCredentialRequest, + final BackupAuthCredentialRequest mediaBackupCredentialRequest) { if (configuredBackupLevel(account).isEmpty()) { throw Status.PERMISSION_DENIED.withDescription("Backups not allowed on account").asRuntimeException(); } + final byte[] serializedMessageCredentialRequest = messagesBackupCredentialRequest.serialize(); + final byte[] serializedMediaCredentialRequest = mediaBackupCredentialRequest.serialize(); - byte[] serializedRequest = backupAuthCredentialRequest.serialize(); - byte[] existingRequest = account.getBackupCredentialRequest(); - if (existingRequest != null && MessageDigest.isEqual(serializedRequest, existingRequest)) { + final boolean messageCredentialRequestMatches = account.getBackupCredentialRequest(BackupCredentialType.MESSAGES) + .map(storedCredentialRequest -> MessageDigest.isEqual(storedCredentialRequest, serializedMessageCredentialRequest)) + .orElse(false); + + final boolean mediaCredentialRequestMatches = account.getBackupCredentialRequest(BackupCredentialType.MEDIA) + .map(storedCredentialRequest -> MessageDigest.isEqual(storedCredentialRequest, serializedMediaCredentialRequest)) + .orElse(false); + + if (messageCredentialRequestMatches && mediaCredentialRequestMatches) { // No need to update or enforce rate limits, this is the credential that the user has already // committed to. return CompletableFuture.completedFuture(null); @@ -105,7 +117,7 @@ public class BackupAuthManager { return rateLimiters.forDescriptor(RateLimiters.For.SET_BACKUP_ID) .validateAsync(account.getUuid()) .thenCompose(ignored -> this.accountsManager - .updateAsync(account, acc -> acc.setBackupCredentialRequest(serializedRequest)) + .updateAsync(account, a -> a.setBackupCredentialRequests(serializedMessageCredentialRequest, serializedMediaCredentialRequest)) .thenRun(Util.NOOP)) .toCompletableFuture(); } @@ -123,12 +135,14 @@ public class BackupAuthManager { * method will also remove the expired voucher from the account. * * @param account The account to create the credentials for + * @param credentialType The type of backup credentials to create * @param redemptionStart The day (must be truncated to a day boundary) the first credential should be valid * @param redemptionEnd The day (must be truncated to a day boundary) the last credential should be valid * @return Credentials and the day on which they may be redeemed */ public CompletableFuture> getBackupAuthCredentials( final Account account, + final BackupCredentialType credentialType, final Instant redemptionStart, final Instant redemptionEnd) { @@ -139,7 +153,7 @@ public class BackupAuthManager { if (hasExpiredVoucher(a)) { a.setBackupVoucher(null); } - }).thenCompose(updated -> getBackupAuthCredentials(updated, redemptionStart, redemptionEnd)); + }).thenCompose(updated -> getBackupAuthCredentials(updated, credentialType, redemptionStart, redemptionEnd)); } // If this account isn't allowed some level of backup access via configuration, don't continue @@ -157,23 +171,20 @@ public class BackupAuthManager { } // fetch the blinded backup-id the account should have previously committed to - final byte[] committedBytes = account.getBackupCredentialRequest(); - if (committedBytes == null) { - throw Status.NOT_FOUND.withDescription("No blinded backup-id has been added to the account").asRuntimeException(); - } + final byte[] committedBytes = account.getBackupCredentialRequest(credentialType) + .orElseThrow(() -> Status.NOT_FOUND.withDescription("No blinded backup-id has been added to the account").asRuntimeException()); try { // create a credential for every day in the requested period final BackupAuthCredentialRequest credentialReq = new BackupAuthCredentialRequest(committedBytes); return CompletableFuture.completedFuture(Stream - .iterate(redemptionStart, curr -> curr.plus(Duration.ofDays(1))) - .takeWhile(redemptionTime -> !redemptionTime.isAfter(redemptionEnd)) + .iterate(redemptionStart, redemptionTime -> !redemptionTime.isAfter(redemptionEnd), curr -> curr.plus(Duration.ofDays(1))) .map(redemptionTime -> { // Check if the account has a voucher that's good for a certain receiptLevel at redemption time, otherwise // use the default receipt level final BackupLevel backupLevel = storedBackupLevel(account, redemptionTime).orElse(configuredBackupLevel); return new Credential( - credentialReq.issueCredential(redemptionTime, backupLevel, serverSecretParams), + credentialReq.issueCredential(redemptionTime, backupLevel, credentialType, serverSecretParams), redemptionTime); }) .toList()); @@ -210,7 +221,7 @@ public class BackupAuthManager { final long receiptLevel = receiptCredentialPresentation.getReceiptLevel(); - if (BackupLevelUtil.fromReceiptLevel(receiptLevel) != BackupLevel.MEDIA) { + if (BackupLevelUtil.fromReceiptLevel(receiptLevel) != BackupLevel.PAID) { throw Status.INVALID_ARGUMENT .withDescription("server does not recognize the requested receipt level") .asRuntimeException(); @@ -281,10 +292,10 @@ public class BackupAuthManager { */ private Optional configuredBackupLevel(final Account account) { if (inExperiment(BACKUP_MEDIA_EXPERIMENT_NAME, account)) { - return Optional.of(BackupLevel.MEDIA); + return Optional.of(BackupLevel.PAID); } if (inExperiment(BACKUP_EXPERIMENT_NAME, account)) { - return Optional.of(BackupLevel.MESSAGES); + return Optional.of(BackupLevel.FREE); } return Optional.empty(); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupManager.java index 7ab32ea26..f917bc472 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupManager.java @@ -28,25 +28,22 @@ import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.signal.libsignal.zkgroup.GenericServerSecretParams; import org.signal.libsignal.zkgroup.VerificationFailedException; import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialPresentation; +import org.signal.libsignal.zkgroup.backups.BackupCredentialType; import org.signal.libsignal.zkgroup.backups.BackupLevel; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.attachments.AttachmentGenerator; import org.whispersystems.textsecuregcm.attachments.TusAttachmentGenerator; import org.whispersystems.textsecuregcm.auth.AuthenticatedBackupUser; -import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.util.AsyncTimerUtil; import org.whispersystems.textsecuregcm.util.ExceptionUtils; +import org.whispersystems.textsecuregcm.util.Pair; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Scheduler; public class BackupManager { - private static final Logger logger = LoggerFactory.getLogger(BackupManager.class); - static final String MESSAGE_BACKUP_NAME = "messageBackup"; public static final long MAX_TOTAL_BACKUP_MEDIA_BYTES = DataSize.gibibytes(100).toBytes(); static final long MAX_MEDIA_OBJECT_SIZE = DataSize.mebibytes(101).toBytes(); @@ -120,8 +117,10 @@ public class BackupManager { // Note: this is a special case where we can't validate the presentation signature against the stored public key // because we are currently setting it. We check against the provided public key, but we must also verify that // there isn't an existing, different stored public key for the backup-id (verified with a condition expression) - final BackupLevel backupLevel = verifyPresentation(presentation).verifySignature(signature, publicKey); - return backupsDb.setPublicKey(presentation.getBackupId(), backupLevel, publicKey) + final Pair credentialTypeAndBackupLevel = + verifyPresentation(presentation).verifySignature(signature, publicKey); + + return backupsDb.setPublicKey(presentation.getBackupId(), credentialTypeAndBackupLevel.second(), publicKey) .exceptionally(ExceptionUtils.exceptionallyHandler(PublicKeyConflictException.class, ex -> { Metrics.counter(ZK_AUTHN_COUNTER_NAME, SUCCESS_TAG_NAME, String.valueOf(false), @@ -144,7 +143,8 @@ public class BackupManager { */ public CompletableFuture createMessageBackupUploadDescriptor( final AuthenticatedBackupUser backupUser) { - checkBackupLevel(backupUser, BackupLevel.MESSAGES); + checkBackupLevel(backupUser, BackupLevel.FREE); + checkBackupCredentialType(backupUser, BackupCredentialType.MESSAGES); // this could race with concurrent updates, but the only effect would be last-writer-wins on the timestamp return backupsDb @@ -154,7 +154,8 @@ public class BackupManager { public CompletableFuture createTemporaryAttachmentUploadDescriptor( final AuthenticatedBackupUser backupUser) { - checkBackupLevel(backupUser, BackupLevel.MEDIA); + checkBackupLevel(backupUser, BackupLevel.PAID); + checkBackupCredentialType(backupUser, BackupCredentialType.MEDIA); return rateLimiters.forDescriptor(RateLimiters.For.BACKUP_ATTACHMENT) .validateAsync(rateLimitKey(backupUser)).thenApply(ignored -> { @@ -172,7 +173,7 @@ public class BackupManager { * @param backupUser an already ZK authenticated backup user */ public CompletableFuture ttlRefresh(final AuthenticatedBackupUser backupUser) { - checkBackupLevel(backupUser, BackupLevel.MESSAGES); + checkBackupLevel(backupUser, BackupLevel.FREE); // update message backup TTL return backupsDb.ttlRefresh(backupUser); } @@ -187,7 +188,7 @@ public class BackupManager { * @return Information about the existing backup */ public CompletableFuture backupInfo(final AuthenticatedBackupUser backupUser) { - checkBackupLevel(backupUser, BackupLevel.MESSAGES); + checkBackupLevel(backupUser, BackupLevel.FREE); return backupsDb.describeBackup(backupUser) .thenApply(backupDescription -> new BackupInfo( backupDescription.cdn(), @@ -210,7 +211,8 @@ public class BackupManager { * detailing why the object could not be copied. */ public Flux copyToBackup(final AuthenticatedBackupUser backupUser, List toCopy) { - checkBackupLevel(backupUser, BackupLevel.MEDIA); + checkBackupLevel(backupUser, BackupLevel.PAID); + checkBackupCredentialType(backupUser, BackupCredentialType.MEDIA); return Mono // Figure out how many objects we're allowed to copy, updating the quota usage for the amount we are allowed @@ -349,7 +351,7 @@ public class BackupManager { * @return A map of headers to include with CDN requests */ public Map generateReadAuth(final AuthenticatedBackupUser backupUser, final int cdnNumber) { - checkBackupLevel(backupUser, BackupLevel.MESSAGES); + checkBackupLevel(backupUser, BackupLevel.FREE); if (cdnNumber != 3) { throw Status.INVALID_ARGUMENT.withDescription("unknown cdn").asRuntimeException(); } @@ -377,7 +379,7 @@ public class BackupManager { final AuthenticatedBackupUser backupUser, final Optional cursor, final int limit) { - checkBackupLevel(backupUser, BackupLevel.MESSAGES); + checkBackupLevel(backupUser, BackupLevel.FREE); return remoteStorageManager.list(cdnMediaDirectory(backupUser), cursor, limit) .thenApply(result -> new ListMediaResult( @@ -395,7 +397,7 @@ public class BackupManager { } public CompletableFuture deleteEntireBackup(final AuthenticatedBackupUser backupUser) { - checkBackupLevel(backupUser, BackupLevel.MESSAGES); + checkBackupLevel(backupUser, BackupLevel.FREE); return backupsDb // Try to swap out the backupDir for the user .scheduleBackupDeletion(backupUser) @@ -408,7 +410,8 @@ public class BackupManager { public Flux deleteMedia(final AuthenticatedBackupUser backupUser, final List storageDescriptors) { - checkBackupLevel(backupUser, BackupLevel.MESSAGES); + checkBackupLevel(backupUser, BackupLevel.FREE); + checkBackupCredentialType(backupUser, BackupCredentialType.MEDIA); // Check for a cdn we don't know how to process if (storageDescriptors.stream().anyMatch(sd -> sd.cdn() != remoteStorageManager.cdnNumber())) { @@ -492,10 +495,16 @@ public class BackupManager { // There was no stored public key, use a bunk public key so that validation will fail return new BackupsDb.AuthenticationData(INVALID_PUBLIC_KEY, null, null); }); + + final Pair credentialTypeAndBackupLevel = + signatureVerifier.verifySignature(signature, authenticationData.publicKey()); + return new AuthenticatedBackupUser( presentation.getBackupId(), - signatureVerifier.verifySignature(signature, authenticationData.publicKey()), - authenticationData.backupDir(), authenticationData.mediaDir()); + credentialTypeAndBackupLevel.first(), + credentialTypeAndBackupLevel.second(), + authenticationData.backupDir(), + authenticationData.mediaDir()); }) .thenApply(result -> { Metrics.counter(ZK_AUTHN_COUNTER_NAME, SUCCESS_TAG_NAME, String.valueOf(true)).increment(); @@ -579,7 +588,7 @@ public class BackupManager { interface PresentationSignatureVerifier { - BackupLevel verifySignature(byte[] signature, ECPublicKey publicKey); + Pair verifySignature(byte[] signature, ECPublicKey publicKey); } /** @@ -611,7 +620,7 @@ public class BackupManager { .withDescription("backup auth credential presentation signature verification failed") .asRuntimeException(); } - return presentation.getBackupLevel(); + return new Pair<>(presentation.getType(), presentation.getBackupLevel()); }; } @@ -622,9 +631,34 @@ public class BackupManager { * @param backupLevel The authorization level to verify the backupUser has access to * @throws {@link Status#PERMISSION_DENIED} error if the backup user is not authorized to access {@code backupLevel} */ - private static void checkBackupLevel(final AuthenticatedBackupUser backupUser, final BackupLevel backupLevel) { + @VisibleForTesting + static void checkBackupLevel(final AuthenticatedBackupUser backupUser, final BackupLevel backupLevel) { if (backupUser.backupLevel().compareTo(backupLevel) < 0) { - Metrics.counter(ZK_AUTHZ_FAILURE_COUNTER_NAME).increment(); + Metrics.counter(ZK_AUTHZ_FAILURE_COUNTER_NAME, + FAILURE_REASON_TAG_NAME, "level") + .increment(); + + throw Status.PERMISSION_DENIED + .withDescription("credential does not support the requested operation") + .asRuntimeException(); + } + } + + /** + * Check that the authenticated backup user is authenticated with the given credential type + * + * @param backupUser The backup user to check + * @param credentialType The credential type to require + * @throws {@link Status#PERMISSION_DENIED} error if the backup user is not authenticated with the given + * {@code credentialType} + */ + @VisibleForTesting + static void checkBackupCredentialType(final AuthenticatedBackupUser backupUser, final BackupCredentialType credentialType) { + if (backupUser.credentialType() != credentialType) { + Metrics.counter(ZK_AUTHZ_FAILURE_COUNTER_NAME, + FAILURE_REASON_TAG_NAME, "credential_type") + .increment(); + throw Status.PERMISSION_DENIED .withDescription("credential does not support the requested operation") .asRuntimeException(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupsDb.java b/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupsDb.java index 3aa4487f7..c705e40b3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupsDb.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupsDb.java @@ -87,7 +87,7 @@ public class BackupsDb { // garbage collection of archive objects. public static final String ATTR_LAST_REFRESH = "R"; // N: Time in seconds since epoch of the last backup media refresh. This timestamp can only be updated if the client - // has BackupLevel.MEDIA, and must be periodically updated to avoid garbage collection of media objects. + // has BackupLevel.PAID, and must be periodically updated to avoid garbage collection of media objects. public static final String ATTR_LAST_MEDIA_REFRESH = "MR"; // B: A 32 byte public key that should be used to sign the presentation used to authenticate requests against the // backup-id @@ -265,7 +265,7 @@ public class BackupsDb { * Indicates that we couldn't schedule a deletion because one was already scheduled. The caller may want to delete the * objects directly. */ - class PendingDeletionException extends IOException {} + static class PendingDeletionException extends IOException {} /** * Attempt to mark a backup as expired and swap in a new empty backupDir for the user. @@ -285,7 +285,7 @@ public class BackupsDb { final byte[] hashedBackupId = hashedBackupId(backupUser); // Clear usage metadata, swap names of things we intend to delete, and record our intent to delete in attr_expired_prefix - return dynamoClient.updateItem(new UpdateBuilder(backupTableName, BackupLevel.MEDIA, hashedBackupId) + return dynamoClient.updateItem(new UpdateBuilder(backupTableName, BackupLevel.PAID, hashedBackupId) .clearMediaUsage(clock) .expireDirectoryNames(secureRandom, ExpiredBackup.ExpirationType.ALL) .setRefreshTimes(Instant.ofEpochSecond(0)) @@ -300,7 +300,7 @@ public class BackupsDb { // is toggling backups on and off. In this case, it should be pretty cheap to directly delete the backup. // Instead of changing the backupDir, just make sure the row has expired/ timestamps and tell the caller we // couldn't schedule the deletion. - dynamoClient.updateItem(new UpdateBuilder(backupTableName, BackupLevel.MEDIA, hashedBackupId) + dynamoClient.updateItem(new UpdateBuilder(backupTableName, BackupLevel.PAID, hashedBackupId) .setRefreshTimes(Instant.ofEpochSecond(0)) .updateItemBuilder() .build()) @@ -399,7 +399,7 @@ public class BackupsDb { } // Clear usage metadata, swap names of things we intend to delete, and record our intent to delete in attr_expired_prefix - return dynamoClient.updateItem(new UpdateBuilder(backupTableName, BackupLevel.MEDIA, expiredBackup.hashedBackupId()) + return dynamoClient.updateItem(new UpdateBuilder(backupTableName, BackupLevel.PAID, expiredBackup.hashedBackupId()) .clearMediaUsage(clock) .expireDirectoryNames(secureRandom, expiredBackup.expirationType()) .addRemoveExpression(Map.entry("#mediaRefresh", ATTR_LAST_MEDIA_REFRESH)) @@ -433,7 +433,7 @@ public class BackupsDb { .build()) .thenRun(Util.NOOP); } else { - return dynamoClient.updateItem(new UpdateBuilder(backupTableName, BackupLevel.MEDIA, hashedBackupId) + return dynamoClient.updateItem(new UpdateBuilder(backupTableName, BackupLevel.PAID, hashedBackupId) .addRemoveExpression(Map.entry("#expiredPrefixes", ATTR_EXPIRED_PREFIX)) .updateItemBuilder() .build()) @@ -722,7 +722,7 @@ public class BackupsDb { Map.entry("#lastRefreshTime", ATTR_LAST_REFRESH), Map.entry(":lastRefreshTime", AttributeValues.n(refreshTime.getEpochSecond()))); - if (backupLevel.compareTo(BackupLevel.MEDIA) >= 0) { + if (backupLevel.compareTo(BackupLevel.PAID) >= 0) { // update the media time if we have the appropriate level addSetExpression("#lastMediaRefreshTime = :lastMediaRefreshTime", Map.entry("#lastMediaRefreshTime", ATTR_LAST_MEDIA_REFRESH), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ArchiveController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ArchiveController.java index 711393d53..0138d1c6d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ArchiveController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ArchiveController.java @@ -23,11 +23,14 @@ import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; import java.time.Instant; +import java.util.Arrays; import java.util.Base64; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; +import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Stream; import javax.validation.Valid; import javax.validation.constraints.Max; @@ -52,6 +55,7 @@ import org.signal.libsignal.protocol.ecc.ECPublicKey; import org.signal.libsignal.zkgroup.InvalidInputException; import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialPresentation; import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialRequest; +import org.signal.libsignal.zkgroup.backups.BackupCredentialType; import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialPresentation; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.backup.BackupAuthManager; @@ -89,11 +93,21 @@ public class ArchiveController { public record SetBackupIdRequest( @Schema(description = """ - A BackupAuthCredentialRequest containing a blinded encrypted backup-id, encoded in standard padded base64 + A BackupAuthCredentialRequest containing a blinded encrypted backup-id, encoded in standard padded base64. + This backup-id should be used for message backups only, and must have the message backup type set on the + credential. """, implementation = String.class) @JsonDeserialize(using = BackupAuthCredentialAdapter.CredentialRequestDeserializer.class) @JsonSerialize(using = BackupAuthCredentialAdapter.CredentialRequestSerializer.class) - @NotNull BackupAuthCredentialRequest backupAuthCredentialRequest) {} + @NotNull BackupAuthCredentialRequest messagesBackupAuthCredentialRequest, + + @Schema(description = """ + A BackupAuthCredentialRequest containing a blinded encrypted backup-id, encoded in standard padded base64. + This backup-id should be used for media only, and must have the media type set on the credential. + """, implementation = String.class) + @JsonDeserialize(using = BackupAuthCredentialAdapter.CredentialRequestDeserializer.class) + @JsonSerialize(using = BackupAuthCredentialAdapter.CredentialRequestSerializer.class) + @NotNull BackupAuthCredentialRequest mediaBackupAuthCredentialRequest) {} @PUT @@ -115,8 +129,9 @@ public class ArchiveController { public CompletionStage setBackupId( @Mutable @Auth final AuthenticatedDevice account, @Valid @NotNull final SetBackupIdRequest setBackupIdRequest) throws RateLimitExceededException { + return this.backupAuthManager - .commitBackupId(account.getAccount(), setBackupIdRequest.backupAuthCredentialRequest) + .commitBackupId(account.getAccount(), setBackupIdRequest.messagesBackupAuthCredentialRequest, setBackupIdRequest.mediaBackupAuthCredentialRequest) .thenApply(Util.ASYNC_EMPTY_RESPONSE); } @@ -166,8 +181,8 @@ public class ArchiveController { } public record BackupAuthCredentialsResponse( - @Schema(description = "A list of BackupAuthCredentials and their validity periods") - List credentials) { + @Schema(description = "A map of credential types to lists of BackupAuthCredentials and their validity periods") + Map> credentials) { public record BackupAuthCredential( @Schema(description = "A BackupAuthCredential, encoded in standard padded base64") @@ -202,14 +217,21 @@ public class ArchiveController { @NotNull @QueryParam("redemptionStartSeconds") Long startSeconds, @NotNull @QueryParam("redemptionEndSeconds") Long endSeconds) { - return this.backupAuthManager.getBackupAuthCredentials( - auth.getAccount(), - Instant.ofEpochSecond(startSeconds), Instant.ofEpochSecond(endSeconds)) - .thenApply(creds -> new BackupAuthCredentialsResponse(creds.stream() - .map(cred -> new BackupAuthCredentialsResponse.BackupAuthCredential( - cred.credential().serialize(), - cred.redemptionTime().getEpochSecond())) - .toList())); + final Map> credentialsByType = + new ConcurrentHashMap<>(); + + return CompletableFuture.allOf(Arrays.stream(BackupCredentialType.values()) + .map(credentialType -> this.backupAuthManager.getBackupAuthCredentials( + auth.getAccount(), + credentialType, + Instant.ofEpochSecond(startSeconds), Instant.ofEpochSecond(endSeconds)) + .thenAccept(credentials -> credentialsByType.put(credentialType, credentials.stream() + .map(credential -> new BackupAuthCredentialsResponse.BackupAuthCredential( + credential.credential().serialize(), + credential.redemptionTime().getEpochSecond())) + .toList()))) + .toArray(CompletableFuture[]::new)) + .thenApply(ignored -> new BackupAuthCredentialsResponse(credentialsByType)); } @@ -227,7 +249,8 @@ public class ArchiveController { @ApiResponse(responseCode = "401", description = """ The provided backup auth credential presentation could not be verified or The public key signature was invalid or - There is no backup associated with the backup-id in the presentation""") + There is no backup associated with the backup-id in the presentation or + The credential was of the wrong type (messages/media)""") @ApiResponse(responseCode = "400", description = "Bad arguments. The request may have been made on an authenticated channel") @interface ApiResponseZkAuth {} @@ -453,7 +476,7 @@ public class ArchiveController { throw new BadRequestException("must not use authenticated connection for anonymous operations"); } return backupManager.authenticateBackupUser(presentation.presentation, signature.signature) - .thenCompose(backupUser -> backupManager.createTemporaryAttachmentUploadDescriptor(backupUser)) + .thenCompose(backupManager::createTemporaryAttachmentUploadDescriptor) .thenApply(result -> new UploadDescriptorResponse( result.cdn(), result.key(), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java index bb1eb61e8..3c6f171f2 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java @@ -22,6 +22,7 @@ import java.util.function.Predicate; import javax.annotation.Nullable; import org.apache.commons.lang3.StringUtils; import org.signal.libsignal.protocol.IdentityKey; +import org.signal.libsignal.zkgroup.backups.BackupCredentialType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.SaltedTokenHash; @@ -116,7 +117,11 @@ public class Account { @JsonProperty("bcr") @Nullable - private byte[] backupCredentialRequest; + private byte[] messagesBackupCredentialRequest; + + @JsonProperty("mbcr") + @Nullable + private byte[] mediaBackupCredentialRequest; @JsonProperty("bv") @Nullable @@ -284,7 +289,7 @@ public class Account { requireNotStale(); return Optional.ofNullable(getPrimaryDevice().getCapabilities()) - .map(Device.DeviceCapabilities::transfer) + .map(DeviceCapabilities::transfer) .orElse(false); } @@ -509,12 +514,22 @@ public class Account { this.svr3ShareSet = svr3ShareSet; } - public byte[] getBackupCredentialRequest() { - return backupCredentialRequest; + public void setBackupCredentialRequests(final byte[] messagesBackupCredentialRequest, + final byte[] mediaBackupCredentialRequest) { + + requireNotStale(); + + this.messagesBackupCredentialRequest = messagesBackupCredentialRequest; + this.mediaBackupCredentialRequest = mediaBackupCredentialRequest; } - public void setBackupCredentialRequest(final byte[] backupCredentialRequest) { - this.backupCredentialRequest = backupCredentialRequest; + public Optional getBackupCredentialRequest(final BackupCredentialType credentialType) { + requireNotStale(); + + return Optional.ofNullable(switch (credentialType) { + case MESSAGES -> messagesBackupCredentialRequest; + case MEDIA -> mediaBackupCredentialRequest; + }); } public @Nullable BackupVoucher getBackupVoucher() { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java index fb0d0e9cd..c4a96306d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Accounts.java @@ -36,6 +36,7 @@ import java.util.function.Predicate; import java.util.stream.Collectors; import javax.annotation.Nonnull; import javax.annotation.Nullable; +import org.signal.libsignal.zkgroup.backups.BackupCredentialType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.util.AsyncTimerUtil; @@ -321,7 +322,9 @@ public class Accounts extends AbstractDynamoDbStore { // Carry over the old backup id commitment. If the new account claimer cannot does not have the secret used to // generate their backup-id, this credential is useless, however if they can produce the same credential they // won't be rate-limited for setting their backup-id. - accountToCreate.setBackupCredentialRequest(existingAccount.getBackupCredentialRequest()); + accountToCreate.setBackupCredentialRequests( + existingAccount.getBackupCredentialRequest(BackupCredentialType.MESSAGES).orElse(null), + existingAccount.getBackupCredentialRequest(BackupCredentialType.MEDIA).orElse(null)); // Carry over the old SVR3 share-set. This is required for an account to restore information from SVR. The share- // set is not a secret, if the new account claimer does not have the SVR3 pin, it is useless. diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/BackupMetricsCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/BackupMetricsCommand.java index 96430236f..e80029f0b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/BackupMetricsCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/BackupMetricsCommand.java @@ -6,8 +6,6 @@ package org.whispersystems.textsecuregcm.workers; import io.dropwizard.core.Application; -import io.dropwizard.core.cli.Cli; -import io.dropwizard.core.cli.EnvironmentCommand; import io.dropwizard.core.setup.Environment; import io.micrometer.core.instrument.DistributionSummary; import io.micrometer.core.instrument.Metrics; @@ -18,8 +16,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.WhisperServerConfiguration; import org.whispersystems.textsecuregcm.backup.BackupManager; -import org.whispersystems.textsecuregcm.metrics.MetricsUtil; -import org.whispersystems.textsecuregcm.util.logging.UncaughtExceptionHandler; import reactor.core.scheduler.Schedulers; import java.time.Clock; @@ -69,13 +65,13 @@ public class BackupMetricsCommand extends AbstractCommandWithDependencies { Runtime.getRuntime().availableProcessors()); final DistributionSummary numObjectsMediaTier = Metrics.summary(name(getClass(), "numObjects"), - "tier", BackupLevel.MEDIA.name()); + "tier", BackupLevel.PAID.name()); final DistributionSummary bytesUsedMediaTier = Metrics.summary(name(getClass(), "bytesUsed"), - "tier", BackupLevel.MEDIA.name()); + "tier", BackupLevel.PAID.name()); final DistributionSummary numObjectsMessagesTier = Metrics.summary(name(getClass(), "numObjects"), - "tier", BackupLevel.MESSAGES.name()); + "tier", BackupLevel.FREE.name()); final DistributionSummary bytesUsedMessagesTier = Metrics.summary(name(getClass(), "bytesUsed"), - "tier", BackupLevel.MESSAGES.name()); + "tier", BackupLevel.FREE.name()); final DistributionSummary timeSinceLastRefresh = Metrics.summary(name(getClass(), "timeSinceLastRefresh")); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthManagerTest.java index 537866c68..aed5e61bd 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthManagerTest.java @@ -24,6 +24,7 @@ import java.time.Duration; import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.List; +import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.function.Consumer; @@ -39,12 +40,14 @@ import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.NullSource; import org.junit.jupiter.params.provider.ValueSource; +import org.junitpioneer.jupiter.cartesian.CartesianTest; import org.mockito.ArgumentCaptor; import org.signal.libsignal.zkgroup.InvalidInputException; import org.signal.libsignal.zkgroup.ServerSecretParams; import org.signal.libsignal.zkgroup.VerificationFailedException; import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialRequest; import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialRequestContext; +import org.signal.libsignal.zkgroup.backups.BackupCredentialType; import org.signal.libsignal.zkgroup.backups.BackupLevel; import org.signal.libsignal.zkgroup.receipts.ClientZkReceiptOperations; import org.signal.libsignal.zkgroup.receipts.ReceiptCredential; @@ -67,7 +70,8 @@ import org.whispersystems.textsecuregcm.util.TestRandomUtil; public class BackupAuthManagerTest { private final UUID aci = UUID.randomUUID(); - private final byte[] backupKey = TestRandomUtil.nextBytes(32); + private final byte[] messagesBackupKey = TestRandomUtil.nextBytes(32); + private final byte[] mediaBackupKey = TestRandomUtil.nextBytes(32); private final ServerSecretParams receiptParams = ServerSecretParams.generate(); private final TestClock clock = TestClock.now(); private final BackupAuthTestUtil backupAuthTestUtil = new BackupAuthTestUtil(clock); @@ -92,6 +96,30 @@ public class BackupAuthManagerTest { clock); } + @Test + void commitBackupId() { + final BackupAuthManager authManager = create(BackupLevel.FREE, false); + + final Account account = mock(Account.class); + when(account.getUuid()).thenReturn(aci); + when(accountsManager.updateAsync(any(), any())) + .thenAnswer(invocation -> { + final Account a = invocation.getArgument(0); + final Consumer updater = invocation.getArgument(1); + + updater.accept(a); + + return CompletableFuture.completedFuture(a); + }); + + final BackupAuthCredentialRequest messagesCredentialRequest = backupAuthTestUtil.getRequest(messagesBackupKey, aci); + final BackupAuthCredentialRequest mediaCredentialRequest = backupAuthTestUtil.getRequest(mediaBackupKey, aci); + + authManager.commitBackupId(account, messagesCredentialRequest, mediaCredentialRequest).join(); + + verify(account).setBackupCredentialRequests(messagesCredentialRequest.serialize(), mediaCredentialRequest.serialize()); + } + @ParameterizedTest @EnumSource @NullSource @@ -102,9 +130,11 @@ public class BackupAuthManagerTest { when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(account)); final ThrowableAssert.ThrowingCallable commit = () -> - authManager.commitBackupId(account, backupAuthTestUtil.getRequest(backupKey, aci)).join(); + authManager.commitBackupId(account, + backupAuthTestUtil.getRequest(messagesBackupKey, aci), + backupAuthTestUtil.getRequest(mediaBackupKey, aci)).join(); if (backupLevel == null) { - Assertions.assertThatExceptionOfType(StatusRuntimeException.class) + assertThatExceptionOfType(StatusRuntimeException.class) .isThrownBy(commit) .extracting(ex -> ex.getStatus().getCode()) .isEqualTo(Status.Code.PERMISSION_DENIED); @@ -113,46 +143,70 @@ public class BackupAuthManagerTest { } } + @CartesianTest + void getBackupAuthCredentials(@CartesianTest.Enum final BackupLevel backupLevel, + @CartesianTest.Enum final BackupCredentialType credentialType) { - @ParameterizedTest - @EnumSource - @NullSource - void credentialsRequiresBackupLevel(final BackupLevel backupLevel) { final BackupAuthManager authManager = create(backupLevel, false); final Account account = mock(Account.class); when(account.getUuid()).thenReturn(aci); - when(account.getBackupCredentialRequest()).thenReturn(backupAuthTestUtil.getRequest(backupKey, aci).serialize()); + when(account.getBackupCredentialRequest(BackupCredentialType.MESSAGES)) + .thenReturn(Optional.of(backupAuthTestUtil.getRequest(messagesBackupKey, aci).serialize())); + when(account.getBackupCredentialRequest(BackupCredentialType.MEDIA)) + .thenReturn(Optional.of(backupAuthTestUtil.getRequest(mediaBackupKey, aci).serialize())); - final ThrowableAssert.ThrowingCallable getCreds = () -> - assertThat(authManager.getBackupAuthCredentials(account, - clock.instant().truncatedTo(ChronoUnit.DAYS), - clock.instant().plus(Duration.ofDays(1)).truncatedTo(ChronoUnit.DAYS)).join()) - .hasSize(2); - if (backupLevel == null) { - Assertions.assertThatExceptionOfType(StatusRuntimeException.class) - .isThrownBy(getCreds) - .extracting(ex -> ex.getStatus().getCode()) - .isEqualTo(Status.Code.PERMISSION_DENIED); - } else { - Assertions.assertThatNoException().isThrownBy(getCreds); - } + assertThat(authManager.getBackupAuthCredentials(account, + credentialType, + clock.instant().truncatedTo(ChronoUnit.DAYS), + clock.instant().plus(Duration.ofDays(1)).truncatedTo(ChronoUnit.DAYS)).join()) + .hasSize(2); } @ParameterizedTest @EnumSource - void getReceiptCredentials(final BackupLevel backupLevel) throws VerificationFailedException { - final BackupAuthManager authManager = create(backupLevel, false); - - final BackupAuthCredentialRequestContext requestContext = BackupAuthCredentialRequestContext.create(backupKey, aci); + void getBackupAuthCredentialsNoBackupLevel(final BackupCredentialType credentialType) { + final BackupAuthManager authManager = create(null, false); final Account account = mock(Account.class); when(account.getUuid()).thenReturn(aci); - when(account.getBackupCredentialRequest()).thenReturn(requestContext.getRequest().serialize()); + when(account.getBackupCredentialRequest(BackupCredentialType.MESSAGES)) + .thenReturn(Optional.of(backupAuthTestUtil.getRequest(messagesBackupKey, aci).serialize())); + when(account.getBackupCredentialRequest(BackupCredentialType.MEDIA)) + .thenReturn(Optional.of(backupAuthTestUtil.getRequest(mediaBackupKey, aci).serialize())); + + assertThatExceptionOfType(StatusRuntimeException.class) + .isThrownBy(() -> authManager.getBackupAuthCredentials(account, + credentialType, + clock.instant().truncatedTo(ChronoUnit.DAYS), + clock.instant().plus(Duration.ofDays(1)).truncatedTo(ChronoUnit.DAYS)).join()) + .extracting(ex -> ex.getStatus().getCode()) + .isEqualTo(Status.Code.PERMISSION_DENIED); + } + + @CartesianTest + void getReceiptCredentials(@CartesianTest.Enum final BackupLevel backupLevel, + @CartesianTest.Enum final BackupCredentialType credentialType) throws VerificationFailedException { + final BackupAuthManager authManager = create(backupLevel, false); + + final byte[] backupKey = switch (credentialType) { + case MESSAGES -> messagesBackupKey; + case MEDIA -> mediaBackupKey; + }; + + final BackupAuthCredentialRequestContext requestContext = + BackupAuthCredentialRequestContext.create(backupKey, aci); + + final Account account = mock(Account.class); + when(account.getUuid()).thenReturn(aci); + when(account.getBackupCredentialRequest(BackupCredentialType.MESSAGES)) + .thenReturn(Optional.of(backupAuthTestUtil.getRequest(messagesBackupKey, aci).serialize())); + when(account.getBackupCredentialRequest(BackupCredentialType.MEDIA)) + .thenReturn(Optional.of(backupAuthTestUtil.getRequest(mediaBackupKey, aci).serialize())); final Instant start = clock.instant().truncatedTo(ChronoUnit.DAYS); final List creds = authManager.getBackupAuthCredentials(account, - start, start.plus(Duration.ofDays(7))).join(); + credentialType, start, start.plus(Duration.ofDays(7))).join(); assertThat(creds).hasSize(8); Instant redemptionTime = start; @@ -190,16 +244,19 @@ public class BackupAuthManagerTest { @MethodSource void invalidCredentialTimeWindows(final Instant requestRedemptionStart, final Instant requestRedemptionEnd, final Instant now) { - final BackupAuthManager authManager = create(BackupLevel.MESSAGES, false); + final BackupAuthManager authManager = create(BackupLevel.FREE, false); final Account account = mock(Account.class); when(account.getUuid()).thenReturn(aci); - when(account.getBackupCredentialRequest()).thenReturn(backupAuthTestUtil.getRequest(backupKey, aci).serialize()); + when(account.getBackupCredentialRequest(BackupCredentialType.MESSAGES)) + .thenReturn(Optional.of(backupAuthTestUtil.getRequest(messagesBackupKey, aci).serialize())); + when(account.getBackupCredentialRequest(BackupCredentialType.MEDIA)) + .thenReturn(Optional.of(backupAuthTestUtil.getRequest(mediaBackupKey, aci).serialize())); clock.pin(now); assertThatExceptionOfType(StatusRuntimeException.class) .isThrownBy( - () -> authManager.getBackupAuthCredentials(account, requestRedemptionStart, requestRedemptionEnd).join()) + () -> authManager.getBackupAuthCredentials(account, BackupCredentialType.MESSAGES, requestRedemptionStart, requestRedemptionEnd).join()) .extracting(ex -> ex.getStatus().getCode()) .isEqualTo(Status.Code.INVALID_ARGUMENT); } @@ -211,19 +268,23 @@ public class BackupAuthManagerTest { final Instant day4 = Instant.EPOCH.plus(Duration.ofDays(4)); final Instant dayMax = day0.plus(BackupAuthManager.MAX_REDEMPTION_DURATION); - final BackupAuthManager authManager = create(BackupLevel.MESSAGES, false); + final BackupAuthManager authManager = create(BackupLevel.FREE, false); final Account account = mock(Account.class); when(account.getUuid()).thenReturn(aci); - when(account.getBackupCredentialRequest()).thenReturn(backupAuthTestUtil.getRequest(backupKey, aci).serialize()); + when(account.getBackupCredentialRequest(BackupCredentialType.MESSAGES)) + .thenReturn(Optional.of(backupAuthTestUtil.getRequest(messagesBackupKey, aci).serialize())); + when(account.getBackupCredentialRequest(BackupCredentialType.MEDIA)) + .thenReturn(Optional.of(backupAuthTestUtil.getRequest(mediaBackupKey, aci).serialize())); when(account.getBackupVoucher()).thenReturn(new Account.BackupVoucher(201, day4)); - final List creds = authManager.getBackupAuthCredentials(account, day0, dayMax).join(); + final List creds = authManager.getBackupAuthCredentials(account, BackupCredentialType.MESSAGES, day0, dayMax).join(); Instant redemptionTime = day0; - final BackupAuthCredentialRequestContext requestContext = BackupAuthCredentialRequestContext.create(backupKey, aci); + final BackupAuthCredentialRequestContext requestContext = BackupAuthCredentialRequestContext.create( + messagesBackupKey, aci); for (int i = 0; i < creds.size(); i++) { // Before the expiration, credentials should have a media receipt, otherwise messages only - final BackupLevel level = i < 5 ? BackupLevel.MEDIA : BackupLevel.MESSAGES; + final BackupLevel level = i < 5 ? BackupLevel.PAID : BackupLevel.FREE; final BackupAuthManager.Credential cred = creds.get(i); assertThat(requestContext .receiveResponse(cred.credential(), redemptionTime, backupAuthTestUtil.params.getPublicParams()) @@ -240,19 +301,23 @@ public class BackupAuthManagerTest { final Instant day2 = Instant.EPOCH.plus(Duration.ofDays(2)); final Instant day3 = Instant.EPOCH.plus(Duration.ofDays(3)); - final BackupAuthManager authManager = create(BackupLevel.MESSAGES, false); + final BackupAuthManager authManager = create(BackupLevel.FREE, false); final Account account = mock(Account.class); when(account.getUuid()).thenReturn(aci); when(account.getBackupVoucher()).thenReturn(new Account.BackupVoucher(3, day1)); final Account updated = mock(Account.class); when(updated.getUuid()).thenReturn(aci); - when(updated.getBackupCredentialRequest()).thenReturn(backupAuthTestUtil.getRequest(backupKey, aci).serialize()); + when(updated.getBackupCredentialRequest(BackupCredentialType.MESSAGES)) + .thenReturn(Optional.of(backupAuthTestUtil.getRequest(messagesBackupKey, aci).serialize())); + when(updated.getBackupCredentialRequest(BackupCredentialType.MEDIA)) + .thenReturn(Optional.of(backupAuthTestUtil.getRequest(mediaBackupKey, aci).serialize())); + when(updated.getBackupVoucher()).thenReturn(null); when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(updated)); clock.pin(day2.plus(Duration.ofSeconds(1))); - assertThat(authManager.getBackupAuthCredentials(account, day2, day2.plus(Duration.ofDays(7))).join()) + assertThat(authManager.getBackupAuthCredentials(account, BackupCredentialType.MESSAGES, day2, day2.plus(Duration.ofDays(7))).join()) .hasSize(8); @SuppressWarnings("unchecked") final ArgumentCaptor> accountUpdater = ArgumentCaptor.forClass( @@ -276,7 +341,7 @@ public class BackupAuthManagerTest { @Test void redeemReceipt() throws InvalidInputException, VerificationFailedException { final Instant expirationTime = Instant.EPOCH.plus(Duration.ofDays(1)); - final BackupAuthManager authManager = create(BackupLevel.MESSAGES, false); + final BackupAuthManager authManager = create(BackupLevel.FREE, false); final Account account = mock(Account.class); when(account.getUuid()).thenReturn(aci); @@ -293,7 +358,7 @@ public class BackupAuthManagerTest { final Instant newExpirationTime = Instant.EPOCH.plus(Duration.ofDays(1)); final Instant existingExpirationTime = Instant.EPOCH.plus(Duration.ofDays(1)).plus(Duration.ofSeconds(1)); - final BackupAuthManager authManager = create(BackupLevel.MESSAGES, false); + final BackupAuthManager authManager = create(BackupLevel.FREE, false); final Account account = mock(Account.class); when(account.getUuid()).thenReturn(aci); @@ -318,8 +383,8 @@ public class BackupAuthManagerTest { void redeemExpiredReceipt() { final Instant expirationTime = Instant.EPOCH.plus(Duration.ofDays(1)); clock.pin(expirationTime.plus(Duration.ofSeconds(1))); - final BackupAuthManager authManager = create(BackupLevel.MESSAGES, false); - Assertions.assertThatExceptionOfType(StatusRuntimeException.class) + final BackupAuthManager authManager = create(BackupLevel.FREE, false); + assertThatExceptionOfType(StatusRuntimeException.class) .isThrownBy(() -> authManager.redeemReceipt(mock(Account.class), receiptPresentation(3, expirationTime)).join()) .extracting(ex -> ex.getStatus().getCode()) .isEqualTo(Status.Code.INVALID_ARGUMENT); @@ -332,8 +397,8 @@ public class BackupAuthManagerTest { void redeemInvalidLevel(long level) { final Instant expirationTime = Instant.EPOCH.plus(Duration.ofDays(1)); clock.pin(expirationTime.plus(Duration.ofSeconds(1))); - final BackupAuthManager authManager = create(BackupLevel.MESSAGES, false); - Assertions.assertThatExceptionOfType(StatusRuntimeException.class) + final BackupAuthManager authManager = create(BackupLevel.FREE, false); + assertThatExceptionOfType(StatusRuntimeException.class) .isThrownBy(() -> authManager.redeemReceipt(mock(Account.class), receiptPresentation(level, expirationTime)).join()) .extracting(ex -> ex.getStatus().getCode()) @@ -344,9 +409,9 @@ public class BackupAuthManagerTest { @Test void redeemInvalidPresentation() throws InvalidInputException, VerificationFailedException { - final BackupAuthManager authManager = create(BackupLevel.MESSAGES, false); + final BackupAuthManager authManager = create(BackupLevel.FREE, false); final ReceiptCredentialPresentation invalid = receiptPresentation(ServerSecretParams.generate(), 3L, Instant.EPOCH); - Assertions.assertThatExceptionOfType(StatusRuntimeException.class) + assertThatExceptionOfType(StatusRuntimeException.class) .isThrownBy(() -> authManager.redeemReceipt(mock(Account.class), invalid).join()) .extracting(ex -> ex.getStatus().getCode()) .isEqualTo(Status.Code.INVALID_ARGUMENT); @@ -357,7 +422,7 @@ public class BackupAuthManagerTest { @Test void receiptAlreadyRedeemed() throws InvalidInputException, VerificationFailedException { final Instant expirationTime = Instant.EPOCH.plus(Duration.ofDays(1)); - final BackupAuthManager authManager = create(BackupLevel.MESSAGES, false); + final BackupAuthManager authManager = create(BackupLevel.FREE, false); final Account account = mock(Account.class); when(account.getUuid()).thenReturn(aci); @@ -397,28 +462,31 @@ public class BackupAuthManagerTest { @Test void testRateLimits() { final AccountsManager accountsManager = mock(AccountsManager.class); - final BackupAuthManager authManager = create(BackupLevel.MESSAGES, true); + final BackupAuthManager authManager = create(BackupLevel.FREE, true); - final BackupAuthCredentialRequest credentialRequest = backupAuthTestUtil.getRequest(backupKey, aci); + final BackupAuthCredentialRequest messagesCredential = backupAuthTestUtil.getRequest(messagesBackupKey, aci); + final BackupAuthCredentialRequest mediaCredential = backupAuthTestUtil.getRequest(mediaBackupKey, aci); final Account account = mock(Account.class); when(account.getUuid()).thenReturn(aci); when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(account)); // Should be rate limited - final RateLimitExceededException ex = CompletableFutureTestUtil.assertFailsWithCause( - RateLimitExceededException.class, - authManager.commitBackupId(account, credentialRequest)); + CompletableFutureTestUtil.assertFailsWithCause(RateLimitExceededException.class, + authManager.commitBackupId(account, messagesCredential, mediaCredential)); // If we don't change the request, shouldn't be rate limited - when(account.getBackupCredentialRequest()).thenReturn(credentialRequest.serialize()); - assertDoesNotThrow(() -> authManager.commitBackupId(account, credentialRequest).join()); + when(account.getBackupCredentialRequest(BackupCredentialType.MESSAGES)) + .thenReturn(Optional.of(backupAuthTestUtil.getRequest(messagesBackupKey, aci).serialize())); + when(account.getBackupCredentialRequest(BackupCredentialType.MEDIA)) + .thenReturn(Optional.of(backupAuthTestUtil.getRequest(mediaBackupKey, aci).serialize())); + assertDoesNotThrow(() -> authManager.commitBackupId(account, messagesCredential, mediaCredential).join()); } private static String experimentName(@Nullable BackupLevel backupLevel) { return switch (backupLevel) { - case MESSAGES -> BackupAuthManager.BACKUP_EXPERIMENT_NAME; - case MEDIA -> BackupAuthManager.BACKUP_MEDIA_EXPERIMENT_NAME; + case FREE -> BackupAuthManager.BACKUP_EXPERIMENT_NAME; + case PAID -> BackupAuthManager.BACKUP_MEDIA_EXPERIMENT_NAME; case null -> "fake_experiment"; }; } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthTestUtil.java b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthTestUtil.java index e52e01ce2..d0029bc83 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthTestUtil.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthTestUtil.java @@ -12,12 +12,14 @@ import java.time.Clock; import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.List; +import java.util.Optional; import java.util.UUID; import org.signal.libsignal.zkgroup.GenericServerSecretParams; import org.signal.libsignal.zkgroup.VerificationFailedException; import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialPresentation; import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialRequest; import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialRequestContext; +import org.signal.libsignal.zkgroup.backups.BackupCredentialType; import org.signal.libsignal.zkgroup.backups.BackupLevel; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.tests.util.ExperimentHelper; @@ -48,7 +50,7 @@ public class BackupAuthTestUtil { final BackupAuthCredentialRequestContext ctx = BackupAuthCredentialRequestContext.create(backupKey, aci); return ctx.receiveResponse( ctx.getRequest() - .issueCredential(clock.instant().truncatedTo(ChronoUnit.DAYS), backupLevel, params), + .issueCredential(clock.instant().truncatedTo(ChronoUnit.DAYS), backupLevel, BackupCredentialType.MESSAGES, params), redemptionTime, params.getPublicParams()) .present(params.getPublicParams()); @@ -57,19 +59,20 @@ public class BackupAuthTestUtil { public List getCredentials( final BackupLevel backupLevel, final BackupAuthCredentialRequest request, + final BackupCredentialType credentialType, final Instant redemptionStart, final Instant redemptionEnd) { final UUID aci = UUID.randomUUID(); final String experimentName = switch (backupLevel) { - case MESSAGES -> BackupAuthManager.BACKUP_EXPERIMENT_NAME; - case MEDIA -> BackupAuthManager.BACKUP_MEDIA_EXPERIMENT_NAME; + case FREE -> BackupAuthManager.BACKUP_EXPERIMENT_NAME; + case PAID -> BackupAuthManager.BACKUP_MEDIA_EXPERIMENT_NAME; }; final BackupAuthManager issuer = new BackupAuthManager( ExperimentHelper.withEnrollment(experimentName, aci), null, null, null, null, params, clock); Account account = mock(Account.class); when(account.getUuid()).thenReturn(aci); - when(account.getBackupCredentialRequest()).thenReturn(request.serialize()); - return issuer.getBackupAuthCredentials(account, redemptionStart, redemptionEnd).join(); + when(account.getBackupCredentialRequest(credentialType)).thenReturn(Optional.of(request.serialize())); + return issuer.getBackupAuthCredentials(account, credentialType, redemptionStart, redemptionEnd).join(); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupManagerTest.java index e23e9d253..f1a8b4259 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupManagerTest.java @@ -45,10 +45,12 @@ import java.util.function.Function; import java.util.stream.IntStream; import javax.annotation.Nullable; import org.apache.commons.lang3.StringUtils; +import org.assertj.core.api.ThrowableAssert; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.ValueSource; import org.junitpioneer.jupiter.cartesian.CartesianTest; @@ -58,6 +60,7 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.zkgroup.GenericServerSecretParams; import org.signal.libsignal.zkgroup.VerificationFailedException; import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialPresentation; +import org.signal.libsignal.zkgroup.backups.BackupCredentialType; import org.signal.libsignal.zkgroup.backups.BackupLevel; import org.whispersystems.textsecuregcm.attachments.TusAttachmentGenerator; import org.whispersystems.textsecuregcm.auth.AuthenticatedBackupUser; @@ -87,7 +90,6 @@ public class BackupManagerTest { private static final CopyParameters COPY_PARAM = new CopyParameters( 3, "abc", 100, COPY_ENCRYPTION_PARAM, TestRandomUtil.nextBytes(15)); - private static final String COPY_DEST_STRING = Base64.getEncoder().encodeToString(COPY_PARAM.destinationMediaId()); private final TestClock testClock = TestClock.now(); private final BackupAuthTestUtil backupAuthTestUtil = new BackupAuthTestUtil(testClock); @@ -125,6 +127,62 @@ public class BackupManagerTest { testClock); } + @ParameterizedTest + @CsvSource({ + "FREE, FREE, false", + "FREE, PAID, true", + "PAID, FREE, false", + "PAID, PAID, false" + }) + void checkBackupLevel(final BackupLevel authenticateBackupLevel, + final BackupLevel requiredLevel, + final boolean expectException) { + + final AuthenticatedBackupUser backupUser = + backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MESSAGES, authenticateBackupLevel); + + final ThrowableAssert.ThrowingCallable checkBackupLevel = + () -> BackupManager.checkBackupLevel(backupUser, requiredLevel); + + if (expectException) { + assertThatExceptionOfType(StatusRuntimeException.class) + .isThrownBy(checkBackupLevel) + .extracting(StatusRuntimeException::getStatus) + .extracting(Status::getCode) + .isEqualTo(Status.Code.PERMISSION_DENIED); + } else { + assertThatNoException().isThrownBy(checkBackupLevel); + } + } + + @ParameterizedTest + @CsvSource({ + "MESSAGES, MESSAGES, false", + "MESSAGES, MEDIA, true", + "MEDIA, MESSAGES, true", + "MEDIA, MEDIA, false" + }) + void checkBackupCredentialType(final BackupCredentialType authenticateCredentialType, + final BackupCredentialType requiredCredentialType, + final boolean expectException) { + + final AuthenticatedBackupUser backupUser = + backupUser(TestRandomUtil.nextBytes(16), authenticateCredentialType, BackupLevel.FREE); + + final ThrowableAssert.ThrowingCallable checkCredentialType = + () -> BackupManager.checkBackupCredentialType(backupUser, requiredCredentialType); + + if (expectException) { + assertThatExceptionOfType(StatusRuntimeException.class) + .isThrownBy(checkCredentialType) + .extracting(StatusRuntimeException::getStatus) + .extracting(Status::getCode) + .isEqualTo(Status.Code.PERMISSION_DENIED); + } else { + assertThatNoException().isThrownBy(checkCredentialType); + } + } + @ParameterizedTest @EnumSource public void createBackup(final BackupLevel backupLevel) { @@ -132,7 +190,7 @@ public class BackupManagerTest { final Instant now = Instant.ofEpochSecond(Duration.ofDays(1).getSeconds()); testClock.pin(now); - final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), backupLevel); + final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MESSAGES, backupLevel); backupManager.createMessageBackupUploadDescriptor(backupUser).join(); verify(tusCredentialGenerator, times(1)) @@ -144,22 +202,46 @@ public class BackupManagerTest { assertThat(info.mediaUsedSpace()).isEqualTo(Optional.empty()); // Check that the initial expiration times are the initial write times - checkExpectedExpirations(now, backupLevel == BackupLevel.MEDIA ? now : null, backupUser); + checkExpectedExpirations(now, backupLevel == BackupLevel.PAID ? now : null, backupUser); + } + + @ParameterizedTest + @EnumSource + public void createBackupWrongCredentialType(final BackupLevel backupLevel) { + + final Instant now = Instant.ofEpochSecond(Duration.ofDays(1).getSeconds()); + testClock.pin(now); + + final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, backupLevel); + + assertThatExceptionOfType(StatusRuntimeException.class) + .isThrownBy(() -> backupManager.createMessageBackupUploadDescriptor(backupUser).join()) + .matches(exception -> exception.getStatus().getCode() == Status.PERMISSION_DENIED.getCode()); } @Test public void createTemporaryMediaAttachmentRateLimited() { - final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); + final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID); when(mediaUploadLimiter.validateAsync(eq(BackupManager.rateLimitKey(backupUser)))) .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null))); - final RateLimitExceededException e = CompletableFutureTestUtil.assertFailsWithCause( + CompletableFutureTestUtil.assertFailsWithCause( RateLimitExceededException.class, backupManager.createTemporaryAttachmentUploadDescriptor(backupUser).toCompletableFuture()); } @Test public void createTemporaryMediaAttachmentWrongTier() { - final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MESSAGES); + final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.FREE); + assertThatExceptionOfType(StatusRuntimeException.class) + .isThrownBy(() -> backupManager.createTemporaryAttachmentUploadDescriptor(backupUser)) + .extracting(StatusRuntimeException::getStatus) + .extracting(Status::getCode) + .isEqualTo(Status.Code.PERMISSION_DENIED); + } + + @Test + public void createTemporaryMediaAttachmentWrongCredentialType() { + final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MESSAGES, BackupLevel.PAID); assertThatExceptionOfType(StatusRuntimeException.class) .isThrownBy(() -> backupManager.createTemporaryAttachmentUploadDescriptor(backupUser)) .extracting(StatusRuntimeException::getStatus) @@ -170,7 +252,7 @@ public class BackupManagerTest { @ParameterizedTest @EnumSource public void ttlRefresh(final BackupLevel backupLevel) { - final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), backupLevel); + final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MESSAGES, backupLevel); final Instant tstart = Instant.ofEpochSecond(1).plus(Duration.ofDays(1)); final Instant tnext = tstart.plus(Duration.ofSeconds(1)); @@ -185,7 +267,7 @@ public class BackupManagerTest { checkExpectedExpirations( tnext, - backupLevel == BackupLevel.MEDIA ? tnext : null, + backupLevel == BackupLevel.PAID ? tnext : null, backupUser); } @@ -195,7 +277,7 @@ public class BackupManagerTest { final Instant tstart = Instant.ofEpochSecond(1).plus(Duration.ofDays(1)); final Instant tnext = tstart.plus(Duration.ofSeconds(1)); - final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), backupLevel); + final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MESSAGES, backupLevel); // create backup at t=tstart testClock.pin(tstart); @@ -207,7 +289,7 @@ public class BackupManagerTest { checkExpectedExpirations( tnext, - backupLevel == BackupLevel.MEDIA ? tnext : null, + backupLevel == BackupLevel.PAID ? tnext : null, backupUser); } @@ -215,7 +297,7 @@ public class BackupManagerTest { public void invalidPresentationNoPublicKey() throws VerificationFailedException { final BackupAuthCredentialPresentation invalidPresentation = backupAuthTestUtil.getPresentation( GenericServerSecretParams.generate(), - BackupLevel.MESSAGES, backupKey, aci); + BackupLevel.FREE, backupKey, aci); final ECKeyPair keyPair = Curve.generateKeyPair(); @@ -233,10 +315,10 @@ public class BackupManagerTest { @Test public void invalidPresentationCorrectSignature() throws VerificationFailedException { final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( - BackupLevel.MESSAGES, backupKey, aci); + BackupLevel.FREE, backupKey, aci); final BackupAuthCredentialPresentation invalidPresentation = backupAuthTestUtil.getPresentation( GenericServerSecretParams.generate(), - BackupLevel.MESSAGES, backupKey, aci); + BackupLevel.FREE, backupKey, aci); final ECKeyPair keyPair = Curve.generateKeyPair(); backupManager.setPublicKey( @@ -256,7 +338,7 @@ public class BackupManagerTest { @Test public void unknownPublicKey() throws VerificationFailedException { final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( - BackupLevel.MESSAGES, backupKey, aci); + BackupLevel.FREE, backupKey, aci); final ECKeyPair keyPair = Curve.generateKeyPair(); final byte[] signature = keyPair.getPrivateKey().calculateSignature(presentation.serialize()); @@ -272,7 +354,7 @@ public class BackupManagerTest { @Test public void mismatchedPublicKey() throws VerificationFailedException { final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( - BackupLevel.MESSAGES, backupKey, aci); + BackupLevel.FREE, backupKey, aci); final ECKeyPair keyPair1 = Curve.generateKeyPair(); final ECKeyPair keyPair2 = Curve.generateKeyPair(); @@ -295,7 +377,7 @@ public class BackupManagerTest { @Test public void signatureValidation() throws VerificationFailedException { final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( - BackupLevel.MESSAGES, backupKey, aci); + BackupLevel.FREE, backupKey, aci); final ECKeyPair keyPair = Curve.generateKeyPair(); final byte[] signature = keyPair.getPrivateKey().calculateSignature(presentation.serialize()); @@ -322,7 +404,7 @@ public class BackupManagerTest { // correct signature final AuthenticatedBackupUser user = backupManager.authenticateBackupUser(presentation, signature).join(); assertThat(user.backupId()).isEqualTo(presentation.getBackupId()); - assertThat(user.backupLevel()).isEqualTo(BackupLevel.MESSAGES); + assertThat(user.backupLevel()).isEqualTo(BackupLevel.FREE); } @Test @@ -330,7 +412,7 @@ public class BackupManagerTest { // credential for 1 day after epoch testClock.pin(Instant.ofEpochSecond(1).plus(Duration.ofDays(1))); - final BackupAuthCredentialPresentation oldCredential = backupAuthTestUtil.getPresentation(BackupLevel.MESSAGES, + final BackupAuthCredentialPresentation oldCredential = backupAuthTestUtil.getPresentation(BackupLevel.FREE, backupKey, aci); final ECKeyPair keyPair = Curve.generateKeyPair(); final byte[] signature = keyPair.getPrivateKey().calculateSignature(oldCredential.serialize()); @@ -355,7 +437,7 @@ public class BackupManagerTest { @Test public void copySuccess() { - final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); + final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID); final CopyResult copied = copy(backupUser); assertThat(copied.cdn()).isEqualTo(3); @@ -372,7 +454,7 @@ public class BackupManagerTest { @Test public void copyFailure() { - final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); + final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID); assertThat(copyError(backupUser, new SourceObjectNotFoundException()).outcome()) .isEqualTo(CopyResult.Outcome.SOURCE_NOT_FOUND); @@ -384,7 +466,7 @@ public class BackupManagerTest { @Test public void copyPartialSuccess() { - final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); + final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID); final List toCopy = List.of( new CopyParameters(3, "success", 100, COPY_ENCRYPTION_PARAM, TestRandomUtil.nextBytes(15)), new CopyParameters(3, "missing", 200, COPY_ENCRYPTION_PARAM, TestRandomUtil.nextBytes(15)), @@ -413,9 +495,20 @@ public class BackupManagerTest { assertThat(AttributeValues.getLong(backup, BackupsDb.ATTR_MEDIA_COUNT, -1L)).isEqualTo(1L); } + @Test + public void copyWrongCredentialType() { + final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MESSAGES, BackupLevel.PAID); + + assertThatExceptionOfType(StatusRuntimeException.class) + .isThrownBy(() -> copy(backupUser)) + .extracting(StatusRuntimeException::getStatus) + .extracting(Status::getCode) + .isEqualTo(Status.Code.PERMISSION_DENIED); + } + @Test public void quotaEnforcementNoRecalculation() { - final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); + final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID); verifyNoInteractions(remoteStorageManager); // set the backupsDb to be out of quota at t=0 @@ -432,7 +525,7 @@ public class BackupManagerTest { @Test public void quotaEnforcementRecalculation() { - final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); + final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID); final String backupMediaPrefix = "%s/%s/".formatted(backupUser.backupDir(), backupUser.mediaDir()); final long remainingAfterRecalc = BackupManager.MAX_TOTAL_BACKUP_MEDIA_BYTES - COPY_PARAM.destinationObjectSize(); @@ -462,7 +555,7 @@ public class BackupManagerTest { @CartesianTest.Values(booleans = {true, false}) boolean hasSpaceBeforeRecalc, @CartesianTest.Values(booleans = {true, false}) boolean hasSpaceAfterRecalc, @CartesianTest.Values(booleans = {true, false}) boolean doesReaclc) { - final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); + final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID); final String backupMediaPrefix = "%s/%s/".formatted(backupUser.backupDir(), backupUser.mediaDir()); final long destSize = COPY_PARAM.destinationObjectSize(); @@ -496,7 +589,7 @@ public class BackupManagerTest { @ValueSource(strings = {"", "cursor"}) public void list(final String cursorVal) { final Optional cursor = Optional.of(cursorVal).filter(StringUtils::isNotBlank); - final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); + final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MESSAGES, BackupLevel.PAID); final String backupMediaPrefix = "%s/%s/".formatted(backupUser.backupDir(), backupUser.mediaDir()); when(remoteStorageManager.cdnNumber()).thenReturn(13); @@ -519,14 +612,14 @@ public class BackupManagerTest { @Test public void deleteEntireBackup() { - final AuthenticatedBackupUser original = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); + final AuthenticatedBackupUser original = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MESSAGES, BackupLevel.PAID); testClock.pin(Instant.ofEpochSecond(10)); // Deleting should swap the backupDir for the user backupManager.deleteEntireBackup(original).join(); verifyNoInteractions(remoteStorageManager); - final AuthenticatedBackupUser after = retrieveBackupUser(original.backupId(), BackupLevel.MEDIA); + final AuthenticatedBackupUser after = retrieveBackupUser(original.backupId(), BackupCredentialType.MESSAGES, BackupLevel.PAID); assertThat(original.backupDir()).isNotEqualTo(after.backupDir()); assertThat(original.mediaDir()).isNotEqualTo(after.mediaDir()); @@ -552,7 +645,7 @@ public class BackupManagerTest { @Test public void delete() { - final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); + final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID); final byte[] mediaId = TestRandomUtil.nextBytes(16); final String backupMediaKey = "%s/%s/%s".formatted( backupUser.backupDir(), @@ -571,9 +664,24 @@ public class BackupManagerTest { .isEqualTo(new UsageInfo(93, 999)); } + @Test + public void deleteWrongCredentialType() { + final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MESSAGES, BackupLevel.PAID); + final byte[] mediaId = TestRandomUtil.nextBytes(16); + final String backupMediaKey = "%s/%s/%s".formatted( + backupUser.backupDir(), + backupUser.mediaDir(), + BackupManager.encodeMediaIdForCdn(mediaId)); + + assertThatThrownBy(() -> + backupManager.deleteMedia(backupUser, List.of(new BackupManager.StorageDescriptor(5, mediaId))).then().block()) + .isInstanceOf(StatusRuntimeException.class) + .matches(e -> ((StatusRuntimeException) e).getStatus().getCode() == Status.PERMISSION_DENIED.getCode()); + } + @Test public void deleteUnknownCdn() { - final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); + final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID); final BackupManager.StorageDescriptor sd = new BackupManager.StorageDescriptor(4, TestRandomUtil.nextBytes(15)); when(remoteStorageManager.cdnNumber()).thenReturn(5); assertThatThrownBy(() -> @@ -584,7 +692,7 @@ public class BackupManagerTest { @Test public void deletePartialFailure() { - final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); + final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID); final List descriptors = new ArrayList<>(); long initialBytes = 0; @@ -621,7 +729,7 @@ public class BackupManagerTest { @Test public void alreadyDeleted() { - final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); + final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID); final byte[] mediaId = TestRandomUtil.nextBytes(16); final String backupMediaKey = "%s/%s/%s".formatted( backupUser.backupDir(), @@ -642,7 +750,7 @@ public class BackupManagerTest { @Test public void listExpiredBackups() { final List backupUsers = IntStream.range(0, 10) - .mapToObj(i -> backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA)) + .mapToObj(i -> backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MESSAGES, BackupLevel.PAID)) .toList(); for (int i = 0; i < backupUsers.size(); i++) { testClock.pin(Instant.ofEpochSecond(i)); @@ -680,11 +788,11 @@ public class BackupManagerTest { // refreshed media timestamp at t=5 testClock.pin(Instant.ofEpochSecond(5)); - backupManager.createMessageBackupUploadDescriptor(backupUser(backupId, BackupLevel.MEDIA)).join(); + backupManager.createMessageBackupUploadDescriptor(backupUser(backupId, BackupCredentialType.MESSAGES, BackupLevel.PAID)).join(); // refreshed messages timestamp at t=6 testClock.pin(Instant.ofEpochSecond(6)); - backupManager.createMessageBackupUploadDescriptor(backupUser(backupId, BackupLevel.MESSAGES)).join(); + backupManager.createMessageBackupUploadDescriptor(backupUser(backupId, BackupCredentialType.MESSAGES, BackupLevel.FREE)).join(); Function> getExpired = time -> backupManager .getExpiredBackups(1, Schedulers.immediate(), time) @@ -704,7 +812,7 @@ public class BackupManagerTest { @ParameterizedTest @EnumSource(mode = EnumSource.Mode.INCLUDE, names = {"MEDIA", "ALL"}) public void expireBackup(ExpiredBackup.ExpirationType expirationType) { - final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); + final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MESSAGES, BackupLevel.PAID); backupManager.createMessageBackupUploadDescriptor(backupUser).join(); final String expectedPrefixToDelete = switch (expirationType) { @@ -746,7 +854,7 @@ public class BackupManagerTest { @Test public void deleteBackupPaginated() { - final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); + final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MESSAGES, BackupLevel.PAID); backupManager.createMessageBackupUploadDescriptor(backupUser).join(); final ExpiredBackup expiredBackup = expiredBackup(ExpiredBackup.ExpirationType.MEDIA, backupUser); @@ -845,9 +953,9 @@ public class BackupManagerTest { } /** - * Create BackupUser with the provided backupId and tier + * Create BackupUser with the provided backupId, credential type, and tier */ - private AuthenticatedBackupUser backupUser(final byte[] backupId, final BackupLevel backupLevel) { + private AuthenticatedBackupUser backupUser(final byte[] backupId, final BackupCredentialType credentialType, final BackupLevel backupLevel) { // Won't actually validate the public key, but need to have a public key to perform BackupsDB operations byte[] privateKey = new byte[32]; ByteBuffer.wrap(privateKey).put(backupId); @@ -856,14 +964,14 @@ public class BackupManagerTest { } catch (InvalidKeyException e) { throw new RuntimeException(e); } - return retrieveBackupUser(backupId, backupLevel); + return retrieveBackupUser(backupId, credentialType, backupLevel); } /** * Retrieve an existing BackupUser from the database */ - private AuthenticatedBackupUser retrieveBackupUser(final byte[] backupId, final BackupLevel backupLevel) { + private AuthenticatedBackupUser retrieveBackupUser(final byte[] backupId, final BackupCredentialType credentialType, final BackupLevel backupLevel) { final BackupsDb.AuthenticationData authData = backupsDb.retrieveAuthenticationData(backupId).join().get(); - return new AuthenticatedBackupUser(backupId, backupLevel, authData.backupDir(), authData.mediaDir()); + return new AuthenticatedBackupUser(backupId, credentialType, backupLevel, authData.backupDir(), authData.mediaDir()); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupsDbTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupsDbTest.java index b93c19557..ab7dd1b25 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupsDbTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupsDbTest.java @@ -23,6 +23,7 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.ValueSource; import org.signal.libsignal.protocol.ecc.Curve; +import org.signal.libsignal.zkgroup.backups.BackupCredentialType; import org.signal.libsignal.zkgroup.backups.BackupLevel; import org.whispersystems.textsecuregcm.auth.AuthenticatedBackupUser; import org.whispersystems.textsecuregcm.storage.DynamoDbExtension; @@ -51,7 +52,7 @@ public class BackupsDbTest { @Test public void trackMediaStats() { - final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); + final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID); // add at least one message backup so we can describe it backupsDb.addMessageBackup(backupUser).join(); int total = 0; @@ -74,7 +75,7 @@ public class BackupsDbTest { @ValueSource(booleans = {false, true}) public void setUsage(boolean mediaAlreadyExists) { testClock.pin(Instant.ofEpochSecond(5)); - final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); + final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID); if (mediaAlreadyExists) { this.backupsDb.trackMedia(backupUser, 1, 10).join(); } @@ -90,12 +91,12 @@ public class BackupsDbTest { final byte[] backupId = TestRandomUtil.nextBytes(16); // Refresh media/messages at t=0 testClock.pin(Instant.ofEpochSecond(0L)); - backupsDb.setPublicKey(backupId, BackupLevel.MEDIA, Curve.generateKeyPair().getPublicKey()).join(); - this.backupsDb.ttlRefresh(backupUser(backupId, BackupLevel.MEDIA)).join(); + backupsDb.setPublicKey(backupId, BackupLevel.PAID, Curve.generateKeyPair().getPublicKey()).join(); + this.backupsDb.ttlRefresh(backupUser(backupId, BackupCredentialType.MEDIA, BackupLevel.PAID)).join(); // refresh only messages at t=2 testClock.pin(Instant.ofEpochSecond(2L)); - this.backupsDb.ttlRefresh(backupUser(backupId, BackupLevel.MESSAGES)).join(); + this.backupsDb.ttlRefresh(backupUser(backupId, BackupCredentialType.MEDIA, BackupLevel.FREE)).join(); final Function> expiredBackups = purgeTime -> backupsDb .getExpiredBackups(1, Schedulers.immediate(), purgeTime) @@ -132,13 +133,13 @@ public class BackupsDbTest { final byte[] backupId = TestRandomUtil.nextBytes(16); // Refresh media/messages at t=0 testClock.pin(Instant.ofEpochSecond(0L)); - backupsDb.setPublicKey(backupId, BackupLevel.MEDIA, Curve.generateKeyPair().getPublicKey()).join(); - this.backupsDb.ttlRefresh(backupUser(backupId, BackupLevel.MEDIA)).join(); + backupsDb.setPublicKey(backupId, BackupLevel.PAID, Curve.generateKeyPair().getPublicKey()).join(); + this.backupsDb.ttlRefresh(backupUser(backupId, BackupCredentialType.MEDIA, BackupLevel.PAID)).join(); if (expirationType == ExpiredBackup.ExpirationType.MEDIA) { // refresh only messages at t=2 so that we only expire media at t=1 testClock.pin(Instant.ofEpochSecond(2L)); - this.backupsDb.ttlRefresh(backupUser(backupId, BackupLevel.MESSAGES)).join(); + this.backupsDb.ttlRefresh(backupUser(backupId, BackupCredentialType.MEDIA, BackupLevel.FREE)).join(); } final Function> expiredBackups = purgeTime -> { @@ -192,7 +193,7 @@ public class BackupsDbTest { // should be nothing to expire at t=1 assertThat(opt).isEmpty(); // The backup should still exist - backupsDb.describeBackup(backupUser(backupId, BackupLevel.MEDIA)).join(); + backupsDb.describeBackup(backupUser(backupId, BackupCredentialType.MEDIA, BackupLevel.PAID)).join(); } else { // Cleaned up the failed attempt, now should tell us to clean the whole backup assertThat(opt.get()).matches(eb -> eb.expirationType() == ExpiredBackup.ExpirationType.ALL, @@ -202,7 +203,7 @@ public class BackupsDbTest { // The backup entry should be gone assertThat(CompletableFutureTestUtil.assertFailsWithCause(StatusRuntimeException.class, - backupsDb.describeBackup(backupUser(backupId, BackupLevel.MEDIA))) + backupsDb.describeBackup(backupUser(backupId, BackupCredentialType.MEDIA, BackupLevel.PAID))) .getStatus().getCode()) .isEqualTo(Status.Code.NOT_FOUND); assertThat(expiredBackups.apply(Instant.ofEpochSecond(10))).isEmpty(); @@ -211,9 +212,9 @@ public class BackupsDbTest { @Test public void list() { - final AuthenticatedBackupUser u1 = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MESSAGES); - final AuthenticatedBackupUser u2 = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); - final AuthenticatedBackupUser u3 = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); + final AuthenticatedBackupUser u1 = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.FREE); + final AuthenticatedBackupUser u2 = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID); + final AuthenticatedBackupUser u3 = backupUser(TestRandomUtil.nextBytes(16), BackupCredentialType.MEDIA, BackupLevel.PAID); // add at least one message backup, so we can describe it testClock.pin(Instant.ofEpochSecond(10)); @@ -248,7 +249,7 @@ public class BackupsDbTest { assertThat(sbm3.lastRefresh()).isEqualTo(sbm3.lastMediaRefresh()).isEqualTo(Instant.ofEpochSecond(30)); } - private AuthenticatedBackupUser backupUser(final byte[] backupId, final BackupLevel backupLevel) { - return new AuthenticatedBackupUser(backupId, backupLevel, "myBackupDir", "myMediaDir"); + private AuthenticatedBackupUser backupUser(final byte[] backupId, final BackupCredentialType credentialType, final BackupLevel backupLevel) { + return new AuthenticatedBackupUser(backupId, credentialType, backupLevel, "myBackupDir", "myMediaDir"); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ArchiveControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ArchiveControllerTest.java index a26955c90..ab3264de2 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ArchiveControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ArchiveControllerTest.java @@ -10,6 +10,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import io.dropwizard.auth.AuthValueFactoryProvider; @@ -51,6 +52,7 @@ import org.signal.libsignal.zkgroup.InvalidInputException; import org.signal.libsignal.zkgroup.ServerSecretParams; import org.signal.libsignal.zkgroup.VerificationFailedException; import org.signal.libsignal.zkgroup.backups.BackupAuthCredentialPresentation; +import org.signal.libsignal.zkgroup.backups.BackupCredentialType; import org.signal.libsignal.zkgroup.backups.BackupLevel; import org.signal.libsignal.zkgroup.receipts.ClientZkReceiptOperations; import org.signal.libsignal.zkgroup.receipts.ReceiptCredential; @@ -59,8 +61,8 @@ import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialRequestContext; import org.signal.libsignal.zkgroup.receipts.ReceiptCredentialResponse; import org.signal.libsignal.zkgroup.receipts.ReceiptSerial; import org.signal.libsignal.zkgroup.receipts.ServerZkReceiptOperations; -import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedBackupUser; +import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.backup.BackupAuthManager; import org.whispersystems.textsecuregcm.backup.BackupAuthTestUtil; import org.whispersystems.textsecuregcm.backup.BackupManager; @@ -71,6 +73,7 @@ import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper; import org.whispersystems.textsecuregcm.mappers.GrpcStatusRuntimeExceptionMapper; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; +import org.whispersystems.textsecuregcm.util.EnumMapUtil; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.TestRandomUtil; import reactor.core.publisher.Flux; @@ -95,7 +98,8 @@ public class ArchiveControllerTest { .build(); private final UUID aci = UUID.randomUUID(); - private final byte[] backupKey = TestRandomUtil.nextBytes(32); + private final byte[] messagesBackupKey = TestRandomUtil.nextBytes(32); + private final byte[] mediaBackupKey = TestRandomUtil.nextBytes(32); @BeforeEach public void setUp() { @@ -132,7 +136,7 @@ public class ArchiveControllerTest { public void anonymousAuthOnly(final String method, final String path, final String body) throws VerificationFailedException { final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( - BackupLevel.MEDIA, backupKey, aci); + BackupLevel.PAID, messagesBackupKey, aci); final Invocation.Builder request = resources.getJerseyTest() .target(path) .request() @@ -152,15 +156,22 @@ public class ArchiveControllerTest { @Test public void setBackupId() throws RateLimitExceededException { - when(backupAuthManager.commitBackupId(any(), any())).thenReturn(CompletableFuture.completedFuture(null)); + when(backupAuthManager.commitBackupId(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null)); final Response response = resources.getJerseyTest() .target("v1/archives/backupid") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ArchiveController.SetBackupIdRequest(backupAuthTestUtil.getRequest(backupKey, aci)), + .put(Entity.entity(new ArchiveController.SetBackupIdRequest( + backupAuthTestUtil.getRequest(messagesBackupKey, aci), + backupAuthTestUtil.getRequest(mediaBackupKey, aci)), MediaType.APPLICATION_JSON_TYPE)); + assertThat(response.getStatus()).isEqualTo(204); + + verify(backupAuthManager).commitBackupId(AuthHelper.VALID_ACCOUNT, + backupAuthTestUtil.getRequest(messagesBackupKey, aci), + backupAuthTestUtil.getRequest(mediaBackupKey, aci)); } @Test @@ -191,7 +202,7 @@ public class ArchiveControllerTest { when(backupManager.setPublicKey(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null)); final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( - BackupLevel.MEDIA, backupKey, aci); + BackupLevel.PAID, messagesBackupKey, aci); final Response response = resources.getJerseyTest() .target("v1/archives/keys") .request() @@ -208,7 +219,7 @@ public class ArchiveControllerTest { when(backupManager.setPublicKey(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null)); final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( - BackupLevel.MEDIA, backupKey, aci); + BackupLevel.PAID, messagesBackupKey, aci); final Response response = resources.getJerseyTest() .target("v1/archives/keys") .request() @@ -223,7 +234,7 @@ public class ArchiveControllerTest { when(backupManager.setPublicKey(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null)); final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( - BackupLevel.MEDIA, backupKey, aci); + BackupLevel.PAID, messagesBackupKey, aci); final Response response = resources.getJerseyTest() .target("v1/archives/keys") .request() @@ -239,8 +250,8 @@ public class ArchiveControllerTest { @ParameterizedTest @CsvSource(textBlock = """ {}, 422 - '{"backupAuthCredentialRequest": "aaa"}', 400 - '{"backupAuthCredentialRequest": ""}', 400 + '{"messagesBackupAuthCredentialRequest": "aaa", "mediaBackupAuthCredentialRequest": "aaa"}', 400 + '{"messagesBackupAuthCredentialRequest": "", "mediaBackupAuthCredentialRequest": ""}', 400 """) public void setBackupIdInvalid(final String requestBody, final int expectedStatus) { final Response response = resources.getJerseyTest() @@ -264,15 +275,17 @@ public class ArchiveControllerTest { public void setBackupIdException(final Exception ex, final boolean sync, final int expectedStatus) throws RateLimitExceededException { if (sync) { - when(backupAuthManager.commitBackupId(any(), any())).thenThrow(ex); + when(backupAuthManager.commitBackupId(any(), any(), any())).thenThrow(ex); } else { - when(backupAuthManager.commitBackupId(any(), any())).thenReturn(CompletableFuture.failedFuture(ex)); + when(backupAuthManager.commitBackupId(any(), any(), any())).thenReturn(CompletableFuture.failedFuture(ex)); } final Response response = resources.getJerseyTest() .target("v1/archives/backupid") .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) - .put(Entity.entity(new ArchiveController.SetBackupIdRequest(backupAuthTestUtil.getRequest(backupKey, aci)), + .put(Entity.entity(new ArchiveController.SetBackupIdRequest( + backupAuthTestUtil.getRequest(messagesBackupKey, aci), + backupAuthTestUtil.getRequest(mediaBackupKey, aci)), MediaType.APPLICATION_JSON_TYPE)); assertThat(response.getStatus()).isEqualTo(expectedStatus); } @@ -281,18 +294,36 @@ public class ArchiveControllerTest { public void getCredentials() { final Instant start = Instant.now().truncatedTo(ChronoUnit.DAYS); final Instant end = start.plus(Duration.ofDays(1)); - final List expectedResponse = backupAuthTestUtil.getCredentials( - BackupLevel.MEDIA, backupAuthTestUtil.getRequest(backupKey, aci), start, end); - when(backupAuthManager.getBackupAuthCredentials(any(), eq(start), eq(end))).thenReturn( - CompletableFuture.completedFuture(expectedResponse)); - final ArchiveController.BackupAuthCredentialsResponse creds = resources.getJerseyTest() + + final Map> expectedCredentialsByType = + EnumMapUtil.toEnumMap(BackupCredentialType.class, credentialType -> backupAuthTestUtil.getCredentials( + BackupLevel.PAID, backupAuthTestUtil.getRequest(messagesBackupKey, aci), credentialType, start, end)); + + expectedCredentialsByType.forEach((credentialType, expectedCredentials) -> + when(backupAuthManager.getBackupAuthCredentials(any(), eq(credentialType), eq(start), eq(end))).thenReturn( + CompletableFuture.completedFuture(expectedCredentials))); + + final ArchiveController.BackupAuthCredentialsResponse credentialResponse = resources.getJerseyTest() .target("v1/archives/auth") .queryParam("redemptionStartSeconds", start.getEpochSecond()) .queryParam("redemptionEndSeconds", end.getEpochSecond()) .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .get(ArchiveController.BackupAuthCredentialsResponse.class); - assertThat(creds.credentials().getFirst().redemptionTime()).isEqualTo(start.getEpochSecond()); + + expectedCredentialsByType.forEach((credentialType, expectedCredentials) -> { + assertThat(credentialResponse.credentials().get(credentialType)).size().isEqualTo(expectedCredentials.size()); + assertThat(credentialResponse.credentials().get(credentialType).getFirst().redemptionTime()) + .isEqualTo(start.getEpochSecond()); + + for (int i = 0; i < expectedCredentials.size(); i++) { + assertThat(credentialResponse.credentials().get(credentialType).get(i).redemptionTime()) + .isEqualTo(expectedCredentials.get(i).redemptionTime().getEpochSecond()); + + assertThat(credentialResponse.credentials().get(credentialType).get(i).credential()) + .isEqualTo(expectedCredentials.get(i).credential().serialize()); + } + }); } public enum BadCredentialsType {MISSING_START, MISSING_END, MISSING_BOTH} @@ -322,9 +353,9 @@ public class ArchiveControllerTest { @Test public void getBackupInfo() throws VerificationFailedException { final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( - BackupLevel.MEDIA, backupKey, aci); + BackupLevel.PAID, messagesBackupKey, aci); when(backupManager.authenticateBackupUser(any(), any())) - .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupLevel.MEDIA))); + .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupCredentialType.MESSAGES, BackupLevel.PAID))); when(backupManager.backupInfo(any())).thenReturn(CompletableFuture.completedFuture(new BackupManager.BackupInfo( 1, "myBackupDir", "myMediaDir", "filename", Optional.empty()))); final ArchiveController.BackupInfoResponse response = resources.getJerseyTest() @@ -342,9 +373,9 @@ public class ArchiveControllerTest { @Test public void putMediaBatchSuccess() throws VerificationFailedException { final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( - BackupLevel.MEDIA, backupKey, aci); + BackupLevel.PAID, messagesBackupKey, aci); when(backupManager.authenticateBackupUser(any(), any())) - .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupLevel.MEDIA))); + .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupCredentialType.MESSAGES, BackupLevel.PAID))); final byte[][] mediaIds = new byte[][]{TestRandomUtil.nextBytes(15), TestRandomUtil.nextBytes(15)}; when(backupManager.copyToBackup(any(), any())) .thenReturn(Flux.just( @@ -389,9 +420,9 @@ public class ArchiveControllerTest { public void putMediaBatchPartialFailure() throws VerificationFailedException { final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( - BackupLevel.MEDIA, backupKey, aci); + BackupLevel.PAID, messagesBackupKey, aci); when(backupManager.authenticateBackupUser(any(), any())) - .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupLevel.MEDIA))); + .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupCredentialType.MESSAGES, BackupLevel.PAID))); final byte[][] mediaIds = IntStream.range(0, 4).mapToObj(i -> TestRandomUtil.nextBytes(15)).toArray(byte[][]::new); when(backupManager.copyToBackup(any(), any())) @@ -448,9 +479,9 @@ public class ArchiveControllerTest { @Test public void copyMediaWithNegativeLength() throws VerificationFailedException { final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( - BackupLevel.MEDIA, backupKey, aci); + BackupLevel.PAID, messagesBackupKey, aci); when(backupManager.authenticateBackupUser(any(), any())) - .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupLevel.MEDIA))); + .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupCredentialType.MESSAGES, BackupLevel.PAID))); final byte[][] mediaIds = new byte[][]{TestRandomUtil.nextBytes(15), TestRandomUtil.nextBytes(15)}; final Response r = resources.getJerseyTest() .target("v1/archives/media/batch") @@ -483,9 +514,9 @@ public class ArchiveControllerTest { @CartesianTest.Values(booleans = {true, false}) final boolean cursorReturned) throws VerificationFailedException { final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( - BackupLevel.MEDIA, backupKey, aci); + BackupLevel.PAID, messagesBackupKey, aci); when(backupManager.authenticateBackupUser(any(), any())) - .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupLevel.MEDIA))); + .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupCredentialType.MESSAGES, BackupLevel.PAID))); final byte[] mediaId = TestRandomUtil.nextBytes(15); final Optional expectedCursor = cursorProvided ? Optional.of("myCursor") : Optional.empty(); @@ -517,10 +548,10 @@ public class ArchiveControllerTest { @Test public void delete() throws VerificationFailedException { - final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation(BackupLevel.MEDIA, - backupKey, aci); + final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation(BackupLevel.PAID, + messagesBackupKey, aci); when(backupManager.authenticateBackupUser(any(), any())) - .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupLevel.MEDIA))); + .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupCredentialType.MESSAGES, BackupLevel.PAID))); final ArchiveController.DeleteMedia deleteRequest = new ArchiveController.DeleteMedia( IntStream @@ -544,9 +575,9 @@ public class ArchiveControllerTest { @Test public void mediaUploadForm() throws RateLimitExceededException, VerificationFailedException { final BackupAuthCredentialPresentation presentation = - backupAuthTestUtil.getPresentation(BackupLevel.MEDIA, backupKey, aci); + backupAuthTestUtil.getPresentation(BackupLevel.PAID, messagesBackupKey, aci); when(backupManager.authenticateBackupUser(any(), any())) - .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupLevel.MEDIA))); + .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupCredentialType.MESSAGES, BackupLevel.PAID))); when(backupManager.createTemporaryAttachmentUploadDescriptor(any())) .thenReturn(CompletableFuture.completedFuture( new BackupUploadDescriptor(3, "abc", Map.of("k", "v"), "example.org"))); @@ -576,9 +607,9 @@ public class ArchiveControllerTest { @Test public void readAuth() throws VerificationFailedException { final BackupAuthCredentialPresentation presentation = - backupAuthTestUtil.getPresentation(BackupLevel.MEDIA, backupKey, aci); + backupAuthTestUtil.getPresentation(BackupLevel.PAID, messagesBackupKey, aci); when(backupManager.authenticateBackupUser(any(), any())) - .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupLevel.MEDIA))); + .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupCredentialType.MESSAGES, BackupLevel.PAID))); when(backupManager.generateReadAuth(any(), eq(3))).thenReturn(Map.of("key", "value")); final ArchiveController.ReadAuthResponse response = resources.getJerseyTest() .target("v1/archives/auth/read") @@ -593,7 +624,7 @@ public class ArchiveControllerTest { @Test public void readAuthInvalidParam() throws VerificationFailedException { final BackupAuthCredentialPresentation presentation = - backupAuthTestUtil.getPresentation(BackupLevel.MEDIA, backupKey, aci); + backupAuthTestUtil.getPresentation(BackupLevel.PAID, messagesBackupKey, aci); Response response = resources.getJerseyTest() .target("v1/archives/auth/read") .request() @@ -615,9 +646,9 @@ public class ArchiveControllerTest { @Test public void deleteEntireBackup() throws VerificationFailedException { final BackupAuthCredentialPresentation presentation = - backupAuthTestUtil.getPresentation(BackupLevel.MEDIA, backupKey, aci); + backupAuthTestUtil.getPresentation(BackupLevel.PAID, messagesBackupKey, aci); when(backupManager.authenticateBackupUser(any(), any())) - .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupLevel.MEDIA))); + .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupCredentialType.MESSAGES, BackupLevel.PAID))); when(backupManager.deleteEntireBackup(any())).thenReturn(CompletableFuture.completedFuture(null)); Response response = resources.getJerseyTest() .target("v1/archives/") @@ -631,25 +662,25 @@ public class ArchiveControllerTest { @Test public void invalidSourceAttachmentKey() throws VerificationFailedException { final BackupAuthCredentialPresentation presentation = backupAuthTestUtil.getPresentation( - BackupLevel.MEDIA, backupKey, aci); + BackupLevel.PAID, messagesBackupKey, aci); when(backupManager.authenticateBackupUser(any(), any())) - .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupLevel.MEDIA))); + .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupCredentialType.MESSAGES, BackupLevel.PAID))); final Response r = resources.getJerseyTest() .target("v1/archives/media") .request() .header("X-Signal-ZK-Auth", Base64.getEncoder().encodeToString(presentation.serialize())) .header("X-Signal-ZK-Auth-Signature", "aaa") .put(Entity.json(new ArchiveController.CopyMediaRequest( - new RemoteAttachment(3, "invalid/urlBase64"), - 100, - TestRandomUtil.nextBytes(15), - TestRandomUtil.nextBytes(32), - TestRandomUtil.nextBytes(32), - TestRandomUtil.nextBytes(16)))); + new RemoteAttachment(3, "invalid/urlBase64"), + 100, + TestRandomUtil.nextBytes(15), + TestRandomUtil.nextBytes(32), + TestRandomUtil.nextBytes(32), + TestRandomUtil.nextBytes(16)))); assertThat(r.getStatus()).isEqualTo(422); } - private static AuthenticatedBackupUser backupUser(byte[] backupId, BackupLevel backupLevel) { - return new AuthenticatedBackupUser(backupId, backupLevel, "myBackupDir", "myMediaDir"); + private static AuthenticatedBackupUser backupUser(final byte[] backupId, final BackupCredentialType credentialType, final BackupLevel backupLevel) { + return new AuthenticatedBackupUser(backupId, credentialType, backupLevel, "myBackupDir", "myMediaDir"); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java index 18e8ccb9d..ce58c046d 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java @@ -53,6 +53,7 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; +import org.signal.libsignal.zkgroup.backups.BackupCredentialType; import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; @@ -426,7 +427,7 @@ class AccountsTest { generateAccount(e164, existingUuid, UUID.randomUUID(), List.of(generateDevice(DEVICE_ID_1))); // the backup credential request and share-set are always preserved across account reclaims - existingAccount.setBackupCredentialRequest(TestRandomUtil.nextBytes(32)); + existingAccount.setBackupCredentialRequests(TestRandomUtil.nextBytes(32), TestRandomUtil.nextBytes(32)); existingAccount.setSvr3ShareSet(TestRandomUtil.nextBytes(100)); createAccount(existingAccount); final Account secondAccount = @@ -435,7 +436,10 @@ class AccountsTest { reclaimAccount(secondAccount); final Account reclaimed = accounts.getByAccountIdentifier(existingUuid).get(); - assertThat(reclaimed.getBackupCredentialRequest()).isEqualTo(existingAccount.getBackupCredentialRequest()); + assertThat(reclaimed.getBackupCredentialRequest(BackupCredentialType.MESSAGES).get()) + .isEqualTo(existingAccount.getBackupCredentialRequest(BackupCredentialType.MESSAGES).get()); + assertThat(reclaimed.getBackupCredentialRequest(BackupCredentialType.MEDIA).get()) + .isEqualTo(existingAccount.getBackupCredentialRequest(BackupCredentialType.MEDIA).get()); assertThat(reclaimed.getSvr3ShareSet()).isEqualTo(existingAccount.getSvr3ShareSet()); }