Make sure we close the HTTP/2 stream after cdn read errors

This commit is contained in:
Ravi Khadiwala 2024-03-27 16:29:01 -05:00 committed by ravi-signal
parent de9eaa98db
commit a550caf63f
1 changed files with 72 additions and 34 deletions

View File

@ -8,6 +8,7 @@ import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.ByteBuffer;
import java.security.cert.CertificateException;
import java.time.Duration;
import java.util.ArrayList;
@ -18,6 +19,7 @@ import java.util.Optional;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executors;
import java.util.concurrent.Flow;
import java.util.concurrent.ScheduledExecutorService;
import java.util.stream.Stream;
import javax.annotation.Nullable;
@ -118,42 +120,16 @@ public class Cdn3RemoteStorageManager implements RemoteStorageManager {
final Timer.Sample sample = Timer.start();
final BackupMediaEncrypter encrypter = new BackupMediaEncrypter(encryptionParameters);
final HttpRequest request = HttpRequest.newBuilder().GET().uri(sourceUri).build();
final int expectedEncryptedLength = encrypter.outputSize(expectedSourceLength);
return cdnHttpClient.sendAsync(request, HttpResponse.BodyHandlers.ofPublisher()).thenCompose(response -> {
if (response.statusCode() == Response.Status.NOT_FOUND.getStatusCode()) {
throw new CompletionException(new SourceObjectNotFoundException());
} else if (response.statusCode() != Response.Status.OK.getStatusCode()) {
throw new CompletionException(new IOException("error reading from source: " + response.statusCode()));
try {
return cdnHttpClient.sendAsync(
createCopyRequest(expectedSourceLength, uploadDescriptor, encrypter, response),
HttpResponse.BodyHandlers.discarding());
} catch (Exception e) {
// Discard the response body so we don't hold the http2 stream open
response.body().subscribe(CancelSubscriber.INSTANCE);
throw ExceptionUtils.wrap(e);
}
final int actualSourceLength = Math.toIntExact(response.headers().firstValueAsLong("Content-Length")
.orElseThrow(() -> new CompletionException(new IOException("upstream missing Content-Length"))));
if (actualSourceLength != expectedSourceLength) {
throw new CompletionException(
new InvalidLengthException("Provided sourceLength " + expectedSourceLength + " was " + actualSourceLength));
}
final HttpRequest.BodyPublisher encryptedBody = HttpRequest.BodyPublishers.fromPublisher(
encrypter.encryptBody(response.body()), expectedEncryptedLength);
final String[] headers = Stream.concat(
uploadDescriptor.headers().entrySet()
.stream()
.flatMap(e -> Stream.of(e.getKey(), e.getValue())),
Stream.of(
TUS_VERSION_HEADER, TUS_VERSION,
TUS_UPLOAD_LENGTH_HEADER, Integer.toString(expectedEncryptedLength),
HttpHeaders.CONTENT_TYPE, TUS_CONTENT_TYPE))
.toArray(String[]::new);
final HttpRequest post = HttpRequest.newBuilder()
.uri(URI.create(uploadDescriptor.signedUploadLocation()))
.headers(headers)
.POST(encryptedBody)
.build();
return cdnHttpClient.sendAsync(post, HttpResponse.BodyHandlers.discarding());
})
.thenAccept(response -> {
if (response.statusCode() != Response.Status.CREATED.getStatusCode() &&
@ -162,6 +138,7 @@ public class Cdn3RemoteStorageManager implements RemoteStorageManager {
}
long uploadOffset = response.headers().firstValueAsLong(TUS_UPLOAD_OFFSET_HEADER)
.orElseThrow(() -> new CompletionException(new IOException("Tus server did not return Upload-Offset")));
final int expectedEncryptedLength = encrypter.outputSize(expectedSourceLength);
if (uploadOffset != expectedEncryptedLength) {
throw new CompletionException(new IOException(
"Expected to upload %d bytes, uploaded %d".formatted(expectedEncryptedLength, uploadOffset)));
@ -171,6 +148,47 @@ public class Cdn3RemoteStorageManager implements RemoteStorageManager {
sample.stop(Metrics.timer(STORAGE_MANAGER_TIMER_NAME, OPERATION_TAG_NAME, "copy")));
}
private HttpRequest createCopyRequest(
final int expectedSourceLength,
final MessageBackupUploadDescriptor uploadDescriptor,
BackupMediaEncrypter encrypter,
HttpResponse<Flow.Publisher<List<ByteBuffer>>> response) throws IOException {
if (response.statusCode() == Response.Status.NOT_FOUND.getStatusCode()) {
throw new SourceObjectNotFoundException();
} else if (response.statusCode() != Response.Status.OK.getStatusCode()) {
throw new IOException("error reading from source: " + response.statusCode());
}
final int actualSourceLength = Math.toIntExact(response.headers().firstValueAsLong("Content-Length")
.orElseThrow(() -> new IOException("upstream missing Content-Length")));
if (actualSourceLength != expectedSourceLength) {
throw new InvalidLengthException(
"Provided sourceLength " + expectedSourceLength + " was " + actualSourceLength);
}
final int expectedEncryptedLength = encrypter.outputSize(expectedSourceLength);
final HttpRequest.BodyPublisher encryptedBody = HttpRequest.BodyPublishers.fromPublisher(
encrypter.encryptBody(response.body()), expectedEncryptedLength);
final String[] headers = Stream.concat(
uploadDescriptor.headers().entrySet()
.stream()
.flatMap(e -> Stream.of(e.getKey(), e.getValue())),
Stream.of(
TUS_VERSION_HEADER, TUS_VERSION,
TUS_UPLOAD_LENGTH_HEADER, Integer.toString(expectedEncryptedLength),
HttpHeaders.CONTENT_TYPE, TUS_CONTENT_TYPE))
.toArray(String[]::new);
return HttpRequest.newBuilder()
.uri(URI.create(uploadDescriptor.signedUploadLocation()))
.headers(headers)
.POST(encryptedBody)
.build();
}
@Override
public CompletionStage<ListResult> list(
final String prefix,
@ -318,5 +336,25 @@ public class Cdn3RemoteStorageManager implements RemoteStorageManager {
return "%s/%s/".formatted(storageManagerBaseUrl, Cdn3BackupCredentialGenerator.CDN_PATH);
}
private static class CancelSubscriber implements Flow.Subscriber<List<ByteBuffer>> {
private static CancelSubscriber INSTANCE = new CancelSubscriber();
@Override
public void onSubscribe(final Flow.Subscription subscription) {
subscription.cancel();
}
@Override
public void onNext(final List<ByteBuffer> item) {
}
@Override
public void onError(final Throwable throwable) {
}
@Override
public void onComplete() {
}
}
}