Simplify WebSocket authentication failure handling

This commit is contained in:
Jon Chambers 2025-06-10 22:59:53 -04:00 committed by Jon Chambers
parent 626a7fdad7
commit 4f1cab407f
9 changed files with 49 additions and 94 deletions

View File

@ -13,7 +13,7 @@ import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.auth.AuthenticationException;
import org.whispersystems.websocket.auth.InvalidCredentialsException;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
@ -22,7 +22,6 @@ public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<Aut
private static final ReusableAuth<AuthenticatedDevice> CREDENTIALS_NOT_PRESENTED = ReusableAuth.anonymous();
private static final ReusableAuth<AuthenticatedDevice> INVALID_CREDENTIALS_PRESENTED = ReusableAuth.invalid();
private final AccountAuthenticator accountAuthenticator;
private final PrincipalSupplier<AuthenticatedDevice> principalSupplier;
@ -34,23 +33,17 @@ public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<Aut
@Override
public ReusableAuth<AuthenticatedDevice> authenticate(final UpgradeRequest request)
throws AuthenticationException {
try {
return authenticatedAccountFromHeaderAuth(request.getHeader(HttpHeaders.AUTHORIZATION));
} catch (final Exception e) {
// this will be handled and logged upstream
// the most likely exception is a transient error connecting to account storage
throw new AuthenticationException(e);
}
}
throws InvalidCredentialsException {
@Nullable final String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION);
private ReusableAuth<AuthenticatedDevice> authenticatedAccountFromHeaderAuth(@Nullable final String authHeader) {
if (authHeader == null) {
return CREDENTIALS_NOT_PRESENTED;
}
return basicCredentialsFromAuthHeader(authHeader)
.flatMap(accountAuthenticator::authenticate)
.map(authenticatedAccount -> ReusableAuth.authenticated(authenticatedAccount, this.principalSupplier))
.orElse(INVALID_CREDENTIALS_PRESENTED);
.orElseThrow(InvalidCredentialsException::new);
}
}

View File

@ -6,6 +6,7 @@
package org.whispersystems.textsecuregcm.websocket;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ -26,7 +27,7 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.auth.InvalidCredentialsException;
import org.whispersystems.websocket.auth.PrincipalSupplier;
class WebSocketAccountAuthenticatorTest {
@ -73,10 +74,11 @@ class WebSocketAccountAuthenticatorTest {
accountAuthenticator,
mock(PrincipalSupplier.class));
final ReusableAuth<AuthenticatedDevice> result = webSocketAuthenticator.authenticate(upgradeRequest);
assertEquals(expectAccount, result.ref().isPresent());
assertEquals(expectInvalid, result.invalidCredentialsProvided());
if (expectInvalid) {
assertThrows(InvalidCredentialsException.class, () -> webSocketAuthenticator.authenticate(upgradeRequest));
} else {
assertEquals(expectAccount, webSocketAuthenticator.authenticate(upgradeRequest).ref().isPresent());
}
}
private static Stream<Arguments> testAuthenticate() {

View File

@ -58,10 +58,10 @@ import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
@ -151,7 +151,6 @@ class WebSocketConnectionTest {
when(upgradeRequest.getParameterMap()).thenReturn(Map.of());
account = webSocketAuthenticator.authenticate(upgradeRequest);
assertFalse(account.ref().isPresent());
assertFalse(account.invalidCredentialsProvided());
connectListener.onWebSocketConnect(sessionContext);
verify(sessionContext, times(2)).addWebsocketClosedListener(

View File

@ -59,14 +59,6 @@ public abstract sealed class ReusableAuth<T extends Principal> {
*/
public abstract Optional<MutableRef<T>> mutableRef();
public boolean invalidCredentialsProvided() {
return switch (this) {
case Invalid<T> ignored -> true;
case ReusableAuth.Anonymous<T> ignored -> false;
case ReusableAuth.Authenticated<T> ignored-> false;
};
}
/**
* @return A {@link ReusableAuth} indicating no credential were provided
*/
@ -75,14 +67,6 @@ public abstract sealed class ReusableAuth<T extends Principal> {
return (ReusableAuth<T>) Anonymous.ANON_RESULT;
}
/**
* @return A {@link ReusableAuth} indicating that invalid credentials were provided
*/
public static <T extends Principal> ReusableAuth<T> invalid() {
//noinspection unchecked
return (ReusableAuth<T>) Invalid.INVALID_RESULT;
}
/**
* Create a successfully authenticated {@link ReusableAuth}
*
@ -96,23 +80,6 @@ public abstract sealed class ReusableAuth<T extends Principal> {
return new Authenticated<>(principal, principalSupplier);
}
private static final class Invalid<T extends Principal> extends ReusableAuth<T> {
@SuppressWarnings({"rawtypes"})
private static final ReusableAuth INVALID_RESULT = new Invalid();
@Override
public Optional<T> ref() {
return Optional.empty();
}
@Override
public Optional<MutableRef<T>> mutableRef() {
return Optional.empty();
}
}
private static final class Anonymous<T extends Principal> extends ReusableAuth<T> {
@SuppressWarnings({"rawtypes"})

View File

@ -22,7 +22,7 @@ import org.glassfish.jersey.CommonProperties;
import org.glassfish.jersey.server.ApplicationHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.websocket.auth.AuthenticationException;
import org.whispersystems.websocket.auth.InvalidCredentialsException;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
@ -45,7 +45,7 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends Jetty
this.environment = environment;
environment.jersey().register(new WebSocketSessionContextValueFactoryProvider.Binder());
environment.jersey().register(new WebsocketAuthValueFactoryProvider.Binder<T>(principalClass));
environment.jersey().register(new WebsocketAuthValueFactoryProvider.Binder<>(principalClass));
environment.jersey().register(new JacksonMessageBodyProvider(environment.getObjectMapper()));
// Jersey buffers responses (by default up to 8192 bytes) just so it can add a content length to responses. We
@ -64,17 +64,9 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends Jetty
try {
Optional<WebSocketAuthenticator<T>> authenticator = Optional.ofNullable(environment.getAuthenticator());
final ReusableAuth<T> authenticated;
if (authenticator.isPresent()) {
authenticated = authenticator.get().authenticate(request);
if (authenticated.invalidCredentialsProvided()) {
response.sendForbidden("Unauthorized");
return null;
}
} else {
authenticated = ReusableAuth.anonymous();
}
final ReusableAuth<T> authenticated = authenticator.isPresent()
? authenticator.get().authenticate(request)
: ReusableAuth.anonymous();
Optional.ofNullable(environment.getAuthenticatedWebSocketUpgradeFilter())
.ifPresent(filter -> filter.handleAuthentication(authenticated, request, response));
@ -87,11 +79,19 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends Jetty
this.environment.getMessageFactory(),
ofNullable(this.environment.getConnectListener()),
this.environment.getIdleTimeout());
} catch (AuthenticationException | IOException e) {
} catch (final InvalidCredentialsException e) {
try {
response.sendForbidden("Unauthorized");
} catch (final IOException ignored) {
}
return null;
} catch (final Exception e) {
// Authentication may fail for non-incorrect-credential reasons (e.g. we couldn't read from the account database).
// If that happens, we don't want to incorrectly tell clients that they provided bad credentials.
logger.warn("Authentication failure", e);
try {
response.sendError(500, "Failure");
} catch (IOException ignored) {
} catch (final IOException ignored) {
}
return null;
}

View File

@ -1,17 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.websocket.auth;
public class AuthenticationException extends Exception {
public AuthenticationException(String s) {
super(s);
}
public AuthenticationException(Exception e) {
super(e);
}
}

View File

@ -0,0 +1,12 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.websocket.auth;
public class InvalidCredentialsException extends Exception {
public InvalidCredentialsException() {
super(null, null, true, false);
}
}

View File

@ -5,10 +5,9 @@
package org.whispersystems.websocket.auth;
import java.security.Principal;
import java.util.Optional;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.whispersystems.websocket.ReusableAuth;
public interface WebSocketAuthenticator<T extends Principal> {
ReusableAuth<T> authenticate(UpgradeRequest request) throws AuthenticationException;
ReusableAuth<T> authenticate(UpgradeRequest request) throws InvalidCredentialsException;
}

View File

@ -25,7 +25,7 @@ import org.eclipse.jetty.websocket.server.JettyWebSocketServletFactory;
import org.glassfish.jersey.server.ResourceConfig;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.websocket.auth.AuthenticationException;
import org.whispersystems.websocket.auth.InvalidCredentialsException;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.AuthenticatedWebSocketUpgradeFilter;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
@ -55,9 +55,9 @@ public class WebSocketResourceProviderFactoryTest {
}
@Test
void testUnauthorized() throws AuthenticationException, IOException {
void testUnauthorized() throws InvalidCredentialsException, IOException {
when(environment.getAuthenticator()).thenReturn(authenticator);
when(authenticator.authenticate(eq(request))).thenReturn(ReusableAuth.invalid());
when(authenticator.authenticate(eq(request))).thenThrow(new InvalidCredentialsException());
when(environment.jersey()).thenReturn(jerseyEnvironment);
WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class,
@ -70,7 +70,7 @@ public class WebSocketResourceProviderFactoryTest {
}
@Test
void testValidAuthorization() throws AuthenticationException {
void testValidAuthorization() throws InvalidCredentialsException {
Account account = new Account();
when(environment.getAuthenticator()).thenReturn(authenticator);
@ -96,9 +96,9 @@ public class WebSocketResourceProviderFactoryTest {
}
@Test
void testErrorAuthorization() throws AuthenticationException, IOException {
void testErrorAuthorization() throws InvalidCredentialsException, IOException {
when(environment.getAuthenticator()).thenReturn(authenticator);
when(authenticator.authenticate(eq(request))).thenThrow(new AuthenticationException("database failure"));
when(authenticator.authenticate(eq(request))).thenThrow(new RuntimeException("database failure"));
when(environment.jersey()).thenReturn(jerseyEnvironment);
WebSocketResourceProviderFactory<Account> factory = new WebSocketResourceProviderFactory<>(environment,
@ -127,7 +127,7 @@ public class WebSocketResourceProviderFactoryTest {
}
@Test
void testAuthenticatedWebSocketUpgradeFilter() throws AuthenticationException {
void testAuthenticatedWebSocketUpgradeFilter() throws InvalidCredentialsException {
final Account account = new Account();
final ReusableAuth<Account> reusableAuth =
ReusableAuth.authenticated(account, PrincipalSupplier.forImmutablePrincipal());