disable response buffering on the websocket

Jersey buffers responses (by default up to 8192 bytes) just so it can
add a content length to responses. We already buffer our responses to
serialize them as protos, so we can compute the content length
ourselves. Setting the buffer to zero disables buffering.
This commit is contained in:
Ravi Khadiwala 2024-03-07 17:16:29 -06:00 committed by ravi-signal
parent 2dc0ea2b89
commit 9e510a678c
4 changed files with 147 additions and 1 deletions

View File

@ -0,0 +1,138 @@
package org.whispersystems.textsecuregcm;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME;
import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import java.io.IOException;
import java.net.URI;
import java.util.EnumSet;
import javax.servlet.DispatcherType;
import javax.servlet.ServletRegistration;
import javax.ws.rs.GET;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.core.HttpHeaders;
import javax.ws.rs.core.MediaType;
import org.apache.commons.lang3.RandomStringUtils;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.glassfish.jersey.server.ManagedAsync;
import org.glassfish.jersey.server.ServerProperties;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ExtendWith(DropwizardExtensionsSupport.class)
public class WebsocketResourceProviderIntegrationTest {
private static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION =
new DropwizardAppExtension<>(TestApplication.class);
private WebSocketClient client;
@BeforeEach
void setUp() throws Exception {
client = new WebSocketClient();
client.start();
}
@AfterEach
void tearDown() throws Exception {
client.stop();
}
public static class TestApplication extends Application<Configuration> {
@Override
public void run(final Configuration configuration, final Environment environment) throws Exception {
final TestController testController = new TestController();
final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration();
final WebSocketEnvironment<AuthenticatedAccount> webSocketEnvironment =
new WebSocketEnvironment<>(environment, webSocketConfiguration);
environment.jersey().register(testController);
environment.servlets()
.addFilter("RemoteAddressFilter", new RemoteAddressFilter(true))
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(testController);
webSocketEnvironment.jersey().register(new RemoteAddressFilter(true));
webSocketEnvironment.setAuthenticator(upgradeRequest ->
ReusableAuth.authenticated(mock(AuthenticatedAccount.class), PrincipalSupplier.forImmutablePrincipal()));
webSocketEnvironment.jersey().property(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE);
webSocketEnvironment.setConnectListener(webSocketSessionContext -> {
});
final WebSocketResourceProviderFactory<AuthenticatedAccount> webSocketServlet =
new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedAccount.class,
webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
final ServletRegistration.Dynamic websocketServlet =
environment.servlets().addServlet("WebSocket", webSocketServlet);
websocketServlet.addMapping("/websocket");
websocketServlet.setAsyncSupported(true);
}
}
@ParameterizedTest
// Jersey's content-length buffering by default does not buffer responses with a content-length of > 8192. We disable
// that buffering and do our own though, so the 9000 byte case should work.
@ValueSource(ints = {0, 1, 100, 1025, 9000})
public void contentLength(int length) throws IOException {
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
final WebSocketResponseMessage readResponse = testWebsocketListener.doGet("/test/%d".formatted(length)).join();
assertThat(readResponse.getHeaders().get(HttpHeaders.CONTENT_LENGTH.toLowerCase()))
.isEqualTo(Integer.toString(length));
}
@Path("/test")
public static class TestController {
@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/{size}")
@ManagedAsync
public String get(@PathParam("size") int size) {
return RandomStringUtils.randomAscii(size);
}
@PUT
@ManagedAsync
public String put() {
return "put";
}
}
}

View File

@ -264,7 +264,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
ByteArrayOutputStream responseBody) throws IOException {
if (requestMessage.hasRequestId()) {
byte[] body = responseBody.toByteArray();
response.getHeaders().putIfAbsent(HttpHeaders.CONTENT_LENGTH, List.of(body.length));
if (body.length <= 0) {
body = null;
}

View File

@ -9,6 +9,7 @@ import static java.util.Optional.ofNullable;
import io.dropwizard.jersey.jackson.JacksonMessageBodyProvider;
import java.io.IOException;
import java.security.Principal;
import java.util.Map;
import java.util.Optional;
import javax.ws.rs.InternalServerErrorException;
import org.apache.commons.lang3.StringUtils;
@ -17,6 +18,7 @@ import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
import org.eclipse.jetty.websocket.server.JettyWebSocketCreator;
import org.eclipse.jetty.websocket.server.JettyWebSocketServlet;
import org.eclipse.jetty.websocket.server.JettyWebSocketServletFactory;
import org.glassfish.jersey.CommonProperties;
import org.glassfish.jersey.server.ApplicationHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -46,6 +48,11 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends Jetty
environment.jersey().register(new WebsocketAuthValueFactoryProvider.Binder<T>(principalClass));
environment.jersey().register(new JacksonMessageBodyProvider(environment.getObjectMapper()));
// Jersey buffers responses (by default up to 8192 bytes) just so it can add a content length to responses. We
// already buffer our responses to serialize them as protos, so we can compute the content length ourselves. Setting
// the buffer to zero disables buffering.
environment.jersey().addProperties(Map.of(CommonProperties.OUTBOUND_CONTENT_LENGTH_BUFFER, 0));
this.jerseyApplicationHandler = new ApplicationHandler(environment.jersey());
this.configuration = configuration;

View File

@ -137,6 +137,7 @@ class WebSocketResourceProviderTest {
return "OK";
}
});
when(response.getHeaders()).thenReturn(new MultivaluedHashMap<>());
ArgumentCaptor<OutputStream> responseOutputStream = ArgumentCaptor.forClass(OutputStream.class);