Fix some accidentally sync async methods

This commit is contained in:
Ravi Khadiwala 2024-05-24 16:29:19 -05:00 committed by ravi-signal
parent c7d1ad56ff
commit cea2abcf6e
7 changed files with 92 additions and 90 deletions

View File

@ -88,7 +88,7 @@ public class BackupAuthManager {
* @throws RateLimitExceededException If too many backup-ids have been committed * @throws RateLimitExceededException If too many backup-ids have been committed
*/ */
public CompletableFuture<Void> commitBackupId(final Account account, public CompletableFuture<Void> commitBackupId(final Account account,
final BackupAuthCredentialRequest backupAuthCredentialRequest) throws RateLimitExceededException { final BackupAuthCredentialRequest backupAuthCredentialRequest) {
if (configuredBackupLevel(account).isEmpty()) { if (configuredBackupLevel(account).isEmpty()) {
throw Status.PERMISSION_DENIED.withDescription("Backups not allowed on account").asRuntimeException(); throw Status.PERMISSION_DENIED.withDescription("Backups not allowed on account").asRuntimeException();
} }
@ -101,11 +101,12 @@ public class BackupAuthManager {
return CompletableFuture.completedFuture(null); return CompletableFuture.completedFuture(null);
} }
rateLimiters.forDescriptor(RateLimiters.For.SET_BACKUP_ID).validate(account.getUuid()); return rateLimiters.forDescriptor(RateLimiters.For.SET_BACKUP_ID)
.validateAsync(account.getUuid())
return this.accountsManager .thenCompose(ignored -> this.accountsManager
.updateAsync(account, acc -> acc.setBackupCredentialRequest(serializedRequest)) .updateAsync(account, acc -> acc.setBackupCredentialRequest(serializedRequest))
.thenRun(Util.NOOP); .thenRun(Util.NOOP))
.toCompletableFuture();
} }
public record Credential(BackupAuthCredentialResponse credential, Instant redemptionTime) {} public record Credential(BackupAuthCredentialResponse credential, Instant redemptionTime) {}

View File

@ -148,19 +148,19 @@ public class BackupManager {
.thenApply(result -> cdn3BackupCredentialGenerator.generateUpload(cdnMessageBackupName(backupUser))); .thenApply(result -> cdn3BackupCredentialGenerator.generateUpload(cdnMessageBackupName(backupUser)));
} }
public BackupUploadDescriptor createTemporaryAttachmentUploadDescriptor(final AuthenticatedBackupUser backupUser) public CompletionStage<BackupUploadDescriptor> createTemporaryAttachmentUploadDescriptor(
throws RateLimitExceededException { final AuthenticatedBackupUser backupUser) {
checkBackupLevel(backupUser, BackupLevel.MEDIA); checkBackupLevel(backupUser, BackupLevel.MEDIA);
RateLimiter.adaptLegacyException(() -> rateLimiters return RateLimiter.adaptLegacyException(rateLimiters
.forDescriptor(RateLimiters.For.BACKUP_ATTACHMENT) .forDescriptor(RateLimiters.For.BACKUP_ATTACHMENT)
.validate(rateLimitKey(backupUser))); .validateAsync(rateLimitKey(backupUser))).thenApply(ignored -> {
final byte[] bytes = new byte[15];
final byte[] bytes = new byte[15]; secureRandom.nextBytes(bytes);
secureRandom.nextBytes(bytes); final String attachmentKey = Base64.getUrlEncoder().encodeToString(bytes);
final String attachmentKey = Base64.getUrlEncoder().encodeToString(bytes); final AttachmentGenerator.Descriptor descriptor = tusAttachmentGenerator.generateAttachment(attachmentKey);
final AttachmentGenerator.Descriptor descriptor = tusAttachmentGenerator.generateAttachment(attachmentKey); return new BackupUploadDescriptor(3, attachmentKey, descriptor.headers(), descriptor.signedUploadLocation());
return new BackupUploadDescriptor(3, attachmentKey, descriptor.headers(), descriptor.signedUploadLocation()); });
} }
/** /**

View File

@ -27,8 +27,8 @@ import java.util.Base64;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage; import java.util.concurrent.CompletionStage;
import javax.annotation.Nullable;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.Max; import javax.validation.constraints.Max;
import javax.validation.constraints.Min; import javax.validation.constraints.Min;
@ -455,13 +455,7 @@ public class ArchiveController {
throw new BadRequestException("must not use authenticated connection for anonymous operations"); throw new BadRequestException("must not use authenticated connection for anonymous operations");
} }
return backupManager.authenticateBackupUser(presentation.presentation, signature.signature) return backupManager.authenticateBackupUser(presentation.presentation, signature.signature)
.thenApply(backupUser -> { .thenCompose(backupUser -> backupManager.createTemporaryAttachmentUploadDescriptor(backupUser))
try {
return backupManager.createTemporaryAttachmentUploadDescriptor(backupUser);
} catch (RateLimitExceededException e) {
throw ExceptionUtils.wrap(e);
}
})
.thenApply(result -> new UploadDescriptorResponse( .thenApply(result -> new UploadDescriptorResponse(
result.cdn(), result.cdn(),
result.key(), result.key(),
@ -553,14 +547,10 @@ public class ArchiveController {
throw new BadRequestException("must not use authenticated connection for anonymous operations"); throw new BadRequestException("must not use authenticated connection for anonymous operations");
} }
final AuthenticatedBackupUser backupUser = backupManager.authenticateBackupUser( return backupManager
presentation.presentation, signature.signature).join(); .authenticateBackupUser(presentation.presentation, signature.signature)
.thenCompose(backupUser -> checkMediaFits(backupUser, copyMediaRequest.objectLength)
final boolean fits = backupManager.canStoreMedia(backupUser, copyMediaRequest.objectLength()).join(); .thenCompose(ignored -> copyMediaImpl(backupUser, copyMediaRequest)))
if (!fits) {
throw new ClientErrorException("Media quota exhausted", Response.Status.REQUEST_ENTITY_TOO_LARGE);
}
return copyMediaImpl(backupUser, copyMediaRequest)
.thenApply(result -> new CopyMediaResponse(result.cdn())) .thenApply(result -> new CopyMediaResponse(result.cdn()))
.exceptionally(e -> { .exceptionally(e -> {
final Throwable unwrapped = ExceptionUtils.unwrap(e); final Throwable unwrapped = ExceptionUtils.unwrap(e);
@ -574,6 +564,16 @@ public class ArchiveController {
}); });
} }
private CompletableFuture<Void> checkMediaFits(AuthenticatedBackupUser backupUser, long amountToStore) {
return backupManager.canStoreMedia(backupUser, amountToStore)
.thenApply(fits -> {
if (!fits) {
throw new ClientErrorException("Media quota exhausted", Response.Status.REQUEST_ENTITY_TOO_LARGE);
}
return null;
});
}
private CompletionStage<BackupManager.StorageDescriptor> copyMediaImpl(final AuthenticatedBackupUser backupUser, private CompletionStage<BackupManager.StorageDescriptor> copyMediaImpl(final AuthenticatedBackupUser backupUser,
final CopyMediaRequest copyMediaRequest) { final CopyMediaRequest copyMediaRequest) {
return this.backupManager.copyToBackup( return this.backupManager.copyToBackup(
@ -662,41 +662,36 @@ public class ArchiveController {
throw new BadRequestException("must not use authenticated connection for anonymous operations"); throw new BadRequestException("must not use authenticated connection for anonymous operations");
} }
final AuthenticatedBackupUser backupUser = backupManager.authenticateBackupUser(
presentation.presentation, signature.signature).join();
// If the entire batch won't fit in the user's remaining quota, reject the whole request. // If the entire batch won't fit in the user's remaining quota, reject the whole request.
final long expectedStorage = copyMediaRequest.items().stream().mapToLong(CopyMediaRequest::objectLength).sum(); final long expectedStorage = copyMediaRequest.items().stream().mapToLong(CopyMediaRequest::objectLength).sum();
final boolean fits = backupManager.canStoreMedia(backupUser, expectedStorage).join();
if (!fits) {
throw new ClientErrorException("Media quota exhausted", Response.Status.REQUEST_ENTITY_TOO_LARGE);
}
return Flux.fromIterable(copyMediaRequest.items) return backupManager.authenticateBackupUser(presentation.presentation, signature.signature)
// Operate sequentially, waiting for one copy to finish before starting the next one. At least right now, .thenCompose(backupUser -> checkMediaFits(backupUser, expectedStorage).thenCompose(
// copying concurrently will introduce contention over the metadata. ignored -> Flux.fromIterable(copyMediaRequest.items)
.concatMap(request -> Mono // Operate sequentially, waiting for one copy to finish before starting the next one. At least right now,
.fromCompletionStage(copyMediaImpl(backupUser, request)) // copying concurrently will introduce contention over the metadata.
.map(result -> new CopyMediaBatchResponse.Entry(200, null, result.cdn(), result.key())) .concatMap(request -> Mono
.onErrorResume(throwable -> ExceptionUtils.unwrap(throwable) instanceof IOException, throwable -> { .fromCompletionStage(copyMediaImpl(backupUser, request))
final Throwable unwrapped = ExceptionUtils.unwrap(throwable); .map(result -> new CopyMediaBatchResponse.Entry(200, null, result.cdn(), result.key()))
.onErrorResume(throwable -> ExceptionUtils.unwrap(throwable) instanceof IOException, throwable -> {
final Throwable unwrapped = ExceptionUtils.unwrap(throwable);
int status; int status;
String error; String error;
if (unwrapped instanceof SourceObjectNotFoundException) { if (unwrapped instanceof SourceObjectNotFoundException) {
status = 410; status = 410;
error = "Source object not found " + unwrapped.getMessage(); error = "Source object not found " + unwrapped.getMessage();
} else if (unwrapped instanceof InvalidLengthException) { } else if (unwrapped instanceof InvalidLengthException) {
status = 400; status = 400;
error = "Invalid length " + unwrapped.getMessage(); error = "Invalid length " + unwrapped.getMessage();
} else { } else {
throw ExceptionUtils.wrap(throwable); throw ExceptionUtils.wrap(throwable);
} }
return Mono.just(new CopyMediaBatchResponse.Entry(status, error, null, request.mediaId)); return Mono.just(new CopyMediaBatchResponse.Entry(status, error, null, request.mediaId));
})) }))
.collectList() .collectList()
.map(list -> Response.status(207).entity(new CopyMediaBatchResponse(list)).build()) .map(list -> Response.status(207).entity(new CopyMediaBatchResponse(list)).build())
.toFuture(); .toFuture()));
} }
@POST @POST
@ -858,8 +853,7 @@ public class ArchiveController {
@DELETE @DELETE
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Operation(summary = "Delete entire backup", @Operation(summary = "Delete entire backup", description = """
description = """
Delete all backup metadata, objects, and stored public key. To use backups again, a public key must be resupplied. Delete all backup metadata, objects, and stored public key. To use backups again, a public key must be resupplied.
""") """)
@ApiResponse(responseCode = "204", description = "The backup has been successfully removed") @ApiResponse(responseCode = "204", description = "The backup has been successfully removed")

View File

@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.limits;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletionStage; import java.util.concurrent.CompletionStage;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
public interface RateLimiter { public interface RateLimiter {
@ -78,6 +79,16 @@ public interface RateLimiter {
return clearAsync(accountUuid.toString()); return clearAsync(accountUuid.toString());
} }
/**
* If the future throws a {@link RateLimitExceededException}, it will adapt it to ensure that
* {@link RateLimitExceededException#isLegacy()} returns {@code false}
*/
static CompletionStage<Void> adaptLegacyException(final CompletionStage<Void> rateLimitFuture) {
return rateLimitFuture.exceptionally(ExceptionUtils.exceptionallyHandler(RateLimitExceededException.class, e -> {
throw ExceptionUtils.wrap(new RateLimitExceededException(e.getRetryDuration().orElse(null), false));
}));
}
/** /**
* If the wrapped {@code validate()} call throws a {@link RateLimitExceededException}, it will adapt it to ensure that * If the wrapped {@code validate()} call throws a {@link RateLimitExceededException}, it will adapt it to ensure that
* {@link RateLimitExceededException#isLegacy()} returns {@code false} * {@link RateLimitExceededException#isLegacy()} returns {@code false}

View File

@ -10,7 +10,6 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset; import static org.mockito.Mockito.reset;
@ -29,6 +28,7 @@ import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.stream.Stream; import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.assertj.core.api.Assertions; import org.assertj.core.api.Assertions;
import org.assertj.core.api.ThrowableAssert; import org.assertj.core.api.ThrowableAssert;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
@ -63,7 +63,6 @@ import org.whispersystems.textsecuregcm.tests.util.ExperimentHelper;
import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil; import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
import org.whispersystems.textsecuregcm.util.TestClock; import org.whispersystems.textsecuregcm.util.TestClock;
import org.whispersystems.textsecuregcm.util.TestRandomUtil; import org.whispersystems.textsecuregcm.util.TestRandomUtil;
import javax.annotation.Nullable;
public class BackupAuthManagerTest { public class BackupAuthManagerTest {
@ -407,8 +406,9 @@ public class BackupAuthManagerTest {
when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(account)); when(accountsManager.updateAsync(any(), any())).thenReturn(CompletableFuture.completedFuture(account));
// Should be rate limited // Should be rate limited
assertThatExceptionOfType(RateLimitExceededException.class) CompletableFutureTestUtil.assertFailsWithCause(
.isThrownBy(() -> authManager.commitBackupId(account, credentialRequest).join()); RateLimitExceededException.class,
authManager.commitBackupId(account, credentialRequest));
// If we don't change the request, shouldn't be rate limited // If we don't change the request, shouldn't be rate limited
when(account.getBackupCredentialRequest()).thenReturn(credentialRequest.serialize()); when(account.getBackupCredentialRequest()).thenReturn(credentialRequest.serialize());
@ -426,6 +426,7 @@ public class BackupAuthManagerTest {
private static RateLimiters allowRateLimiter() { private static RateLimiters allowRateLimiter() {
final RateLimiters limiters = mock(RateLimiters.class); final RateLimiters limiters = mock(RateLimiters.class);
final RateLimiter limiter = mock(RateLimiter.class); final RateLimiter limiter = mock(RateLimiter.class);
when(limiter.validateAsync(any(UUID.class))).thenReturn(CompletableFuture.completedFuture(null));
when(limiters.forDescriptor(RateLimiters.For.SET_BACKUP_ID)).thenReturn(limiter); when(limiters.forDescriptor(RateLimiters.For.SET_BACKUP_ID)).thenReturn(limiter);
return limiters; return limiters;
} }
@ -433,11 +434,8 @@ public class BackupAuthManagerTest {
private static RateLimiters denyRateLimiter(final UUID aci) { private static RateLimiters denyRateLimiter(final UUID aci) {
final RateLimiters limiters = mock(RateLimiters.class); final RateLimiters limiters = mock(RateLimiters.class);
final RateLimiter limiter = mock(RateLimiter.class); final RateLimiter limiter = mock(RateLimiter.class);
try { when(limiter.validateAsync(aci))
doThrow(new RateLimitExceededException(null, false)).when(limiter).validate(aci); .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null, false)));
} catch (RateLimitExceededException e) {
throw new AssertionError(e);
}
when(limiters.forDescriptor(RateLimiters.For.SET_BACKUP_ID)).thenReturn(limiter); when(limiters.forDescriptor(RateLimiters.For.SET_BACKUP_ID)).thenReturn(limiter);
return limiters; return limiters;
} }

View File

@ -13,7 +13,6 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset; import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
@ -25,7 +24,6 @@ import static org.mockito.Mockito.when;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.StatusRuntimeException; import io.grpc.StatusRuntimeException;
import java.io.IOException; import java.io.IOException;
import java.net.URI;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.security.MessageDigest; import java.security.MessageDigest;
@ -141,15 +139,14 @@ public class BackupManagerTest {
} }
@Test @Test
public void createTemporaryMediaAttachmentRateLimited() throws RateLimitExceededException { public void createTemporaryMediaAttachmentRateLimited() {
final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA); final AuthenticatedBackupUser backupUser = backupUser(TestRandomUtil.nextBytes(16), BackupLevel.MEDIA);
doThrow(new RateLimitExceededException(null, true)) when(mediaUploadLimiter.validateAsync(eq(BackupManager.rateLimitKey(backupUser))))
.when(mediaUploadLimiter) .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null, true)));
.validate(eq(BackupManager.rateLimitKey(backupUser))); final RateLimitExceededException e = CompletableFutureTestUtil.assertFailsWithCause(
RateLimitExceededException.class,
assertThatExceptionOfType(RateLimitExceededException.class) backupManager.createTemporaryAttachmentUploadDescriptor(backupUser).toCompletableFuture());
.isThrownBy(() -> backupManager.createTemporaryAttachmentUploadDescriptor(backupUser)) assertThat(e.isLegacy()).isFalse();
.satisfies(e -> assertThat(e.isLegacy()).isFalse());
} }
@Test @Test

View File

@ -541,7 +541,8 @@ public class ArchiveControllerTest {
when(backupManager.authenticateBackupUser(any(), any())) when(backupManager.authenticateBackupUser(any(), any()))
.thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupLevel.MEDIA))); .thenReturn(CompletableFuture.completedFuture(backupUser(presentation.getBackupId(), BackupLevel.MEDIA)));
when(backupManager.createTemporaryAttachmentUploadDescriptor(any())) when(backupManager.createTemporaryAttachmentUploadDescriptor(any()))
.thenReturn(new BackupUploadDescriptor(3, "abc", Map.of("k", "v"), "example.org")); .thenReturn(CompletableFuture.completedFuture(
new BackupUploadDescriptor(3, "abc", Map.of("k", "v"), "example.org")));
final ArchiveController.UploadDescriptorResponse desc = resources.getJerseyTest() final ArchiveController.UploadDescriptorResponse desc = resources.getJerseyTest()
.target("v1/archives/media/upload/form") .target("v1/archives/media/upload/form")
.request() .request()
@ -555,7 +556,7 @@ public class ArchiveControllerTest {
// rate limit // rate limit
when(backupManager.createTemporaryAttachmentUploadDescriptor(any())) when(backupManager.createTemporaryAttachmentUploadDescriptor(any()))
.thenThrow(new RateLimitExceededException(null, false)); .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null, false)));
final Response response = resources.getJerseyTest() final Response response = resources.getJerseyTest()
.target("v1/archives/media/upload/form") .target("v1/archives/media/upload/form")
.request() .request()