Extract configuration for WebSocket max message sizes

This commit is contained in:
Chris Eager 2021-08-13 14:29:13 -05:00 committed by Chris Eager
parent a398e2269c
commit 19f7b207b7
5 changed files with 89 additions and 52 deletions

View File

@ -642,9 +642,9 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
provisioningEnvironment.jersey().register(rateLimitChallengeExceptionMapper);
WebSocketResourceProviderFactory<AuthenticatedAccount> webSocketServlet = new WebSocketResourceProviderFactory<>(
webSocketEnvironment, AuthenticatedAccount.class);
webSocketEnvironment, AuthenticatedAccount.class, config.getWebSocketConfiguration());
WebSocketResourceProviderFactory<AuthenticatedAccount> provisioningServlet = new WebSocketResourceProviderFactory<>(
provisioningEnvironment, AuthenticatedAccount.class);
provisioningEnvironment, AuthenticatedAccount.class, config.getWebSocketConfiguration());
ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", webSocketServlet);
ServletRegistration.Dynamic provisioning = environment.servlets().addServlet("Provisioning", provisioningServlet);

View File

@ -102,6 +102,11 @@
<artifactId>mockito-inline</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@ -23,6 +23,7 @@ import org.whispersystems.websocket.auth.AuthenticationException;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.session.WebSocketSessionContextValueFactoryProvider;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ -31,9 +32,11 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends WebSo
private static final Logger logger = LoggerFactory.getLogger(WebSocketResourceProviderFactory.class);
private final WebSocketEnvironment<T> environment;
private final ApplicationHandler jerseyApplicationHandler;
private final ApplicationHandler jerseyApplicationHandler;
private final WebSocketConfiguration configuration;
public WebSocketResourceProviderFactory(WebSocketEnvironment<T> environment, Class<T> principalClass) {
public WebSocketResourceProviderFactory(WebSocketEnvironment<T> environment, Class<T> principalClass,
WebSocketConfiguration configuration) {
this.environment = environment;
environment.jersey().register(new WebSocketSessionContextValueFactoryProvider.Binder());
@ -41,6 +44,8 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends WebSo
environment.jersey().register(new JacksonMessageBodyProvider(environment.getObjectMapper()));
this.jerseyApplicationHandler = new ApplicationHandler(environment.jersey());
this.configuration = configuration;
}
@Override
@ -79,9 +84,8 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends WebSo
@Override
public void configure(WebSocketServletFactory factory) {
factory.setCreator(this);
// TODO extract to configuration
factory.getPolicy().setMaxBinaryMessageSize(512 * 1024);
factory.getPolicy().setMaxTextMessageSize(512 * 1024);
factory.getPolicy().setMaxBinaryMessageSize(configuration.getMaxBinaryMessageSize());
factory.getPolicy().setMaxTextMessageSize(configuration.getMaxTextMessageSize());
}
private String getRemoteAddress(ServletUpgradeRequest request) {

View File

@ -5,10 +5,11 @@
package org.whispersystems.websocket.configuration;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.whispersystems.websocket.logging.WebsocketRequestLoggerFactory;
import javax.validation.Valid;
import javax.validation.constraints.Max;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotNull;
import org.whispersystems.websocket.logging.WebsocketRequestLoggerFactory;
public class WebSocketConfiguration {
@ -17,7 +18,25 @@ public class WebSocketConfiguration {
@JsonProperty
private WebsocketRequestLoggerFactory requestLog = new WebsocketRequestLoggerFactory();
@Min(512 * 1024) // 512 KB
@Max(10 * 1024 * 1024) // 10 MB
@JsonProperty
private int maxBinaryMessageSize = 512 * 1024;
@Min(512 * 1024) // 512 KB
@Max(10 * 1024 * 1024) // 10 MB
@JsonProperty
private int maxTextMessageSize = 512 * 1024;
public WebsocketRequestLoggerFactory getRequestLog() {
return requestLog;
}
public int getMaxBinaryMessageSize() {
return maxBinaryMessageSize;
}
public int getMaxTextMessageSize() {
return maxTextMessageSize;
}
}

View File

@ -1,12 +1,12 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.websocket;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
@ -25,27 +25,43 @@ import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse;
import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory;
import org.glassfish.jersey.server.ResourceConfig;
import org.junit.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.websocket.auth.AuthenticationException;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
public class WebSocketResourceProviderFactoryTest {
@Test
public void testUnauthorized() throws AuthenticationException, IOException {
ResourceConfig jerseyEnvironment = new DropwizardResourceConfig();
WebSocketEnvironment environment = mock(WebSocketEnvironment.class );
WebSocketAuthenticator authenticator = mock(WebSocketAuthenticator.class);
ServletUpgradeRequest request = mock(ServletUpgradeRequest.class );
ServletUpgradeResponse response = mock(ServletUpgradeResponse.class);
private ResourceConfig jerseyEnvironment;
private WebSocketEnvironment<Account> environment;
private WebSocketAuthenticator<Account> authenticator;
private ServletUpgradeRequest request;
private ServletUpgradeResponse response;
@BeforeEach
void setup() {
jerseyEnvironment = new DropwizardResourceConfig();
//noinspection unchecked
environment = mock(WebSocketEnvironment.class);
//noinspection unchecked
authenticator = mock(WebSocketAuthenticator.class);
request = mock(ServletUpgradeRequest.class);
response = mock(ServletUpgradeResponse.class);
}
@Test
void testUnauthorized() throws AuthenticationException, IOException {
when(environment.getAuthenticator()).thenReturn(authenticator);
when(authenticator.authenticate(eq(request))).thenReturn(new WebSocketAuthenticator.AuthenticationResult<>(Optional.empty(), true));
when(authenticator.authenticate(eq(request))).thenReturn(
new WebSocketAuthenticator.AuthenticationResult<>(Optional.empty(), true));
when(environment.jersey()).thenReturn(jerseyEnvironment);
WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory(environment, Account.class);
Object connection = factory.createWebSocket(request, response);
WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class,
mock(WebSocketConfiguration.class));
Object connection = factory.createWebSocket(request, response);
assertNull(connection);
verify(response).sendForbidden(eq("Unauthorized"));
@ -53,47 +69,40 @@ public class WebSocketResourceProviderFactoryTest {
}
@Test
public void testValidAuthorization() throws AuthenticationException {
ResourceConfig jerseyEnvironment = new DropwizardResourceConfig();
WebSocketEnvironment environment = mock(WebSocketEnvironment.class );
WebSocketAuthenticator authenticator = mock(WebSocketAuthenticator.class );
ServletUpgradeRequest request = mock(ServletUpgradeRequest.class );
ServletUpgradeResponse response = mock(ServletUpgradeResponse.class );
Session session = mock(Session.class );
Account account = new Account();
void testValidAuthorization() throws AuthenticationException {
Session session = mock(Session.class);
Account account = new Account();
when(environment.getAuthenticator()).thenReturn(authenticator);
when(authenticator.authenticate(eq(request))).thenReturn(new WebSocketAuthenticator.AuthenticationResult<>(Optional.of(account), true));
when(authenticator.authenticate(eq(request))).thenReturn(
new WebSocketAuthenticator.AuthenticationResult<>(Optional.of(account), true));
when(environment.jersey()).thenReturn(jerseyEnvironment);
when(session.getUpgradeRequest()).thenReturn(mock(UpgradeRequest.class));
WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory(environment, Account.class);
Object connection = factory.createWebSocket(request, response);
WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class,
mock(WebSocketConfiguration.class));
Object connection = factory.createWebSocket(request, response);
assertNotNull(connection);
verifyNoMoreInteractions(response);
verify(authenticator).authenticate(eq(request));
((WebSocketResourceProvider)connection).onWebSocketConnect(session);
((WebSocketResourceProvider<?>) connection).onWebSocketConnect(session);
assertNotNull(((WebSocketResourceProvider) connection).getContext().getAuthenticated());
assertEquals(((WebSocketResourceProvider)connection).getContext().getAuthenticated(), account);
assertNotNull(((WebSocketResourceProvider<?>) connection).getContext().getAuthenticated());
assertEquals(((WebSocketResourceProvider<?>) connection).getContext().getAuthenticated(), account);
}
@Test
public void testErrorAuthorization() throws AuthenticationException, IOException {
ResourceConfig jerseyEnvironment = new DropwizardResourceConfig();
WebSocketEnvironment environment = mock(WebSocketEnvironment.class );
WebSocketAuthenticator authenticator = mock(WebSocketAuthenticator.class );
ServletUpgradeRequest request = mock(ServletUpgradeRequest.class );
ServletUpgradeResponse response = mock(ServletUpgradeResponse.class );
void testErrorAuthorization() throws AuthenticationException, IOException {
when(environment.getAuthenticator()).thenReturn(authenticator);
when(authenticator.authenticate(eq(request))).thenThrow(new AuthenticationException("database failure"));
when(environment.jersey()).thenReturn(jerseyEnvironment);
WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory(environment, Account.class);
Object connection = factory.createWebSocket(request, response);
WebSocketResourceProviderFactory<Account> factory = new WebSocketResourceProviderFactory<>(environment,
Account.class,
mock(WebSocketConfiguration.class));
Object connection = factory.createWebSocket(request, response);
assertNull(connection);
verify(response).sendError(eq(500), eq("Failure"));
@ -101,14 +110,14 @@ public class WebSocketResourceProviderFactoryTest {
}
@Test
public void testConfigure() {
ResourceConfig jerseyEnvironment = new DropwizardResourceConfig();
WebSocketEnvironment environment = mock(WebSocketEnvironment.class );
WebSocketServletFactory servletFactory = mock(WebSocketServletFactory.class );
void testConfigure() {
WebSocketServletFactory servletFactory = mock(WebSocketServletFactory.class);
when(environment.jersey()).thenReturn(jerseyEnvironment);
when(servletFactory.getPolicy()).thenReturn(mock(WebSocketPolicy.class));
WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory(environment, Account.class);
WebSocketResourceProviderFactory<Account> factory = new WebSocketResourceProviderFactory<>(environment,
Account.class,
mock(WebSocketConfiguration.class));
factory.configure(servletFactory);
verify(servletFactory).setCreator(eq(factory));