From 26ffa19f369c8107b5ca9e6fe9d31c821e9b5d01 Mon Sep 17 00:00:00 2001 From: Ravi Khadiwala Date: Tue, 6 Feb 2024 16:59:42 -0600 Subject: [PATCH] Lifecycle management for Account objects reused accross websocket requests --- .../textsecuregcm/WhisperServerService.java | 3 +- .../auth/AccountAuthenticator.java | 4 +- ...hEnablementRefreshRequirementProvider.java | 51 +-- .../auth/AuthenticatedAccount.java | 14 +- .../auth/ChangesPhoneNumber.java | 20 ++ .../auth/ContainerRequestUtil.java | 33 +- ...umberChangeRefreshRequirementProvider.java | 38 ++- ...socketRefreshApplicationEventListener.java | 2 +- .../controllers/AccountControllerV2.java | 2 + .../storage/AccountPrincipalSupplier.java | 36 ++ .../textsecuregcm/storage/AccountUtil.java | 4 +- .../RefreshingAccountAndDeviceSupplier.java | 35 -- ...> RefreshingAccountNotFoundException.java} | 4 +- .../WebSocketAccountAuthenticator.java | 40 +-- .../WebsocketReuseAuthIntegrationTest.java | 279 ++++++++++++++++ ...blementRefreshRequirementProviderTest.java | 46 +-- .../auth/ContainerRequestUtilTest.java | 55 ++++ ...rChangeRefreshRequirementProviderTest.java | 308 ++++++++++++++---- ...edentialAuthenticationInterceptorTest.java | 2 +- .../DirectoryControllerV2Test.java | 2 +- .../MetricsRequestEventListenerTest.java | 20 +- ...efreshingAccountAndDeviceSupplierTest.java | 71 ---- .../tests/util/TestPrincipal.java | 27 ++ .../tests/util/TestWebsocketListener.java | 79 +++++ .../LoggingUnhandledExceptionMapperTest.java | 18 +- .../WebSocketAccountAuthenticatorTest.java | 26 +- .../WebSocketConnectionIntegrationTest.java | 6 +- .../websocket/WebSocketConnectionTest.java | 21 +- .../websocket/ReusableAuth.java | 182 +++++++++++ .../websocket/WebSocketResourceProvider.java | 70 ++-- .../WebSocketResourceProviderFactory.java | 11 +- .../websocket/auth/Mutable.java | 26 ++ .../websocket/auth/PrincipalSupplier.java | 58 ++++ .../websocket/auth/ReadOnly.java | 25 ++ .../auth/WebSocketAuthenticator.java | 23 +- .../WebsocketAuthValueFactoryProvider.java | 88 +++-- .../WebSocketResourceProviderFactoryTest.java | 10 +- .../WebSocketResourceProviderTest.java | 35 +- 38 files changed, 1317 insertions(+), 457 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/auth/ChangesPhoneNumber.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountPrincipalSupplier.java delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceSupplier.java rename service/src/main/java/org/whispersystems/textsecuregcm/storage/{RefreshingAccountAndDeviceNotFoundException.java => RefreshingAccountNotFoundException.java} (51%) create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/WebsocketReuseAuthIntegrationTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/auth/ContainerRequestUtilTest.java delete mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceSupplierTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestPrincipal.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestWebsocketListener.java create mode 100644 websocket-resources/src/main/java/org/whispersystems/websocket/ReusableAuth.java create mode 100644 websocket-resources/src/main/java/org/whispersystems/websocket/auth/Mutable.java create mode 100644 websocket-resources/src/main/java/org/whispersystems/websocket/auth/PrincipalSupplier.java create mode 100644 websocket-resources/src/main/java/org/whispersystems/websocket/auth/ReadOnly.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 413f5d549..eb8b680d5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -188,6 +188,7 @@ import org.whispersystems.textsecuregcm.spam.SenderOverrideProvider; import org.whispersystems.textsecuregcm.spam.SpamChecker; import org.whispersystems.textsecuregcm.spam.SpamFilter; import org.whispersystems.textsecuregcm.storage.AccountLockManager; +import org.whispersystems.textsecuregcm.storage.AccountPrincipalSupplier; import org.whispersystems.textsecuregcm.storage.Accounts; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.ChangeNumberManager; @@ -812,7 +813,7 @@ public class WhisperServerService extends Application webSocketEnvironment = new WebSocketEnvironment<>(environment, config.getWebSocketConfiguration(), Duration.ofMillis(90000)); webSocketEnvironment.jersey().register(new VirtualExecutorServiceProvider("managed-async-websocket-virtual-thread-")); - webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator)); + webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator, new AccountPrincipalSupplier(accountsManager))); webSocketEnvironment.setConnectListener( new AuthenticatedConnectListener(receiptSender, messagesManager, pushNotificationManager, clientPresenceManager, websocketScheduledExecutor, messageDeliveryScheduler, clientReleaseManager)); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AccountAuthenticator.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AccountAuthenticator.java index 0bcdf074e..c6ef61f39 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AccountAuthenticator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AccountAuthenticator.java @@ -21,7 +21,6 @@ import org.apache.commons.lang3.StringUtils; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.storage.RefreshingAccountAndDeviceSupplier; import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Util; @@ -108,8 +107,7 @@ public class AccountAuthenticator implements Authenticator buildDevicesEnabledMap(final Account account) { - return account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::isEnabled)); - } @Override public void handleRequestFiltered(final RequestEvent requestEvent) { @@ -60,10 +55,13 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres setAccount(requestEvent.getContainerRequest(), account)); } } - public static void setAccount(final ContainerRequest containerRequest, final Account account) { - containerRequest.setProperty(ACCOUNT_UUID, account.getUuid()); - containerRequest.setProperty(DEVICES_ENABLED, buildDevicesEnabledMap(account)); + setAccount(containerRequest, ContainerRequestUtil.AccountInfo.fromAccount(account)); + } + + private static void setAccount(final ContainerRequest containerRequest, final ContainerRequestUtil.AccountInfo info) { + containerRequest.setProperty(ACCOUNT_UUID, info.accountId()); + containerRequest.setProperty(DEVICES_ENABLED, info.devicesEnabled()); } @Override @@ -75,25 +73,28 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres @SuppressWarnings("unchecked") final Map initialDevicesEnabled = (Map) requestEvent.getContainerRequest().getProperty(DEVICES_ENABLED); - return accountsManager.getByAccountIdentifier((UUID) requestEvent.getContainerRequest().getProperty(ACCOUNT_UUID)).map(account -> { - final Set deviceIdsToDisplace; - final Map currentDevicesEnabled = buildDevicesEnabledMap(account); + return accountsManager.getByAccountIdentifier((UUID) requestEvent.getContainerRequest().getProperty(ACCOUNT_UUID)) + .map(ContainerRequestUtil.AccountInfo::fromAccount) + .map(account -> { + final Set deviceIdsToDisplace; + final Map currentDevicesEnabled = account.devicesEnabled(); - if (!initialDevicesEnabled.equals(currentDevicesEnabled)) { - deviceIdsToDisplace = new HashSet<>(initialDevicesEnabled.keySet()); - deviceIdsToDisplace.addAll(currentDevicesEnabled.keySet()); - } else { - deviceIdsToDisplace = Collections.emptySet(); - } + if (!initialDevicesEnabled.equals(currentDevicesEnabled)) { + deviceIdsToDisplace = new HashSet<>(initialDevicesEnabled.keySet()); + deviceIdsToDisplace.addAll(currentDevicesEnabled.keySet()); + } else { + deviceIdsToDisplace = Collections.emptySet(); + } - return deviceIdsToDisplace.stream() - .map(deviceId -> new Pair<>(account.getUuid(), deviceId)) - .collect(Collectors.toList()); - }).orElseGet(() -> { - logger.error("Request had account, but it is no longer present"); - return Collections.emptyList(); - }); - } else + return deviceIdsToDisplace.stream() + .map(deviceId -> new Pair<>(account.accountId(), deviceId)) + .collect(Collectors.toList()); + }).orElseGet(() -> { + logger.error("Request had account, but it is no longer present"); + return Collections.emptyList(); + }); + } else { return Collections.emptyList(); + } } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthenticatedAccount.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthenticatedAccount.java index 13c9e504c..4fdca7be4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthenticatedAccount.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthenticatedAccount.java @@ -10,24 +10,24 @@ import java.util.function.Supplier; import javax.security.auth.Subject; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.util.Pair; public class AuthenticatedAccount implements Principal, AccountAndAuthenticatedDeviceHolder { + private final Account account; + private final Device device; - private final Supplier> accountAndDevice; - - public AuthenticatedAccount(final Supplier> accountAndDevice) { - this.accountAndDevice = accountAndDevice; + public AuthenticatedAccount(final Account account, final Device device) { + this.account = account; + this.device = device; } @Override public Account getAccount() { - return accountAndDevice.get().first(); + return account; } @Override public Device getAuthenticatedDevice() { - return accountAndDevice.get().second(); + return device; } // Principal implementation diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/ChangesPhoneNumber.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/ChangesPhoneNumber.java new file mode 100644 index 000000000..b1defdf9d --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/ChangesPhoneNumber.java @@ -0,0 +1,20 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.auth; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Indicates that an endpoint changes the phone number and PNI keys associated with an account, and that + * any websockets associated with the account may need to be refreshed after a call to that endpoint. + */ +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +public @interface ChangesPhoneNumber { +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/ContainerRequestUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/ContainerRequestUtil.java index f551b9761..eb2ecf029 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/ContainerRequestUtil.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/ContainerRequestUtil.java @@ -7,15 +7,42 @@ package org.whispersystems.textsecuregcm.auth; import org.glassfish.jersey.server.ContainerRequest; import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.Device; import javax.ws.rs.core.SecurityContext; +import java.util.Map; import java.util.Optional; +import java.util.UUID; +import java.util.stream.Collectors; class ContainerRequestUtil { - static Optional getAuthenticatedAccount(final ContainerRequest request) { + private static Map buildDevicesEnabledMap(final Account account) { + return account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::isEnabled)); + } + + /** + * A read-only subset of the authenticated Account object, to enforce that filter-based consumers do not perform + * account modifying operations. + */ + record AccountInfo(UUID accountId, String e164, Map devicesEnabled) { + + static AccountInfo fromAccount(final Account account) { + return new AccountInfo( + account.getUuid(), + account.getNumber(), + buildDevicesEnabledMap(account)); + } + } + + static Optional getAuthenticatedAccount(final ContainerRequest request) { return Optional.ofNullable(request.getSecurityContext()) .map(SecurityContext::getUserPrincipal) - .map(principal -> principal instanceof AccountAndAuthenticatedDeviceHolder - ? ((AccountAndAuthenticatedDeviceHolder) principal).getAccount() : null); + .map(principal -> { + if (principal instanceof AccountAndAuthenticatedDeviceHolder aaadh) { + return aaadh.getAccount(); + } + return null; + }) + .map(AccountInfo::fromAccount); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProvider.java index 522a49964..db1c35ae9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProvider.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProvider.java @@ -7,40 +7,50 @@ package org.whispersystems.textsecuregcm.auth; import java.util.Collections; import java.util.List; -import java.util.Optional; import java.util.UUID; import java.util.stream.Collectors; import org.glassfish.jersey.server.monitoring.RequestEvent; -import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.util.Pair; public class PhoneNumberChangeRefreshRequirementProvider implements WebsocketRefreshRequirementProvider { + private static final String ACCOUNT_UUID = + PhoneNumberChangeRefreshRequirementProvider.class.getName() + ".accountUuid"; + private static final String INITIAL_NUMBER_KEY = PhoneNumberChangeRefreshRequirementProvider.class.getName() + ".initialNumber"; + private final AccountsManager accountsManager; + + public PhoneNumberChangeRefreshRequirementProvider(final AccountsManager accountsManager) { + this.accountsManager = accountsManager; + } @Override public void handleRequestFiltered(final RequestEvent requestEvent) { + if (requestEvent.getUriInfo().getMatchedResourceMethod().getInvocable().getHandlingMethod() + .getAnnotation(ChangesPhoneNumber.class) == null) { + return; + } ContainerRequestUtil.getAuthenticatedAccount(requestEvent.getContainerRequest()) - .ifPresent(account -> requestEvent.getContainerRequest().setProperty(INITIAL_NUMBER_KEY, account.getNumber())); + .ifPresent(account -> { + requestEvent.getContainerRequest().setProperty(INITIAL_NUMBER_KEY, account.e164()); + requestEvent.getContainerRequest().setProperty(ACCOUNT_UUID, account.accountId()); + }); } @Override public List> handleRequestFinished(final RequestEvent requestEvent) { final String initialNumber = (String) requestEvent.getContainerRequest().getProperty(INITIAL_NUMBER_KEY); - if (initialNumber != null) { - final Optional maybeAuthenticatedAccount = - ContainerRequestUtil.getAuthenticatedAccount(requestEvent.getContainerRequest()); - - return maybeAuthenticatedAccount - .filter(account -> !initialNumber.equals(account.getNumber())) - .map(account -> account.getDevices().stream() - .map(device -> new Pair<>(account.getUuid(), device.getId())) - .collect(Collectors.toList())) - .orElse(Collections.emptyList()); - } else { + if (initialNumber == null) { return Collections.emptyList(); } + return accountsManager.getByAccountIdentifier((UUID) requestEvent.getContainerRequest().getProperty(ACCOUNT_UUID)) + .filter(account -> !initialNumber.equals(account.getNumber())) + .map(account -> account.getDevices().stream() + .map(device -> new Pair<>(account.getUuid(), device.getId())) + .collect(Collectors.toList())) + .orElse(Collections.emptyList()); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshApplicationEventListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshApplicationEventListener.java index ad7ffeb9c..4ad8c67ca 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshApplicationEventListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshApplicationEventListener.java @@ -24,7 +24,7 @@ public class WebsocketRefreshApplicationEventListener implements ApplicationEven this.websocketRefreshRequestEventListener = new WebsocketRefreshRequestEventListener(clientPresenceManager, new AuthEnablementRefreshRequirementProvider(accountsManager), - new PhoneNumberChangeRefreshRequirementProvider()); + new PhoneNumberChangeRefreshRequirementProvider(accountsManager)); } @Override diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2.java index 7f0c0a760..7025557f0 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2.java @@ -36,6 +36,7 @@ import javax.ws.rs.WebApplicationException; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.auth.ChangesPhoneNumber; import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager; import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager; import org.whispersystems.textsecuregcm.entities.AccountDataReportResponse; @@ -90,6 +91,7 @@ public class AccountControllerV2 { @Path("/number") @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) + @ChangesPhoneNumber @Operation(summary = "Change number", description = "Changes a phone number for an existing account.") @ApiResponse(responseCode = "200", description = "The phone number associated with the authenticated account was changed successfully", useReturnTypeSchema = true) @ApiResponse(responseCode = "401", description = "Account authentication check failed.") diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountPrincipalSupplier.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountPrincipalSupplier.java new file mode 100644 index 000000000..12d20a3ee --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountPrincipalSupplier.java @@ -0,0 +1,36 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.storage; + +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.websocket.auth.PrincipalSupplier; + +public class AccountPrincipalSupplier implements PrincipalSupplier { + + private final AccountsManager accountsManager; + + public AccountPrincipalSupplier(final AccountsManager accountsManager) { + this.accountsManager = accountsManager; + } + + @Override + public AuthenticatedAccount refresh(final AuthenticatedAccount oldAccount) { + final Account account = accountsManager.getByAccountIdentifier(oldAccount.getAccount().getUuid()) + .orElseThrow(() -> new RefreshingAccountNotFoundException("Could not find account")); + final Device device = account.getDevice(oldAccount.getAuthenticatedDevice().getId()) + .orElseThrow(() -> new RefreshingAccountNotFoundException("Could not find device")); + return new AuthenticatedAccount(account, device); + } + + @Override + public AuthenticatedAccount deepCopy(final AuthenticatedAccount authenticatedAccount) { + final Account cloned = AccountUtil.cloneAccountAsNotStale(authenticatedAccount.getAccount()); + return new AuthenticatedAccount( + cloned, + cloned.getDevice(authenticatedAccount.getAuthenticatedDevice().getId()) + .orElseThrow(() -> new IllegalStateException( + "Could not find device from a clone of an account where the device was present"))); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountUtil.java index cffdeea13..53d826b94 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountUtil.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountUtil.java @@ -5,10 +5,12 @@ package org.whispersystems.textsecuregcm.storage; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.util.SystemMapper; +import org.whispersystems.websocket.auth.PrincipalSupplier; import java.io.IOException; -class AccountUtil { +public class AccountUtil { static Account cloneAccountAsNotStale(final Account account) { try { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceSupplier.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceSupplier.java deleted file mode 100644 index fa0556e42..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceSupplier.java +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright 2021 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.storage; - -import java.util.function.Supplier; -import org.whispersystems.textsecuregcm.util.Pair; - -public class RefreshingAccountAndDeviceSupplier implements Supplier> { - - private Account account; - private Device device; - private final AccountsManager accountsManager; - - public RefreshingAccountAndDeviceSupplier(Account account, byte deviceId, AccountsManager accountsManager) { - this.account = account; - this.device = account.getDevice(deviceId) - .orElseThrow(() -> new RefreshingAccountAndDeviceNotFoundException("Could not find device")); - this.accountsManager = accountsManager; - } - - @Override - public Pair get() { - if (account.isStale()) { - account = accountsManager.getByAccountIdentifier(account.getUuid()) - .orElseThrow(() -> new RuntimeException("Could not find account")); - device = account.getDevice(device.getId()) - .orElseThrow(() -> new RefreshingAccountAndDeviceNotFoundException("Could not find device")); - } - - return new Pair<>(account, device); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceNotFoundException.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountNotFoundException.java similarity index 51% rename from service/src/main/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceNotFoundException.java rename to service/src/main/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountNotFoundException.java index 0ea57a0ad..768435b5e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceNotFoundException.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountNotFoundException.java @@ -5,9 +5,9 @@ package org.whispersystems.textsecuregcm.storage; -public class RefreshingAccountAndDeviceNotFoundException extends RuntimeException { +public class RefreshingAccountNotFoundException extends RuntimeException { - public RefreshingAccountAndDeviceNotFoundException(final String message) { + public RefreshingAccountNotFoundException(final String message) { super(message); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java index 28f3a00cb..c03826dcb 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java @@ -11,41 +11,40 @@ import com.google.common.net.HttpHeaders; import io.dropwizard.auth.basic.BasicCredentials; import java.util.List; import java.util.Map; -import java.util.Optional; import javax.annotation.Nullable; import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.websocket.ReusableAuth; import org.whispersystems.websocket.auth.AuthenticationException; +import org.whispersystems.websocket.auth.PrincipalSupplier; import org.whispersystems.websocket.auth.WebSocketAuthenticator; public class WebSocketAccountAuthenticator implements WebSocketAuthenticator { - private static final AuthenticationResult CREDENTIALS_NOT_PRESENTED = - new AuthenticationResult<>(Optional.empty(), false); + private static final ReusableAuth CREDENTIALS_NOT_PRESENTED = ReusableAuth.anonymous(); - private static final AuthenticationResult INVALID_CREDENTIALS_PRESENTED = - new AuthenticationResult<>(Optional.empty(), true); + private static final ReusableAuth INVALID_CREDENTIALS_PRESENTED = ReusableAuth.invalid(); private final AccountAuthenticator accountAuthenticator; + private final PrincipalSupplier principalSupplier; - - public WebSocketAccountAuthenticator(final AccountAuthenticator accountAuthenticator) { + public WebSocketAccountAuthenticator(final AccountAuthenticator accountAuthenticator, + final PrincipalSupplier principalSupplier) { this.accountAuthenticator = accountAuthenticator; + this.principalSupplier = principalSupplier; } @Override - public AuthenticationResult authenticate(final UpgradeRequest request) + public ReusableAuth authenticate(final UpgradeRequest request) throws AuthenticationException { try { - final AuthenticationResult authResultFromHeader = - authenticatedAccountFromHeaderAuth(request.getHeader(HttpHeaders.AUTHORIZATION)); - // the logic here is that if the `Authorization` header was set for the request, - // it takes the priority and we use the result of the header-based auth - // ignoring the result of the query-based auth. - if (authResultFromHeader.credentialsPresented()) { - return authResultFromHeader; + // If the `Authorization` header was set for the request it takes priority, and we use the result of the + // header-based auth ignoring the result of the query-based auth. + final String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION); + if (authHeader != null) { + return authenticatedAccountFromHeaderAuth(authHeader); } return authenticatedAccountFromQueryParams(request); } catch (final Exception e) { @@ -55,7 +54,7 @@ public class WebSocketAccountAuthenticator implements WebSocketAuthenticator authenticatedAccountFromQueryParams(final UpgradeRequest request) { + private ReusableAuth authenticatedAccountFromQueryParams(final UpgradeRequest request) { final Map> parameters = request.getParameterMap(); final List usernames = parameters.get("login"); final List passwords = parameters.get("password"); @@ -65,16 +64,19 @@ public class WebSocketAccountAuthenticator implements WebSocketAuthenticator(accountAuthenticator.authenticate(credentials), true); + return accountAuthenticator.authenticate(credentials) + .map(authenticatedAccount -> ReusableAuth.authenticated(authenticatedAccount, this.principalSupplier)) + .orElse(INVALID_CREDENTIALS_PRESENTED); } - private AuthenticationResult authenticatedAccountFromHeaderAuth(@Nullable final String authHeader) + private ReusableAuth authenticatedAccountFromHeaderAuth(@Nullable final String authHeader) throws AuthenticationException { if (authHeader == null) { return CREDENTIALS_NOT_PRESENTED; } return basicCredentialsFromAuthHeader(authHeader) - .map(credentials -> new AuthenticationResult<>(accountAuthenticator.authenticate(credentials), true)) + .flatMap(credentials -> accountAuthenticator.authenticate(credentials)) + .map(authenticatedAccount -> ReusableAuth.authenticated(authenticatedAccount, this.principalSupplier)) .orElse(INVALID_CREDENTIALS_PRESENTED); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/WebsocketReuseAuthIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/WebsocketReuseAuthIntegrationTest.java new file mode 100644 index 000000000..18ee672df --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/WebsocketReuseAuthIntegrationTest.java @@ -0,0 +1,279 @@ +package org.whispersystems.textsecuregcm; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; +import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME; + +import io.dropwizard.auth.Auth; +import io.dropwizard.core.Application; +import io.dropwizard.core.Configuration; +import io.dropwizard.core.setup.Environment; +import io.dropwizard.testing.junit5.DropwizardAppExtension; +import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; +import java.io.IOException; +import java.net.URI; +import java.util.EnumSet; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.stream.IntStream; +import javax.servlet.DispatcherType; +import javax.servlet.ServletRegistration; +import javax.ws.rs.GET; +import javax.ws.rs.Path; +import javax.ws.rs.PathParam; +import org.eclipse.jetty.websocket.client.WebSocketClient; +import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; +import org.glassfish.jersey.server.ManagedAsync; +import org.glassfish.jersey.server.ServerProperties; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; +import org.whispersystems.textsecuregcm.storage.RefreshingAccountNotFoundException; +import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener; +import org.whispersystems.websocket.ReusableAuth; +import org.whispersystems.websocket.WebSocketResourceProviderFactory; +import org.whispersystems.websocket.auth.PrincipalSupplier; +import org.whispersystems.websocket.auth.ReadOnly; +import org.whispersystems.websocket.configuration.WebSocketConfiguration; +import org.whispersystems.websocket.messages.WebSocketResponseMessage; +import org.whispersystems.websocket.setup.WebSocketEnvironment; + +@ExtendWith(DropwizardExtensionsSupport.class) +public class WebsocketReuseAuthIntegrationTest { + + private static final AuthenticatedAccount ACCOUNT = mock(AuthenticatedAccount.class); + @SuppressWarnings("unchecked") + private static final PrincipalSupplier PRINCIPAL_SUPPLIER = mock(PrincipalSupplier.class); + private static final DropwizardAppExtension DROPWIZARD_APP_EXTENSION = + new DropwizardAppExtension<>(TestApplication.class); + + + private WebSocketClient client; + + @BeforeEach + void setUp() throws Exception { + reset(PRINCIPAL_SUPPLIER); + reset(ACCOUNT); + when(ACCOUNT.getName()).thenReturn("original"); + client = new WebSocketClient(); + client.start(); + } + + @AfterEach + void tearDown() throws Exception { + client.stop(); + } + + + public static class TestApplication extends Application { + + @Override + public void run(final Configuration configuration, final Environment environment) throws Exception { + final TestController testController = new TestController(); + + final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration(); + + final WebSocketEnvironment webSocketEnvironment = + new WebSocketEnvironment<>(environment, webSocketConfiguration); + + environment.jersey().register(testController); + environment.servlets() + .addFilter("RemoteAddressFilter", new RemoteAddressFilter(true)) + .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*"); + webSocketEnvironment.jersey().register(testController); + webSocketEnvironment.jersey().register(new RemoteAddressFilter(true)); + webSocketEnvironment.setAuthenticator(upgradeRequest -> ReusableAuth.authenticated(ACCOUNT, PRINCIPAL_SUPPLIER)); + + webSocketEnvironment.jersey().property(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE); + webSocketEnvironment.setConnectListener(webSocketSessionContext -> { + }); + + final WebSocketResourceProviderFactory webSocketServlet = + new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedAccount.class, + webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME); + + JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null); + + final ServletRegistration.Dynamic websocketServlet = + environment.servlets().addServlet("WebSocket", webSocketServlet); + + websocketServlet.addMapping("/websocket"); + websocketServlet.setAsyncSupported(true); + } + } + + private WebSocketResponseMessage make1WebsocketRequest(final String requestPath) throws IOException { + + final TestWebsocketListener testWebsocketListener = new TestWebsocketListener(); + + client.connect(testWebsocketListener, + URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort()))); + return testWebsocketListener.doGet(requestPath).join(); + } + + @ParameterizedTest + @ValueSource(strings = {"/test/read-auth", "/test/optional-read-auth"}) + public void readAuth(final String path) throws IOException { + final WebSocketResponseMessage response = make1WebsocketRequest(path); + assertThat(response.getStatus()).isEqualTo(200); + verifyNoMoreInteractions(PRINCIPAL_SUPPLIER); + } + + @ParameterizedTest + @ValueSource(strings = {"/test/write-auth", "/test/optional-write-auth"}) + public void writeAuth(final String path) throws IOException { + final AuthenticatedAccount copiedAccount = mock(AuthenticatedAccount.class); + when(copiedAccount.getName()).thenReturn("copy"); + when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(copiedAccount); + + final WebSocketResponseMessage response = make1WebsocketRequest(path); + assertThat(response.getStatus()).isEqualTo(200); + assertThat(response.getBody().map(String::new)).get().isEqualTo("copy"); + verify(PRINCIPAL_SUPPLIER, times(1)).deepCopy(any()); + verifyNoMoreInteractions(PRINCIPAL_SUPPLIER); + } + + @Test + public void readAfterWrite() throws IOException { + when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(ACCOUNT); + final AuthenticatedAccount account2 = mock(AuthenticatedAccount.class); + when(account2.getName()).thenReturn("refresh"); + when(PRINCIPAL_SUPPLIER.refresh(any())).thenReturn(account2); + + final TestWebsocketListener testWebsocketListener = new TestWebsocketListener(); + client.connect(testWebsocketListener, + URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort()))); + + final WebSocketResponseMessage readResponse = testWebsocketListener.doGet("/test/read-auth").join(); + assertThat(readResponse.getBody().map(String::new)).get().isEqualTo("original"); + + final WebSocketResponseMessage writeResponse = testWebsocketListener.doGet("/test/write-auth").join(); + assertThat(writeResponse.getBody().map(String::new)).get().isEqualTo("original"); + + final WebSocketResponseMessage readResponse2 = testWebsocketListener.doGet("/test/read-auth").join(); + assertThat(readResponse2.getBody().map(String::new)).get().isEqualTo("refresh"); + } + + @Test + public void readAfterWriteRefreshFails() throws IOException { + when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(ACCOUNT); + when(PRINCIPAL_SUPPLIER.refresh(any())).thenThrow(RefreshingAccountNotFoundException.class); + + final TestWebsocketListener testWebsocketListener = new TestWebsocketListener(); + client.connect(testWebsocketListener, + URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort()))); + + final WebSocketResponseMessage writeResponse = testWebsocketListener.doGet("/test/write-auth").join(); + assertThat(writeResponse.getBody().map(String::new)).get().isEqualTo("original"); + + final WebSocketResponseMessage readResponse2 = testWebsocketListener.doGet("/test/read-auth").join(); + assertThat(readResponse2.getStatus()).isEqualTo(500); + } + + @Test + public void readConcurrentWithWrite() throws IOException, ExecutionException, InterruptedException, TimeoutException { + final AuthenticatedAccount deepCopy = mock(AuthenticatedAccount.class); + when(deepCopy.getName()).thenReturn("deepCopy"); + when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(deepCopy); + + final AuthenticatedAccount refresh = mock(AuthenticatedAccount.class); + when(refresh.getName()).thenReturn("refresh"); + when(PRINCIPAL_SUPPLIER.refresh(any())).thenReturn(refresh); + + final TestWebsocketListener testWebsocketListener = new TestWebsocketListener(); + client.connect(testWebsocketListener, + URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort()))); + + // start a write request that takes a while to finish + final CompletableFuture writeResponse = + testWebsocketListener.doGet("/test/start-delayed-write/foo"); + + // send a bunch of reads, they should reflect the original auth + final List> futures = IntStream.range(0, 10) + .boxed().map(i -> testWebsocketListener.doGet("/test/read-auth")) + .toList(); + CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)).join(); + for (CompletableFuture future : futures) { + assertThat(future.join().getBody().map(String::new)).get().isEqualTo("original"); + } + + assertThat(writeResponse.isDone()).isFalse(); + + // finish the delayed write request + testWebsocketListener.doGet("/test/finish-delayed-write/foo").get(1, TimeUnit.SECONDS); + assertThat(writeResponse.join().getBody().map(String::new)).get().isEqualTo("deepCopy"); + + // subsequent reads should have the refreshed auth + final WebSocketResponseMessage readResponse = testWebsocketListener.doGet("/test/read-auth").join(); + assertThat(readResponse.getBody().map(String::new)).get().isEqualTo("refresh"); + } + + + @Path("/test") + public static class TestController { + + private final ConcurrentHashMap delayedWriteLatches = new ConcurrentHashMap<>(); + + @GET + @Path("/read-auth") + @ManagedAsync + public String readAuth(@ReadOnly @Auth final AuthenticatedAccount account) { + return account.getName(); + } + + @GET + @Path("/optional-read-auth") + @ManagedAsync + public String optionalReadAuth(@ReadOnly @Auth final Optional account) { + return account.map(AuthenticatedAccount::getName).orElse("empty"); + } + + @GET + @Path("/write-auth") + @ManagedAsync + public String writeAuth(@Auth final AuthenticatedAccount account) { + return account.getName(); + } + + @GET + @Path("/optional-write-auth") + @ManagedAsync + public String optionalWriteAuth(@Auth final Optional account) { + return account.map(AuthenticatedAccount::getName).orElse("empty"); + } + + @GET + @Path("/start-delayed-write/{id}") + @ManagedAsync + public String startDelayedWrite(@Auth final AuthenticatedAccount account, @PathParam("id") String id) + throws InterruptedException { + delayedWriteLatches.computeIfAbsent(id, i -> new CountDownLatch(1)).await(); + return account.getName(); + } + + @GET + @Path("/finish-delayed-write/{id}") + @ManagedAsync + public String finishDelayedWrite(@PathParam("id") String id) { + delayedWriteLatches.computeIfAbsent(id, i -> new CountDownLatch(1)).countDown(); + return "ok"; + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java index 253a54b7c..bc3a913e2 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java @@ -7,8 +7,6 @@ package org.whispersystems.textsecuregcm.auth; import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -30,7 +28,6 @@ import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.security.Principal; import java.time.Duration; -import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; import java.util.LinkedList; @@ -76,7 +73,9 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; import org.whispersystems.textsecuregcm.util.SystemMapper; +import org.whispersystems.websocket.ReusableAuth; import org.whispersystems.websocket.WebSocketResourceProvider; +import org.whispersystems.websocket.auth.PrincipalSupplier; import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider; import org.whispersystems.websocket.logging.WebsocketRequestLog; import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory; @@ -132,38 +131,6 @@ class AuthEnablementRefreshRequirementProviderTest { .forEach(device -> when(clientPresenceManager.isPresent(uuid, device.getId())).thenReturn(true)); } - @Test - void testBuildDevicesEnabled() { - - final byte disabledDeviceId = 3; - - final Account account = mock(Account.class); - - final List devices = new ArrayList<>(); - when(account.getDevices()).thenReturn(devices); - - IntStream.range(1, 5) - .forEach(id -> { - final Device device = mock(Device.class); - when(device.getId()).thenReturn((byte) id); - when(device.isEnabled()).thenReturn(id != disabledDeviceId); - devices.add(device); - }); - - final Map devicesEnabled = AuthEnablementRefreshRequirementProvider.buildDevicesEnabledMap(account); - - assertEquals(4, devicesEnabled.size()); - - assertAll(devicesEnabled.entrySet().stream() - .map(deviceAndEnabled -> () -> { - if (deviceAndEnabled.getKey().equals(disabledDeviceId)) { - assertFalse(deviceAndEnabled.getValue()); - } else { - assertTrue(deviceAndEnabled.getValue()); - } - })); - } - @ParameterizedTest @MethodSource void testDeviceEnabledChanged(final Map initialEnabled, final Map finalEnabled) { @@ -308,7 +275,7 @@ class AuthEnablementRefreshRequirementProviderTest { WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); provider = new WebSocketResourceProvider<>("127.0.0.1", RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, - applicationHandler, requestLog, new TestPrincipal("test", account, authenticatedDevice), + applicationHandler, requestLog, TestPrincipal.reusableAuth("test", account, authenticatedDevice), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); remoteEndpoint = mock(RemoteEndpoint.class); @@ -349,7 +316,7 @@ class AuthEnablementRefreshRequirementProviderTest { private final Account account; private final Device device; - private TestPrincipal(String name, final Account account, final Device device) { + private TestPrincipal(final String name, final Account account, final Device device) { this.name = name; this.account = account; this.device = device; @@ -369,6 +336,11 @@ class AuthEnablementRefreshRequirementProviderTest { public Device getAuthenticatedDevice() { return device; } + + public static ReusableAuth reusableAuth(final String name, final Account account, final Device device) { + return ReusableAuth.authenticated(new TestPrincipal(name, account, device), PrincipalSupplier.forImmutablePrincipal()); + } + } @Path("/v1/test") diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/ContainerRequestUtilTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/ContainerRequestUtilTest.java new file mode 100644 index 000000000..e0a87d3f0 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/ContainerRequestUtilTest.java @@ -0,0 +1,55 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.auth; + +import org.junit.jupiter.api.Test; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.Device; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.IntStream; + +import static org.junit.jupiter.api.Assertions.assertAll; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ContainerRequestUtilTest { + + @Test + void testBuildDevicesEnabled() { + + final byte disabledDeviceId = 3; + + final Account account = mock(Account.class); + + final List devices = new ArrayList<>(); + when(account.getDevices()).thenReturn(devices); + + IntStream.range(1, 5) + .forEach(id -> { + final Device device = mock(Device.class); + when(device.getId()).thenReturn((byte) id); + when(device.isEnabled()).thenReturn(id != disabledDeviceId); + devices.add(device); + }); + + final Map devicesEnabled = ContainerRequestUtil.AccountInfo.fromAccount(account).devicesEnabled(); + + assertEquals(4, devicesEnabled.size()); + + assertAll(devicesEnabled.entrySet().stream() + .map(deviceAndEnabled -> () -> { + if (deviceAndEnabled.getKey().equals(disabledDeviceId)) { + assertFalse(deviceAndEnabled.getValue()); + } else { + assertTrue(deviceAndEnabled.getValue()); + } + })); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProviderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProviderTest.java index 2341c5ee6..148bf788f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProviderTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProviderTest.java @@ -5,108 +5,292 @@ package org.whispersystems.textsecuregcm.auth; -import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.doAnswer; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; +import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME; +import com.google.common.net.HttpHeaders; +import io.dropwizard.auth.Auth; +import io.dropwizard.auth.AuthDynamicFeature; +import io.dropwizard.auth.basic.BasicCredentialAuthFilter; +import io.dropwizard.core.Application; +import io.dropwizard.core.Configuration; +import io.dropwizard.core.setup.Environment; +import io.dropwizard.testing.junit5.DropwizardAppExtension; +import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; +import java.io.IOException; +import java.net.URI; import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.EnumSet; +import java.util.Optional; import java.util.UUID; -import javax.annotation.Nullable; -import javax.ws.rs.core.SecurityContext; -import org.glassfish.jersey.server.ContainerRequest; -import org.glassfish.jersey.server.monitoring.RequestEvent; +import javax.servlet.DispatcherType; +import javax.servlet.ServletRegistration; +import javax.ws.rs.GET; +import javax.ws.rs.Path; +import javax.ws.rs.client.Invocation; +import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; +import org.eclipse.jetty.websocket.client.WebSocketClient; +import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; +import org.glassfish.jersey.server.ManagedAsync; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; +import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; +import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.util.Pair; +import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; +import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener; +import org.whispersystems.textsecuregcm.util.HeaderUtils; +import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator; +import org.whispersystems.websocket.WebSocketResourceProviderFactory; +import org.whispersystems.websocket.auth.PrincipalSupplier; +import org.whispersystems.websocket.auth.ReadOnly; +import org.whispersystems.websocket.configuration.WebSocketConfiguration; +import org.whispersystems.websocket.setup.WebSocketEnvironment; +@ExtendWith(DropwizardExtensionsSupport.class) class PhoneNumberChangeRefreshRequirementProviderTest { - private PhoneNumberChangeRefreshRequirementProvider provider; - - private Account account; - private RequestEvent requestEvent; - private ContainerRequest request; - - private static final UUID ACCOUNT_UUID = UUID.randomUUID(); private static final String NUMBER = "+18005551234"; private static final String CHANGED_NUMBER = "+18005554321"; + private static final String TEST_CRED_HEADER = HeaderUtils.basicAuthHeader("test", "password"); + + + private static final DropwizardAppExtension DROPWIZARD_APP_EXTENSION = new DropwizardAppExtension<>( + TestApplication.class); + + private static final AccountAuthenticator AUTHENTICATOR = mock(AccountAuthenticator.class); + private static final AccountsManager ACCOUNTS_MANAGER = mock(AccountsManager.class); + private static final ClientPresenceManager CLIENT_PRESENCE = mock(ClientPresenceManager.class); + + private WebSocketClient client; + private final Account account1 = new Account(); + private final Account account2 = new Account(); + private final Device authenticatedDevice = DevicesHelper.createDevice(Device.PRIMARY_ID); + @BeforeEach - void setUp() { - provider = new PhoneNumberChangeRefreshRequirementProvider(); + void setUp() throws Exception { + reset(AUTHENTICATOR, CLIENT_PRESENCE, ACCOUNTS_MANAGER); + client = new WebSocketClient(); + client.start(); - account = mock(Account.class); - final Device device = mock(Device.class); + final UUID uuid = UUID.randomUUID(); + account1.setUuid(uuid); + account1.addDevice(authenticatedDevice); + account1.setNumber(NUMBER, UUID.randomUUID()); - when(account.getUuid()).thenReturn(ACCOUNT_UUID); - when(account.getNumber()).thenReturn(NUMBER); - when(account.getDevices()).thenReturn(List.of(device)); - when(device.getId()).thenReturn(Device.PRIMARY_ID); + account2.setUuid(uuid); + account2.addDevice(authenticatedDevice); + account2.setNumber(CHANGED_NUMBER, UUID.randomUUID()); - request = mock(ContainerRequest.class); + } - final Map requestProperties = new HashMap<>(); + @AfterEach + void tearDown() throws Exception { + client.stop(); + } - doAnswer(invocation -> { - requestProperties.put(invocation.getArgument(0, String.class), invocation.getArgument(1)); - return null; - }).when(request).setProperty(anyString(), any()); - when(request.getProperty(anyString())).thenAnswer( - invocation -> requestProperties.get(invocation.getArgument(0, String.class))); + public static class TestApplication extends Application { - requestEvent = mock(RequestEvent.class); - when(requestEvent.getContainerRequest()).thenReturn(request); + @Override + public void run(final Configuration configuration, final Environment environment) throws Exception { + final TestController testController = new TestController(); + + final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration(); + + final WebSocketEnvironment webSocketEnvironment = + new WebSocketEnvironment<>(environment, webSocketConfiguration); + + environment.jersey().register(testController); + webSocketEnvironment.jersey().register(testController); + environment.servlets() + .addFilter("RemoteAddressFilter", new RemoteAddressFilter(true)) + .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*"); + webSocketEnvironment.jersey().register(new RemoteAddressFilter(true)); + webSocketEnvironment.jersey() + .register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE)); + environment.jersey() + .register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE)); + webSocketEnvironment.setConnectListener(webSocketSessionContext -> { + }); + + + environment.jersey().register(new AuthDynamicFeature(new BasicCredentialAuthFilter.Builder() + .setAuthenticator(AUTHENTICATOR) + .buildAuthFilter())); + webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(AUTHENTICATOR, mock(PrincipalSupplier.class))); + + final WebSocketResourceProviderFactory webSocketServlet = + new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedAccount.class, + webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME); + + JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null); + + final ServletRegistration.Dynamic websocketServlet = + environment.servlets().addServlet("WebSocket", webSocketServlet); + + websocketServlet.addMapping("/websocket"); + websocketServlet.setAsyncSupported(true); + } + } + + enum Protocol { HTTP, WEBSOCKET } + + private void makeAnonymousRequest(final Protocol protocol, final String requestPath) throws IOException { + makeRequest(protocol, requestPath, true); + } + + /* + * Make an authenticated request that will return account1 as the principal + */ + private void makeAuthenticatedRequest( + final Protocol protocol, + final String requestPath) throws IOException { + when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedAccount(account1, authenticatedDevice))); + makeRequest(protocol,requestPath, false); + } + + private void makeRequest(final Protocol protocol, final String requestPath, final boolean anonymous) throws IOException { + switch (protocol) { + case WEBSOCKET -> { + final TestWebsocketListener testWebsocketListener = new TestWebsocketListener(); + final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest(); + if (!anonymous) { + upgradeRequest.setHeader(HttpHeaders.AUTHORIZATION, TEST_CRED_HEADER); + } + client.connect( + testWebsocketListener, + URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())), + upgradeRequest); + testWebsocketListener.sendRequest(requestPath, "GET", Collections.emptyList(), Optional.empty()).join(); + } + case HTTP -> { + final Invocation.Builder request = DROPWIZARD_APP_EXTENSION.client() + .target("http://127.0.0.1:%s%s".formatted(DROPWIZARD_APP_EXTENSION.getLocalPort(), requestPath)) + .request(); + if (!anonymous) { + request.header(HttpHeaders.AUTHORIZATION, TEST_CRED_HEADER); + } + request.get(); + } + } + } + + @ParameterizedTest + @EnumSource(Protocol.class) + void handleRequestNoChange(final Protocol protocol) throws IOException { + when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account1)); + makeAuthenticatedRequest(protocol, "/test/annotated"); + + // Event listeners can fire after responses are sent + verify(ACCOUNTS_MANAGER, timeout(5000).times(1)).getByAccountIdentifier(eq(account1.getUuid())); + verifyNoMoreInteractions(CLIENT_PRESENCE); + verifyNoMoreInteractions(ACCOUNTS_MANAGER); + } + + @ParameterizedTest + @EnumSource(Protocol.class) + void handleRequestChange(final Protocol protocol) throws IOException { + when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account2)); + when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedAccount(account1, authenticatedDevice))); + + makeAuthenticatedRequest(protocol, "/test/annotated"); + + // Make sure we disconnect the account if the account has changed numbers. Event listeners can fire after responses + // are sent, so use a timeout. + verify(CLIENT_PRESENCE, timeout(5000)) + .disconnectPresence(eq(account1.getUuid()), eq(authenticatedDevice.getId())); + verifyNoMoreInteractions(CLIENT_PRESENCE); } @Test - void handleRequestNoChange() { - setAuthenticatedAccount(request, account); + void handleRequestChangeAsyncEndpoint() throws IOException { + when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account2)); + when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedAccount(account1, authenticatedDevice))); - provider.handleRequestFiltered(requestEvent); - assertEquals(Collections.emptyList(), provider.handleRequestFinished(requestEvent)); + // Event listeners with asynchronous HTTP endpoints don't currently correctly maintain state between request and + // response + makeAuthenticatedRequest(Protocol.WEBSOCKET, "/test/async-annotated"); + + // Make sure we disconnect the account if the account has changed numbers. Event listeners can fire after responses + // are sent, so use a timeout. + verify(CLIENT_PRESENCE, timeout(5000)) + .disconnectPresence(eq(account1.getUuid()), eq(authenticatedDevice.getId())); + verifyNoMoreInteractions(CLIENT_PRESENCE); } - @Test - void handleRequestNumberChange() { - setAuthenticatedAccount(request, account); + @ParameterizedTest + @EnumSource(Protocol.class) + void handleRequestNotAnnotated(final Protocol protocol) throws IOException, InterruptedException { + makeAuthenticatedRequest(protocol,"/test/not-annotated"); - provider.handleRequestFiltered(requestEvent); - when(account.getNumber()).thenReturn(CHANGED_NUMBER); - assertEquals(List.of(new Pair<>(ACCOUNT_UUID, Device.PRIMARY_ID)), provider.handleRequestFinished(requestEvent)); + // Give a tick for event listeners to run. Racy, but should occasionally catch an errant running listener if one is + // introduced. + Thread.sleep(100); + + // Shouldn't even read the account if the method has not been annotated + verifyNoMoreInteractions(ACCOUNTS_MANAGER); + verifyNoMoreInteractions(CLIENT_PRESENCE); } - @Test - void handleRequestNoAuthenticatedAccount() { - final ContainerRequest request = mock(ContainerRequest.class); - setAuthenticatedAccount(request, null); + @ParameterizedTest + @EnumSource(Protocol.class) + void handleRequestNotAuthenticated(final Protocol protocol) throws IOException, InterruptedException { + makeAnonymousRequest(protocol, "/test/not-authenticated"); - when(requestEvent.getContainerRequest()).thenReturn(request); + // Give a tick for event listeners to run. Racy, but should occasionally catch an errant running listener if one is + // introduced. + Thread.sleep(100); - provider.handleRequestFiltered(requestEvent); - assertEquals(Collections.emptyList(), provider.handleRequestFinished(requestEvent)); + // Shouldn't even read the account if the method has not been annotated + verifyNoMoreInteractions(ACCOUNTS_MANAGER); + verifyNoMoreInteractions(CLIENT_PRESENCE); } - private static void setAuthenticatedAccount(final ContainerRequest mockRequest, @Nullable final Account account) { - final SecurityContext securityContext = mock(SecurityContext.class); - when(mockRequest.getSecurityContext()).thenReturn(securityContext); + @Path("/test") + public static class TestController { - if (account != null) { - final AuthenticatedAccount authenticatedAccount = mock(AuthenticatedAccount.class); + @GET + @Path("/annotated") + @ChangesPhoneNumber + public String annotated(@ReadOnly @Auth final AuthenticatedAccount account) { + return "ok"; + } - when(securityContext.getUserPrincipal()).thenReturn(authenticatedAccount); - when(authenticatedAccount.getAccount()).thenReturn(account); - } else { - when(securityContext.getUserPrincipal()).thenReturn(null); + @GET + @Path("/async-annotated") + @ChangesPhoneNumber + @ManagedAsync + public String asyncAnnotated(@ReadOnly @Auth final AuthenticatedAccount account) { + return "ok"; + } + + @GET + @Path("/not-authenticated") + @ChangesPhoneNumber + public String notAuthenticated() { + return "ok"; + } + + @GET + @Path("/not-annotated") + public String notAnnotated(@ReadOnly @Auth final AuthenticatedAccount account) { + return "ok"; } } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/BasicCredentialAuthenticationInterceptorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/BasicCredentialAuthenticationInterceptorTest.java index 7bda74261..8bf1a184a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/BasicCredentialAuthenticationInterceptorTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/grpc/BasicCredentialAuthenticationInterceptorTest.java @@ -87,7 +87,7 @@ class BasicCredentialAuthenticationInterceptorTest { when(device.getId()).thenReturn(Device.PRIMARY_ID); when(accountAuthenticator.authenticate(any())) - .thenReturn(Optional.of(new AuthenticatedAccount(() -> new Pair<>(account, device)))); + .thenReturn(Optional.of(new AuthenticatedAccount(account, device))); } else { when(accountAuthenticator.authenticate(any())) .thenReturn(Optional.empty()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DirectoryControllerV2Test.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DirectoryControllerV2Test.java index 296bd0920..50f1efc9b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DirectoryControllerV2Test.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DirectoryControllerV2Test.java @@ -39,7 +39,7 @@ class DirectoryControllerV2Test { when(account.getUuid()).thenReturn(uuid); final ExternalServiceCredentials credentials = (ExternalServiceCredentials) controller.getAuthToken( - new AuthenticatedAccount(() -> new Pair<>(account, mock(Device.class)))).getEntity(); + new AuthenticatedAccount(account, mock(Device.class))).getEntity(); assertEquals(credentials.username(), "d369bc712e2e0dd36258"); assertEquals(credentials.password(), "1633738643:4433b0fab41f25f79dd4"); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java index a712a07ef..f93bd08b9 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java @@ -51,7 +51,10 @@ import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; +import org.whispersystems.textsecuregcm.tests.util.TestPrincipal; +import org.whispersystems.websocket.ReusableAuth; import org.whispersystems.websocket.WebSocketResourceProvider; +import org.whispersystems.websocket.auth.WebSocketAuthenticator; import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider; import org.whispersystems.websocket.logging.WebsocketRequestLog; import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory; @@ -139,7 +142,7 @@ class MetricsRequestEventListenerTest { final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); final WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"), + RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, TestPrincipal.reusableAuth("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); final Session session = mock(Session.class); @@ -201,7 +204,7 @@ class MetricsRequestEventListenerTest { final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); final WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"), + RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, TestPrincipal.reusableAuth("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); final Session session = mock(Session.class); @@ -252,19 +255,6 @@ class MetricsRequestEventListenerTest { return SubProtocol.WebSocketMessage.parseFrom(responseCaptor.getValue().array()).getResponse(); } - public static class TestPrincipal implements Principal { - - private final String name; - - private TestPrincipal(String name) { - this.name = name; - } - - @Override - public String getName() { - return name; - } - } @Path("/v1/test") public static class TestResource { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceSupplierTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceSupplierTest.java deleted file mode 100644 index 4c30456c6..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceSupplierTest.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright 2021 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.storage; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotSame; -import static org.junit.jupiter.api.Assertions.assertSame; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import java.util.Optional; -import java.util.UUID; -import org.junit.jupiter.api.Test; -import org.whispersystems.textsecuregcm.util.Pair; - -class RefreshingAccountAndDeviceSupplierTest { - - @Test - void test() { - - final AccountsManager accountsManager = mock(AccountsManager.class); - - final UUID uuid = UUID.randomUUID(); - final byte deviceId = 2; - - final Account initialAccount = mock(Account.class); - final Device initialDevice = mock(Device.class); - - when(initialAccount.getUuid()).thenReturn(uuid); - when(initialDevice.getId()).thenReturn(deviceId); - when(initialAccount.getDevice(deviceId)).thenReturn(Optional.of(initialDevice)); - - when(accountsManager.getByAccountIdentifier(any(UUID.class))).thenAnswer(answer -> { - final Account account = mock(Account.class); - final Device device = mock(Device.class); - - when(account.getUuid()).thenReturn(answer.getArgument(0, UUID.class)); - when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); - when(device.getId()).thenReturn(deviceId); - - return Optional.of(account); - }); - - final RefreshingAccountAndDeviceSupplier refreshingAccountAndDeviceSupplier = new RefreshingAccountAndDeviceSupplier( - initialAccount, deviceId, accountsManager); - - Pair accountAndDevice = refreshingAccountAndDeviceSupplier.get(); - - assertSame(initialAccount, accountAndDevice.first()); - assertSame(initialDevice, accountAndDevice.second()); - - accountAndDevice = refreshingAccountAndDeviceSupplier.get(); - - assertSame(initialAccount, accountAndDevice.first()); - assertSame(initialDevice, accountAndDevice.second()); - - when(initialAccount.isStale()).thenReturn(true); - - accountAndDevice = refreshingAccountAndDeviceSupplier.get(); - - assertNotSame(initialAccount, accountAndDevice.first()); - assertNotSame(initialDevice, accountAndDevice.second()); - - assertEquals(uuid, accountAndDevice.first().getUuid()); - } - -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestPrincipal.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestPrincipal.java new file mode 100644 index 000000000..201e10a50 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestPrincipal.java @@ -0,0 +1,27 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.tests.util; + +import java.security.Principal; +import org.whispersystems.websocket.ReusableAuth; +import org.whispersystems.websocket.auth.PrincipalSupplier; + +public class TestPrincipal implements Principal { + + private final String name; + + private TestPrincipal(String name) { + this.name = name; + } + + @Override + public String getName() { + return name; + } + + public static ReusableAuth reusableAuth(final String name) { + return ReusableAuth.authenticated(new TestPrincipal(name), PrincipalSupplier.forImmutablePrincipal()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestWebsocketListener.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestWebsocketListener.java new file mode 100644 index 000000000..a010776b6 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestWebsocketListener.java @@ -0,0 +1,79 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.tests.util; + +import org.eclipse.jetty.websocket.api.Session; +import org.eclipse.jetty.websocket.api.WebSocketListener; +import org.whispersystems.websocket.messages.WebSocketMessage; +import org.whispersystems.websocket.messages.WebSocketMessageFactory; +import org.whispersystems.websocket.messages.WebSocketResponseMessage; +import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; + +public class TestWebsocketListener implements WebSocketListener { + + private final AtomicLong requestId = new AtomicLong(); + private final CompletableFuture started = new CompletableFuture<>(); + private final ConcurrentHashMap> responseFutures = new ConcurrentHashMap<>(); + private final WebSocketMessageFactory messageFactory; + + public TestWebsocketListener() { + this.messageFactory = new ProtobufWebSocketMessageFactory(); + } + + + @Override + public void onWebSocketConnect(final Session session) { + started.complete(session); + + } + + public CompletableFuture doGet(final String requestPath) { + return sendRequest(requestPath, "GET", List.of("Accept: application/json"), Optional.empty()); + } + + public CompletableFuture sendRequest( + final String requestPath, + final String verb, + final List headers, + final Optional body) { + return started.thenCompose(session -> { + final long id = requestId.incrementAndGet(); + final CompletableFuture future = new CompletableFuture<>(); + responseFutures.put(id, future); + final byte[] requestBytes = messageFactory.createRequest( + Optional.of(id), verb, requestPath, headers, body).toByteArray(); + try { + session.getRemote().sendBytes(ByteBuffer.wrap(requestBytes)); + } catch (IOException e) { + throw new RuntimeException(e); + } + return future; + }); + } + + @Override + public void onWebSocketBinary(final byte[] payload, final int offset, final int length) { + try { + WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length); + if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.RESPONSE_MESSAGE) { + responseFutures.get(webSocketMessage.getResponseMessage().getRequestId()) + .complete(webSocketMessage.getResponseMessage()); + } else { + throw new RuntimeException("Unexpected message type: " + webSocketMessage.getType()); + } + } catch (final Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java index 207d4dfbe..58e52ca34 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java @@ -57,6 +57,7 @@ import org.junit.jupiter.params.provider.MethodSource; import org.slf4j.Logger; import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper; +import org.whispersystems.textsecuregcm.tests.util.TestPrincipal; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.websocket.WebSocketResourceProvider; import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider; @@ -175,7 +176,8 @@ class LoggingUnhandledExceptionMapperTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"), + RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, + TestPrincipal.reusableAuth("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); @@ -238,18 +240,4 @@ class LoggingUnhandledExceptionMapperTest { throw new RuntimeException(); } } - - public static class TestPrincipal implements Principal { - - private final String name; - - private TestPrincipal(String name) { - this.name = name; - } - - @Override - public String getName() { - return name; - } - } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticatorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticatorTest.java index 8601c8ad6..b1305091f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticatorTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticatorTest.java @@ -28,8 +28,8 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.util.HeaderUtils; -import org.whispersystems.textsecuregcm.util.Pair; -import org.whispersystems.websocket.auth.WebSocketAuthenticator; +import org.whispersystems.websocket.ReusableAuth; +import org.whispersystems.websocket.auth.PrincipalSupplier; class WebSocketAccountAuthenticatorTest { @@ -52,7 +52,7 @@ class WebSocketAccountAuthenticatorTest { accountAuthenticator = mock(AccountAuthenticator.class); when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD)))) - .thenReturn(Optional.of(new AuthenticatedAccount(() -> new Pair<>(mock(Account.class), mock(Device.class))))); + .thenReturn(Optional.of(new AuthenticatedAccount(mock(Account.class), mock(Device.class)))); when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD)))) .thenReturn(Optional.empty()); @@ -66,7 +66,7 @@ class WebSocketAccountAuthenticatorTest { @Nullable final String authorizationHeaderValue, final Map> upgradeRequestParameters, final boolean expectAccount, - final boolean expectCredentialsPresented) throws Exception { + final boolean expectInvalid) throws Exception { when(upgradeRequest.getParameterMap()).thenReturn(upgradeRequestParameters); if (authorizationHeaderValue != null) { @@ -74,13 +74,13 @@ class WebSocketAccountAuthenticatorTest { } final WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator( - accountAuthenticator); + accountAuthenticator, + mock(PrincipalSupplier.class)); - final WebSocketAuthenticator.AuthenticationResult result = webSocketAuthenticator.authenticate( - upgradeRequest); + final ReusableAuth result = webSocketAuthenticator.authenticate(upgradeRequest); - assertEquals(expectAccount, result.getUser().isPresent()); - assertEquals(expectCredentialsPresented, result.credentialsPresented()); + assertEquals(expectAccount, result.ref().isPresent()); + assertEquals(expectInvalid, result.invalidCredentialsProvided()); } private static Stream testAuthenticate() { @@ -94,17 +94,17 @@ class WebSocketAccountAuthenticatorTest { HeaderUtils.basicAuthHeader(INVALID_USER, INVALID_PASSWORD); return Stream.of( // if `Authorization` header is present, outcome should not depend on the value of query parameters - Arguments.of(headerWithValidAuth, Map.of(), true, true), + Arguments.of(headerWithValidAuth, Map.of(), true, false), Arguments.of(headerWithInvalidAuth, Map.of(), false, true), Arguments.of("invalid header value", Map.of(), false, true), - Arguments.of(headerWithValidAuth, paramsMapWithValidAuth, true, true), + Arguments.of(headerWithValidAuth, paramsMapWithValidAuth, true, false), Arguments.of(headerWithInvalidAuth, paramsMapWithValidAuth, false, true), Arguments.of("invalid header value", paramsMapWithValidAuth, false, true), - Arguments.of(headerWithValidAuth, paramsMapWithInvalidAuth, true, true), + Arguments.of(headerWithValidAuth, paramsMapWithInvalidAuth, true, false), Arguments.of(headerWithInvalidAuth, paramsMapWithInvalidAuth, false, true), Arguments.of("invalid header value", paramsMapWithInvalidAuth, false, true), // if `Authorization` header is not set, outcome should match the query params based auth - Arguments.of(null, paramsMapWithValidAuth, true, true), + Arguments.of(null, paramsMapWithValidAuth, true, false), Arguments.of(null, paramsMapWithInvalidAuth, false, true), Arguments.of(null, Map.of(), false, false) ); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java index d13ddc5ee..56560996a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java @@ -125,7 +125,7 @@ class WebSocketConnectionIntegrationTest { final WebSocketConnection webSocketConnection = new WebSocketConnection( mock(ReceiptSender.class), new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService), - new AuthenticatedAccount(() -> new Pair<>(account, device)), + new AuthenticatedAccount(account, device), device, webSocketClient, scheduledExecutorService, @@ -210,7 +210,7 @@ class WebSocketConnectionIntegrationTest { final WebSocketConnection webSocketConnection = new WebSocketConnection( mock(ReceiptSender.class), new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService), - new AuthenticatedAccount(() -> new Pair<>(account, device)), + new AuthenticatedAccount(account, device), device, webSocketClient, scheduledExecutorService, @@ -276,7 +276,7 @@ class WebSocketConnectionIntegrationTest { final WebSocketConnection webSocketConnection = new WebSocketConnection( mock(ReceiptSender.class), new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService), - new AuthenticatedAccount(() -> new Pair<>(account, device)), + new AuthenticatedAccount(account, device), device, webSocketClient, 100, // use a very short timeout, so that this test completes quickly diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index fc03a4bbd..479dde23e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -64,9 +64,9 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; -import org.whispersystems.textsecuregcm.util.Pair; +import org.whispersystems.websocket.ReusableAuth; import org.whispersystems.websocket.WebSocketClient; -import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult; +import org.whispersystems.websocket.auth.PrincipalSupplier; import org.whispersystems.websocket.messages.WebSocketResponseMessage; import org.whispersystems.websocket.session.WebSocketSessionContext; import reactor.core.publisher.Flux; @@ -101,7 +101,7 @@ class WebSocketConnectionTest { accountsManager = mock(AccountsManager.class); account = mock(Account.class); device = mock(Device.class); - auth = new AuthenticatedAccount(() -> new Pair<>(account, device)); + auth = new AuthenticatedAccount(account, device); upgradeRequest = mock(UpgradeRequest.class); messagesManager = mock(MessagesManager.class); receiptSender = mock(ReceiptSender.class); @@ -118,18 +118,19 @@ class WebSocketConnectionTest { @Test void testCredentials() throws Exception { - WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator); + WebSocketAccountAuthenticator webSocketAuthenticator = + new WebSocketAccountAuthenticator(accountAuthenticator, mock(PrincipalSupplier.class)); AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, messagesManager, mock(PushNotificationManager.class), mock(ClientPresenceManager.class), retrySchedulingExecutor, messageDeliveryScheduler, clientReleaseManager); WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD)))) - .thenReturn(Optional.of(new AuthenticatedAccount(() -> new Pair<>(account, device)))); + .thenReturn(Optional.of(new AuthenticatedAccount(account, device))); - AuthenticationResult account = webSocketAuthenticator.authenticate(upgradeRequest); - when(sessionContext.getAuthenticated()).thenReturn(account.getUser().orElse(null)); - when(sessionContext.getAuthenticated(AuthenticatedAccount.class)).thenReturn(account.getUser().orElse(null)); + ReusableAuth account = webSocketAuthenticator.authenticate(upgradeRequest); + when(sessionContext.getAuthenticated()).thenReturn(account.ref().orElse(null)); + when(sessionContext.getAuthenticated(AuthenticatedAccount.class)).thenReturn(account.ref().orElse(null)); final WebSocketClient webSocketClient = mock(WebSocketClient.class); when(webSocketClient.getUserAgent()).thenReturn("Signal-Android/6.22.8"); @@ -144,8 +145,8 @@ class WebSocketConnectionTest { // unauthenticated when(upgradeRequest.getParameterMap()).thenReturn(Map.of()); account = webSocketAuthenticator.authenticate(upgradeRequest); - assertFalse(account.getUser().isPresent()); - assertFalse(account.credentialsPresented()); + assertFalse(account.ref().isPresent()); + assertFalse(account.invalidCredentialsProvided()); connectListener.onWebSocketConnect(sessionContext); verify(sessionContext, times(2)).addWebsocketClosedListener( diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/ReusableAuth.java b/websocket-resources/src/main/java/org/whispersystems/websocket/ReusableAuth.java new file mode 100644 index 000000000..f6e1e79c1 --- /dev/null +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/ReusableAuth.java @@ -0,0 +1,182 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.websocket; + +import java.security.Principal; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import org.whispersystems.websocket.auth.PrincipalSupplier; + +/** + * This class holds a principal that can be reused across requests on a websocket. Since two requests may operate + * concurrently on the same principal, and some principals contain non thread-safe mutable state, appropriate use of + * this class ensures that no data races occur. It also ensures that after a principal is modified, a subsequent request + * gets the up-to-date principal + * + * @param The underlying principal type + * @see PrincipalSupplier + */ +public abstract sealed class ReusableAuth { + + /** + * Get a reference to the underlying principal that callers pledge not to modify. + *

+ * The reference returned will potentially be provided to many threads concurrently accessing the principal. Callers + * should use this method only if they can ensure that they will not modify the in-memory principal object AND they do + * not intend to modify the underlying canonical representation of the principal. + *

+ * For example, if a caller retrieves a reference to a principal, does not modify the in memory state, but updates a + * field on a database that should be reflected in subsequent retrievals of the principal, they will have met the + * first criteria, but not the second. In that case they should instead use {@link #mutableRef()}. + *

+ * If other callers have modified the underlying principal by using {@link #mutableRef()}, this method may need to + * refresh the principal via {@link PrincipalSupplier#refresh} which could be a blocking operation. + * + * @return If authenticated, a reference to the underlying principal that should not be modified + */ + public abstract Optional ref(); + + + public interface MutableRef { + + T ref(); + + void close(); + } + + /** + * Get a reference to the underlying principal that may be modified. + *

+ * The underlying principal can be safely modified. Multiple threads may operate on the same {@link ReusableAuth} so + * long as they each have their own {@link MutableRef}. After any modifications, the caller must call + * {@link MutableRef#close} to notify the principal has become dirty. Close should be called after modifications but + * before sending a response on the websocket. This ensures that a request that comes in after a modification response + * is received is guaranteed to see the modification. + * + * @return If authenticated, a reference to the underlying principal that may be modified + */ + public abstract Optional> mutableRef(); + + public boolean invalidCredentialsProvided() { + return switch (this) { + case Invalid ignored -> true; + case ReusableAuth.Anonymous ignored -> false; + case ReusableAuth.Authenticated ignored-> false; + }; + } + + /** + * @return A {@link ReusableAuth} indicating no credential were provided + */ + public static ReusableAuth anonymous() { + //noinspection unchecked + return (ReusableAuth) Anonymous.ANON_RESULT; + } + + /** + * @return A {@link ReusableAuth} indicating that invalid credentials were provided + */ + public static ReusableAuth invalid() { + //noinspection unchecked + return (ReusableAuth) Invalid.INVALID_RESULT; + } + + /** + * Create a successfully authenticated {@link ReusableAuth} + * + * @param principal The authenticated principal + * @param principalSupplier Instructions for how to refresh or copy this principal + * @param The principal type + * @return A {@link ReusableAuth} for a successfully authenticated principal + */ + public static ReusableAuth authenticated(T principal, + PrincipalSupplier principalSupplier) { + return new Authenticated<>(principal, principalSupplier); + } + + + private static final class Invalid extends ReusableAuth { + + @SuppressWarnings({"rawtypes"}) + private static final ReusableAuth INVALID_RESULT = new Invalid(); + + @Override + public Optional ref() { + return Optional.empty(); + } + + @Override + public Optional> mutableRef() { + return Optional.empty(); + } + } + + private static final class Anonymous extends ReusableAuth { + + @SuppressWarnings({"rawtypes"}) + private static final ReusableAuth ANON_RESULT = new Anonymous(); + + @Override + public Optional ref() { + return Optional.empty(); + } + + @Override + public Optional> mutableRef() { + return Optional.empty(); + } + } + + private static final class Authenticated extends ReusableAuth { + + private T basePrincipal; + private final AtomicBoolean needRefresh = new AtomicBoolean(false); + private final PrincipalSupplier principalSupplier; + + Authenticated(final T basePrincipal, PrincipalSupplier principalSupplier) { + this.basePrincipal = basePrincipal; + this.principalSupplier = principalSupplier; + + } + + @Override + public Optional ref() { + maybeRefresh(); + return Optional.of(basePrincipal); + } + + @Override + public Optional> mutableRef() { + maybeRefresh(); + return Optional.of(new AuthenticatedMutableRef(principalSupplier.deepCopy(basePrincipal))); + } + + private void maybeRefresh() { + if (needRefresh.compareAndSet(true, false)) { + basePrincipal = principalSupplier.refresh(basePrincipal); + } + } + + private class AuthenticatedMutableRef implements MutableRef { + + final T ref; + + private AuthenticatedMutableRef(T ref) { + this.ref = ref; + } + + public T ref() { + return ref; + } + + public void close() { + needRefresh.set(true); + } + } + } + + private ReusableAuth() { + } +} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java index 006fa5333..3677365d9 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java @@ -58,7 +58,7 @@ public class WebSocketResourceProvider implements WebSocket private final Map> requestMap = new ConcurrentHashMap<>(); - private final T authenticated; + private final ReusableAuth reusableAuth; private final WebSocketMessageFactory messageFactory; private final Optional connectListener; private final ApplicationHandler jerseyHandler; @@ -77,7 +77,7 @@ public class WebSocketResourceProvider implements WebSocket String remoteAddressPropertyName, ApplicationHandler jerseyHandler, WebsocketRequestLog requestLog, - T authenticated, + ReusableAuth authenticated, WebSocketMessageFactory messageFactory, Optional connectListener, Duration idleTimeout) { @@ -85,7 +85,7 @@ public class WebSocketResourceProvider implements WebSocket this.remoteAddressPropertyName = remoteAddressPropertyName; this.jerseyHandler = jerseyHandler; this.requestLog = requestLog; - this.authenticated = authenticated; + this.reusableAuth = authenticated; this.messageFactory = messageFactory; this.connectListener = connectListener; this.idleTimeout = idleTimeout; @@ -97,7 +97,7 @@ public class WebSocketResourceProvider implements WebSocket this.remoteEndpoint = session.getRemote(); this.context = new WebSocketSessionContext( new WebSocketClient(session, remoteEndpoint, messageFactory, requestMap)); - this.context.setAuthenticated(authenticated); + this.context.setAuthenticated(reusableAuth.ref().orElse(null)); this.session.setIdleTimeout(idleTimeout); connectListener.ifPresent(listener -> listener.onWebSocketConnect(this.context)); @@ -162,6 +162,17 @@ public class WebSocketResourceProvider implements WebSocket logger.debug("onWebSocketText!"); } + /** + * The property name where {@link org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider} can find an + * {@link ReusableAuth} object that lives for the lifetime of the websocket + */ + public static final String REUSABLE_AUTH_PROPERTY = WebSocketResourceProvider.class.getName() + ".reusableAuth"; + + /** + * The property name where {@link org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider} can install a + * {@link org.whispersystems.websocket.ReusableAuth.MutableRef} for us to close when the request is finished + */ + public static final String RESOLVED_PRINCIPAL_PROPERTY = WebSocketResourceProvider.class.getName() + ".resolvedPrincipal"; private void handleRequest(WebSocketRequestMessage requestMessage) { ContainerRequest containerRequest = new ContainerRequest(null, URI.create(requestMessage.getPath()), requestMessage.getVerb(), new WebSocketSecurityContext(new ContextPrincipal(context)), @@ -173,30 +184,43 @@ public class WebSocketResourceProvider implements WebSocket } containerRequest.setProperty(remoteAddressPropertyName, remoteAddress); + containerRequest.setProperty(REUSABLE_AUTH_PROPERTY, reusableAuth); ByteArrayOutputStream responseBody = new ByteArrayOutputStream(); CompletableFuture responseFuture = (CompletableFuture) jerseyHandler.apply( containerRequest, responseBody); - responseFuture.thenAccept(response -> { - try { - sendResponse(requestMessage, response, responseBody); - } catch (IOException e) { - throw new RuntimeException(e); - } - requestLog.log(remoteAddress, containerRequest, response); - }).exceptionally(exception -> { - logger.warn("Websocket Error: " + requestMessage.getVerb() + " " + requestMessage.getPath() + "\n" - + requestMessage.getBody(), exception); - try { - sendErrorResponse(requestMessage, Response.status(500).build()); - } catch (IOException e) { - logger.warn("Failed to send error response", e); - } - requestLog.log(remoteAddress, containerRequest, - new ContainerResponse(containerRequest, Response.status(500).build())); - return null; - }); + responseFuture + .whenComplete((ignoredResponse, ignoredError) -> { + // If the request ended up being one that mutates our principal, we have to close it to indicate we're done + // with the mutation operation + final Object resolvedPrincipal = containerRequest.getProperty(RESOLVED_PRINCIPAL_PROPERTY); + if (resolvedPrincipal instanceof ReusableAuth.MutableRef ref) { + ref.close(); + } else if (resolvedPrincipal != null) { + logger.warn("unexpected resolved principal type {} : {}", resolvedPrincipal.getClass(), resolvedPrincipal); + } + }) + .thenAccept(response -> { + try { + sendResponse(requestMessage, response, responseBody); + } catch (IOException e) { + throw new RuntimeException(e); + } + requestLog.log(remoteAddress, containerRequest, response); + }) + .exceptionally(exception -> { + logger.warn("Websocket Error: " + requestMessage.getVerb() + " " + requestMessage.getPath() + "\n" + + requestMessage.getBody(), exception); + try { + sendErrorResponse(requestMessage, Response.status(500).build()); + } catch (IOException e) { + logger.warn("Failed to send error response", e); + } + requestLog.log(remoteAddress, containerRequest, + new ContainerResponse(containerRequest, Response.status(500).build())); + return null; + }); } @VisibleForTesting diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java index 1b3314430..b7378952a 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java @@ -22,7 +22,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.websocket.auth.AuthenticationException; import org.whispersystems.websocket.auth.WebSocketAuthenticator; -import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult; import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider; import org.whispersystems.websocket.configuration.WebSocketConfiguration; import org.whispersystems.websocket.session.WebSocketSessionContextValueFactoryProvider; @@ -57,17 +56,17 @@ public class WebSocketResourceProviderFactory extends Jetty public Object createWebSocket(final JettyServerUpgradeRequest request, final JettyServerUpgradeResponse response) { try { Optional> authenticator = Optional.ofNullable(environment.getAuthenticator()); - T authenticated = null; + final ReusableAuth authenticated; if (authenticator.isPresent()) { - AuthenticationResult authenticationResult = authenticator.get().authenticate(request); + authenticated = authenticator.get().authenticate(request); - if (authenticationResult.getUser().isEmpty() && authenticationResult.credentialsPresented()) { + if (authenticated.invalidCredentialsProvided()) { response.sendForbidden("Unauthorized"); return null; - } else { - authenticated = authenticationResult.getUser().orElse(null); } + } else { + authenticated = ReusableAuth.anonymous(); } return new WebSocketResourceProvider<>(getRemoteAddress(request), diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/Mutable.java b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/Mutable.java new file mode 100644 index 000000000..b5fadf7fd --- /dev/null +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/Mutable.java @@ -0,0 +1,26 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.websocket.auth; + +import io.dropwizard.auth.Auth; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * An @{@link Auth} object annotated with {@link Mutable} indicates that the consumer of the object + * will modify the object or its underlying canonical source. + * + * Note: An {@link Auth} object that does not specify @{@link ReadOnly} will be assumed to be @Mutable + * + * @see org.whispersystems.websocket.ReusableAuth + */ +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.FIELD, ElementType.PARAMETER}) +public @interface Mutable { +} + diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/PrincipalSupplier.java b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/PrincipalSupplier.java new file mode 100644 index 000000000..414d11d66 --- /dev/null +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/PrincipalSupplier.java @@ -0,0 +1,58 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.websocket.auth; + +/** + * Teach {@link org.whispersystems.websocket.ReusableAuth} how to make a deep copy of a principal (that is safe to + * concurrently modify while the original principal is being read), and how to refresh a principal after it has been + * potentially modified. + * + * @param The underlying principal type + */ +public interface PrincipalSupplier { + + /** + * Re-fresh the principal after it has been modified. + *

+ * If the principal is populated from a backing store, refresh should re-read it. + * + * @param t the potentially stale principal to refresh + * @return The up-to-date principal + */ + T refresh(T t); + + /** + * Create a deep, in-memory copy of the principal. This should be identical to the original principal, but should + * share no mutable state with the original. It should be safe for two threads to independently write and read from + * two independent deep copies. + * + * @param t the principal to copy + * @return An in-memory copy of the principal + */ + T deepCopy(T t); + + class ImmutablePrincipalSupplier implements PrincipalSupplier { + @SuppressWarnings({"rawtypes"}) + private static final PrincipalSupplier INSTANCE = new ImmutablePrincipalSupplier(); + + @Override + public T refresh(final T t) { + return t; + } + + @Override + public T deepCopy(final T t) { + return t; + } + } + + /** + * @return A principal supplier that can be used if the principal type does not support modification. + */ + static PrincipalSupplier forImmutablePrincipal() { + //noinspection unchecked + return (PrincipalSupplier) ImmutablePrincipalSupplier.INSTANCE; + } +} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/ReadOnly.java b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/ReadOnly.java new file mode 100644 index 000000000..8deab76ce --- /dev/null +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/ReadOnly.java @@ -0,0 +1,25 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.websocket.auth; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * An @{@link io.dropwizard.auth.Auth} object annotated with {@link ReadOnly} indicates that the consumer of the object + * will never modify the object, nor its underlying canonical source. + *

+ * For example, a consumer of a @ReadOnly AuthenticatedAccount promises to never modify the in-memory + * AuthenticatedAccount and to never modify the underlying Account database for the account. + * + * @see org.whispersystems.websocket.ReusableAuth + */ +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.FIELD, ElementType.PARAMETER}) +public @interface ReadOnly { +} + diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebSocketAuthenticator.java b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebSocketAuthenticator.java index a8fee10b4..e769e2d28 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebSocketAuthenticator.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebSocketAuthenticator.java @@ -7,27 +7,8 @@ package org.whispersystems.websocket.auth; import java.security.Principal; import java.util.Optional; import org.eclipse.jetty.websocket.api.UpgradeRequest; +import org.whispersystems.websocket.ReusableAuth; public interface WebSocketAuthenticator { - - AuthenticationResult authenticate(UpgradeRequest request) throws AuthenticationException; - - @SuppressWarnings("OptionalUsedAsFieldOrParameterType") - class AuthenticationResult { - private final Optional user; - private final boolean credentialsPresented; - - public AuthenticationResult(final Optional user, final boolean credentialsPresented) { - this.user = user; - this.credentialsPresented = credentialsPresented; - } - - public Optional getUser() { - return user; - } - - public boolean credentialsPresented() { - return credentialsPresented; - } - } + ReusableAuth authenticate(UpgradeRequest request) throws AuthenticationException; } diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebsocketAuthValueFactoryProvider.java b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebsocketAuthValueFactoryProvider.java index 28c488ea3..776e20ac5 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebsocketAuthValueFactoryProvider.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebsocketAuthValueFactoryProvider.java @@ -5,24 +5,28 @@ package org.whispersystems.websocket.auth; import io.dropwizard.auth.Auth; +import java.lang.reflect.ParameterizedType; +import java.security.Principal; +import java.util.Optional; +import java.util.function.Function; +import javax.annotation.Nullable; +import javax.inject.Inject; +import javax.inject.Singleton; +import javax.ws.rs.WebApplicationException; import org.glassfish.jersey.internal.inject.AbstractBinder; import org.glassfish.jersey.server.ContainerRequest; import org.glassfish.jersey.server.internal.inject.AbstractValueParamProvider; import org.glassfish.jersey.server.internal.inject.MultivaluedParameterExtractorProvider; import org.glassfish.jersey.server.model.Parameter; import org.glassfish.jersey.server.spi.internal.ValueParamProvider; - -import javax.annotation.Nullable; -import javax.inject.Inject; -import javax.inject.Singleton; -import javax.ws.rs.WebApplicationException; -import java.lang.reflect.ParameterizedType; -import java.security.Principal; -import java.util.Optional; -import java.util.function.Function; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.websocket.ReusableAuth; +import org.whispersystems.websocket.WebSocketResourceProvider; @Singleton public class WebsocketAuthValueFactoryProvider extends AbstractValueParamProvider { + private static final Logger logger = LoggerFactory.getLogger(WebsocketAuthValueFactoryProvider.class); private final Class principalClass; @@ -39,18 +43,38 @@ public class WebsocketAuthValueFactoryProvider extends Abst return null; } - if (parameter.getRawType() == Optional.class && - ParameterizedType.class.isAssignableFrom(parameter.getType().getClass()) && - principalClass == ((ParameterizedType)parameter.getType()).getActualTypeArguments()[0]) - { - return request -> new OptionalContainerRequestValueFactory(request).provide(); + final boolean readOnly = parameter.isAnnotationPresent(ReadOnly.class); + + if (parameter.getRawType() == Optional.class + && ParameterizedType.class.isAssignableFrom(parameter.getType().getClass()) + && principalClass == ((ParameterizedType) parameter.getType()).getActualTypeArguments()[0]) { + return containerRequest -> createPrincipal(containerRequest, readOnly); } else if (principalClass.equals(parameter.getRawType())) { - return request -> new StandardContainerRequestValueFactory(request).provide(); + return containerRequest -> + createPrincipal(containerRequest, readOnly) + .orElseThrow(() -> new WebApplicationException("Authenticated resource", 401)); } else { throw new IllegalStateException("Can't inject unassignable principal: " + principalClass + " for parameter: " + parameter); } } + private Optional createPrincipal(final ContainerRequest request, final boolean readOnly) { + final Object obj = request.getProperty(WebSocketResourceProvider.REUSABLE_AUTH_PROPERTY); + if (!(obj instanceof ReusableAuth)) { + logger.warn("Unexpected reusable auth property type {} : {}", obj.getClass(), obj); + return Optional.empty(); + } + @SuppressWarnings("unchecked") final ReusableAuth reusableAuth = (ReusableAuth) obj; + if (readOnly) { + return reusableAuth.ref(); + } else { + return reusableAuth.mutableRef().map(writeRef -> { + request.setProperty(WebSocketResourceProvider.RESOLVED_PRINCIPAL_PROPERTY, writeRef); + return writeRef.ref(); + }); + } + } + @Singleton static class WebsocketPrincipalClassProvider { @@ -80,38 +104,4 @@ public class WebsocketAuthValueFactoryProvider extends Abst bind(WebsocketAuthValueFactoryProvider.class).to(ValueParamProvider.class).in(Singleton.class); } } - - private static class StandardContainerRequestValueFactory { - - private final ContainerRequest request; - - public StandardContainerRequestValueFactory(ContainerRequest request) { - this.request = request; - } - - public Principal provide() { - final Principal principal = request.getSecurityContext().getUserPrincipal(); - - if (principal == null) { - throw new WebApplicationException("Authenticated resource", 401); - } - - return principal; - } - - } - - private static class OptionalContainerRequestValueFactory { - - private final ContainerRequest request; - - public OptionalContainerRequestValueFactory(ContainerRequest request) { - this.request = request; - } - - public Optional provide() { - return Optional.ofNullable(request.getSecurityContext().getUserPrincipal()); - } - } - } diff --git a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java index 4546242ea..c11f3ace2 100644 --- a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java +++ b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java @@ -7,6 +7,7 @@ package org.whispersystems.websocket; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -27,6 +28,7 @@ import org.glassfish.jersey.server.ResourceConfig; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.whispersystems.websocket.auth.AuthenticationException; +import org.whispersystems.websocket.auth.PrincipalSupplier; import org.whispersystems.websocket.auth.WebSocketAuthenticator; import org.whispersystems.websocket.configuration.WebSocketConfiguration; import org.whispersystems.websocket.setup.WebSocketEnvironment; @@ -56,8 +58,7 @@ public class WebSocketResourceProviderFactoryTest { @Test void testUnauthorized() throws AuthenticationException, IOException { when(environment.getAuthenticator()).thenReturn(authenticator); - when(authenticator.authenticate(eq(request))).thenReturn( - new WebSocketAuthenticator.AuthenticationResult<>(Optional.empty(), true)); + when(authenticator.authenticate(eq(request))).thenReturn(ReusableAuth.invalid()); when(environment.jersey()).thenReturn(jerseyEnvironment); WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory<>(environment, Account.class, @@ -74,8 +75,8 @@ public class WebSocketResourceProviderFactoryTest { Account account = new Account(); when(environment.getAuthenticator()).thenReturn(authenticator); - when(authenticator.authenticate(eq(request))).thenReturn( - new WebSocketAuthenticator.AuthenticationResult<>(Optional.of(account), true)); + when(authenticator.authenticate(eq(request))) + .thenReturn(ReusableAuth.authenticated(account, PrincipalSupplier.forImmutablePrincipal())); when(environment.jersey()).thenReturn(jerseyEnvironment); final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class); when(httpServletRequest.getAttribute(REMOTE_ADDRESS_PROPERTY_NAME)).thenReturn("127.0.0.1"); @@ -137,6 +138,7 @@ public class WebSocketResourceProviderFactoryTest { public boolean implies(Subject subject) { return false; } + } diff --git a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java index 2f8536cd4..f2955fdf1 100644 --- a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java +++ b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java @@ -59,6 +59,7 @@ import org.glassfish.jersey.server.ResourceConfig; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; +import org.whispersystems.websocket.auth.PrincipalSupplier; import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider; import org.whispersystems.websocket.logging.WebsocketRequestLog; import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory; @@ -80,7 +81,7 @@ class WebSocketResourceProviderTest { WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, - new TestPrincipal("fooz"), + immutableTestPrincipal("fooz"), new ProtobufWebSocketMessageFactory(), Optional.of(connectListener), Duration.ofMillis(30000)); @@ -108,7 +109,7 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = mock(ApplicationHandler.class); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"), + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); @@ -184,7 +185,7 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = mock(ApplicationHandler.class); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"), + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); @@ -240,7 +241,7 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"), + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); @@ -280,7 +281,7 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"), + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); @@ -320,7 +321,7 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("authorizedUserName"), + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("authorizedUserName"), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); @@ -360,8 +361,8 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, null, new ProtobufWebSocketMessageFactory(), - Optional.empty(), Duration.ofMillis(30000)); + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, ReusableAuth.anonymous(), + new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); @@ -399,7 +400,7 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("something"), + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("something"), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); @@ -439,8 +440,8 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, null, new ProtobufWebSocketMessageFactory(), - Optional.empty(), Duration.ofMillis(30000)); + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, ReusableAuth.anonymous(), + new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); @@ -479,7 +480,7 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"), + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); @@ -520,7 +521,7 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"), + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); @@ -562,7 +563,7 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"), + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); @@ -602,7 +603,7 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"), + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); @@ -727,6 +728,10 @@ class WebSocketResourceProviderTest { } } + public static ReusableAuth immutableTestPrincipal(final String name) { + return ReusableAuth.authenticated(new TestPrincipal(name), PrincipalSupplier.forImmutablePrincipal()); + } + public static class TestException extends Exception { public TestException(String message) {