From d0fdae3df7227da7774ff5088e8e0ac426a3b65f Mon Sep 17 00:00:00 2001 From: Sergey Skrobotov Date: Mon, 25 Sep 2023 11:28:23 -0700 Subject: [PATCH] Enable header-based auth for WebSocket connections --- ...icCredentialAuthenticationInterceptor.java | 31 ++++------ .../textsecuregcm/util/HeaderUtils.java | 35 +++++++++++ .../WebSocketAccountAuthenticator.java | 61 ++++++++++++++----- ...edentialAuthenticationInterceptorTest.java | 27 ++------ .../WebSocketAccountAuthenticatorTest.java | 51 ++++++++++++---- .../websocket/WebSocketConnectionTest.java | 4 +- .../WebSocketResourceProviderFactory.java | 4 +- .../auth/WebSocketAuthenticator.java | 19 +++--- 8 files changed, 147 insertions(+), 85 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/BasicCredentialAuthenticationInterceptor.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/BasicCredentialAuthenticationInterceptor.java index 73a1df2b0..95b66b18a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/BasicCredentialAuthenticationInterceptor.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/grpc/BasicCredentialAuthenticationInterceptor.java @@ -18,6 +18,7 @@ import java.util.Optional; import org.apache.commons.lang3.StringUtils; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.BaseAccountAuthenticator; +import org.whispersystems.textsecuregcm.util.HeaderUtils; /** * A basic credential authentication interceptor enforces the presence of a valid username and password on every call. @@ -39,7 +40,7 @@ public class BasicCredentialAuthenticationInterceptor implements ServerIntercept @VisibleForTesting static final Metadata.Key BASIC_CREDENTIALS = - Metadata.Key.of("x-signal-basic-auth-credentials", Metadata.ASCII_STRING_MARSHALLER); + Metadata.Key.of("x-signal-auth", Metadata.ASCII_STRING_MARSHALLER); private static final Metadata EMPTY_TRAILERS = new Metadata(); @@ -48,17 +49,20 @@ public class BasicCredentialAuthenticationInterceptor implements ServerIntercept } @Override - public ServerCall.Listener interceptCall(final ServerCall call, + public ServerCall.Listener interceptCall( + final ServerCall call, final Metadata headers, final ServerCallHandler next) { - final String credentialString = headers.get(BASIC_CREDENTIALS); + final String authHeader = headers.get(BASIC_CREDENTIALS); - if (StringUtils.isNotBlank(credentialString)) { - try { - final BasicCredentials credentials = extractBasicCredentials(credentialString); + if (StringUtils.isNotBlank(authHeader)) { + final Optional maybeCredentials = HeaderUtils.basicCredentialsFromAuthHeader(authHeader); + if (maybeCredentials.isEmpty()) { + call.close(Status.UNAUTHENTICATED.withDescription("Could not parse credentials"), EMPTY_TRAILERS); + } else { final Optional maybeAuthenticatedAccount = - baseAccountAuthenticator.authenticate(credentials, false); + baseAccountAuthenticator.authenticate(maybeCredentials.get(), false); if (maybeAuthenticatedAccount.isPresent()) { final AuthenticatedAccount authenticatedAccount = maybeAuthenticatedAccount.get(); @@ -71,8 +75,6 @@ public class BasicCredentialAuthenticationInterceptor implements ServerIntercept } else { call.close(Status.UNAUTHENTICATED.withDescription("Credentials not accepted"), EMPTY_TRAILERS); } - } catch (final IllegalArgumentException e) { - call.close(Status.UNAUTHENTICATED.withDescription("Could not parse credentials"), EMPTY_TRAILERS); } } else { call.close(Status.UNAUTHENTICATED.withDescription("No credentials provided"), EMPTY_TRAILERS); @@ -80,15 +82,4 @@ public class BasicCredentialAuthenticationInterceptor implements ServerIntercept return new ServerCall.Listener<>() {}; } - - @VisibleForTesting - static BasicCredentials extractBasicCredentials(final String credentials) { - if (credentials.indexOf(':') < 0) { - throw new IllegalArgumentException("Credentials do not include a username and password part"); - } - - final String[] pieces = credentials.split(":", 2); - - return new BasicCredentials(pieces[0], pieces[1]); - } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/HeaderUtils.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/HeaderUtils.java index 85c3efc47..89b9e7b3b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/HeaderUtils.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/HeaderUtils.java @@ -7,6 +7,7 @@ package org.whispersystems.textsecuregcm.util; import static java.util.Objects.requireNonNull; +import io.dropwizard.auth.basic.BasicCredentials; import java.nio.charset.StandardCharsets; import java.util.Base64; import java.util.Optional; @@ -63,4 +64,38 @@ public final class HeaderUtils { }) .filter(StringUtils::isNotBlank); } + + /** + * Parses a Base64-encoded value of the `Authorization` header + * in the form of `Basic dXNlcm5hbWU6cGFzc3dvcmQ=`. + * Note: parsing logic is copied from {@link io.dropwizard.auth.basic.BasicCredentialAuthFilter#getCredentials(String)}. + */ + public static Optional basicCredentialsFromAuthHeader(final String authHeader) { + final int space = authHeader.indexOf(' '); + if (space <= 0) { + return Optional.empty(); + } + + final String method = authHeader.substring(0, space); + if (!"Basic".equalsIgnoreCase(method)) { + return Optional.empty(); + } + + final String decoded; + try { + decoded = new String(Base64.getDecoder().decode(authHeader.substring(space + 1)), StandardCharsets.UTF_8); + } catch (IllegalArgumentException e) { + return Optional.empty(); + } + + // Decoded credentials is 'username:password' + final int i = decoded.indexOf(':'); + if (i <= 0) { + return Optional.empty(); + } + + final String username = decoded.substring(0, i); + final String password = decoded.substring(i + 1); + return Optional.of(new BasicCredentials(username, password)); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java index 460c5f498..28f3a00cb 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java @@ -1,14 +1,18 @@ /* - * Copyright 2013-2021 Signal Messenger, LLC + * Copyright 2013 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.textsecuregcm.websocket; +import static org.whispersystems.textsecuregcm.util.HeaderUtils.basicCredentialsFromAuthHeader; + +import com.google.common.net.HttpHeaders; import io.dropwizard.auth.basic.BasicCredentials; import java.util.List; import java.util.Map; import java.util.Optional; +import javax.annotation.Nullable; import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; @@ -18,29 +22,32 @@ import org.whispersystems.websocket.auth.WebSocketAuthenticator; public class WebSocketAccountAuthenticator implements WebSocketAuthenticator { + private static final AuthenticationResult CREDENTIALS_NOT_PRESENTED = + new AuthenticationResult<>(Optional.empty(), false); + + private static final AuthenticationResult INVALID_CREDENTIALS_PRESENTED = + new AuthenticationResult<>(Optional.empty(), true); + private final AccountAuthenticator accountAuthenticator; - public WebSocketAccountAuthenticator(AccountAuthenticator accountAuthenticator) { + + public WebSocketAccountAuthenticator(final AccountAuthenticator accountAuthenticator) { this.accountAuthenticator = accountAuthenticator; } @Override - public AuthenticationResult authenticate(UpgradeRequest request) + public AuthenticationResult authenticate(final UpgradeRequest request) throws AuthenticationException { - Map> parameters = request.getParameterMap(); - List usernames = parameters.get("login"); - List passwords = parameters.get("password"); - - if (usernames == null || usernames.size() == 0 || - passwords == null || passwords.size() == 0) { - return new AuthenticationResult<>(Optional.empty(), false); - } - - BasicCredentials credentials = new BasicCredentials(usernames.get(0).replace(" ", "+"), - passwords.get(0).replace(" ", "+")); - try { - return new AuthenticationResult<>(accountAuthenticator.authenticate(credentials), true); + final AuthenticationResult authResultFromHeader = + authenticatedAccountFromHeaderAuth(request.getHeader(HttpHeaders.AUTHORIZATION)); + // the logic here is that if the `Authorization` header was set for the request, + // it takes the priority and we use the result of the header-based auth + // ignoring the result of the query-based auth. + if (authResultFromHeader.credentialsPresented()) { + return authResultFromHeader; + } + return authenticatedAccountFromQueryParams(request); } catch (final Exception e) { // this will be handled and logged upstream // the most likely exception is a transient error connecting to account storage @@ -48,4 +55,26 @@ public class WebSocketAccountAuthenticator implements WebSocketAuthenticator authenticatedAccountFromQueryParams(final UpgradeRequest request) { + final Map> parameters = request.getParameterMap(); + final List usernames = parameters.get("login"); + final List passwords = parameters.get("password"); + if (usernames == null || usernames.size() == 0 || + passwords == null || passwords.size() == 0) { + return CREDENTIALS_NOT_PRESENTED; + } + final BasicCredentials credentials = new BasicCredentials(usernames.get(0).replace(" ", "+"), + passwords.get(0).replace(" ", "+")); + return new AuthenticationResult<>(accountAuthenticator.authenticate(credentials), true); + } + + private AuthenticationResult authenticatedAccountFromHeaderAuth(@Nullable final String authHeader) + throws AuthenticationException { + if (authHeader == null) { + return CREDENTIALS_NOT_PRESENTED; + } + return basicCredentialsFromAuthHeader(authHeader) + .map(credentials -> new AuthenticationResult<>(accountAuthenticator.authenticate(credentials), true)) + .orElse(INVALID_CREDENTIALS_PRESENTED); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/BasicCredentialAuthenticationInterceptorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/BasicCredentialAuthenticationInterceptorTest.java index 1ac0d8565..16eadcd35 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/BasicCredentialAuthenticationInterceptorTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/BasicCredentialAuthenticationInterceptorTest.java @@ -13,7 +13,6 @@ import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import io.dropwizard.auth.basic.BasicCredentials; import io.grpc.CallCredentials; import io.grpc.ManagedChannel; import io.grpc.Metadata; @@ -30,7 +29,6 @@ import java.util.stream.Stream; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; @@ -41,6 +39,7 @@ import org.whispersystems.textsecuregcm.auth.BaseAccountAuthenticator; import org.whispersystems.textsecuregcm.grpc.EchoServiceImpl; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.Pair; class BasicCredentialAuthenticationInterceptorTest { @@ -122,8 +121,10 @@ class BasicCredentialAuthenticationInterceptorTest { malformedCredentialHeaders.put(BasicCredentialAuthenticationInterceptor.BASIC_CREDENTIALS, "Incorrect"); final Metadata structurallyValidCredentialHeaders = new Metadata(); - structurallyValidCredentialHeaders.put(BasicCredentialAuthenticationInterceptor.BASIC_CREDENTIALS, - UUID.randomUUID() + ":" + RandomStringUtils.randomAlphanumeric(16)); + structurallyValidCredentialHeaders.put( + BasicCredentialAuthenticationInterceptor.BASIC_CREDENTIALS, + HeaderUtils.basicAuthHeader(UUID.randomUUID().toString(), RandomStringUtils.randomAlphanumeric(16)) + ); return Stream.of( Arguments.of(new Metadata(), true, false), @@ -132,22 +133,4 @@ class BasicCredentialAuthenticationInterceptorTest { Arguments.of(structurallyValidCredentialHeaders, true, true) ); } - - @Test - void extractBasicCredentials() { - final String username = UUID.randomUUID().toString(); - final String password = RandomStringUtils.random(16); - - final BasicCredentials basicCredentials = - BasicCredentialAuthenticationInterceptor.extractBasicCredentials(username + ":" + password); - - assertEquals(username, basicCredentials.getUsername()); - assertEquals(password, basicCredentials.getPassword()); - } - - @Test - void extractBasicCredentialsIllegalArgument() { - assertThrows(IllegalArgumentException.class, - () -> BasicCredentialAuthenticationInterceptor.extractBasicCredentials("This does not include a password")); - } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticatorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticatorTest.java index 8148e41b3..8601c8ad6 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticatorTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticatorTest.java @@ -6,17 +6,18 @@ package org.whispersystems.textsecuregcm.websocket; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import com.google.common.net.HttpHeaders; import com.google.i18n.phonenumbers.PhoneNumberUtil; import io.dropwizard.auth.basic.BasicCredentials; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.stream.Stream; +import javax.annotation.Nullable; import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; @@ -26,6 +27,7 @@ import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.websocket.auth.WebSocketAuthenticator; @@ -33,9 +35,12 @@ class WebSocketAccountAuthenticatorTest { private static final String VALID_USER = PhoneNumberUtil.getInstance().format( PhoneNumberUtil.getInstance().getExampleNumber("NZ"), PhoneNumberUtil.PhoneNumberFormat.E164); + private static final String VALID_PASSWORD = "valid"; + private static final String INVALID_USER = PhoneNumberUtil.getInstance().format( PhoneNumberUtil.getInstance().getExampleNumber("AU"), PhoneNumberUtil.PhoneNumberFormat.E164); + private static final String INVALID_PASSWORD = "invalid"; private AccountAuthenticator accountAuthenticator; @@ -57,10 +62,16 @@ class WebSocketAccountAuthenticatorTest { @ParameterizedTest @MethodSource - void testAuthenticate(final Map> upgradeRequestParameters, final boolean expectAccount, - final boolean expectRequired) throws Exception { + void testAuthenticate( + @Nullable final String authorizationHeaderValue, + final Map> upgradeRequestParameters, + final boolean expectAccount, + final boolean expectCredentialsPresented) throws Exception { when(upgradeRequest.getParameterMap()).thenReturn(upgradeRequestParameters); + if (authorizationHeaderValue != null) { + when(upgradeRequest.getHeader(eq(HttpHeaders.AUTHORIZATION))).thenReturn(authorizationHeaderValue); + } final WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator( accountAuthenticator); @@ -68,20 +79,34 @@ class WebSocketAccountAuthenticatorTest { final WebSocketAuthenticator.AuthenticationResult result = webSocketAuthenticator.authenticate( upgradeRequest); - if (expectAccount) { - assertTrue(result.getUser().isPresent()); - } else { - assertTrue(result.getUser().isEmpty()); - } - - assertEquals(expectRequired, result.isRequired()); + assertEquals(expectAccount, result.getUser().isPresent()); + assertEquals(expectCredentialsPresented, result.credentialsPresented()); } private static Stream testAuthenticate() { + final Map> paramsMapWithValidAuth = + Map.of("login", List.of(VALID_USER), "password", List.of(VALID_PASSWORD)); + final Map> paramsMapWithInvalidAuth = + Map.of("login", List.of(INVALID_USER), "password", List.of(INVALID_PASSWORD)); + final String headerWithValidAuth = + HeaderUtils.basicAuthHeader(VALID_USER, VALID_PASSWORD); + final String headerWithInvalidAuth = + HeaderUtils.basicAuthHeader(INVALID_USER, INVALID_PASSWORD); return Stream.of( - Arguments.of(Map.of("login", List.of(VALID_USER), "password", List.of(VALID_PASSWORD)), true, true), - Arguments.of(Map.of("login", List.of(INVALID_USER), "password", List.of(INVALID_PASSWORD)), false, true), - Arguments.of(Map.of(), false, false) + // if `Authorization` header is present, outcome should not depend on the value of query parameters + Arguments.of(headerWithValidAuth, Map.of(), true, true), + Arguments.of(headerWithInvalidAuth, Map.of(), false, true), + Arguments.of("invalid header value", Map.of(), false, true), + Arguments.of(headerWithValidAuth, paramsMapWithValidAuth, true, true), + Arguments.of(headerWithInvalidAuth, paramsMapWithValidAuth, false, true), + Arguments.of("invalid header value", paramsMapWithValidAuth, false, true), + Arguments.of(headerWithValidAuth, paramsMapWithInvalidAuth, true, true), + Arguments.of(headerWithInvalidAuth, paramsMapWithInvalidAuth, false, true), + Arguments.of("invalid header value", paramsMapWithInvalidAuth, false, true), + // if `Authorization` header is not set, outcome should match the query params based auth + Arguments.of(null, paramsMapWithValidAuth, true, true), + Arguments.of(null, paramsMapWithInvalidAuth, false, true), + Arguments.of(null, Map.of(), false, false) ); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index a95ae6dff..5076b27af 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2022 Signal Messenger, LLC + * Copyright 2013 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -142,7 +142,7 @@ class WebSocketConnectionTest { when(upgradeRequest.getParameterMap()).thenReturn(Map.of()); account = webSocketAuthenticator.authenticate(upgradeRequest); assertFalse(account.getUser().isPresent()); - assertFalse(account.isRequired()); + assertFalse(account.credentialsPresented()); connectListener.onWebSocketConnect(sessionContext); verify(sessionContext, times(2)).addWebsocketClosedListener( diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java index 8233e2346..e9b6e88a4 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.websocket; @@ -57,7 +57,7 @@ public class WebSocketResourceProviderFactory extends WebSo if (authenticator.isPresent()) { AuthenticationResult authenticationResult = authenticator.get().authenticate(request); - if (authenticationResult.getUser().isEmpty() && authenticationResult.isRequired()) { + if (authenticationResult.getUser().isEmpty() && authenticationResult.credentialsPresented()) { response.sendForbidden("Unauthorized"); return null; } else { diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebSocketAuthenticator.java b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebSocketAuthenticator.java index f8d7969fb..4c913fe86 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebSocketAuthenticator.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebSocketAuthenticator.java @@ -1,33 +1,32 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.websocket.auth; -import org.eclipse.jetty.websocket.api.UpgradeRequest; - import java.security.Principal; import java.util.Optional; +import org.eclipse.jetty.websocket.api.UpgradeRequest; public interface WebSocketAuthenticator { AuthenticationResult authenticate(UpgradeRequest request) throws AuthenticationException; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") - public class AuthenticationResult { + class AuthenticationResult { private final Optional user; - private final boolean required; + private final boolean credentialsPresented; - public AuthenticationResult(Optional user, boolean required) { - this.user = user; - this.required = required; + public AuthenticationResult(final Optional user, final boolean credentialsPresented) { + this.user = user; + this.credentialsPresented = credentialsPresented; } public Optional getUser() { return user; } - public boolean isRequired() { - return required; + public boolean credentialsPresented() { + return credentialsPresented; } } }