Simplify WebSocket authentication failure handling

This commit is contained in:
Jon Chambers 2025-06-10 22:59:53 -04:00 committed by Jon Chambers
parent 626a7fdad7
commit 4f1cab407f
9 changed files with 49 additions and 94 deletions

View File

@ -13,7 +13,7 @@ import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.websocket.ReusableAuth; import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.auth.AuthenticationException; import org.whispersystems.websocket.auth.InvalidCredentialsException;
import org.whispersystems.websocket.auth.PrincipalSupplier; import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.WebSocketAuthenticator; import org.whispersystems.websocket.auth.WebSocketAuthenticator;
@ -22,7 +22,6 @@ public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<Aut
private static final ReusableAuth<AuthenticatedDevice> CREDENTIALS_NOT_PRESENTED = ReusableAuth.anonymous(); private static final ReusableAuth<AuthenticatedDevice> CREDENTIALS_NOT_PRESENTED = ReusableAuth.anonymous();
private static final ReusableAuth<AuthenticatedDevice> INVALID_CREDENTIALS_PRESENTED = ReusableAuth.invalid();
private final AccountAuthenticator accountAuthenticator; private final AccountAuthenticator accountAuthenticator;
private final PrincipalSupplier<AuthenticatedDevice> principalSupplier; private final PrincipalSupplier<AuthenticatedDevice> principalSupplier;
@ -34,23 +33,17 @@ public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<Aut
@Override @Override
public ReusableAuth<AuthenticatedDevice> authenticate(final UpgradeRequest request) public ReusableAuth<AuthenticatedDevice> authenticate(final UpgradeRequest request)
throws AuthenticationException { throws InvalidCredentialsException {
try {
return authenticatedAccountFromHeaderAuth(request.getHeader(HttpHeaders.AUTHORIZATION)); @Nullable final String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION);
} catch (final Exception e) {
// this will be handled and logged upstream
// the most likely exception is a transient error connecting to account storage
throw new AuthenticationException(e);
}
}
private ReusableAuth<AuthenticatedDevice> authenticatedAccountFromHeaderAuth(@Nullable final String authHeader) {
if (authHeader == null) { if (authHeader == null) {
return CREDENTIALS_NOT_PRESENTED; return CREDENTIALS_NOT_PRESENTED;
} }
return basicCredentialsFromAuthHeader(authHeader) return basicCredentialsFromAuthHeader(authHeader)
.flatMap(accountAuthenticator::authenticate) .flatMap(accountAuthenticator::authenticate)
.map(authenticatedAccount -> ReusableAuth.authenticated(authenticatedAccount, this.principalSupplier)) .map(authenticatedAccount -> ReusableAuth.authenticated(authenticatedAccount, this.principalSupplier))
.orElse(INVALID_CREDENTIALS_PRESENTED); .orElseThrow(InvalidCredentialsException::new);
} }
} }

View File

@ -6,6 +6,7 @@
package org.whispersystems.textsecuregcm.websocket; package org.whispersystems.textsecuregcm.websocket;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -26,7 +27,7 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.websocket.ReusableAuth; import org.whispersystems.websocket.auth.InvalidCredentialsException;
import org.whispersystems.websocket.auth.PrincipalSupplier; import org.whispersystems.websocket.auth.PrincipalSupplier;
class WebSocketAccountAuthenticatorTest { class WebSocketAccountAuthenticatorTest {
@ -73,10 +74,11 @@ class WebSocketAccountAuthenticatorTest {
accountAuthenticator, accountAuthenticator,
mock(PrincipalSupplier.class)); mock(PrincipalSupplier.class));
final ReusableAuth<AuthenticatedDevice> result = webSocketAuthenticator.authenticate(upgradeRequest); if (expectInvalid) {
assertThrows(InvalidCredentialsException.class, () -> webSocketAuthenticator.authenticate(upgradeRequest));
assertEquals(expectAccount, result.ref().isPresent()); } else {
assertEquals(expectInvalid, result.invalidCredentialsProvided()); assertEquals(expectAccount, webSocketAuthenticator.authenticate(upgradeRequest).ref().isPresent());
}
} }
private static Stream<Arguments> testAuthenticate() { private static Stream<Arguments> testAuthenticate() {

View File

@ -58,10 +58,10 @@ import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor; import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
@ -151,7 +151,6 @@ class WebSocketConnectionTest {
when(upgradeRequest.getParameterMap()).thenReturn(Map.of()); when(upgradeRequest.getParameterMap()).thenReturn(Map.of());
account = webSocketAuthenticator.authenticate(upgradeRequest); account = webSocketAuthenticator.authenticate(upgradeRequest);
assertFalse(account.ref().isPresent()); assertFalse(account.ref().isPresent());
assertFalse(account.invalidCredentialsProvided());
connectListener.onWebSocketConnect(sessionContext); connectListener.onWebSocketConnect(sessionContext);
verify(sessionContext, times(2)).addWebsocketClosedListener( verify(sessionContext, times(2)).addWebsocketClosedListener(

View File

@ -59,14 +59,6 @@ public abstract sealed class ReusableAuth<T extends Principal> {
*/ */
public abstract Optional<MutableRef<T>> mutableRef(); public abstract Optional<MutableRef<T>> mutableRef();
public boolean invalidCredentialsProvided() {
return switch (this) {
case Invalid<T> ignored -> true;
case ReusableAuth.Anonymous<T> ignored -> false;
case ReusableAuth.Authenticated<T> ignored-> false;
};
}
/** /**
* @return A {@link ReusableAuth} indicating no credential were provided * @return A {@link ReusableAuth} indicating no credential were provided
*/ */
@ -75,14 +67,6 @@ public abstract sealed class ReusableAuth<T extends Principal> {
return (ReusableAuth<T>) Anonymous.ANON_RESULT; return (ReusableAuth<T>) Anonymous.ANON_RESULT;
} }
/**
* @return A {@link ReusableAuth} indicating that invalid credentials were provided
*/
public static <T extends Principal> ReusableAuth<T> invalid() {
//noinspection unchecked
return (ReusableAuth<T>) Invalid.INVALID_RESULT;
}
/** /**
* Create a successfully authenticated {@link ReusableAuth} * Create a successfully authenticated {@link ReusableAuth}
* *
@ -96,23 +80,6 @@ public abstract sealed class ReusableAuth<T extends Principal> {
return new Authenticated<>(principal, principalSupplier); return new Authenticated<>(principal, principalSupplier);
} }
private static final class Invalid<T extends Principal> extends ReusableAuth<T> {
@SuppressWarnings({"rawtypes"})
private static final ReusableAuth INVALID_RESULT = new Invalid();
@Override
public Optional<T> ref() {
return Optional.empty();
}
@Override
public Optional<MutableRef<T>> mutableRef() {
return Optional.empty();
}
}
private static final class Anonymous<T extends Principal> extends ReusableAuth<T> { private static final class Anonymous<T extends Principal> extends ReusableAuth<T> {
@SuppressWarnings({"rawtypes"}) @SuppressWarnings({"rawtypes"})

View File

@ -22,7 +22,7 @@ import org.glassfish.jersey.CommonProperties;
import org.glassfish.jersey.server.ApplicationHandler; import org.glassfish.jersey.server.ApplicationHandler;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.websocket.auth.AuthenticationException; import org.whispersystems.websocket.auth.InvalidCredentialsException;
import org.whispersystems.websocket.auth.WebSocketAuthenticator; import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider; import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
import org.whispersystems.websocket.configuration.WebSocketConfiguration; import org.whispersystems.websocket.configuration.WebSocketConfiguration;
@ -45,7 +45,7 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends Jetty
this.environment = environment; this.environment = environment;
environment.jersey().register(new WebSocketSessionContextValueFactoryProvider.Binder()); environment.jersey().register(new WebSocketSessionContextValueFactoryProvider.Binder());
environment.jersey().register(new WebsocketAuthValueFactoryProvider.Binder<T>(principalClass)); environment.jersey().register(new WebsocketAuthValueFactoryProvider.Binder<>(principalClass));
environment.jersey().register(new JacksonMessageBodyProvider(environment.getObjectMapper())); environment.jersey().register(new JacksonMessageBodyProvider(environment.getObjectMapper()));
// Jersey buffers responses (by default up to 8192 bytes) just so it can add a content length to responses. We // Jersey buffers responses (by default up to 8192 bytes) just so it can add a content length to responses. We
@ -64,17 +64,9 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends Jetty
try { try {
Optional<WebSocketAuthenticator<T>> authenticator = Optional.ofNullable(environment.getAuthenticator()); Optional<WebSocketAuthenticator<T>> authenticator = Optional.ofNullable(environment.getAuthenticator());
final ReusableAuth<T> authenticated; final ReusableAuth<T> authenticated = authenticator.isPresent()
if (authenticator.isPresent()) { ? authenticator.get().authenticate(request)
authenticated = authenticator.get().authenticate(request); : ReusableAuth.anonymous();
if (authenticated.invalidCredentialsProvided()) {
response.sendForbidden("Unauthorized");
return null;
}
} else {
authenticated = ReusableAuth.anonymous();
}
Optional.ofNullable(environment.getAuthenticatedWebSocketUpgradeFilter()) Optional.ofNullable(environment.getAuthenticatedWebSocketUpgradeFilter())
.ifPresent(filter -> filter.handleAuthentication(authenticated, request, response)); .ifPresent(filter -> filter.handleAuthentication(authenticated, request, response));
@ -87,11 +79,19 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends Jetty
this.environment.getMessageFactory(), this.environment.getMessageFactory(),
ofNullable(this.environment.getConnectListener()), ofNullable(this.environment.getConnectListener()),
this.environment.getIdleTimeout()); this.environment.getIdleTimeout());
} catch (AuthenticationException | IOException e) { } catch (final InvalidCredentialsException e) {
try {
response.sendForbidden("Unauthorized");
} catch (final IOException ignored) {
}
return null;
} catch (final Exception e) {
// Authentication may fail for non-incorrect-credential reasons (e.g. we couldn't read from the account database).
// If that happens, we don't want to incorrectly tell clients that they provided bad credentials.
logger.warn("Authentication failure", e); logger.warn("Authentication failure", e);
try { try {
response.sendError(500, "Failure"); response.sendError(500, "Failure");
} catch (IOException ignored) { } catch (final IOException ignored) {
} }
return null; return null;
} }

View File

@ -1,17 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.websocket.auth;
public class AuthenticationException extends Exception {
public AuthenticationException(String s) {
super(s);
}
public AuthenticationException(Exception e) {
super(e);
}
}

View File

@ -0,0 +1,12 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.websocket.auth;
public class InvalidCredentialsException extends Exception {
public InvalidCredentialsException() {
super(null, null, true, false);
}
}

View File

@ -5,10 +5,9 @@
package org.whispersystems.websocket.auth; package org.whispersystems.websocket.auth;
import java.security.Principal; import java.security.Principal;
import java.util.Optional;
import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.whispersystems.websocket.ReusableAuth; import org.whispersystems.websocket.ReusableAuth;
public interface WebSocketAuthenticator<T extends Principal> { public interface WebSocketAuthenticator<T extends Principal> {
ReusableAuth<T> authenticate(UpgradeRequest request) throws AuthenticationException; ReusableAuth<T> authenticate(UpgradeRequest request) throws InvalidCredentialsException;
} }

View File

@ -25,7 +25,7 @@ import org.eclipse.jetty.websocket.server.JettyWebSocketServletFactory;
import org.glassfish.jersey.server.ResourceConfig; import org.glassfish.jersey.server.ResourceConfig;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.whispersystems.websocket.auth.AuthenticationException; import org.whispersystems.websocket.auth.InvalidCredentialsException;
import org.whispersystems.websocket.auth.PrincipalSupplier; import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.AuthenticatedWebSocketUpgradeFilter; import org.whispersystems.websocket.auth.AuthenticatedWebSocketUpgradeFilter;
import org.whispersystems.websocket.auth.WebSocketAuthenticator; import org.whispersystems.websocket.auth.WebSocketAuthenticator;
@ -55,9 +55,9 @@ public class WebSocketResourceProviderFactoryTest {
} }
@Test @Test
void testUnauthorized() throws AuthenticationException, IOException { void testUnauthorized() throws InvalidCredentialsException, IOException {
when(environment.getAuthenticator()).thenReturn(authenticator); when(environment.getAuthenticator()).thenReturn(authenticator);
when(authenticator.authenticate(eq(request))).thenReturn(ReusableAuth.invalid()); when(authenticator.authenticate(eq(request))).thenThrow(new InvalidCredentialsException());
when(environment.jersey()).thenReturn(jerseyEnvironment); when(environment.jersey()).thenReturn(jerseyEnvironment);
WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class, WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class,
@ -70,7 +70,7 @@ public class WebSocketResourceProviderFactoryTest {
} }
@Test @Test
void testValidAuthorization() throws AuthenticationException { void testValidAuthorization() throws InvalidCredentialsException {
Account account = new Account(); Account account = new Account();
when(environment.getAuthenticator()).thenReturn(authenticator); when(environment.getAuthenticator()).thenReturn(authenticator);
@ -96,9 +96,9 @@ public class WebSocketResourceProviderFactoryTest {
} }
@Test @Test
void testErrorAuthorization() throws AuthenticationException, IOException { void testErrorAuthorization() throws InvalidCredentialsException, IOException {
when(environment.getAuthenticator()).thenReturn(authenticator); when(environment.getAuthenticator()).thenReturn(authenticator);
when(authenticator.authenticate(eq(request))).thenThrow(new AuthenticationException("database failure")); when(authenticator.authenticate(eq(request))).thenThrow(new RuntimeException("database failure"));
when(environment.jersey()).thenReturn(jerseyEnvironment); when(environment.jersey()).thenReturn(jerseyEnvironment);
WebSocketResourceProviderFactory<Account> factory = new WebSocketResourceProviderFactory<>(environment, WebSocketResourceProviderFactory<Account> factory = new WebSocketResourceProviderFactory<>(environment,
@ -127,7 +127,7 @@ public class WebSocketResourceProviderFactoryTest {
} }
@Test @Test
void testAuthenticatedWebSocketUpgradeFilter() throws AuthenticationException { void testAuthenticatedWebSocketUpgradeFilter() throws InvalidCredentialsException {
final Account account = new Account(); final Account account = new Account();
final ReusableAuth<Account> reusableAuth = final ReusableAuth<Account> reusableAuth =
ReusableAuth.authenticated(account, PrincipalSupplier.forImmutablePrincipal()); ReusableAuth.authenticated(account, PrincipalSupplier.forImmutablePrincipal());