diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index f4fde4014..d2bf1cfb3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -220,6 +220,7 @@ import org.whispersystems.textsecuregcm.storage.VerificationSessions; import org.whispersystems.textsecuregcm.subscriptions.BankMandateTranslator; import org.whispersystems.textsecuregcm.subscriptions.BraintreeManager; import org.whispersystems.textsecuregcm.subscriptions.StripeManager; +import org.whispersystems.textsecuregcm.util.BufferingInterceptor; import org.whispersystems.textsecuregcm.util.DynamoDbFromConfig; import org.whispersystems.textsecuregcm.util.ManagedAwsCrt; import org.whispersystems.textsecuregcm.util.SystemMapper; @@ -836,6 +837,7 @@ public class WhisperServerService extends Application + * Jersey's {@link CommittingOutputStream} has two modes: direct write and buffered writes. In buffered mode, if the + * total amount written does not exceed the output stream's buffer size, CommittingOutputStream will compute the + * content-length for us. However, when it passes through our write to its own underlying output stream it uses + * {@link ByteArrayOutputStream#writeTo(OutputStream)} which performs the write under a synchronized block. + *

+ * If we just disable buffering, we lose our content length. However, we can't really set content-length ourselves + * without the same access to internal state that CommittingOutputStream has. Fortunately, the underlying OutputStream + * wrapped by CommittingOutputStream ALSO has an internal buffer, and can compute the content-length from that if the + * content fits. But to make use of that, we need to avoid flushing that output stream until calling close, so that the + * underlying output stream can see that it has all the data. Unfortunately the runtime inserts manual flushes after + * writes rather than letting the underlying output stream handle it. + *

+ * So here we disable buffering on CommittingOutputStream, and buffer ourselves. We don't write anything to the + * CommittingOutputStream until we are going to close, and we do nothing on flush. + */ +public class BufferingInterceptor implements WriterInterceptor { + + @Override + public void aroundWriteTo(final WriterInterceptorContext ctx) throws IOException, WebApplicationException { + final OutputStream orig = ctx.getOutputStream(); + if (Thread.currentThread().isVirtual() && orig instanceof CommittingOutputStream cos) { + cos.enableBuffering(0); + ctx.setOutputStream(new BufferingOutputStream(cos)); + } + ctx.proceed(); + } + + private static class BufferingOutputStream extends ByteArrayOutputStream { + + private final CommittingOutputStream original; + + BufferingOutputStream(final CommittingOutputStream original) { + this.original = original; + } + + @Override + public void close() throws IOException { + original.write(buf, 0, count); + original.close(); + super.close(); + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/BufferingInterceptorIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/BufferingInterceptorIntegrationTest.java new file mode 100644 index 000000000..470207ac8 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/BufferingInterceptorIntegrationTest.java @@ -0,0 +1,77 @@ +package org.whispersystems.textsecuregcm; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.dropwizard.core.Application; +import io.dropwizard.core.Configuration; +import io.dropwizard.core.setup.Environment; +import io.dropwizard.testing.junit5.DropwizardAppExtension; +import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; +import javax.ws.rs.GET; +import javax.ws.rs.Path; +import javax.ws.rs.PathParam; +import javax.ws.rs.Produces; +import javax.ws.rs.core.HttpHeaders; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; +import org.apache.commons.lang3.RandomStringUtils; +import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; +import org.glassfish.jersey.server.ManagedAsync; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.whispersystems.textsecuregcm.util.BufferingInterceptor; +import org.whispersystems.textsecuregcm.util.VirtualExecutorServiceProvider; + +@ExtendWith(DropwizardExtensionsSupport.class) +public class BufferingInterceptorIntegrationTest { + private static final DropwizardAppExtension DROPWIZARD_APP_EXTENSION = + new DropwizardAppExtension<>(TestApplication.class); + + public static class TestApplication extends Application { + + @Override + public void run(final Configuration configuration, final Environment environment) throws Exception { + final TestController testController = new TestController(); + environment.jersey().register(testController); + environment.jersey().register(new BufferingInterceptor()); + environment.jersey().register(new VirtualExecutorServiceProvider("virtual-thread-")); + JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null); + } + } + + @Test + public void testVirtual() { + final Response response = DROPWIZARD_APP_EXTENSION.client() + .target("http://127.0.0.1:%d/test/virtual/8".formatted(DROPWIZARD_APP_EXTENSION.getLocalPort())) + .request().get(); + assertThat(response.getHeaders().getFirst(HttpHeaders.CONTENT_LENGTH)).isEqualTo("8"); + } + + @Test + public void testPlatform() { + final Response response = DROPWIZARD_APP_EXTENSION.client() + .target("http://127.0.0.1:%d/test/platform/8".formatted(DROPWIZARD_APP_EXTENSION.getLocalPort())) + .request().get(); + assertThat(response.getHeaders().getFirst(HttpHeaders.CONTENT_LENGTH)).isEqualTo("8"); + + } + + @Path("/test") + public static class TestController { + + @GET + @Produces(MediaType.APPLICATION_JSON) + @Path("/virtual/{size}") + @ManagedAsync + public String getVirtual(@PathParam("size") int size) { + return RandomStringUtils.randomAscii(size); + } + + @GET + @Produces(MediaType.APPLICATION_JSON) + @Path("/platform/{size}") + public String getPlatform(@PathParam("size") int size) { + return RandomStringUtils.randomAscii(size); + } + } +}