From 9e510a678cc16c5358f80daac9988811ffcacdc4 Mon Sep 17 00:00:00 2001 From: Ravi Khadiwala Date: Thu, 7 Mar 2024 17:16:29 -0600 Subject: [PATCH] disable response buffering on the websocket Jersey buffers responses (by default up to 8192 bytes) just so it can add a content length to responses. We already buffer our responses to serialize them as protos, so we can compute the content length ourselves. Setting the buffer to zero disables buffering. --- ...socketResourceProviderIntegrationTest.java | 138 ++++++++++++++++++ .../websocket/WebSocketResourceProvider.java | 2 +- .../WebSocketResourceProviderFactory.java | 7 + .../WebSocketResourceProviderTest.java | 1 + 4 files changed, 147 insertions(+), 1 deletion(-) create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/WebsocketResourceProviderIntegrationTest.java diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/WebsocketResourceProviderIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/WebsocketResourceProviderIntegrationTest.java new file mode 100644 index 000000000..3728ad04d --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/WebsocketResourceProviderIntegrationTest.java @@ -0,0 +1,138 @@ +package org.whispersystems.textsecuregcm; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.when; +import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME; + +import io.dropwizard.core.Application; +import io.dropwizard.core.Configuration; +import io.dropwizard.core.setup.Environment; +import io.dropwizard.testing.junit5.DropwizardAppExtension; +import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; +import java.io.IOException; +import java.net.URI; +import java.util.EnumSet; +import javax.servlet.DispatcherType; +import javax.servlet.ServletRegistration; +import javax.ws.rs.GET; +import javax.ws.rs.PUT; +import javax.ws.rs.Path; +import javax.ws.rs.PathParam; +import javax.ws.rs.Produces; +import javax.ws.rs.core.HttpHeaders; +import javax.ws.rs.core.MediaType; +import org.apache.commons.lang3.RandomStringUtils; +import org.eclipse.jetty.websocket.client.WebSocketClient; +import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; +import org.glassfish.jersey.server.ManagedAsync; +import org.glassfish.jersey.server.ServerProperties; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; +import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener; +import org.whispersystems.websocket.ReusableAuth; +import org.whispersystems.websocket.WebSocketResourceProviderFactory; +import org.whispersystems.websocket.auth.PrincipalSupplier; +import org.whispersystems.websocket.configuration.WebSocketConfiguration; +import org.whispersystems.websocket.messages.WebSocketResponseMessage; +import org.whispersystems.websocket.setup.WebSocketEnvironment; + +@ExtendWith(DropwizardExtensionsSupport.class) +public class WebsocketResourceProviderIntegrationTest { + private static final DropwizardAppExtension DROPWIZARD_APP_EXTENSION = + new DropwizardAppExtension<>(TestApplication.class); + + + private WebSocketClient client; + + @BeforeEach + void setUp() throws Exception { + client = new WebSocketClient(); + client.start(); + } + + @AfterEach + void tearDown() throws Exception { + client.stop(); + } + + + public static class TestApplication extends Application { + + @Override + public void run(final Configuration configuration, final Environment environment) throws Exception { + final TestController testController = new TestController(); + + final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration(); + + final WebSocketEnvironment webSocketEnvironment = + new WebSocketEnvironment<>(environment, webSocketConfiguration); + + environment.jersey().register(testController); + environment.servlets() + .addFilter("RemoteAddressFilter", new RemoteAddressFilter(true)) + .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*"); + webSocketEnvironment.jersey().register(testController); + webSocketEnvironment.jersey().register(new RemoteAddressFilter(true)); + webSocketEnvironment.setAuthenticator(upgradeRequest -> + ReusableAuth.authenticated(mock(AuthenticatedAccount.class), PrincipalSupplier.forImmutablePrincipal())); + + webSocketEnvironment.jersey().property(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE); + webSocketEnvironment.setConnectListener(webSocketSessionContext -> { + }); + + final WebSocketResourceProviderFactory webSocketServlet = + new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedAccount.class, + webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME); + + JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null); + + final ServletRegistration.Dynamic websocketServlet = + environment.servlets().addServlet("WebSocket", webSocketServlet); + + websocketServlet.addMapping("/websocket"); + websocketServlet.setAsyncSupported(true); + } + } + + + @ParameterizedTest + // Jersey's content-length buffering by default does not buffer responses with a content-length of > 8192. We disable + // that buffering and do our own though, so the 9000 byte case should work. + @ValueSource(ints = {0, 1, 100, 1025, 9000}) + public void contentLength(int length) throws IOException { + final TestWebsocketListener testWebsocketListener = new TestWebsocketListener(); + client.connect(testWebsocketListener, + URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort()))); + + final WebSocketResponseMessage readResponse = testWebsocketListener.doGet("/test/%d".formatted(length)).join(); + assertThat(readResponse.getHeaders().get(HttpHeaders.CONTENT_LENGTH.toLowerCase())) + .isEqualTo(Integer.toString(length)); + } + + + @Path("/test") + public static class TestController { + + @GET + @Produces(MediaType.APPLICATION_JSON) + @Path("/{size}") + @ManagedAsync + public String get(@PathParam("size") int size) { + return RandomStringUtils.randomAscii(size); + } + + @PUT + @ManagedAsync + public String put() { + return "put"; + } + } +} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java index 168ea477e..00f08da26 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java @@ -264,7 +264,7 @@ public class WebSocketResourceProvider implements WebSocket ByteArrayOutputStream responseBody) throws IOException { if (requestMessage.hasRequestId()) { byte[] body = responseBody.toByteArray(); - + response.getHeaders().putIfAbsent(HttpHeaders.CONTENT_LENGTH, List.of(body.length)); if (body.length <= 0) { body = null; } diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java index b7378952a..d8a60d381 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java @@ -9,6 +9,7 @@ import static java.util.Optional.ofNullable; import io.dropwizard.jersey.jackson.JacksonMessageBodyProvider; import java.io.IOException; import java.security.Principal; +import java.util.Map; import java.util.Optional; import javax.ws.rs.InternalServerErrorException; import org.apache.commons.lang3.StringUtils; @@ -17,6 +18,7 @@ import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse; import org.eclipse.jetty.websocket.server.JettyWebSocketCreator; import org.eclipse.jetty.websocket.server.JettyWebSocketServlet; import org.eclipse.jetty.websocket.server.JettyWebSocketServletFactory; +import org.glassfish.jersey.CommonProperties; import org.glassfish.jersey.server.ApplicationHandler; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -46,6 +48,11 @@ public class WebSocketResourceProviderFactory extends Jetty environment.jersey().register(new WebsocketAuthValueFactoryProvider.Binder(principalClass)); environment.jersey().register(new JacksonMessageBodyProvider(environment.getObjectMapper())); + // Jersey buffers responses (by default up to 8192 bytes) just so it can add a content length to responses. We + // already buffer our responses to serialize them as protos, so we can compute the content length ourselves. Setting + // the buffer to zero disables buffering. + environment.jersey().addProperties(Map.of(CommonProperties.OUTBOUND_CONTENT_LENGTH_BUFFER, 0)); + this.jerseyApplicationHandler = new ApplicationHandler(environment.jersey()); this.configuration = configuration; diff --git a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java index f2955fdf1..42250d29b 100644 --- a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java +++ b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java @@ -137,6 +137,7 @@ class WebSocketResourceProviderTest { return "OK"; } }); + when(response.getHeaders()).thenReturn(new MultivaluedHashMap<>()); ArgumentCaptor responseOutputStream = ArgumentCaptor.forClass(OutputStream.class);