From 552079d3c24bac7f08ba5231bbfacb8770d80eee Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Fri, 28 Feb 2025 09:33:24 -0500 Subject: [PATCH] Add an interceptor interface for WebSocket authentication --- .../WebSocketAccountAuthenticator.java | 2 +- .../WebSocketResourceProviderFactory.java | 3 +++ .../AuthenticatedWebSocketUpgradeFilter.java | 18 ++++++++++++++ .../websocket/setup/WebSocketEnvironment.java | 13 ++++++++++ .../WebSocketResourceProviderFactoryTest.java | 24 +++++++++++++++++++ 5 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 websocket-resources/src/main/java/org/whispersystems/websocket/auth/AuthenticatedWebSocketUpgradeFilter.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java index 10afbae40..d9640c177 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java @@ -49,7 +49,7 @@ public class WebSocketAccountAuthenticator implements WebSocketAuthenticator accountAuthenticator.authenticate(credentials)) + .flatMap(accountAuthenticator::authenticate) .map(authenticatedAccount -> ReusableAuth.authenticated(authenticatedAccount, this.principalSupplier)) .orElse(INVALID_CREDENTIALS_PRESENTED); } diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java index 3858565fd..0c0aa0ec0 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java @@ -76,6 +76,9 @@ public class WebSocketResourceProviderFactory extends Jetty authenticated = ReusableAuth.anonymous(); } + Optional.ofNullable(environment.getAuthenticatedWebSocketUpgradeFilter()) + .ifPresent(filter -> filter.handleAuthentication(authenticated, request, response)); + return new WebSocketResourceProvider<>(getRemoteAddress(request), remoteAddressPropertyName, this.jerseyApplicationHandler, diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/AuthenticatedWebSocketUpgradeFilter.java b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/AuthenticatedWebSocketUpgradeFilter.java new file mode 100644 index 000000000..daf25b080 --- /dev/null +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/AuthenticatedWebSocketUpgradeFilter.java @@ -0,0 +1,18 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.websocket.auth; + +import java.security.Principal; +import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest; +import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse; +import org.whispersystems.websocket.ReusableAuth; + +public interface AuthenticatedWebSocketUpgradeFilter { + + void handleAuthentication(ReusableAuth authenticated, + JettyServerUpgradeRequest request, + JettyServerUpgradeResponse response); +} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/setup/WebSocketEnvironment.java b/websocket-resources/src/main/java/org/whispersystems/websocket/setup/WebSocketEnvironment.java index 7d5d79f88..e22ea3eb0 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/setup/WebSocketEnvironment.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/setup/WebSocketEnvironment.java @@ -11,11 +11,13 @@ import jakarta.validation.Validator; import java.security.Principal; import java.time.Duration; import org.glassfish.jersey.server.ResourceConfig; +import org.whispersystems.websocket.auth.AuthenticatedWebSocketUpgradeFilter; import org.whispersystems.websocket.auth.WebSocketAuthenticator; import org.whispersystems.websocket.configuration.WebSocketConfiguration; import org.whispersystems.websocket.logging.WebsocketRequestLog; import org.whispersystems.websocket.messages.WebSocketMessageFactory; import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory; +import javax.annotation.Nullable; public class WebSocketEnvironment { @@ -26,6 +28,7 @@ public class WebSocketEnvironment { private final Duration idleTimeout; private WebSocketAuthenticator authenticator; + private AuthenticatedWebSocketUpgradeFilter authenticatedWebSocketUpgradeFilter; private WebSocketMessageFactory messageFactory; private WebSocketConnectListener connectListener; @@ -50,6 +53,7 @@ public class WebSocketEnvironment { return jerseyConfig; } + @Nullable public WebSocketAuthenticator getAuthenticator() { return authenticator; } @@ -58,6 +62,15 @@ public class WebSocketEnvironment { this.authenticator = authenticator; } + @Nullable + public AuthenticatedWebSocketUpgradeFilter getAuthenticatedWebSocketUpgradeFilter() { + return authenticatedWebSocketUpgradeFilter; + } + + public void setAuthenticatedWebSocketUpgradeFilter(final AuthenticatedWebSocketUpgradeFilter authenticatedWebSocketUpgradeFilter) { + this.authenticatedWebSocketUpgradeFilter = authenticatedWebSocketUpgradeFilter; + } + public Duration getIdleTimeout() { return idleTimeout; } diff --git a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java index 0a2140984..ee1b019b1 100644 --- a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java +++ b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java @@ -27,6 +27,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.whispersystems.websocket.auth.AuthenticationException; import org.whispersystems.websocket.auth.PrincipalSupplier; +import org.whispersystems.websocket.auth.AuthenticatedWebSocketUpgradeFilter; import org.whispersystems.websocket.auth.WebSocketAuthenticator; import org.whispersystems.websocket.configuration.WebSocketConfiguration; import org.whispersystems.websocket.setup.WebSocketEnvironment; @@ -125,6 +126,29 @@ public class WebSocketResourceProviderFactoryTest { verify(servletFactory).setCreator(eq(factory)); } + @Test + void testAuthenticatedWebSocketUpgradeFilter() throws AuthenticationException { + final Account account = new Account(); + final ReusableAuth reusableAuth = + ReusableAuth.authenticated(account, PrincipalSupplier.forImmutablePrincipal()); + + when(environment.getAuthenticator()).thenReturn(authenticator); + when(authenticator.authenticate(eq(request))).thenReturn(reusableAuth); + when(environment.jersey()).thenReturn(jerseyEnvironment); + final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class); + when(httpServletRequest.getAttribute(REMOTE_ADDRESS_PROPERTY_NAME)).thenReturn("127.0.0.1"); + when(request.getHttpServletRequest()).thenReturn(httpServletRequest); + + final AuthenticatedWebSocketUpgradeFilter filter = mock(AuthenticatedWebSocketUpgradeFilter.class); + when(environment.getAuthenticatedWebSocketUpgradeFilter()).thenReturn(filter); + + final WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory<>(environment, Account.class, + mock(WebSocketConfiguration.class), REMOTE_ADDRESS_PROPERTY_NAME); + assertNotNull(factory.createWebSocket(request, response)); + + verify(filter).handleAuthentication(reusableAuth, request, response); + } + private static class Account implements Principal { @Override