Fix some accidentally sync async methods

This commit is contained in:
Ravi Khadiwala 2024-05-24 16:29:19 -05:00 committed by ravi-signal
parent c7d1ad56ff
commit cea2abcf6e
7 changed files with 92 additions and 90 deletions

View File

@ -88,7 +88,7 @@ public class BackupAuthManager {
* @throws RateLimitExceededException If too many backup-ids have been committed * @throws RateLimitExceededException If too many backup-ids have been committed
*/ */
public CompletableFuture<Void> commitBackupId(final Account account, public CompletableFuture<Void> commitBackupId(final Account account,
final BackupAuthCredentialRequest backupAuthCredentialRequest) throws RateLimitExceededException { final BackupAuthCredentialRequest backupAuthCredentialRequest) {
if (configuredBackupLevel(account).isEmpty()) { if (configuredBackupLevel(account).isEmpty()) {
throw Status.PERMISSION_DENIED.withDescription("Backups not allowed on account").asRuntimeException(); throw Status.PERMISSION_DENIED.withDescription("Backups not allowed on account").asRuntimeException();
} }
@ -101,11 +101,12 @@ public class BackupAuthManager {
return CompletableFuture.completedFuture(null); return CompletableFuture.completedFuture(null);
} }
rateLimiters.forDescriptor(RateLimiters.For.SET_BACKUP_ID).validate(account.getUuid()); return rateLimiters.forDescriptor(RateLimiters.For.SET_BACKUP_ID)
.validateAsync(account.getUuid())
return this.accountsManager .thenCompose(ignored -> this.accountsManager
.updateAsync(account, acc -> acc.setBackupCredentialRequest(serializedRequest)) .updateAsync(account, acc -> acc.setBackupCredentialRequest(serializedRequest))
.thenRun(Util.NOOP); .thenRun(Util.NOOP))
.toCompletableFuture();
} }
public record Credential(BackupAuthCredentialResponse credential, Instant redemptionTime) {} public record Credential(BackupAuthCredentialResponse credential, Instant redemptionTime) {}

View File

@ -148,19 +148,19 @@ public class BackupManager {
.thenApply(result -> cdn3BackupCredentialGenerator.generateUpload(cdnMessageBackupName(backupUser))); .thenApply(result -> cdn3BackupCredentialGenerator.generateUpload(cdnMessageBackupName(backupUser)));
} }
public BackupUploadDescriptor createTemporaryAttachmentUploadDescriptor(final AuthenticatedBackupUser backupUser) public CompletionStage<BackupUploadDescriptor> createTemporaryAttachmentUploadDescriptor(
throws RateLimitExceededException { final AuthenticatedBackupUser backupUser) {
checkBackupLevel(backupUser, BackupLevel.MEDIA); checkBackupLevel(backupUser, BackupLevel.MEDIA);
RateLimiter.adaptLegacyException(() -> rateLimiters return RateLimiter.adaptLegacyException(rateLimiters
.forDescriptor(RateLimiters.For.BACKUP_ATTACHMENT) .forDescriptor(RateLimiters.For.BACKUP_ATTACHMENT)
.validate(rateLimitKey(backupUser))); .validateAsync(rateLimitKey(backupUser))).thenApply(ignored -> {
final byte[] bytes = new byte[15];
final byte[] bytes = new byte[15]; secureRandom.nextBytes(bytes);
secureRandom.nextBytes(bytes); final String attachmentKey = Base64.getUrlEncoder().encodeToString(bytes);
final String attachmentKey = Base64.getUrlEncoder().encodeToString(bytes); final AttachmentGenerator.Descriptor descriptor = tusAttachmentGenerator.generateAttachment(attachmentKey);
final AttachmentGenerator.Descriptor descriptor = tusAttachmentGenerator.generateAttachment(attachmentKey); return new BackupUploadDescriptor(3, attachmentKey, descriptor.headers(), descriptor.signedUploadLocation());
return new BackupUploadDescriptor(3, attachmentKey, descriptor.headers(), descriptor.signedUploadLocation()); });
} }
/** /**

View File

@ -27,8 +27,8 @@ import java.util.Base64;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage; import java.util.concurrent.CompletionStage;
import javax.annotation.Nullable;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.Max; import javax.validation.constraints.Max;
import javax.validation.constraints.Min; import javax.validation.constraints.Min;
@ -186,13 +186,13 @@ public class ArchiveController {
description = """ description = """
After setting a blinded backup-id with PUT /v1/archives/, this fetches credentials that can be used to perform After setting a blinded backup-id with PUT /v1/archives/, this fetches credentials that can be used to perform
operations against that backup-id. Clients may (and should) request up to 7 days of credentials at a time. operations against that backup-id. Clients may (and should) request up to 7 days of credentials at a time.
The redemptionStart and redemptionEnd seconds must be UTC day aligned, and must not span more than 7 days. The redemptionStart and redemptionEnd seconds must be UTC day aligned, and must not span more than 7 days.
Each credential contains a receipt level which indicates the backup level the credential is good for. If the Each credential contains a receipt level which indicates the backup level the credential is good for. If the
account has paid backup access that expires at some point in the provided redemption window, credentials with account has paid backup access that expires at some point in the provided redemption window, credentials with
redemption times after the expiration may be on a lower backup level. redemption times after the expiration may be on a lower backup level.
Clients must validate the receipt level on the credential matches a known receipt level before using it. Clients must validate the receipt level on the credential matches a known receipt level before using it.
""") """)
@ApiResponse(responseCode = "200", content = @Content(schema = @Schema(implementation = BackupAuthCredentialsResponse.class))) @ApiResponse(responseCode = "200", content = @Content(schema = @Schema(implementation = BackupAuthCredentialsResponse.class)))
@ -455,13 +455,7 @@ public class ArchiveController {
throw new BadRequestException("must not use authenticated connection for anonymous operations"); throw new BadRequestException("must not use authenticated connection for anonymous operations");
} }
return backupManager.authenticateBackupUser(presentation.presentation, signature.signature) return backupManager.authenticateBackupUser(presentation.presentation, signature.signature)
.thenApply(backupUser -> { .thenCompose(backupUser -> backupManager.createTemporaryAttachmentUploadDescriptor(backupUser))
try {
return backupManager.createTemporaryAttachmentUploadDescriptor(backupUser);
} catch (RateLimitExceededException e) {
throw ExceptionUtils.wrap(e);
}
})
.thenApply(result -> new UploadDescriptorResponse( .thenApply(result -> new UploadDescriptorResponse(
result.cdn(), result.cdn(),
result.key(), result.key(),
@ -553,14 +547,10 @@ public class ArchiveController {
throw new BadRequestException("must not use authenticated connection for anonymous operations"); throw new BadRequestException("must not use authenticated connection for anonymous operations");
} }
final AuthenticatedBackupUser backupUser = backupManager.authenticateBackupUser( return backupManager
presentation.presentation, signature.signature).join(); .authenticateBackupUser(presentation.presentation, signature.signature)
.thenCompose(backupUser -> checkMediaFits(backupUser, copyMediaRequest.objectLength)
final boolean fits = backupManager.canStoreMedia(backupUser, copyMediaRequest.objectLength()).join(); .thenCompose(ignored -> copyMediaImpl(backupUser, copyMediaRequest)))
if (!fits) {
throw new ClientErrorException("Media quota exhausted", Response.Status.REQUEST_ENTITY_TOO_LARGE);
}
return copyMediaImpl(backupUser, copyMediaRequest)
.thenApply(result -> new CopyMediaResponse(result.cdn())) .thenApply(result -> new CopyMediaResponse(result.cdn()))
.exceptionally(e -> { .exceptionally(e -> {
final Throwable unwrapped = ExceptionUtils.unwrap(e); final Throwable unwrapped = ExceptionUtils.unwrap(e);
@ -574,6 +564,16 @@ public class ArchiveController {
}); });
} }
private CompletableFuture<Void> checkMediaFits(AuthenticatedBackupUser backupUser, long amountToStore) {
return backupManager.canStoreMedia(backupUser, amountToStore)
.thenApply(fits -> {
if (!fits) {
throw new ClientErrorException("Media quota exhausted", Response.Status.REQUEST_ENTITY_TOO_LARGE);
}
return null;
});
}
private CompletionStage<BackupManager.StorageDescriptor> copyMediaImpl(final AuthenticatedBackupUser backupUser, private CompletionStage<BackupManager.StorageDescriptor> copyMediaImpl(final AuthenticatedBackupUser backupUser,
final CopyMediaRequest copyMediaRequest) { final CopyMediaRequest copyMediaRequest) {
return this.backupManager.copyToBackup( return this.backupManager.copyToBackup(
@ -662,41 +662,36 @@ public class ArchiveController {
throw new BadRequestException("must not use authenticated connection for anonymous operations"); throw new BadRequestException("must not use authenticated connection for anonymous operations");
} }
final AuthenticatedBackupUser backupUser = backupManager.authenticateBackupUser(
presentation.presentation, signature.signature).join();
// If the entire batch won't fit in the user's remaining quota, reject the whole request. // If the entire batch won't fit in the user's remaining quota, reject the whole request.
final long expectedStorage = copyMediaRequest.items().stream().mapToLong(CopyMediaRequest::objectLength).sum(); final long expectedStorage = copyMediaRequest.items().stream().mapToLong(CopyMediaRequest::objectLength).sum();
final boolean fits = backupManager.canStoreMedia(backupUser, expectedStorage).join();
if (!fits) {
throw new ClientErrorException("Media quota exhausted", Response.Status.REQUEST_ENTITY_TOO_LARGE);
}
return Flux.fromIterable(copyMediaRequest.items) return backupManager.authenticateBackupUser(presentation.presentation, signature.signature)
// Operate sequentially, waiting for one copy to finish before starting the next one. At least right now, .thenCompose(backupUser -> checkMediaFits(backupUser, expectedStorage).thenCompose(
// copying concurrently will introduce contention over the metadata. ignored -> Flux.fromIterable(copyMediaRequest.items)
.concatMap(request -> Mono // Operate sequentially, waiting for one copy to finish before starting the next one. At least right now,
.fromCompletionStage(copyMediaImpl(backupUser, request)) // copying concurrently will introduce contention over the metadata.
.map(result -> new CopyMediaBatchResponse.Entry(200, null, result.cdn(), result.key())) .concatMap(request -> Mono
.onErrorResume(throwable -> ExceptionUtils.unwrap(throwable) instanceof IOException, throwable -> { .fromCompletionStage(copyMediaImpl(backupUser, request))
final Throwable unwrapped = ExceptionUtils.unwrap(throwable); .map(result -> new CopyMediaBatchResponse.Entry(200, null, result.cdn(), result.key()))
.onErrorResume(throwable -> ExceptionUtils.unwrap(throwable) instanceof IOException, throwable -> {
final Throwable unwrapped = ExceptionUtils.unwrap(throwable);
int status; int status;
String error; String error;
if (unwrapped instanceof SourceObjectNotFoundException) { if (unwrapped instanceof SourceObjectNotFoundException) {
status = 410; status = 410;
error = "Source object not found " + unwrapped.getMessage(); error = "Source object not found " + unwrapped.getMessage();
} else if (unwrapped instanceof InvalidLengthException) { } else if (unwrapped instanceof InvalidLengthException) {
status = 400; status = 400;
error = "Invalid length " + unwrapped.getMessage(); error = "Invalid length " + unwrapped.getMessage();
} else { } else {
throw ExceptionUtils.wrap(throwable); throw ExceptionUtils.wrap(throwable);
} }
return Mono.just(new CopyMediaBatchResponse.Entry(status, error, null, request.mediaId)); return Mono.just(new CopyMediaBatchResponse.Entry(status, error, null, request.mediaId));
})) }))
.collectList() .collectList()
.map(list -> Response.status(207).entity(new CopyMediaBatchResponse(list)).build()) .map(list -> Response.status(207).entity(new CopyMediaBatchResponse(list)).build())
.toFuture(); .toFuture()));
} }
@POST @POST
@ -858,8 +853,7 @@ public class ArchiveController {
@DELETE @DELETE
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Operation(summary = "Delete entire backup", @Operation(summary = "Delete entire backup", description = """
description = """
Delete all backup metadata, objects, and stored public key. To use backups again, a public key must be resupplied. Delete all backup metadata, objects, and stored public key. To use backups again, a public key must be resupplied.
""") """)
@ApiResponse(responseCode = "204", description = "The backup has been successfully removed") @ApiResponse(responseCode = "204", description = "The backup has been successfully removed")

View File

@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.limits;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletionStage; import java.util.concurrent.CompletionStage;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
public interface RateLimiter { public interface RateLimiter {
@ -78,6 +79,16 @@ public interface RateLimiter {
return clearAsync(accountUuid.toString()); return clearAsync(accountUuid.toString());
} }
/**
* If the future throws a {@link RateLimitExceededException}, it will adapt it to ensure that
* {@link RateLimitExceededException#isLegacy()} returns {@code false}
*/
static CompletionStage<Void> adaptLegacyException(final CompletionStage<Void> rateLimitFuture) {
return rateLimitFuture.exceptionally(ExceptionUtils.exceptionallyHandler(RateLimitExceededException.class, e -> {
throw ExceptionUtils.wrap(new RateLimitExceededException(e.getRetryDuration().orElse(null), false));
}));
}
/** /**
* If the wrapped {@code validate()} call throws a {@link RateLimitExceededException}, it will adapt it to ensure that * If the wrapped {@code validate()} call throws a {@link RateLimitExceededException}, it will adapt it to ensure that
* {@link RateLimitExceededException#isLegacy()} returns {@code false} * {@link RateLimitExceededException#isLegacy()} returns {@code false}

View File

@ -10,7 +10,6 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset; import static org.mockito.Mockito.reset;
@ -29,6 +28,7 @@ import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.stream.Stream; import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.assertj.core.api.Assertions; import org.assertj.core.api.Assertions;
import org.assertj.core.api.ThrowableAssert; import org.assertj.core.api.ThrowableAssert;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
@ -63,7 +63,6 @@ import org.whispersystems.textsecuregcm.tests.util.ExperimentHelper;
import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil; import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
import org.whispersystems.textsecuregcm.util.TestClock; import org.whispersystems.textsecuregcm.util.TestClock;
import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import javax.annotation.Nullable;
public class BackupAuthManagerTest { public class BackupAuthManagerTest {
@ -407,8 +406,9 @@ public class BackupAuthManagerTest {
when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(account)); when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(account));
// Should be rate limited // Should be rate limited
assertThatExceptionOfType(RateLimitExceededException.class) CompletableFutureTestUtil.assertFailsWithCause(
.isThrownBy(() -> authManager.commitBackupId(account, credentialRequest).join()); RateLimitExceededException.class,
authManager.commitBackupId(account, credentialRequest));
// If we don't change the request, shouldn't be rate limited // If we don't change the request, shouldn't be rate limited
when(account.getBackupCredentialRequest()).thenReturn(credentialRequest.serialize()); when(account.getBackupCredentialRequest()).thenReturn(credentialRequest.serialize());
@ -426,6 +426,7 @@ public class BackupAuthManagerTest {
private static RateLimiters allowRateLimiter() { private static RateLimiters allowRateLimiter() {
final RateLimiters limiters = mock(RateLimiters.class); final RateLimiters limiters = mock(RateLimiters.class);
final RateLimiter limiter = mock(RateLimiter.class); final RateLimiter limiter = mock(RateLimiter.class);
when(limiter.validateAsync(any(UUID.class))).thenReturn(CompletableFuture.completedFuture(null));
when(limiters.forDescriptor(RateLimiters.For.SET_BACKUP_ID)).thenReturn(limiter); when(limiters.forDescriptor(RateLimiters.For.SET_BACKUP_ID)).thenReturn(limiter);
return limiters; return limiters;
} }
@ -433,11 +434,8 @@ public class BackupAuthManagerTest {
private static RateLimiters denyRateLimiter(final UUID aci) { private static RateLimiters denyRateLimiter(final UUID aci) {
final RateLimiters limiters = mock(RateLimiters.class); final RateLimiters limiters = mock(RateLimiters.class);
final RateLimiter limiter = mock(RateLimiter.class); final RateLimiter limiter = mock(RateLimiter.class);
try { when(limiter.validateAsync(aci))
doThrow(new RateLimitExceededException(null, false)).when(limiter).validate(aci); .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null, false)));
} catch (RateLimitExceededException e) {
throw new AssertionError(e);
}
when(limiters.forDescriptor(RateLimiters.For.SET_BACKUP_ID)).thenReturn(limiter); when(limiters.forDescriptor(RateLimiters.For.SET_BACKUP_ID)).thenReturn(limiter);
return limiters; return limiters;
} }

View File

@ -13,7 +13,6 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset; import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
@ -25,7 +24,6 @@ import static org.mockito.Mockito.when;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.StatusRuntimeException; import io.grpc.StatusRuntimeException;
import java.io.IOException; import java.io.IOException;
import java.net.URI;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.security.MessageDigest; import java.security.MessageDigest;
@ -101,7 +99,7 @@ public class BackupManagerTest {
final RateLimiters rateLimiters = mock(RateLimiters.class); final RateLimiters rateLimiters = mock(RateLimiters.class);
when(rateLimiters.forDescriptor(RateLimiters.For.BACKUP_ATTACHMENT)).thenReturn(mediaUploadLimiter); when(rateLimiters.forDescriptor(RateLimiters.For.BACKUP_ATTACHMENT)).thenReturn(mediaUploadLimiter);
when(remoteStorageManager.cdnNumber()).thenReturn(3); when(remoteStorageManager.cdnNumber()).thenReturn(3);
this.backupsDb = new BackupsDb( this.backupsDb = new BackupsDb(
@ -141,15 +139,14 @@ public class BackupManagerTest {
} }
@Test @Test
public void createTemporaryMediaAttachmentRateLimited() throws RateLimitExceededException { public void createTemporaryMediaAttachmentRateLimited() {
final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA);
doThrow(new RateLimitExceededException(null, true)) when(mediaUploadLimiter.validateAsync(eq(BackupManager.rateLimitKey(backupUser))))
.when(mediaUploadLimiter) .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null, true)));
.validate(eq(BackupManager.rateLimitKey(backupUser))); final RateLimitExceededException e = CompletableFutureTestUtil.assertFailsWithCause(
RateLimitExceededException.class,
assertThatExceptionOfType(RateLimitExceededException.class) backupManager.createTemporaryAttachmentUploadDescriptor(backupUser).toCompletableFuture());
.isThrownBy(() -> backupManager.createTemporaryAttachmentUploadDescriptor(backupUser)) assertThat(e.isLegacy()).isFalse();
.satisfies(e -> assertThat(e.isLegacy()).isFalse());
} }
@Test @Test

View File

@ -541,7 +541,8 @@ public class ArchiveControllerTest {
when(backupManager.authenticateBackupUser(any(), any())) when(backupManager.authenticateBackupUser(any(), any()))
.thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupLevel.MEDIA))); .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupLevel.MEDIA)));
when(backupManager.createTemporaryAttachmentUploadDescriptor(any())) when(backupManager.createTemporaryAttachmentUploadDescriptor(any()))
.thenReturn(new BackupUploadDescriptor(3, "abc", Map.of("k", "v"), "example.org")); .thenReturn(CompletableFuture.completedFuture(
new BackupUploadDescriptor(3, "abc", Map.of("k", "v"), "example.org")));
final ArchiveController.UploadDescriptorResponse desc = resources.getJerseyTest() final ArchiveController.UploadDescriptorResponse desc = resources.getJerseyTest()
.target("v1/archives/media/upload/form") .target("v1/archives/media/upload/form")
.request() .request()
@ -555,7 +556,7 @@ public class ArchiveControllerTest {
// rate limit // rate limit
when(backupManager.createTemporaryAttachmentUploadDescriptor(any())) when(backupManager.createTemporaryAttachmentUploadDescriptor(any()))
.thenThrow(new RateLimitExceededException(null, false)); .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null, false)));
final Response response = resources.getJerseyTest() final Response response = resources.getJerseyTest()
.target("v1/archives/media/upload/form") .target("v1/archives/media/upload/form")
.request() .request()