avoid baos::writeTo on virtual threads

This commit is contained in:
ravi-signal 2024-03-27 16:58:38 -05:00 committed by GitHub
parent a733f5c615
commit 37b657cbbd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 142 additions and 0 deletions

View File

@ -220,6 +220,7 @@ import org.whispersystems.textsecuregcm.storage.VerificationSessions;
import org.whispersystems.textsecuregcm.subscriptions.BankMandateTranslator;
import org.whispersystems.textsecuregcm.subscriptions.BraintreeManager;
import org.whispersystems.textsecuregcm.subscriptions.StripeManager;
import org.whispersystems.textsecuregcm.util.BufferingInterceptor;
import org.whispersystems.textsecuregcm.util.DynamoDbFromConfig;
import org.whispersystems.textsecuregcm.util.ManagedAwsCrt;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -836,6 +837,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
Set.of(websocketServletPath, provisioningWebsocketServletPath, "/health-check"));
metricsHttpChannelListener.configure(environment);
environment.jersey().register(new BufferingInterceptor());
environment.jersey().register(new VirtualExecutorServiceProvider("managed-async-virtual-thread-"));
environment.jersey().register(new RequestStatisticsFilter(TrafficSource.HTTP));
environment.jersey().register(MultiRecipientMessageProvider.class);

View File

@ -0,0 +1,63 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.util;
import org.glassfish.jersey.message.internal.CommittingOutputStream;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.ext.WriterInterceptor;
import javax.ws.rs.ext.WriterInterceptorContext;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
/**
* This is an elaborate workaround to avoid doing blocking operations under synchronized blocks, which is currently a
* suboptimal case for virtual threads.
* <p>
* Jersey's {@link CommittingOutputStream} has two modes: direct write and buffered writes. In buffered mode, if the
* total amount written does not exceed the output stream's buffer size, CommittingOutputStream will compute the
* content-length for us. However, when it passes through our write to its own underlying output stream it uses
* {@link ByteArrayOutputStream#writeTo(OutputStream)} which performs the write under a synchronized block.
* <p>
* If we just disable buffering, we lose our content length. However, we can't really set content-length ourselves
* without the same access to internal state that CommittingOutputStream has. Fortunately, the underlying OutputStream
* wrapped by CommittingOutputStream ALSO has an internal buffer, and can compute the content-length from that if the
* content fits. But to make use of that, we need to avoid flushing that output stream until calling close, so that the
* underlying output stream can see that it has all the data. Unfortunately the runtime inserts manual flushes after
* writes rather than letting the underlying output stream handle it.
* <p>
* So here we disable buffering on CommittingOutputStream, and buffer ourselves. We don't write anything to the
* CommittingOutputStream until we are going to close, and we do nothing on flush.
*/
public class BufferingInterceptor implements WriterInterceptor {
@Override
public void aroundWriteTo(final WriterInterceptorContext ctx) throws IOException, WebApplicationException {
final OutputStream orig = ctx.getOutputStream();
if (Thread.currentThread().isVirtual() && orig instanceof CommittingOutputStream cos) {
cos.enableBuffering(0);
ctx.setOutputStream(new BufferingOutputStream(cos));
}
ctx.proceed();
}
private static class BufferingOutputStream extends ByteArrayOutputStream {
private final CommittingOutputStream original;
BufferingOutputStream(final CommittingOutputStream original) {
this.original = original;
}
@Override
public void close() throws IOException {
original.write(buf, 0, count);
original.close();
super.close();
}
}
}

View File

@ -0,0 +1,77 @@
package org.whispersystems.textsecuregcm;
import static org.assertj.core.api.Assertions.assertThat;
import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.core.HttpHeaders;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.apache.commons.lang3.RandomStringUtils;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.glassfish.jersey.server.ManagedAsync;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.whispersystems.textsecuregcm.util.BufferingInterceptor;
import org.whispersystems.textsecuregcm.util.VirtualExecutorServiceProvider;
@ExtendWith(DropwizardExtensionsSupport.class)
public class BufferingInterceptorIntegrationTest {
private static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION =
new DropwizardAppExtension<>(TestApplication.class);
public static class TestApplication extends Application<Configuration> {
@Override
public void run(final Configuration configuration, final Environment environment) throws Exception {
final TestController testController = new TestController();
environment.jersey().register(testController);
environment.jersey().register(new BufferingInterceptor());
environment.jersey().register(new VirtualExecutorServiceProvider("virtual-thread-"));
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
}
}
@Test
public void testVirtual() {
final Response response = DROPWIZARD_APP_EXTENSION.client()
.target("http://127.0.0.1:%d/test/virtual/8".formatted(DROPWIZARD_APP_EXTENSION.getLocalPort()))
.request().get();
assertThat(response.getHeaders().getFirst(HttpHeaders.CONTENT_LENGTH)).isEqualTo("8");
}
@Test
public void testPlatform() {
final Response response = DROPWIZARD_APP_EXTENSION.client()
.target("http://127.0.0.1:%d/test/platform/8".formatted(DROPWIZARD_APP_EXTENSION.getLocalPort()))
.request().get();
assertThat(response.getHeaders().getFirst(HttpHeaders.CONTENT_LENGTH)).isEqualTo("8");
}
@Path("/test")
public static class TestController {
@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/virtual/{size}")
@ManagedAsync
public String getVirtual(@PathParam("size") int size) {
return RandomStringUtils.randomAscii(size);
}
@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/platform/{size}")
public String getPlatform(@PathParam("size") int size) {
return RandomStringUtils.randomAscii(size);
}
}
}