Add an interceptor interface for WebSocket authentication

This commit is contained in:
Jon Chambers 2025-02-28 09:33:24 -05:00 committed by Jon Chambers
parent 59d984e25d
commit 552079d3c2
5 changed files with 59 additions and 1 deletions

View File

@ -49,7 +49,7 @@ public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<Aut
return CREDENTIALS_NOT_PRESENTED;
}
return basicCredentialsFromAuthHeader(authHeader)
.flatMap(credentials -> accountAuthenticator.authenticate(credentials))
.flatMap(accountAuthenticator::authenticate)
.map(authenticatedAccount -> ReusableAuth.authenticated(authenticatedAccount, this.principalSupplier))
.orElse(INVALID_CREDENTIALS_PRESENTED);
}

View File

@ -76,6 +76,9 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends Jetty
authenticated = ReusableAuth.anonymous();
}
Optional.ofNullable(environment.getAuthenticatedWebSocketUpgradeFilter())
.ifPresent(filter -> filter.handleAuthentication(authenticated, request, response));
return new WebSocketResourceProvider<>(getRemoteAddress(request),
remoteAddressPropertyName,
this.jerseyApplicationHandler,

View File

@ -0,0 +1,18 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.websocket.auth;
import java.security.Principal;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
import org.whispersystems.websocket.ReusableAuth;
public interface AuthenticatedWebSocketUpgradeFilter<T extends Principal> {
void handleAuthentication(ReusableAuth<T> authenticated,
JettyServerUpgradeRequest request,
JettyServerUpgradeResponse response);
}

View File

@ -11,11 +11,13 @@ import jakarta.validation.Validator;
import java.security.Principal;
import java.time.Duration;
import org.glassfish.jersey.server.ResourceConfig;
import org.whispersystems.websocket.auth.AuthenticatedWebSocketUpgradeFilter;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.logging.WebsocketRequestLog;
import org.whispersystems.websocket.messages.WebSocketMessageFactory;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
import javax.annotation.Nullable;
public class WebSocketEnvironment<T extends Principal> {
@ -26,6 +28,7 @@ public class WebSocketEnvironment<T extends Principal> {
private final Duration idleTimeout;
private WebSocketAuthenticator<T> authenticator;
private AuthenticatedWebSocketUpgradeFilter<T> authenticatedWebSocketUpgradeFilter;
private WebSocketMessageFactory messageFactory;
private WebSocketConnectListener connectListener;
@ -50,6 +53,7 @@ public class WebSocketEnvironment<T extends Principal> {
return jerseyConfig;
}
@Nullable
public WebSocketAuthenticator<T> getAuthenticator() {
return authenticator;
}
@ -58,6 +62,15 @@ public class WebSocketEnvironment<T extends Principal> {
this.authenticator = authenticator;
}
@Nullable
public AuthenticatedWebSocketUpgradeFilter<T> getAuthenticatedWebSocketUpgradeFilter() {
return authenticatedWebSocketUpgradeFilter;
}
public void setAuthenticatedWebSocketUpgradeFilter(final AuthenticatedWebSocketUpgradeFilter<T> authenticatedWebSocketUpgradeFilter) {
this.authenticatedWebSocketUpgradeFilter = authenticatedWebSocketUpgradeFilter;
}
public Duration getIdleTimeout() {
return idleTimeout;
}

View File

@ -27,6 +27,7 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.websocket.auth.AuthenticationException;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.AuthenticatedWebSocketUpgradeFilter;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ -125,6 +126,29 @@ public class WebSocketResourceProviderFactoryTest {
verify(servletFactory).setCreator(eq(factory));
}
@Test
void testAuthenticatedWebSocketUpgradeFilter() throws AuthenticationException {
final Account account = new Account();
final ReusableAuth<Account> reusableAuth =
ReusableAuth.authenticated(account, PrincipalSupplier.forImmutablePrincipal());
when(environment.getAuthenticator()).thenReturn(authenticator);
when(authenticator.authenticate(eq(request))).thenReturn(reusableAuth);
when(environment.jersey()).thenReturn(jerseyEnvironment);
final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class);
when(httpServletRequest.getAttribute(REMOTE_ADDRESS_PROPERTY_NAME)).thenReturn("127.0.0.1");
when(request.getHttpServletRequest()).thenReturn(httpServletRequest);
final AuthenticatedWebSocketUpgradeFilter<Account> filter = mock(AuthenticatedWebSocketUpgradeFilter.class);
when(environment.getAuthenticatedWebSocketUpgradeFilter()).thenReturn(filter);
final WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class,
mock(WebSocketConfiguration.class), REMOTE_ADDRESS_PROPERTY_NAME);
assertNotNull(factory.createWebSocket(request, response));
verify(filter).handleAuthentication(reusableAuth, request, response);
}
private static class Account implements Principal {
@Override