Extract configuration for WebSocket max message sizes

This commit is contained in:
Chris Eager 2021-08-13 14:29:13 -05:00 committed by Chris Eager
parent a398e2269c
commit 19f7b207b7
5 changed files with 89 additions and 52 deletions

View File

@ -642,9 +642,9 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
provisioningEnvironment.jersey().register(rateLimitChallengeExceptionMapper); provisioningEnvironment.jersey().register(rateLimitChallengeExceptionMapper);
WebSocketResourceProviderFactory<AuthenticatedAccount> webSocketServlet = new WebSocketResourceProviderFactory<>( WebSocketResourceProviderFactory<AuthenticatedAccount> webSocketServlet = new WebSocketResourceProviderFactory<>(
webSocketEnvironment, AuthenticatedAccount.class); webSocketEnvironment, AuthenticatedAccount.class, config.getWebSocketConfiguration());
WebSocketResourceProviderFactory<AuthenticatedAccount> provisioningServlet = new WebSocketResourceProviderFactory<>( WebSocketResourceProviderFactory<AuthenticatedAccount> provisioningServlet = new WebSocketResourceProviderFactory<>(
provisioningEnvironment, AuthenticatedAccount.class); provisioningEnvironment, AuthenticatedAccount.class, config.getWebSocketConfiguration());
ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", webSocketServlet); ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", webSocketServlet);
ServletRegistration.Dynamic provisioning = environment.servlets().addServlet("Provisioning", provisioningServlet); ServletRegistration.Dynamic provisioning = environment.servlets().addServlet("Provisioning", provisioningServlet);

View File

@ -102,6 +102,11 @@
<artifactId>mockito-inline</artifactId> <artifactId>mockito-inline</artifactId>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
</project> </project>

View File

@ -23,6 +23,7 @@ import org.whispersystems.websocket.auth.AuthenticationException;
import org.whispersystems.websocket.auth.WebSocketAuthenticator; import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult; import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider; import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.session.WebSocketSessionContextValueFactoryProvider; import org.whispersystems.websocket.session.WebSocketSessionContextValueFactoryProvider;
import org.whispersystems.websocket.setup.WebSocketEnvironment; import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ -31,9 +32,11 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends WebSo
private static final Logger logger = LoggerFactory.getLogger(WebSocketResourceProviderFactory.class); private static final Logger logger = LoggerFactory.getLogger(WebSocketResourceProviderFactory.class);
private final WebSocketEnvironment<T> environment; private final WebSocketEnvironment<T> environment;
private final ApplicationHandler jerseyApplicationHandler; private final ApplicationHandler jerseyApplicationHandler;
private final WebSocketConfiguration configuration;
public WebSocketResourceProviderFactory(WebSocketEnvironment<T> environment, Class<T> principalClass) { public WebSocketResourceProviderFactory(WebSocketEnvironment<T> environment, Class<T> principalClass,
WebSocketConfiguration configuration) {
this.environment = environment; this.environment = environment;
environment.jersey().register(new WebSocketSessionContextValueFactoryProvider.Binder()); environment.jersey().register(new WebSocketSessionContextValueFactoryProvider.Binder());
@ -41,6 +44,8 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends WebSo
environment.jersey().register(new JacksonMessageBodyProvider(environment.getObjectMapper())); environment.jersey().register(new JacksonMessageBodyProvider(environment.getObjectMapper()));
this.jerseyApplicationHandler = new ApplicationHandler(environment.jersey()); this.jerseyApplicationHandler = new ApplicationHandler(environment.jersey());
this.configuration = configuration;
} }
@Override @Override
@ -79,9 +84,8 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends WebSo
@Override @Override
public void configure(WebSocketServletFactory factory) { public void configure(WebSocketServletFactory factory) {
factory.setCreator(this); factory.setCreator(this);
// TODO extract to configuration factory.getPolicy().setMaxBinaryMessageSize(configuration.getMaxBinaryMessageSize());
factory.getPolicy().setMaxBinaryMessageSize(512 * 1024); factory.getPolicy().setMaxTextMessageSize(configuration.getMaxTextMessageSize());
factory.getPolicy().setMaxTextMessageSize(512 * 1024);
} }
private String getRemoteAddress(ServletUpgradeRequest request) { private String getRemoteAddress(ServletUpgradeRequest request) {

View File

@ -5,10 +5,11 @@
package org.whispersystems.websocket.configuration; package org.whispersystems.websocket.configuration;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import org.whispersystems.websocket.logging.WebsocketRequestLoggerFactory;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.Max;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import org.whispersystems.websocket.logging.WebsocketRequestLoggerFactory;
public class WebSocketConfiguration { public class WebSocketConfiguration {
@ -17,7 +18,25 @@ public class WebSocketConfiguration {
@JsonProperty @JsonProperty
private WebsocketRequestLoggerFactory requestLog = new WebsocketRequestLoggerFactory(); private WebsocketRequestLoggerFactory requestLog = new WebsocketRequestLoggerFactory();
@Min(512 * 1024) // 512 KB
@Max(10 * 1024 * 1024) // 10 MB
@JsonProperty
private int maxBinaryMessageSize = 512 * 1024;
@Min(512 * 1024) // 512 KB
@Max(10 * 1024 * 1024) // 10 MB
@JsonProperty
private int maxTextMessageSize = 512 * 1024;
public WebsocketRequestLoggerFactory getRequestLog() { public WebsocketRequestLoggerFactory getRequestLog() {
return requestLog; return requestLog;
} }
public int getMaxBinaryMessageSize() {
return maxBinaryMessageSize;
}
public int getMaxTextMessageSize() {
return maxTextMessageSize;
}
} }

View File

@ -1,12 +1,12 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
package org.whispersystems.websocket; package org.whispersystems.websocket;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.Assert.assertNull; import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -25,27 +25,43 @@ import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse; import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse;
import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory; import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory;
import org.glassfish.jersey.server.ResourceConfig; import org.glassfish.jersey.server.ResourceConfig;
import org.junit.Test; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.websocket.auth.AuthenticationException; import org.whispersystems.websocket.auth.AuthenticationException;
import org.whispersystems.websocket.auth.WebSocketAuthenticator; import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.setup.WebSocketEnvironment; import org.whispersystems.websocket.setup.WebSocketEnvironment;
public class WebSocketResourceProviderFactoryTest { public class WebSocketResourceProviderFactoryTest {
@Test private ResourceConfig jerseyEnvironment;
public void testUnauthorized() throws AuthenticationException, IOException { private WebSocketEnvironment<Account> environment;
ResourceConfig jerseyEnvironment = new DropwizardResourceConfig(); private WebSocketAuthenticator<Account> authenticator;
WebSocketEnvironment environment = mock(WebSocketEnvironment.class ); private ServletUpgradeRequest request;
WebSocketAuthenticator authenticator = mock(WebSocketAuthenticator.class); private ServletUpgradeResponse response;
ServletUpgradeRequest request = mock(ServletUpgradeRequest.class );
ServletUpgradeResponse response = mock(ServletUpgradeResponse.class);
@BeforeEach
void setup() {
jerseyEnvironment = new DropwizardResourceConfig();
//noinspection unchecked
environment = mock(WebSocketEnvironment.class);
//noinspection unchecked
authenticator = mock(WebSocketAuthenticator.class);
request = mock(ServletUpgradeRequest.class);
response = mock(ServletUpgradeResponse.class);
}
@Test
void testUnauthorized() throws AuthenticationException, IOException {
when(environment.getAuthenticator()).thenReturn(authenticator); when(environment.getAuthenticator()).thenReturn(authenticator);
when(authenticator.authenticate(eq(request))).thenReturn(new WebSocketAuthenticator.AuthenticationResult<>(Optional.empty(), true)); when(authenticator.authenticate(eq(request))).thenReturn(
new WebSocketAuthenticator.AuthenticationResult<>(Optional.empty(), true));
when(environment.jersey()).thenReturn(jerseyEnvironment); when(environment.jersey()).thenReturn(jerseyEnvironment);
WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory(environment, Account.class); WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class,
Object connection = factory.createWebSocket(request, response); mock(WebSocketConfiguration.class));
Object connection = factory.createWebSocket(request, response);
assertNull(connection); assertNull(connection);
verify(response).sendForbidden(eq("Unauthorized")); verify(response).sendForbidden(eq("Unauthorized"));
@ -53,47 +69,40 @@ public class WebSocketResourceProviderFactoryTest {
} }
@Test @Test
public void testValidAuthorization() throws AuthenticationException { void testValidAuthorization() throws AuthenticationException {
ResourceConfig jerseyEnvironment = new DropwizardResourceConfig(); Session session = mock(Session.class);
WebSocketEnvironment environment = mock(WebSocketEnvironment.class ); Account account = new Account();
WebSocketAuthenticator authenticator = mock(WebSocketAuthenticator.class );
ServletUpgradeRequest request = mock(ServletUpgradeRequest.class );
ServletUpgradeResponse response = mock(ServletUpgradeResponse.class );
Session session = mock(Session.class );
Account account = new Account();
when(environment.getAuthenticator()).thenReturn(authenticator); when(environment.getAuthenticator()).thenReturn(authenticator);
when(authenticator.authenticate(eq(request))).thenReturn(new WebSocketAuthenticator.AuthenticationResult<>(Optional.of(account), true)); when(authenticator.authenticate(eq(request))).thenReturn(
new WebSocketAuthenticator.AuthenticationResult<>(Optional.of(account), true));
when(environment.jersey()).thenReturn(jerseyEnvironment); when(environment.jersey()).thenReturn(jerseyEnvironment);
when(session.getUpgradeRequest()).thenReturn(mock(UpgradeRequest.class)); when(session.getUpgradeRequest()).thenReturn(mock(UpgradeRequest.class));
WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory(environment, Account.class); WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class,
Object connection = factory.createWebSocket(request, response); mock(WebSocketConfiguration.class));
Object connection = factory.createWebSocket(request, response);
assertNotNull(connection); assertNotNull(connection);
verifyNoMoreInteractions(response); verifyNoMoreInteractions(response);
verify(authenticator).authenticate(eq(request)); verify(authenticator).authenticate(eq(request));
((WebSocketResourceProvider)connection).onWebSocketConnect(session); ((WebSocketResourceProvider<?>) connection).onWebSocketConnect(session);
assertNotNull(((WebSocketResourceProvider) connection).getContext().getAuthenticated()); assertNotNull(((WebSocketResourceProvider<?>) connection).getContext().getAuthenticated());
assertEquals(((WebSocketResourceProvider)connection).getContext().getAuthenticated(), account); assertEquals(((WebSocketResourceProvider<?>) connection).getContext().getAuthenticated(), account);
} }
@Test @Test
public void testErrorAuthorization() throws AuthenticationException, IOException { void testErrorAuthorization() throws AuthenticationException, IOException {
ResourceConfig jerseyEnvironment = new DropwizardResourceConfig();
WebSocketEnvironment environment = mock(WebSocketEnvironment.class );
WebSocketAuthenticator authenticator = mock(WebSocketAuthenticator.class );
ServletUpgradeRequest request = mock(ServletUpgradeRequest.class );
ServletUpgradeResponse response = mock(ServletUpgradeResponse.class );
when(environment.getAuthenticator()).thenReturn(authenticator); when(environment.getAuthenticator()).thenReturn(authenticator);
when(authenticator.authenticate(eq(request))).thenThrow(new AuthenticationException("database failure")); when(authenticator.authenticate(eq(request))).thenThrow(new AuthenticationException("database failure"));
when(environment.jersey()).thenReturn(jerseyEnvironment); when(environment.jersey()).thenReturn(jerseyEnvironment);
WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory(environment, Account.class); WebSocketResourceProviderFactory<Account> factory = new WebSocketResourceProviderFactory<>(environment,
Object connection = factory.createWebSocket(request, response); Account.class,
mock(WebSocketConfiguration.class));
Object connection = factory.createWebSocket(request, response);
assertNull(connection); assertNull(connection);
verify(response).sendError(eq(500), eq("Failure")); verify(response).sendError(eq(500), eq("Failure"));
@ -101,14 +110,14 @@ public class WebSocketResourceProviderFactoryTest {
} }
@Test @Test
public void testConfigure() { void testConfigure() {
ResourceConfig jerseyEnvironment = new DropwizardResourceConfig(); WebSocketServletFactory servletFactory = mock(WebSocketServletFactory.class);
WebSocketEnvironment environment = mock(WebSocketEnvironment.class );
WebSocketServletFactory servletFactory = mock(WebSocketServletFactory.class );
when(environment.jersey()).thenReturn(jerseyEnvironment); when(environment.jersey()).thenReturn(jerseyEnvironment);
when(servletFactory.getPolicy()).thenReturn(mock(WebSocketPolicy.class)); when(servletFactory.getPolicy()).thenReturn(mock(WebSocketPolicy.class));
WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory(environment, Account.class); WebSocketResourceProviderFactory<Account> factory = new WebSocketResourceProviderFactory<>(environment,
Account.class,
mock(WebSocketConfiguration.class));
factory.configure(servletFactory); factory.configure(servletFactory);
verify(servletFactory).setCreator(eq(factory)); verify(servletFactory).setCreator(eq(factory));