diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 6d81055ae..9ed0db953 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -642,9 +642,9 @@ public class WhisperServerService extends Application webSocketServlet = new WebSocketResourceProviderFactory<>( - webSocketEnvironment, AuthenticatedAccount.class); + webSocketEnvironment, AuthenticatedAccount.class, config.getWebSocketConfiguration()); WebSocketResourceProviderFactory provisioningServlet = new WebSocketResourceProviderFactory<>( - provisioningEnvironment, AuthenticatedAccount.class); + provisioningEnvironment, AuthenticatedAccount.class, config.getWebSocketConfiguration()); ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", webSocketServlet); ServletRegistration.Dynamic provisioning = environment.servlets().addServlet("Provisioning", provisioningServlet); diff --git a/websocket-resources/pom.xml b/websocket-resources/pom.xml index 00e4b1b69..0857c9aa4 100644 --- a/websocket-resources/pom.xml +++ b/websocket-resources/pom.xml @@ -102,6 +102,11 @@ mockito-inline test + + org.junit.jupiter + junit-jupiter + test + diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java index 50fed50ec..ab0876243 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java @@ -23,6 +23,7 @@ import org.whispersystems.websocket.auth.AuthenticationException; import org.whispersystems.websocket.auth.WebSocketAuthenticator; import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult; import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider; +import org.whispersystems.websocket.configuration.WebSocketConfiguration; import org.whispersystems.websocket.session.WebSocketSessionContextValueFactoryProvider; import org.whispersystems.websocket.setup.WebSocketEnvironment; @@ -31,9 +32,11 @@ public class WebSocketResourceProviderFactory extends WebSo private static final Logger logger = LoggerFactory.getLogger(WebSocketResourceProviderFactory.class); private final WebSocketEnvironment environment; - private final ApplicationHandler jerseyApplicationHandler; + private final ApplicationHandler jerseyApplicationHandler; + private final WebSocketConfiguration configuration; - public WebSocketResourceProviderFactory(WebSocketEnvironment environment, Class principalClass) { + public WebSocketResourceProviderFactory(WebSocketEnvironment environment, Class principalClass, + WebSocketConfiguration configuration) { this.environment = environment; environment.jersey().register(new WebSocketSessionContextValueFactoryProvider.Binder()); @@ -41,6 +44,8 @@ public class WebSocketResourceProviderFactory extends WebSo environment.jersey().register(new JacksonMessageBodyProvider(environment.getObjectMapper())); this.jerseyApplicationHandler = new ApplicationHandler(environment.jersey()); + + this.configuration = configuration; } @Override @@ -79,9 +84,8 @@ public class WebSocketResourceProviderFactory extends WebSo @Override public void configure(WebSocketServletFactory factory) { factory.setCreator(this); - // TODO extract to configuration - factory.getPolicy().setMaxBinaryMessageSize(512 * 1024); - factory.getPolicy().setMaxTextMessageSize(512 * 1024); + factory.getPolicy().setMaxBinaryMessageSize(configuration.getMaxBinaryMessageSize()); + factory.getPolicy().setMaxTextMessageSize(configuration.getMaxTextMessageSize()); } private String getRemoteAddress(ServletUpgradeRequest request) { diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/configuration/WebSocketConfiguration.java b/websocket-resources/src/main/java/org/whispersystems/websocket/configuration/WebSocketConfiguration.java index ffa67b603..fb8c8a957 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/configuration/WebSocketConfiguration.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/configuration/WebSocketConfiguration.java @@ -5,10 +5,11 @@ package org.whispersystems.websocket.configuration; import com.fasterxml.jackson.annotation.JsonProperty; -import org.whispersystems.websocket.logging.WebsocketRequestLoggerFactory; - import javax.validation.Valid; +import javax.validation.constraints.Max; +import javax.validation.constraints.Min; import javax.validation.constraints.NotNull; +import org.whispersystems.websocket.logging.WebsocketRequestLoggerFactory; public class WebSocketConfiguration { @@ -17,7 +18,25 @@ public class WebSocketConfiguration { @JsonProperty private WebsocketRequestLoggerFactory requestLog = new WebsocketRequestLoggerFactory(); + @Min(512 * 1024) // 512 KB + @Max(10 * 1024 * 1024) // 10 MB + @JsonProperty + private int maxBinaryMessageSize = 512 * 1024; + + @Min(512 * 1024) // 512 KB + @Max(10 * 1024 * 1024) // 10 MB + @JsonProperty + private int maxTextMessageSize = 512 * 1024; + public WebsocketRequestLoggerFactory getRequestLog() { return requestLog; } + + public int getMaxBinaryMessageSize() { + return maxBinaryMessageSize; + } + + public int getMaxTextMessageSize() { + return maxTextMessageSize; + } } diff --git a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java index 28fa3d8c2..33baf6e29 100644 --- a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java +++ b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java @@ -1,12 +1,12 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.websocket; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -25,27 +25,43 @@ import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest; import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse; import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory; import org.glassfish.jersey.server.ResourceConfig; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.whispersystems.websocket.auth.AuthenticationException; import org.whispersystems.websocket.auth.WebSocketAuthenticator; +import org.whispersystems.websocket.configuration.WebSocketConfiguration; import org.whispersystems.websocket.setup.WebSocketEnvironment; public class WebSocketResourceProviderFactoryTest { - @Test - public void testUnauthorized() throws AuthenticationException, IOException { - ResourceConfig jerseyEnvironment = new DropwizardResourceConfig(); - WebSocketEnvironment environment = mock(WebSocketEnvironment.class ); - WebSocketAuthenticator authenticator = mock(WebSocketAuthenticator.class); - ServletUpgradeRequest request = mock(ServletUpgradeRequest.class ); - ServletUpgradeResponse response = mock(ServletUpgradeResponse.class); + private ResourceConfig jerseyEnvironment; + private WebSocketEnvironment environment; + private WebSocketAuthenticator authenticator; + private ServletUpgradeRequest request; + private ServletUpgradeResponse response; + @BeforeEach + void setup() { + jerseyEnvironment = new DropwizardResourceConfig(); + //noinspection unchecked + environment = mock(WebSocketEnvironment.class); + //noinspection unchecked + authenticator = mock(WebSocketAuthenticator.class); + request = mock(ServletUpgradeRequest.class); + response = mock(ServletUpgradeResponse.class); + + } + + @Test + void testUnauthorized() throws AuthenticationException, IOException { when(environment.getAuthenticator()).thenReturn(authenticator); - when(authenticator.authenticate(eq(request))).thenReturn(new WebSocketAuthenticator.AuthenticationResult<>(Optional.empty(), true)); + when(authenticator.authenticate(eq(request))).thenReturn( + new WebSocketAuthenticator.AuthenticationResult<>(Optional.empty(), true)); when(environment.jersey()).thenReturn(jerseyEnvironment); - WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory(environment, Account.class); - Object connection = factory.createWebSocket(request, response); + WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory<>(environment, Account.class, + mock(WebSocketConfiguration.class)); + Object connection = factory.createWebSocket(request, response); assertNull(connection); verify(response).sendForbidden(eq("Unauthorized")); @@ -53,47 +69,40 @@ public class WebSocketResourceProviderFactoryTest { } @Test - public void testValidAuthorization() throws AuthenticationException { - ResourceConfig jerseyEnvironment = new DropwizardResourceConfig(); - WebSocketEnvironment environment = mock(WebSocketEnvironment.class ); - WebSocketAuthenticator authenticator = mock(WebSocketAuthenticator.class ); - ServletUpgradeRequest request = mock(ServletUpgradeRequest.class ); - ServletUpgradeResponse response = mock(ServletUpgradeResponse.class ); - Session session = mock(Session.class ); - Account account = new Account(); + void testValidAuthorization() throws AuthenticationException { + Session session = mock(Session.class); + Account account = new Account(); when(environment.getAuthenticator()).thenReturn(authenticator); - when(authenticator.authenticate(eq(request))).thenReturn(new WebSocketAuthenticator.AuthenticationResult<>(Optional.of(account), true)); + when(authenticator.authenticate(eq(request))).thenReturn( + new WebSocketAuthenticator.AuthenticationResult<>(Optional.of(account), true)); when(environment.jersey()).thenReturn(jerseyEnvironment); when(session.getUpgradeRequest()).thenReturn(mock(UpgradeRequest.class)); - WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory(environment, Account.class); - Object connection = factory.createWebSocket(request, response); + WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory<>(environment, Account.class, + mock(WebSocketConfiguration.class)); + Object connection = factory.createWebSocket(request, response); assertNotNull(connection); verifyNoMoreInteractions(response); verify(authenticator).authenticate(eq(request)); - ((WebSocketResourceProvider)connection).onWebSocketConnect(session); + ((WebSocketResourceProvider) connection).onWebSocketConnect(session); - assertNotNull(((WebSocketResourceProvider) connection).getContext().getAuthenticated()); - assertEquals(((WebSocketResourceProvider)connection).getContext().getAuthenticated(), account); + assertNotNull(((WebSocketResourceProvider) connection).getContext().getAuthenticated()); + assertEquals(((WebSocketResourceProvider) connection).getContext().getAuthenticated(), account); } @Test - public void testErrorAuthorization() throws AuthenticationException, IOException { - ResourceConfig jerseyEnvironment = new DropwizardResourceConfig(); - WebSocketEnvironment environment = mock(WebSocketEnvironment.class ); - WebSocketAuthenticator authenticator = mock(WebSocketAuthenticator.class ); - ServletUpgradeRequest request = mock(ServletUpgradeRequest.class ); - ServletUpgradeResponse response = mock(ServletUpgradeResponse.class ); - + void testErrorAuthorization() throws AuthenticationException, IOException { when(environment.getAuthenticator()).thenReturn(authenticator); when(authenticator.authenticate(eq(request))).thenThrow(new AuthenticationException("database failure")); when(environment.jersey()).thenReturn(jerseyEnvironment); - WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory(environment, Account.class); - Object connection = factory.createWebSocket(request, response); + WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory<>(environment, + Account.class, + mock(WebSocketConfiguration.class)); + Object connection = factory.createWebSocket(request, response); assertNull(connection); verify(response).sendError(eq(500), eq("Failure")); @@ -101,14 +110,14 @@ public class WebSocketResourceProviderFactoryTest { } @Test - public void testConfigure() { - ResourceConfig jerseyEnvironment = new DropwizardResourceConfig(); - WebSocketEnvironment environment = mock(WebSocketEnvironment.class ); - WebSocketServletFactory servletFactory = mock(WebSocketServletFactory.class ); + void testConfigure() { + WebSocketServletFactory servletFactory = mock(WebSocketServletFactory.class); when(environment.jersey()).thenReturn(jerseyEnvironment); when(servletFactory.getPolicy()).thenReturn(mock(WebSocketPolicy.class)); - WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory(environment, Account.class); + WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory<>(environment, + Account.class, + mock(WebSocketConfiguration.class)); factory.configure(servletFactory); verify(servletFactory).setCreator(eq(factory));