diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupAuthManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupAuthManager.java index 74e9a4eec..ddfadcce5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupAuthManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupAuthManager.java @@ -88,7 +88,7 @@ public class BackupAuthManager { * @throws RateLimitExceededException If too many backup-ids have been committed */ public CompletableFuture commitBackupId(final Account account, - final BackupAuthCredentialRequest backupAuthCredentialRequest) throws RateLimitExceededException { + final BackupAuthCredentialRequest backupAuthCredentialRequest) { if (configuredBackupLevel(account).isEmpty()) { throw Status.PERMISSION_DENIED.withDescription("Backups not allowed on account").asRuntimeException(); } @@ -101,11 +101,12 @@ public class BackupAuthManager { return CompletableFuture.completedFuture(null); } - rateLimiters.forDescriptor(RateLimiters.For.SET_BACKUP_ID).validate(account.getUuid()); - - return this.accountsManager - .updateAsync(account, acc -> acc.setBackupCredentialRequest(serializedRequest)) - .thenRun(Util.NOOP); + return rateLimiters.forDescriptor(RateLimiters.For.SET_BACKUP_ID) + .validateAsync(account.getUuid()) + .thenCompose(ignored -> this.accountsManager + .updateAsync(account, acc -> acc.setBackupCredentialRequest(serializedRequest)) + .thenRun(Util.NOOP)) + .toCompletableFuture(); } public record Credential(BackupAuthCredentialResponse credential, Instant redemptionTime) {} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupManager.java index 2f16fb29a..a79892841 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/backup/BackupManager.java @@ -148,19 +148,19 @@ public class BackupManager { .thenApply(result -> cdn3BackupCredentialGenerator.generateUpload(cdnMessageBackupName(backupUser))); } - public BackupUploadDescriptor createTemporaryAttachmentUploadDescriptor(final AuthenticatedBackupUser backupUser) - throws RateLimitExceededException { + public CompletionStage createTemporaryAttachmentUploadDescriptor( + final AuthenticatedBackupUser backupUser) { checkBackupLevel(backupUser, BackupLevel.MEDIA); - RateLimiter.adaptLegacyException(() -> rateLimiters + return RateLimiter.adaptLegacyException(rateLimiters .forDescriptor(RateLimiters.For.BACKUP_ATTACHMENT) - .validate(rateLimitKey(backupUser))); - - final byte[] bytes = new byte[15]; - secureRandom.nextBytes(bytes); - final String attachmentKey = Base64.getUrlEncoder().encodeToString(bytes); - final AttachmentGenerator.Descriptor descriptor = tusAttachmentGenerator.generateAttachment(attachmentKey); - return new BackupUploadDescriptor(3, attachmentKey, descriptor.headers(), descriptor.signedUploadLocation()); + .validateAsync(rateLimitKey(backupUser))).thenApply(ignored -> { + final byte[] bytes = new byte[15]; + secureRandom.nextBytes(bytes); + final String attachmentKey = Base64.getUrlEncoder().encodeToString(bytes); + final AttachmentGenerator.Descriptor descriptor = tusAttachmentGenerator.generateAttachment(attachmentKey); + return new BackupUploadDescriptor(3, attachmentKey, descriptor.headers(), descriptor.signedUploadLocation()); + }); } /** diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ArchiveController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ArchiveController.java index c17447006..8be846ad5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ArchiveController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ArchiveController.java @@ -27,8 +27,8 @@ import java.util.Base64; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; -import javax.annotation.Nullable; import javax.validation.Valid; import javax.validation.constraints.Max; import javax.validation.constraints.Min; @@ -186,13 +186,13 @@ public class ArchiveController { description = """ After setting a blinded backup-id with PUT /v1/archives/, this fetches credentials that can be used to perform operations against that backup-id. Clients may (and should) request up to 7 days of credentials at a time. - + The redemptionStart and redemptionEnd seconds must be UTC day aligned, and must not span more than 7 days. - + Each credential contains a receipt level which indicates the backup level the credential is good for. If the account has paid backup access that expires at some point in the provided redemption window, credentials with redemption times after the expiration may be on a lower backup level. - + Clients must validate the receipt level on the credential matches a known receipt level before using it. """) @ApiResponse(responseCode = "200", content = @Content(schema = @Schema(implementation = BackupAuthCredentialsResponse.class))) @@ -455,13 +455,7 @@ public class ArchiveController { throw new BadRequestException("must not use authenticated connection for anonymous operations"); } return backupManager.authenticateBackupUser(presentation.presentation, signature.signature) - .thenApply(backupUser -> { - try { - return backupManager.createTemporaryAttachmentUploadDescriptor(backupUser); - } catch (RateLimitExceededException e) { - throw ExceptionUtils.wrap(e); - } - }) + .thenCompose(backupUser -> backupManager.createTemporaryAttachmentUploadDescriptor(backupUser)) .thenApply(result -> new UploadDescriptorResponse( result.cdn(), result.key(), @@ -553,14 +547,10 @@ public class ArchiveController { throw new BadRequestException("must not use authenticated connection for anonymous operations"); } - final AuthenticatedBackupUser backupUser = backupManager.authenticateBackupUser( - presentation.presentation, signature.signature).join(); - - final boolean fits = backupManager.canStoreMedia(backupUser, copyMediaRequest.objectLength()).join(); - if (!fits) { - throw new ClientErrorException("Media quota exhausted", Response.Status.REQUEST_ENTITY_TOO_LARGE); - } - return copyMediaImpl(backupUser, copyMediaRequest) + return backupManager + .authenticateBackupUser(presentation.presentation, signature.signature) + .thenCompose(backupUser -> checkMediaFits(backupUser, copyMediaRequest.objectLength) + .thenCompose(ignored -> copyMediaImpl(backupUser, copyMediaRequest))) .thenApply(result -> new CopyMediaResponse(result.cdn())) .exceptionally(e -> { final Throwable unwrapped = ExceptionUtils.unwrap(e); @@ -574,6 +564,16 @@ public class ArchiveController { }); } + private CompletableFuture checkMediaFits(AuthenticatedBackupUser backupUser, long amountToStore) { + return backupManager.canStoreMedia(backupUser, amountToStore) + .thenApply(fits -> { + if (!fits) { + throw new ClientErrorException("Media quota exhausted", Response.Status.REQUEST_ENTITY_TOO_LARGE); + } + return null; + }); + } + private CompletionStage copyMediaImpl(final AuthenticatedBackupUser backupUser, final CopyMediaRequest copyMediaRequest) { return this.backupManager.copyToBackup( @@ -662,41 +662,36 @@ public class ArchiveController { throw new BadRequestException("must not use authenticated connection for anonymous operations"); } - final AuthenticatedBackupUser backupUser = backupManager.authenticateBackupUser( - presentation.presentation, signature.signature).join(); - // If the entire batch won't fit in the user's remaining quota, reject the whole request. final long expectedStorage = copyMediaRequest.items().stream().mapToLong(CopyMediaRequest::objectLength).sum(); - final boolean fits = backupManager.canStoreMedia(backupUser, expectedStorage).join(); - if (!fits) { - throw new ClientErrorException("Media quota exhausted", Response.Status.REQUEST_ENTITY_TOO_LARGE); - } - return Flux.fromIterable(copyMediaRequest.items) - // Operate sequentially, waiting for one copy to finish before starting the next one. At least right now, - // copying concurrently will introduce contention over the metadata. - .concatMap(request -> Mono - .fromCompletionStage(copyMediaImpl(backupUser, request)) - .map(result -> new CopyMediaBatchResponse.Entry(200, null, result.cdn(), result.key())) - .onErrorResume(throwable -> ExceptionUtils.unwrap(throwable) instanceof IOException, throwable -> { - final Throwable unwrapped = ExceptionUtils.unwrap(throwable); + return backupManager.authenticateBackupUser(presentation.presentation, signature.signature) + .thenCompose(backupUser -> checkMediaFits(backupUser, expectedStorage).thenCompose( + ignored -> Flux.fromIterable(copyMediaRequest.items) + // Operate sequentially, waiting for one copy to finish before starting the next one. At least right now, + // copying concurrently will introduce contention over the metadata. + .concatMap(request -> Mono + .fromCompletionStage(copyMediaImpl(backupUser, request)) + .map(result -> new CopyMediaBatchResponse.Entry(200, null, result.cdn(), result.key())) + .onErrorResume(throwable -> ExceptionUtils.unwrap(throwable) instanceof IOException, throwable -> { + final Throwable unwrapped = ExceptionUtils.unwrap(throwable); - int status; - String error; - if (unwrapped instanceof SourceObjectNotFoundException) { - status = 410; - error = "Source object not found " + unwrapped.getMessage(); - } else if (unwrapped instanceof InvalidLengthException) { - status = 400; - error = "Invalid length " + unwrapped.getMessage(); - } else { - throw ExceptionUtils.wrap(throwable); - } - return Mono.just(new CopyMediaBatchResponse.Entry(status, error, null, request.mediaId)); - })) - .collectList() - .map(list -> Response.status(207).entity(new CopyMediaBatchResponse(list)).build()) - .toFuture(); + int status; + String error; + if (unwrapped instanceof SourceObjectNotFoundException) { + status = 410; + error = "Source object not found " + unwrapped.getMessage(); + } else if (unwrapped instanceof InvalidLengthException) { + status = 400; + error = "Invalid length " + unwrapped.getMessage(); + } else { + throw ExceptionUtils.wrap(throwable); + } + return Mono.just(new CopyMediaBatchResponse.Entry(status, error, null, request.mediaId)); + })) + .collectList() + .map(list -> Response.status(207).entity(new CopyMediaBatchResponse(list)).build()) + .toFuture())); } @POST @@ -858,8 +853,7 @@ public class ArchiveController { @DELETE @Produces(MediaType.APPLICATION_JSON) - @Operation(summary = "Delete entire backup", - description = """ + @Operation(summary = "Delete entire backup", description = """ Delete all backup metadata, objects, and stored public key. To use backups again, a public key must be resupplied. """) @ApiResponse(responseCode = "204", description = "The backup has been successfully removed") diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java index 98fa8a362..4fb063bee 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java @@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.limits; import java.util.UUID; import java.util.concurrent.CompletionStage; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import org.whispersystems.textsecuregcm.util.ExceptionUtils; import reactor.core.publisher.Mono; public interface RateLimiter { @@ -78,6 +79,16 @@ public interface RateLimiter { return clearAsync(accountUuid.toString()); } + /** + * If the future throws a {@link RateLimitExceededException}, it will adapt it to ensure that + * {@link RateLimitExceededException#isLegacy()} returns {@code false} + */ + static CompletionStage adaptLegacyException(final CompletionStage rateLimitFuture) { + return rateLimitFuture.exceptionally(ExceptionUtils.exceptionallyHandler(RateLimitExceededException.class, e -> { + throw ExceptionUtils.wrap(new RateLimitExceededException(e.getRetryDuration().orElse(null), false)); + })); + } + /** * If the wrapped {@code validate()} call throws a {@link RateLimitExceededException}, it will adapt it to ensure that * {@link RateLimitExceededException#isLegacy()} returns {@code false} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthManagerTest.java index a48fcf6b4..2d1accfdb 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupAuthManagerTest.java @@ -10,7 +10,6 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; @@ -29,6 +28,7 @@ import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.function.Consumer; import java.util.stream.Stream; +import javax.annotation.Nullable; import org.assertj.core.api.Assertions; import org.assertj.core.api.ThrowableAssert; import org.junit.jupiter.api.BeforeEach; @@ -63,7 +63,6 @@ import org.whispersystems.textsecuregcm.tests.util.ExperimentHelper; import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil; import org.whispersystems.textsecuregcm.util.TestClock; import org.whispersystems.textsecuregcm.util.TestRandomUtil; -import javax.annotation.Nullable; public class BackupAuthManagerTest { @@ -407,8 +406,9 @@ public class BackupAuthManagerTest { when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(account)); // Should be rate limited - assertThatExceptionOfType(RateLimitExceededException.class) - .isThrownBy(() -> authManager.commitBackupId(account, credentialRequest).join()); + CompletableFutureTestUtil.assertFailsWithCause( + RateLimitExceededException.class, + authManager.commitBackupId(account, credentialRequest)); // If we don't change the request, shouldn't be rate limited when(account.getBackupCredentialRequest()).thenReturn(credentialRequest.serialize()); @@ -426,6 +426,7 @@ public class BackupAuthManagerTest { private static RateLimiters allowRateLimiter() { final RateLimiters limiters = mock(RateLimiters.class); final RateLimiter limiter = mock(RateLimiter.class); + when(limiter.validateAsync(any(UUID.class))).thenReturn(CompletableFuture.completedFuture(null)); when(limiters.forDescriptor(RateLimiters.For.SET_BACKUP_ID)).thenReturn(limiter); return limiters; } @@ -433,11 +434,8 @@ public class BackupAuthManagerTest { private static RateLimiters denyRateLimiter(final UUID aci) { final RateLimiters limiters = mock(RateLimiters.class); final RateLimiter limiter = mock(RateLimiter.class); - try { - doThrow(new RateLimitExceededException(null, false)).when(limiter).validate(aci); - } catch (RateLimitExceededException e) { - throw new AssertionError(e); - } + when(limiter.validateAsync(aci)) + .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null, false))); when(limiters.forDescriptor(RateLimiters.For.SET_BACKUP_ID)).thenReturn(limiter); return limiters; } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupManagerTest.java index be958cb6b..d6ddbeddf 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/backup/BackupManagerTest.java @@ -13,7 +13,6 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; @@ -25,7 +24,6 @@ import static org.mockito.Mockito.when; import io.grpc.Status; import io.grpc.StatusRuntimeException; import java.io.IOException; -import java.net.URI; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.security.MessageDigest; @@ -101,7 +99,7 @@ public class BackupManagerTest { final RateLimiters rateLimiters = mock(RateLimiters.class); when(rateLimiters.forDescriptor(RateLimiters.For.BACKUP_ATTACHMENT)).thenReturn(mediaUploadLimiter); - + when(remoteStorageManager.cdnNumber()).thenReturn(3); this.backupsDb = new BackupsDb( @@ -141,15 +139,14 @@ public class BackupManagerTest { } @Test - public void createTemporaryMediaAttachmentRateLimited() throws RateLimitExceededException { + public void createTemporaryMediaAttachmentRateLimited() { final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); - doThrow(new RateLimitExceededException(null, true)) - .when(mediaUploadLimiter) - .validate(eq(BackupManager.rateLimitKey(backupUser))); - - assertThatExceptionOfType(RateLimitExceededException.class) - .isThrownBy(() -> backupManager.createTemporaryAttachmentUploadDescriptor(backupUser)) - .satisfies(e -> assertThat(e.isLegacy()).isFalse()); + when(mediaUploadLimiter.validateAsync(eq(BackupManager.rateLimitKey(backupUser)))) + .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null, true))); + final RateLimitExceededException e = CompletableFutureTestUtil.assertFailsWithCause( + RateLimitExceededException.class, + backupManager.createTemporaryAttachmentUploadDescriptor(backupUser).toCompletableFuture()); + assertThat(e.isLegacy()).isFalse(); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ArchiveControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ArchiveControllerTest.java index e89ef08fe..b7e5110e8 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ArchiveControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ArchiveControllerTest.java @@ -541,7 +541,8 @@ public class ArchiveControllerTest { when(backupManager.authenticateBackupUser(any(), any())) .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupLevel.MEDIA))); when(backupManager.createTemporaryAttachmentUploadDescriptor(any())) - .thenReturn(new BackupUploadDescriptor(3, "abc", Map.of("k", "v"), "example.org")); + .thenReturn(CompletableFuture.completedFuture( + new BackupUploadDescriptor(3, "abc", Map.of("k", "v"), "example.org"))); final ArchiveController.UploadDescriptorResponse desc = resources.getJerseyTest() .target("v1/archives/media/upload/form") .request() @@ -555,7 +556,7 @@ public class ArchiveControllerTest { // rate limit when(backupManager.createTemporaryAttachmentUploadDescriptor(any())) - .thenThrow(new RateLimitExceededException(null, false)); + .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null, false))); final Response response = resources.getJerseyTest() .target("v1/archives/media/upload/form") .request()