Lifecycle management for Account objects reused accross websocket requests

This commit is contained in:
Ravi Khadiwala 2024-02-06 16:59:42 -06:00 committed by ravi-signal
parent 29ef3f0b41
commit 26ffa19f36
38 changed files with 1317 additions and 457 deletions

View File

@ -188,6 +188,7 @@ import org.whispersystems.textsecuregcm.spam.SenderOverrideProvider;
import org.whispersystems.textsecuregcm.spam.SpamChecker; import org.whispersystems.textsecuregcm.spam.SpamChecker;
import org.whispersystems.textsecuregcm.spam.SpamFilter; import org.whispersystems.textsecuregcm.spam.SpamFilter;
import org.whispersystems.textsecuregcm.storage.AccountLockManager; import org.whispersystems.textsecuregcm.storage.AccountLockManager;
import org.whispersystems.textsecuregcm.storage.AccountPrincipalSupplier;
import org.whispersystems.textsecuregcm.storage.Accounts; import org.whispersystems.textsecuregcm.storage.Accounts;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ChangeNumberManager; import org.whispersystems.textsecuregcm.storage.ChangeNumberManager;
@ -812,7 +813,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
WebSocketEnvironment<AuthenticatedAccount> webSocketEnvironment = new WebSocketEnvironment<>(environment, WebSocketEnvironment<AuthenticatedAccount> webSocketEnvironment = new WebSocketEnvironment<>(environment,
config.getWebSocketConfiguration(), Duration.ofMillis(90000)); config.getWebSocketConfiguration(), Duration.ofMillis(90000));
webSocketEnvironment.jersey().register(new VirtualExecutorServiceProvider("managed-async-websocket-virtual-thread-")); webSocketEnvironment.jersey().register(new VirtualExecutorServiceProvider("managed-async-websocket-virtual-thread-"));
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator)); webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator, new AccountPrincipalSupplier(accountsManager)));
webSocketEnvironment.setConnectListener( webSocketEnvironment.setConnectListener(
new AuthenticatedConnectListener(receiptSender, messagesManager, pushNotificationManager, new AuthenticatedConnectListener(receiptSender, messagesManager, pushNotificationManager,
clientPresenceManager, websocketScheduledExecutor, messageDeliveryScheduler, clientReleaseManager)); clientPresenceManager, websocketScheduledExecutor, messageDeliveryScheduler, clientReleaseManager));

View File

@ -21,7 +21,6 @@ import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.RefreshingAccountAndDeviceSupplier;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
@ -108,8 +107,7 @@ public class AccountAuthenticator implements Authenticator<BasicCredentials, Aut
device.get(), device.get(),
SaltedTokenHash.generateFor(basicCredentials.getPassword())); // new credentials have current version SaltedTokenHash.generateFor(basicCredentials.getPassword())); // new credentials have current version
} }
return Optional.of(new AuthenticatedAccount( return Optional.of(new AuthenticatedAccount(authenticatedAccount, device.get()));
new RefreshingAccountAndDeviceSupplier(authenticatedAccount, device.get().getId(), accountsManager)));
} }
return Optional.empty(); return Optional.empty();

View File

@ -5,7 +5,6 @@
package org.whispersystems.textsecuregcm.auth; package org.whispersystems.textsecuregcm.auth;
import com.google.common.annotations.VisibleForTesting;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
@ -45,10 +44,6 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
this.accountsManager = accountsManager; this.accountsManager = accountsManager;
} }
@VisibleForTesting
static Map<Byte, Boolean> buildDevicesEnabledMap(final Account account) {
return account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::isEnabled));
}
@Override @Override
public void handleRequestFiltered(final RequestEvent requestEvent) { public void handleRequestFiltered(final RequestEvent requestEvent) {
@ -60,10 +55,13 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
setAccount(requestEvent.getContainerRequest(), account)); setAccount(requestEvent.getContainerRequest(), account));
} }
} }
public static void setAccount(final ContainerRequest containerRequest, final Account account) { public static void setAccount(final ContainerRequest containerRequest, final Account account) {
containerRequest.setProperty(ACCOUNT_UUID, account.getUuid()); setAccount(containerRequest, ContainerRequestUtil.AccountInfo.fromAccount(account));
containerRequest.setProperty(DEVICES_ENABLED, buildDevicesEnabledMap(account)); }
private static void setAccount(final ContainerRequest containerRequest, final ContainerRequestUtil.AccountInfo info) {
containerRequest.setProperty(ACCOUNT_UUID, info.accountId());
containerRequest.setProperty(DEVICES_ENABLED, info.devicesEnabled());
} }
@Override @Override
@ -75,9 +73,11 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
@SuppressWarnings("unchecked") final Map<Byte, Boolean> initialDevicesEnabled = @SuppressWarnings("unchecked") final Map<Byte, Boolean> initialDevicesEnabled =
(Map<Byte, Boolean>) requestEvent.getContainerRequest().getProperty(DEVICES_ENABLED); (Map<Byte, Boolean>) requestEvent.getContainerRequest().getProperty(DEVICES_ENABLED);
return accountsManager.getByAccountIdentifier((UUID) requestEvent.getContainerRequest().getProperty(ACCOUNT_UUID)).map(account -> { return accountsManager.getByAccountIdentifier((UUID) requestEvent.getContainerRequest().getProperty(ACCOUNT_UUID))
.map(ContainerRequestUtil.AccountInfo::fromAccount)
.map(account -> {
final Set<Byte> deviceIdsToDisplace; final Set<Byte> deviceIdsToDisplace;
final Map<Byte, Boolean> currentDevicesEnabled = buildDevicesEnabledMap(account); final Map<Byte, Boolean> currentDevicesEnabled = account.devicesEnabled();
if (!initialDevicesEnabled.equals(currentDevicesEnabled)) { if (!initialDevicesEnabled.equals(currentDevicesEnabled)) {
deviceIdsToDisplace = new HashSet<>(initialDevicesEnabled.keySet()); deviceIdsToDisplace = new HashSet<>(initialDevicesEnabled.keySet());
@ -87,13 +87,14 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
} }
return deviceIdsToDisplace.stream() return deviceIdsToDisplace.stream()
.map(deviceId -> new Pair<>(account.getUuid(), deviceId)) .map(deviceId -> new Pair<>(account.accountId(), deviceId))
.collect(Collectors.toList()); .collect(Collectors.toList());
}).orElseGet(() -> { }).orElseGet(() -> {
logger.error("Request had account, but it is no longer present"); logger.error("Request had account, but it is no longer present");
return Collections.emptyList(); return Collections.emptyList();
}); });
} else } else {
return Collections.emptyList(); return Collections.emptyList();
} }
}
} }

View File

@ -10,24 +10,24 @@ import java.util.function.Supplier;
import javax.security.auth.Subject; import javax.security.auth.Subject;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Pair;
public class AuthenticatedAccount implements Principal, AccountAndAuthenticatedDeviceHolder { public class AuthenticatedAccount implements Principal, AccountAndAuthenticatedDeviceHolder {
private final Account account;
private final Device device;
private final Supplier<Pair<Account, Device>> accountAndDevice; public AuthenticatedAccount(final Account account, final Device device) {
this.account = account;
public AuthenticatedAccount(final Supplier<Pair<Account, Device>> accountAndDevice) { this.device = device;
this.accountAndDevice = accountAndDevice;
} }
@Override @Override
public Account getAccount() { public Account getAccount() {
return accountAndDevice.get().first(); return account;
} }
@Override @Override
public Device getAuthenticatedDevice() { public Device getAuthenticatedDevice() {
return accountAndDevice.get().second(); return device;
} }
// Principal implementation // Principal implementation

View File

@ -0,0 +1,20 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* Indicates that an endpoint changes the phone number and PNI keys associated with an account, and that
* any websockets associated with the account may need to be refreshed after a call to that endpoint.
*/
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface ChangesPhoneNumber {
}

View File

@ -7,15 +7,42 @@ package org.whispersystems.textsecuregcm.auth;
import org.glassfish.jersey.server.ContainerRequest; import org.glassfish.jersey.server.ContainerRequest;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import javax.ws.rs.core.SecurityContext; import javax.ws.rs.core.SecurityContext;
import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.UUID;
import java.util.stream.Collectors;
class ContainerRequestUtil { class ContainerRequestUtil {
static Optional<Account> getAuthenticatedAccount(final ContainerRequest request) { private static Map<Byte, Boolean> buildDevicesEnabledMap(final Account account) {
return account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::isEnabled));
}
/**
* A read-only subset of the authenticated Account object, to enforce that filter-based consumers do not perform
* account modifying operations.
*/
record AccountInfo(UUID accountId, String e164, Map<Byte, Boolean> devicesEnabled) {
static AccountInfo fromAccount(final Account account) {
return new AccountInfo(
account.getUuid(),
account.getNumber(),
buildDevicesEnabledMap(account));
}
}
static Optional<AccountInfo> getAuthenticatedAccount(final ContainerRequest request) {
return Optional.ofNullable(request.getSecurityContext()) return Optional.ofNullable(request.getSecurityContext())
.map(SecurityContext::getUserPrincipal) .map(SecurityContext::getUserPrincipal)
.map(principal -> principal instanceof AccountAndAuthenticatedDeviceHolder .map(principal -> {
? ((AccountAndAuthenticatedDeviceHolder) principal).getAccount() : null); if (principal instanceof AccountAndAuthenticatedDeviceHolder aaadh) {
return aaadh.getAccount();
}
return null;
})
.map(AccountInfo::fromAccount);
} }
} }

View File

@ -7,40 +7,50 @@ package org.whispersystems.textsecuregcm.auth;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.glassfish.jersey.server.monitoring.RequestEvent; import org.glassfish.jersey.server.monitoring.RequestEvent;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
public class PhoneNumberChangeRefreshRequirementProvider implements WebsocketRefreshRequirementProvider { public class PhoneNumberChangeRefreshRequirementProvider implements WebsocketRefreshRequirementProvider {
private static final String ACCOUNT_UUID =
PhoneNumberChangeRefreshRequirementProvider.class.getName() + ".accountUuid";
private static final String INITIAL_NUMBER_KEY = private static final String INITIAL_NUMBER_KEY =
PhoneNumberChangeRefreshRequirementProvider.class.getName() + ".initialNumber"; PhoneNumberChangeRefreshRequirementProvider.class.getName() + ".initialNumber";
private final AccountsManager accountsManager;
public PhoneNumberChangeRefreshRequirementProvider(final AccountsManager accountsManager) {
this.accountsManager = accountsManager;
}
@Override @Override
public void handleRequestFiltered(final RequestEvent requestEvent) { public void handleRequestFiltered(final RequestEvent requestEvent) {
if (requestEvent.getUriInfo().getMatchedResourceMethod().getInvocable().getHandlingMethod()
.getAnnotation(ChangesPhoneNumber.class) == null) {
return;
}
ContainerRequestUtil.getAuthenticatedAccount(requestEvent.getContainerRequest()) ContainerRequestUtil.getAuthenticatedAccount(requestEvent.getContainerRequest())
.ifPresent(account -> requestEvent.getContainerRequest().setProperty(INITIAL_NUMBER_KEY, account.getNumber())); .ifPresent(account -> {
requestEvent.getContainerRequest().setProperty(INITIAL_NUMBER_KEY, account.e164());
requestEvent.getContainerRequest().setProperty(ACCOUNT_UUID, account.accountId());
});
} }
@Override @Override
public List<Pair<UUID, Byte>> handleRequestFinished(final RequestEvent requestEvent) { public List<Pair<UUID, Byte>> handleRequestFinished(final RequestEvent requestEvent) {
final String initialNumber = (String) requestEvent.getContainerRequest().getProperty(INITIAL_NUMBER_KEY); final String initialNumber = (String) requestEvent.getContainerRequest().getProperty(INITIAL_NUMBER_KEY);
if (initialNumber != null) { if (initialNumber == null) {
final Optional<Account> maybeAuthenticatedAccount = return Collections.emptyList();
ContainerRequestUtil.getAuthenticatedAccount(requestEvent.getContainerRequest()); }
return accountsManager.getByAccountIdentifier((UUID) requestEvent.getContainerRequest().getProperty(ACCOUNT_UUID))
return maybeAuthenticatedAccount
.filter(account -> !initialNumber.equals(account.getNumber())) .filter(account -> !initialNumber.equals(account.getNumber()))
.map(account -> account.getDevices().stream() .map(account -> account.getDevices().stream()
.map(device -> new Pair<>(account.getUuid(), device.getId())) .map(device -> new Pair<>(account.getUuid(), device.getId()))
.collect(Collectors.toList())) .collect(Collectors.toList()))
.orElse(Collections.emptyList()); .orElse(Collections.emptyList());
} else {
return Collections.emptyList();
}
} }
} }

View File

@ -24,7 +24,7 @@ public class WebsocketRefreshApplicationEventListener implements ApplicationEven
this.websocketRefreshRequestEventListener = new WebsocketRefreshRequestEventListener(clientPresenceManager, this.websocketRefreshRequestEventListener = new WebsocketRefreshRequestEventListener(clientPresenceManager,
new AuthEnablementRefreshRequirementProvider(accountsManager), new AuthEnablementRefreshRequirementProvider(accountsManager),
new PhoneNumberChangeRefreshRequirementProvider()); new PhoneNumberChangeRefreshRequirementProvider(accountsManager));
} }
@Override @Override

View File

@ -36,6 +36,7 @@ import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.ChangesPhoneNumber;
import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager; import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager;
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager; import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
import org.whispersystems.textsecuregcm.entities.AccountDataReportResponse; import org.whispersystems.textsecuregcm.entities.AccountDataReportResponse;
@ -90,6 +91,7 @@ public class AccountControllerV2 {
@Path("/number") @Path("/number")
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@ChangesPhoneNumber
@Operation(summary = "Change number", description = "Changes a phone number for an existing account.") @Operation(summary = "Change number", description = "Changes a phone number for an existing account.")
@ApiResponse(responseCode = "200", description = "The phone number associated with the authenticated account was changed successfully", useReturnTypeSchema = true) @ApiResponse(responseCode = "200", description = "The phone number associated with the authenticated account was changed successfully", useReturnTypeSchema = true)
@ApiResponse(responseCode = "401", description = "Account authentication check failed.") @ApiResponse(responseCode = "401", description = "Account authentication check failed.")

View File

@ -0,0 +1,36 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.websocket.auth.PrincipalSupplier;
public class AccountPrincipalSupplier implements PrincipalSupplier<AuthenticatedAccount> {
private final AccountsManager accountsManager;
public AccountPrincipalSupplier(final AccountsManager accountsManager) {
this.accountsManager = accountsManager;
}
@Override
public AuthenticatedAccount refresh(final AuthenticatedAccount oldAccount) {
final Account account = accountsManager.getByAccountIdentifier(oldAccount.getAccount().getUuid())
.orElseThrow(() -> new RefreshingAccountNotFoundException("Could not find account"));
final Device device = account.getDevice(oldAccount.getAuthenticatedDevice().getId())
.orElseThrow(() -> new RefreshingAccountNotFoundException("Could not find device"));
return new AuthenticatedAccount(account, device);
}
@Override
public AuthenticatedAccount deepCopy(final AuthenticatedAccount authenticatedAccount) {
final Account cloned = AccountUtil.cloneAccountAsNotStale(authenticatedAccount.getAccount());
return new AuthenticatedAccount(
cloned,
cloned.getDevice(authenticatedAccount.getAuthenticatedDevice().getId())
.orElseThrow(() -> new IllegalStateException(
"Could not find device from a clone of an account where the device was present")));
}
}

View File

@ -5,10 +5,12 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import java.io.IOException; import java.io.IOException;
class AccountUtil { public class AccountUtil {
static Account cloneAccountAsNotStale(final Account account) { static Account cloneAccountAsNotStale(final Account account) {
try { try {

View File

@ -1,35 +0,0 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import java.util.function.Supplier;
import org.whispersystems.textsecuregcm.util.Pair;
public class RefreshingAccountAndDeviceSupplier implements Supplier<Pair<Account, Device>> {
private Account account;
private Device device;
private final AccountsManager accountsManager;
public RefreshingAccountAndDeviceSupplier(Account account, byte deviceId, AccountsManager accountsManager) {
this.account = account;
this.device = account.getDevice(deviceId)
.orElseThrow(() -> new RefreshingAccountAndDeviceNotFoundException("Could not find device"));
this.accountsManager = accountsManager;
}
@Override
public Pair<Account, Device> get() {
if (account.isStale()) {
account = accountsManager.getByAccountIdentifier(account.getUuid())
.orElseThrow(() -> new RuntimeException("Could not find account"));
device = account.getDevice(device.getId())
.orElseThrow(() -> new RefreshingAccountAndDeviceNotFoundException("Could not find device"));
}
return new Pair<>(account, device);
}
}

View File

@ -5,9 +5,9 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
public class RefreshingAccountAndDeviceNotFoundException extends RuntimeException { public class RefreshingAccountNotFoundException extends RuntimeException {
public RefreshingAccountAndDeviceNotFoundException(final String message) { public RefreshingAccountNotFoundException(final String message) {
super(message); super(message);
} }

View File

@ -11,41 +11,40 @@ import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.basic.BasicCredentials; import io.dropwizard.auth.basic.BasicCredentials;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.auth.AuthenticationException; import org.whispersystems.websocket.auth.AuthenticationException;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.WebSocketAuthenticator; import org.whispersystems.websocket.auth.WebSocketAuthenticator;
public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<AuthenticatedAccount> { public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<AuthenticatedAccount> {
private static final AuthenticationResult<AuthenticatedAccount> CREDENTIALS_NOT_PRESENTED = private static final ReusableAuth<AuthenticatedAccount> CREDENTIALS_NOT_PRESENTED = ReusableAuth.anonymous();
new AuthenticationResult<>(Optional.empty(), false);
private static final AuthenticationResult<AuthenticatedAccount> INVALID_CREDENTIALS_PRESENTED = private static final ReusableAuth<AuthenticatedAccount> INVALID_CREDENTIALS_PRESENTED = ReusableAuth.invalid();
new AuthenticationResult<>(Optional.empty(), true);
private final AccountAuthenticator accountAuthenticator; private final AccountAuthenticator accountAuthenticator;
private final PrincipalSupplier<AuthenticatedAccount> principalSupplier;
public WebSocketAccountAuthenticator(final AccountAuthenticator accountAuthenticator,
public WebSocketAccountAuthenticator(final AccountAuthenticator accountAuthenticator) { final PrincipalSupplier<AuthenticatedAccount> principalSupplier) {
this.accountAuthenticator = accountAuthenticator; this.accountAuthenticator = accountAuthenticator;
this.principalSupplier = principalSupplier;
} }
@Override @Override
public AuthenticationResult<AuthenticatedAccount> authenticate(final UpgradeRequest request) public ReusableAuth<AuthenticatedAccount> authenticate(final UpgradeRequest request)
throws AuthenticationException { throws AuthenticationException {
try { try {
final AuthenticationResult<AuthenticatedAccount> authResultFromHeader = // If the `Authorization` header was set for the request it takes priority, and we use the result of the
authenticatedAccountFromHeaderAuth(request.getHeader(HttpHeaders.AUTHORIZATION)); // header-based auth ignoring the result of the query-based auth.
// the logic here is that if the `Authorization` header was set for the request, final String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION);
// it takes the priority and we use the result of the header-based auth if (authHeader != null) {
// ignoring the result of the query-based auth. return authenticatedAccountFromHeaderAuth(authHeader);
if (authResultFromHeader.credentialsPresented()) {
return authResultFromHeader;
} }
return authenticatedAccountFromQueryParams(request); return authenticatedAccountFromQueryParams(request);
} catch (final Exception e) { } catch (final Exception e) {
@ -55,7 +54,7 @@ public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<Aut
} }
} }
private AuthenticationResult<AuthenticatedAccount> authenticatedAccountFromQueryParams(final UpgradeRequest request) { private ReusableAuth<AuthenticatedAccount> authenticatedAccountFromQueryParams(final UpgradeRequest request) {
final Map<String, List<String>> parameters = request.getParameterMap(); final Map<String, List<String>> parameters = request.getParameterMap();
final List<String> usernames = parameters.get("login"); final List<String> usernames = parameters.get("login");
final List<String> passwords = parameters.get("password"); final List<String> passwords = parameters.get("password");
@ -65,16 +64,19 @@ public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<Aut
} }
final BasicCredentials credentials = new BasicCredentials(usernames.get(0).replace(" ", "+"), final BasicCredentials credentials = new BasicCredentials(usernames.get(0).replace(" ", "+"),
passwords.get(0).replace(" ", "+")); passwords.get(0).replace(" ", "+"));
return new AuthenticationResult<>(accountAuthenticator.authenticate(credentials), true); return accountAuthenticator.authenticate(credentials)
.map(authenticatedAccount -> ReusableAuth.authenticated(authenticatedAccount, this.principalSupplier))
.orElse(INVALID_CREDENTIALS_PRESENTED);
} }
private AuthenticationResult<AuthenticatedAccount> authenticatedAccountFromHeaderAuth(@Nullable final String authHeader) private ReusableAuth<AuthenticatedAccount> authenticatedAccountFromHeaderAuth(@Nullable final String authHeader)
throws AuthenticationException { throws AuthenticationException {
if (authHeader == null) { if (authHeader == null) {
return CREDENTIALS_NOT_PRESENTED; return CREDENTIALS_NOT_PRESENTED;
} }
return basicCredentialsFromAuthHeader(authHeader) return basicCredentialsFromAuthHeader(authHeader)
.map(credentials -> new AuthenticationResult<>(accountAuthenticator.authenticate(credentials), true)) .flatMap(credentials -> accountAuthenticator.authenticate(credentials))
.map(authenticatedAccount -> ReusableAuth.authenticated(authenticatedAccount, this.principalSupplier))
.orElse(INVALID_CREDENTIALS_PRESENTED); .orElse(INVALID_CREDENTIALS_PRESENTED);
} }
} }

View File

@ -0,0 +1,279 @@
package org.whispersystems.textsecuregcm;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME;
import io.dropwizard.auth.Auth;
import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import java.io.IOException;
import java.net.URI;
import java.util.EnumSet;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.IntStream;
import javax.servlet.DispatcherType;
import javax.servlet.ServletRegistration;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.glassfish.jersey.server.ManagedAsync;
import org.glassfish.jersey.server.ServerProperties;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.storage.RefreshingAccountNotFoundException;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.ReadOnly;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ExtendWith(DropwizardExtensionsSupport.class)
public class WebsocketReuseAuthIntegrationTest {
private static final AuthenticatedAccount ACCOUNT = mock(AuthenticatedAccount.class);
@SuppressWarnings("unchecked")
private static final PrincipalSupplier<AuthenticatedAccount> PRINCIPAL_SUPPLIER = mock(PrincipalSupplier.class);
private static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION =
new DropwizardAppExtension<>(TestApplication.class);
private WebSocketClient client;
@BeforeEach
void setUp() throws Exception {
reset(PRINCIPAL_SUPPLIER);
reset(ACCOUNT);
when(ACCOUNT.getName()).thenReturn("original");
client = new WebSocketClient();
client.start();
}
@AfterEach
void tearDown() throws Exception {
client.stop();
}
public static class TestApplication extends Application<Configuration> {
@Override
public void run(final Configuration configuration, final Environment environment) throws Exception {
final TestController testController = new TestController();
final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration();
final WebSocketEnvironment<AuthenticatedAccount> webSocketEnvironment =
new WebSocketEnvironment<>(environment, webSocketConfiguration);
environment.jersey().register(testController);
environment.servlets()
.addFilter("RemoteAddressFilter", new RemoteAddressFilter(true))
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(testController);
webSocketEnvironment.jersey().register(new RemoteAddressFilter(true));
webSocketEnvironment.setAuthenticator(upgradeRequest -> ReusableAuth.authenticated(ACCOUNT, PRINCIPAL_SUPPLIER));
webSocketEnvironment.jersey().property(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE);
webSocketEnvironment.setConnectListener(webSocketSessionContext -> {
});
final WebSocketResourceProviderFactory<AuthenticatedAccount> webSocketServlet =
new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedAccount.class,
webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
final ServletRegistration.Dynamic websocketServlet =
environment.servlets().addServlet("WebSocket", webSocketServlet);
websocketServlet.addMapping("/websocket");
websocketServlet.setAsyncSupported(true);
}
}
private WebSocketResponseMessage make1WebsocketRequest(final String requestPath) throws IOException {
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
return testWebsocketListener.doGet(requestPath).join();
}
@ParameterizedTest
@ValueSource(strings = {"/test/read-auth", "/test/optional-read-auth"})
public void readAuth(final String path) throws IOException {
final WebSocketResponseMessage response = make1WebsocketRequest(path);
assertThat(response.getStatus()).isEqualTo(200);
verifyNoMoreInteractions(PRINCIPAL_SUPPLIER);
}
@ParameterizedTest
@ValueSource(strings = {"/test/write-auth", "/test/optional-write-auth"})
public void writeAuth(final String path) throws IOException {
final AuthenticatedAccount copiedAccount = mock(AuthenticatedAccount.class);
when(copiedAccount.getName()).thenReturn("copy");
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(copiedAccount);
final WebSocketResponseMessage response = make1WebsocketRequest(path);
assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.getBody().map(String::new)).get().isEqualTo("copy");
verify(PRINCIPAL_SUPPLIER, times(1)).deepCopy(any());
verifyNoMoreInteractions(PRINCIPAL_SUPPLIER);
}
@Test
public void readAfterWrite() throws IOException {
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(ACCOUNT);
final AuthenticatedAccount account2 = mock(AuthenticatedAccount.class);
when(account2.getName()).thenReturn("refresh");
when(PRINCIPAL_SUPPLIER.refresh(any())).thenReturn(account2);
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
final WebSocketResponseMessage readResponse = testWebsocketListener.doGet("/test/read-auth").join();
assertThat(readResponse.getBody().map(String::new)).get().isEqualTo("original");
final WebSocketResponseMessage writeResponse = testWebsocketListener.doGet("/test/write-auth").join();
assertThat(writeResponse.getBody().map(String::new)).get().isEqualTo("original");
final WebSocketResponseMessage readResponse2 = testWebsocketListener.doGet("/test/read-auth").join();
assertThat(readResponse2.getBody().map(String::new)).get().isEqualTo("refresh");
}
@Test
public void readAfterWriteRefreshFails() throws IOException {
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(ACCOUNT);
when(PRINCIPAL_SUPPLIER.refresh(any())).thenThrow(RefreshingAccountNotFoundException.class);
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
final WebSocketResponseMessage writeResponse = testWebsocketListener.doGet("/test/write-auth").join();
assertThat(writeResponse.getBody().map(String::new)).get().isEqualTo("original");
final WebSocketResponseMessage readResponse2 = testWebsocketListener.doGet("/test/read-auth").join();
assertThat(readResponse2.getStatus()).isEqualTo(500);
}
@Test
public void readConcurrentWithWrite() throws IOException, ExecutionException, InterruptedException, TimeoutException {
final AuthenticatedAccount deepCopy = mock(AuthenticatedAccount.class);
when(deepCopy.getName()).thenReturn("deepCopy");
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(deepCopy);
final AuthenticatedAccount refresh = mock(AuthenticatedAccount.class);
when(refresh.getName()).thenReturn("refresh");
when(PRINCIPAL_SUPPLIER.refresh(any())).thenReturn(refresh);
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
// start a write request that takes a while to finish
final CompletableFuture<WebSocketResponseMessage> writeResponse =
testWebsocketListener.doGet("/test/start-delayed-write/foo");
// send a bunch of reads, they should reflect the original auth
final List<CompletableFuture<WebSocketResponseMessage>> futures = IntStream.range(0, 10)
.boxed().map(i -> testWebsocketListener.doGet("/test/read-auth"))
.toList();
CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)).join();
for (CompletableFuture<WebSocketResponseMessage> future : futures) {
assertThat(future.join().getBody().map(String::new)).get().isEqualTo("original");
}
assertThat(writeResponse.isDone()).isFalse();
// finish the delayed write request
testWebsocketListener.doGet("/test/finish-delayed-write/foo").get(1, TimeUnit.SECONDS);
assertThat(writeResponse.join().getBody().map(String::new)).get().isEqualTo("deepCopy");
// subsequent reads should have the refreshed auth
final WebSocketResponseMessage readResponse = testWebsocketListener.doGet("/test/read-auth").join();
assertThat(readResponse.getBody().map(String::new)).get().isEqualTo("refresh");
}
@Path("/test")
public static class TestController {
private final ConcurrentHashMap<String, CountDownLatch> delayedWriteLatches = new ConcurrentHashMap<>();
@GET
@Path("/read-auth")
@ManagedAsync
public String readAuth(@ReadOnly @Auth final AuthenticatedAccount account) {
return account.getName();
}
@GET
@Path("/optional-read-auth")
@ManagedAsync
public String optionalReadAuth(@ReadOnly @Auth final Optional<AuthenticatedAccount> account) {
return account.map(AuthenticatedAccount::getName).orElse("empty");
}
@GET
@Path("/write-auth")
@ManagedAsync
public String writeAuth(@Auth final AuthenticatedAccount account) {
return account.getName();
}
@GET
@Path("/optional-write-auth")
@ManagedAsync
public String optionalWriteAuth(@Auth final Optional<AuthenticatedAccount> account) {
return account.map(AuthenticatedAccount::getName).orElse("empty");
}
@GET
@Path("/start-delayed-write/{id}")
@ManagedAsync
public String startDelayedWrite(@Auth final AuthenticatedAccount account, @PathParam("id") String id)
throws InterruptedException {
delayedWriteLatches.computeIfAbsent(id, i -> new CountDownLatch(1)).await();
return account.getName();
}
@GET
@Path("/finish-delayed-write/{id}")
@ManagedAsync
public String finishDelayedWrite(@PathParam("id") String id) {
delayedWriteLatches.computeIfAbsent(id, i -> new CountDownLatch(1)).countDown();
return "ok";
}
}
}

View File

@ -7,8 +7,6 @@ package org.whispersystems.textsecuregcm.auth;
import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
@ -30,7 +28,6 @@ import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.security.Principal; import java.security.Principal;
import java.time.Duration; import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Base64; import java.util.Base64;
import java.util.LinkedList; import java.util.LinkedList;
@ -76,7 +73,9 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketResourceProvider; import org.whispersystems.websocket.WebSocketResourceProvider;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider; import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
import org.whispersystems.websocket.logging.WebsocketRequestLog; import org.whispersystems.websocket.logging.WebsocketRequestLog;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory; import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
@ -132,38 +131,6 @@ class AuthEnablementRefreshRequirementProviderTest {
.forEach(device -> when(clientPresenceManager.isPresent(uuid, device.getId())).thenReturn(true)); .forEach(device -> when(clientPresenceManager.isPresent(uuid, device.getId())).thenReturn(true));
} }
@Test
void testBuildDevicesEnabled() {
final byte disabledDeviceId = 3;
final Account account = mock(Account.class);
final List<Device> devices = new ArrayList<>();
when(account.getDevices()).thenReturn(devices);
IntStream.range(1, 5)
.forEach(id -> {
final Device device = mock(Device.class);
when(device.getId()).thenReturn((byte) id);
when(device.isEnabled()).thenReturn(id != disabledDeviceId);
devices.add(device);
});
final Map<Byte, Boolean> devicesEnabled = AuthEnablementRefreshRequirementProvider.buildDevicesEnabledMap(account);
assertEquals(4, devicesEnabled.size());
assertAll(devicesEnabled.entrySet().stream()
.map(deviceAndEnabled -> () -> {
if (deviceAndEnabled.getKey().equals(disabledDeviceId)) {
assertFalse(deviceAndEnabled.getValue());
} else {
assertTrue(deviceAndEnabled.getValue());
}
}));
}
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void testDeviceEnabledChanged(final Map<Byte, Boolean> initialEnabled, final Map<Byte, Boolean> finalEnabled) { void testDeviceEnabledChanged(final Map<Byte, Boolean> initialEnabled, final Map<Byte, Boolean> finalEnabled) {
@ -308,7 +275,7 @@ class AuthEnablementRefreshRequirementProviderTest {
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
provider = new WebSocketResourceProvider<>("127.0.0.1", RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, provider = new WebSocketResourceProvider<>("127.0.0.1", RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME,
applicationHandler, requestLog, new TestPrincipal("test", account, authenticatedDevice), applicationHandler, requestLog, TestPrincipal.reusableAuth("test", account, authenticatedDevice),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
remoteEndpoint = mock(RemoteEndpoint.class); remoteEndpoint = mock(RemoteEndpoint.class);
@ -349,7 +316,7 @@ class AuthEnablementRefreshRequirementProviderTest {
private final Account account; private final Account account;
private final Device device; private final Device device;
private TestPrincipal(String name, final Account account, final Device device) { private TestPrincipal(final String name, final Account account, final Device device) {
this.name = name; this.name = name;
this.account = account; this.account = account;
this.device = device; this.device = device;
@ -369,6 +336,11 @@ class AuthEnablementRefreshRequirementProviderTest {
public Device getAuthenticatedDevice() { public Device getAuthenticatedDevice() {
return device; return device;
} }
public static ReusableAuth<TestPrincipal> reusableAuth(final String name, final Account account, final Device device) {
return ReusableAuth.authenticated(new TestPrincipal(name, account, device), PrincipalSupplier.forImmutablePrincipal());
}
} }
@Path("/v1/test") @Path("/v1/test")

View File

@ -0,0 +1,55 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.IntStream;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class ContainerRequestUtilTest {
@Test
void testBuildDevicesEnabled() {
final byte disabledDeviceId = 3;
final Account account = mock(Account.class);
final List<Device> devices = new ArrayList<>();
when(account.getDevices()).thenReturn(devices);
IntStream.range(1, 5)
.forEach(id -> {
final Device device = mock(Device.class);
when(device.getId()).thenReturn((byte) id);
when(device.isEnabled()).thenReturn(id != disabledDeviceId);
devices.add(device);
});
final Map<Byte, Boolean> devicesEnabled = ContainerRequestUtil.AccountInfo.fromAccount(account).devicesEnabled();
assertEquals(4, devicesEnabled.size());
assertAll(devicesEnabled.entrySet().stream()
.map(deviceAndEnabled -> () -> {
if (deviceAndEnabled.getKey().equals(disabledDeviceId)) {
assertFalse(deviceAndEnabled.getValue());
} else {
assertTrue(deviceAndEnabled.getValue());
}
}));
}
}

View File

@ -5,108 +5,292 @@
package org.whispersystems.textsecuregcm.auth; package org.whispersystems.textsecuregcm.auth;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME;
import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.Auth;
import io.dropwizard.auth.AuthDynamicFeature;
import io.dropwizard.auth.basic.BasicCredentialAuthFilter;
import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import java.io.IOException;
import java.net.URI;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.EnumSet;
import java.util.List; import java.util.Optional;
import java.util.Map;
import java.util.UUID; import java.util.UUID;
import javax.annotation.Nullable; import javax.servlet.DispatcherType;
import javax.ws.rs.core.SecurityContext; import javax.servlet.ServletRegistration;
import org.glassfish.jersey.server.ContainerRequest; import javax.ws.rs.GET;
import org.glassfish.jersey.server.monitoring.RequestEvent; import javax.ws.rs.Path;
import javax.ws.rs.client.Invocation;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.glassfish.jersey.server.ManagedAsync;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.ReadOnly;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ExtendWith(DropwizardExtensionsSupport.class)
class PhoneNumberChangeRefreshRequirementProviderTest { class PhoneNumberChangeRefreshRequirementProviderTest {
private PhoneNumberChangeRefreshRequirementProvider provider;
private Account account;
private RequestEvent requestEvent;
private ContainerRequest request;
private static final UUID ACCOUNT_UUID = UUID.randomUUID();
private static final String NUMBER = "+18005551234"; private static final String NUMBER = "+18005551234";
private static final String CHANGED_NUMBER = "+18005554321"; private static final String CHANGED_NUMBER = "+18005554321";
private static final String TEST_CRED_HEADER = HeaderUtils.basicAuthHeader("test", "password");
private static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION = new DropwizardAppExtension<>(
TestApplication.class);
private static final AccountAuthenticator AUTHENTICATOR = mock(AccountAuthenticator.class);
private static final AccountsManager ACCOUNTS_MANAGER = mock(AccountsManager.class);
private static final ClientPresenceManager CLIENT_PRESENCE = mock(ClientPresenceManager.class);
private WebSocketClient client;
private final Account account1 = new Account();
private final Account account2 = new Account();
private final Device authenticatedDevice = DevicesHelper.createDevice(Device.PRIMARY_ID);
@BeforeEach @BeforeEach
void setUp() { void setUp() throws Exception {
provider = new PhoneNumberChangeRefreshRequirementProvider(); reset(AUTHENTICATOR, CLIENT_PRESENCE, ACCOUNTS_MANAGER);
client = new WebSocketClient();
client.start();
account = mock(Account.class); final UUID uuid = UUID.randomUUID();
final Device device = mock(Device.class); account1.setUuid(uuid);
account1.addDevice(authenticatedDevice);
account1.setNumber(NUMBER, UUID.randomUUID());
when(account.getUuid()).thenReturn(ACCOUNT_UUID); account2.setUuid(uuid);
when(account.getNumber()).thenReturn(NUMBER); account2.addDevice(authenticatedDevice);
when(account.getDevices()).thenReturn(List.of(device)); account2.setNumber(CHANGED_NUMBER, UUID.randomUUID());
when(device.getId()).thenReturn(Device.PRIMARY_ID);
request = mock(ContainerRequest.class); }
final Map<String, Object> requestProperties = new HashMap<>(); @AfterEach
void tearDown() throws Exception {
client.stop();
}
doAnswer(invocation -> {
requestProperties.put(invocation.getArgument(0, String.class), invocation.getArgument(1));
return null;
}).when(request).setProperty(anyString(), any());
when(request.getProperty(anyString())).thenAnswer( public static class TestApplication extends Application<Configuration> {
invocation -> requestProperties.get(invocation.getArgument(0, String.class)));
requestEvent = mock(RequestEvent.class); @Override
when(requestEvent.getContainerRequest()).thenReturn(request); public void run(final Configuration configuration, final Environment environment) throws Exception {
final TestController testController = new TestController();
final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration();
final WebSocketEnvironment<AuthenticatedAccount> webSocketEnvironment =
new WebSocketEnvironment<>(environment, webSocketConfiguration);
environment.jersey().register(testController);
webSocketEnvironment.jersey().register(testController);
environment.servlets()
.addFilter("RemoteAddressFilter", new RemoteAddressFilter(true))
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(new RemoteAddressFilter(true));
webSocketEnvironment.jersey()
.register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE));
environment.jersey()
.register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE));
webSocketEnvironment.setConnectListener(webSocketSessionContext -> {
});
environment.jersey().register(new AuthDynamicFeature(new BasicCredentialAuthFilter.Builder<AuthenticatedAccount>()
.setAuthenticator(AUTHENTICATOR)
.buildAuthFilter()));
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(AUTHENTICATOR, mock(PrincipalSupplier.class)));
final WebSocketResourceProviderFactory<AuthenticatedAccount> webSocketServlet =
new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedAccount.class,
webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
final ServletRegistration.Dynamic websocketServlet =
environment.servlets().addServlet("WebSocket", webSocketServlet);
websocketServlet.addMapping("/websocket");
websocketServlet.setAsyncSupported(true);
}
}
enum Protocol { HTTP, WEBSOCKET }
private void makeAnonymousRequest(final Protocol protocol, final String requestPath) throws IOException {
makeRequest(protocol, requestPath, true);
}
/*
* Make an authenticated request that will return account1 as the principal
*/
private void makeAuthenticatedRequest(
final Protocol protocol,
final String requestPath) throws IOException {
when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedAccount(account1, authenticatedDevice)));
makeRequest(protocol,requestPath, false);
}
private void makeRequest(final Protocol protocol, final String requestPath, final boolean anonymous) throws IOException {
switch (protocol) {
case WEBSOCKET -> {
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
if (!anonymous) {
upgradeRequest.setHeader(HttpHeaders.AUTHORIZATION, TEST_CRED_HEADER);
}
client.connect(
testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())),
upgradeRequest);
testWebsocketListener.sendRequest(requestPath, "GET", Collections.emptyList(), Optional.empty()).join();
}
case HTTP -> {
final Invocation.Builder request = DROPWIZARD_APP_EXTENSION.client()
.target("http://127.0.0.1:%s%s".formatted(DROPWIZARD_APP_EXTENSION.getLocalPort(), requestPath))
.request();
if (!anonymous) {
request.header(HttpHeaders.AUTHORIZATION, TEST_CRED_HEADER);
}
request.get();
}
}
}
@ParameterizedTest
@EnumSource(Protocol.class)
void handleRequestNoChange(final Protocol protocol) throws IOException {
when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account1));
makeAuthenticatedRequest(protocol, "/test/annotated");
// Event listeners can fire after responses are sent
verify(ACCOUNTS_MANAGER, timeout(5000).times(1)).getByAccountIdentifier(eq(account1.getUuid()));
verifyNoMoreInteractions(CLIENT_PRESENCE);
verifyNoMoreInteractions(ACCOUNTS_MANAGER);
}
@ParameterizedTest
@EnumSource(Protocol.class)
void handleRequestChange(final Protocol protocol) throws IOException {
when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account2));
when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedAccount(account1, authenticatedDevice)));
makeAuthenticatedRequest(protocol, "/test/annotated");
// Make sure we disconnect the account if the account has changed numbers. Event listeners can fire after responses
// are sent, so use a timeout.
verify(CLIENT_PRESENCE, timeout(5000))
.disconnectPresence(eq(account1.getUuid()), eq(authenticatedDevice.getId()));
verifyNoMoreInteractions(CLIENT_PRESENCE);
} }
@Test @Test
void handleRequestNoChange() { void handleRequestChangeAsyncEndpoint() throws IOException {
setAuthenticatedAccount(request, account); when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account2));
when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedAccount(account1, authenticatedDevice)));
provider.handleRequestFiltered(requestEvent); // Event listeners with asynchronous HTTP endpoints don't currently correctly maintain state between request and
assertEquals(Collections.emptyList(), provider.handleRequestFinished(requestEvent)); // response
makeAuthenticatedRequest(Protocol.WEBSOCKET, "/test/async-annotated");
// Make sure we disconnect the account if the account has changed numbers. Event listeners can fire after responses
// are sent, so use a timeout.
verify(CLIENT_PRESENCE, timeout(5000))
.disconnectPresence(eq(account1.getUuid()), eq(authenticatedDevice.getId()));
verifyNoMoreInteractions(CLIENT_PRESENCE);
} }
@Test @ParameterizedTest
void handleRequestNumberChange() { @EnumSource(Protocol.class)
setAuthenticatedAccount(request, account); void handleRequestNotAnnotated(final Protocol protocol) throws IOException, InterruptedException {
makeAuthenticatedRequest(protocol,"/test/not-annotated");
provider.handleRequestFiltered(requestEvent); // Give a tick for event listeners to run. Racy, but should occasionally catch an errant running listener if one is
when(account.getNumber()).thenReturn(CHANGED_NUMBER); // introduced.
assertEquals(List.of(new Pair<>(ACCOUNT_UUID, Device.PRIMARY_ID)), provider.handleRequestFinished(requestEvent)); Thread.sleep(100);
// Shouldn't even read the account if the method has not been annotated
verifyNoMoreInteractions(ACCOUNTS_MANAGER);
verifyNoMoreInteractions(CLIENT_PRESENCE);
} }
@Test @ParameterizedTest
void handleRequestNoAuthenticatedAccount() { @EnumSource(Protocol.class)
final ContainerRequest request = mock(ContainerRequest.class); void handleRequestNotAuthenticated(final Protocol protocol) throws IOException, InterruptedException {
setAuthenticatedAccount(request, null); makeAnonymousRequest(protocol, "/test/not-authenticated");
when(requestEvent.getContainerRequest()).thenReturn(request); // Give a tick for event listeners to run. Racy, but should occasionally catch an errant running listener if one is
// introduced.
Thread.sleep(100);
provider.handleRequestFiltered(requestEvent); // Shouldn't even read the account if the method has not been annotated
assertEquals(Collections.emptyList(), provider.handleRequestFinished(requestEvent)); verifyNoMoreInteractions(ACCOUNTS_MANAGER);
verifyNoMoreInteractions(CLIENT_PRESENCE);
} }
private static void setAuthenticatedAccount(final ContainerRequest mockRequest, @Nullable final Account account) {
final SecurityContext securityContext = mock(SecurityContext.class);
when(mockRequest.getSecurityContext()).thenReturn(securityContext); @Path("/test")
public static class TestController {
if (account != null) { @GET
final AuthenticatedAccount authenticatedAccount = mock(AuthenticatedAccount.class); @Path("/annotated")
@ChangesPhoneNumber
public String annotated(@ReadOnly @Auth final AuthenticatedAccount account) {
return "ok";
}
when(securityContext.getUserPrincipal()).thenReturn(authenticatedAccount); @GET
when(authenticatedAccount.getAccount()).thenReturn(account); @Path("/async-annotated")
} else { @ChangesPhoneNumber
when(securityContext.getUserPrincipal()).thenReturn(null); @ManagedAsync
public String asyncAnnotated(@ReadOnly @Auth final AuthenticatedAccount account) {
return "ok";
}
@GET
@Path("/not-authenticated")
@ChangesPhoneNumber
public String notAuthenticated() {
return "ok";
}
@GET
@Path("/not-annotated")
public String notAnnotated(@ReadOnly @Auth final AuthenticatedAccount account) {
return "ok";
} }
} }
} }

View File

@ -87,7 +87,7 @@ class BasicCredentialAuthenticationInterceptorTest {
when(device.getId()).thenReturn(Device.PRIMARY_ID); when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(accountAuthenticator.authenticate(any())) when(accountAuthenticator.authenticate(any()))
.thenReturn(Optional.of(new AuthenticatedAccount(() -> new Pair<>(account, device)))); .thenReturn(Optional.of(new AuthenticatedAccount(account, device)));
} else { } else {
when(accountAuthenticator.authenticate(any())) when(accountAuthenticator.authenticate(any()))
.thenReturn(Optional.empty()); .thenReturn(Optional.empty());

View File

@ -39,7 +39,7 @@ class DirectoryControllerV2Test {
when(account.getUuid()).thenReturn(uuid); when(account.getUuid()).thenReturn(uuid);
final ExternalServiceCredentials credentials = (ExternalServiceCredentials) controller.getAuthToken( final ExternalServiceCredentials credentials = (ExternalServiceCredentials) controller.getAuthToken(
new AuthenticatedAccount(() -> new Pair<>(account, mock(Device.class)))).getEntity(); new AuthenticatedAccount(account, mock(Device.class))).getEntity();
assertEquals(credentials.username(), "d369bc712e2e0dd36258"); assertEquals(credentials.username(), "d369bc712e2e0dd36258");
assertEquals(credentials.password(), "1633738643:4433b0fab41f25f79dd4"); assertEquals(credentials.password(), "1633738643:4433b0fab41f25f79dd4");

View File

@ -51,7 +51,10 @@ import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.tests.util.TestPrincipal;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketResourceProvider; import org.whispersystems.websocket.WebSocketResourceProvider;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider; import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
import org.whispersystems.websocket.logging.WebsocketRequestLog; import org.whispersystems.websocket.logging.WebsocketRequestLog;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory; import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
@ -139,7 +142,7 @@ class MetricsRequestEventListenerTest {
final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
final WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", final WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"), RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, TestPrincipal.reusableAuth("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
final Session session = mock(Session.class); final Session session = mock(Session.class);
@ -201,7 +204,7 @@ class MetricsRequestEventListenerTest {
final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
final WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", final WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"), RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, TestPrincipal.reusableAuth("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
final Session session = mock(Session.class); final Session session = mock(Session.class);
@ -252,19 +255,6 @@ class MetricsRequestEventListenerTest {
return SubProtocol.WebSocketMessage.parseFrom(responseCaptor.getValue().array()).getResponse(); return SubProtocol.WebSocketMessage.parseFrom(responseCaptor.getValue().array()).getResponse();
} }
public static class TestPrincipal implements Principal {
private final String name;
private TestPrincipal(String name) {
this.name = name;
}
@Override
public String getName() {
return name;
}
}
@Path("/v1/test") @Path("/v1/test")
public static class TestResource { public static class TestResource {

View File

@ -1,71 +0,0 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.util.Optional;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.util.Pair;
class RefreshingAccountAndDeviceSupplierTest {
@Test
void test() {
final AccountsManager accountsManager = mock(AccountsManager.class);
final UUID uuid = UUID.randomUUID();
final byte deviceId = 2;
final Account initialAccount = mock(Account.class);
final Device initialDevice = mock(Device.class);
when(initialAccount.getUuid()).thenReturn(uuid);
when(initialDevice.getId()).thenReturn(deviceId);
when(initialAccount.getDevice(deviceId)).thenReturn(Optional.of(initialDevice));
when(accountsManager.getByAccountIdentifier(any(UUID.class))).thenAnswer(answer -> {
final Account account = mock(Account.class);
final Device device = mock(Device.class);
when(account.getUuid()).thenReturn(answer.getArgument(0, UUID.class));
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(device.getId()).thenReturn(deviceId);
return Optional.of(account);
});
final RefreshingAccountAndDeviceSupplier refreshingAccountAndDeviceSupplier = new RefreshingAccountAndDeviceSupplier(
initialAccount, deviceId, accountsManager);
Pair<Account, Device> accountAndDevice = refreshingAccountAndDeviceSupplier.get();
assertSame(initialAccount, accountAndDevice.first());
assertSame(initialDevice, accountAndDevice.second());
accountAndDevice = refreshingAccountAndDeviceSupplier.get();
assertSame(initialAccount, accountAndDevice.first());
assertSame(initialDevice, accountAndDevice.second());
when(initialAccount.isStale()).thenReturn(true);
accountAndDevice = refreshingAccountAndDeviceSupplier.get();
assertNotSame(initialAccount, accountAndDevice.first());
assertNotSame(initialDevice, accountAndDevice.second());
assertEquals(uuid, accountAndDevice.first().getUuid());
}
}

View File

@ -0,0 +1,27 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.tests.util;
import java.security.Principal;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.auth.PrincipalSupplier;
public class TestPrincipal implements Principal {
private final String name;
private TestPrincipal(String name) {
this.name = name;
}
@Override
public String getName() {
return name;
}
public static ReusableAuth<TestPrincipal> reusableAuth(final String name) {
return ReusableAuth.authenticated(new TestPrincipal(name), PrincipalSupplier.forImmutablePrincipal());
}
}

View File

@ -0,0 +1,79 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.tests.util;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.whispersystems.websocket.messages.WebSocketMessage;
import org.whispersystems.websocket.messages.WebSocketMessageFactory;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
public class TestWebsocketListener implements WebSocketListener {
private final AtomicLong requestId = new AtomicLong();
private final CompletableFuture<Session> started = new CompletableFuture<>();
private final ConcurrentHashMap<Long, CompletableFuture<WebSocketResponseMessage>> responseFutures = new ConcurrentHashMap<>();
private final WebSocketMessageFactory messageFactory;
public TestWebsocketListener() {
this.messageFactory = new ProtobufWebSocketMessageFactory();
}
@Override
public void onWebSocketConnect(final Session session) {
started.complete(session);
}
public CompletableFuture<WebSocketResponseMessage> doGet(final String requestPath) {
return sendRequest(requestPath, "GET", List.of("Accept: application/json"), Optional.empty());
}
public CompletableFuture<WebSocketResponseMessage> sendRequest(
final String requestPath,
final String verb,
final List<String> headers,
final Optional<byte[]> body) {
return started.thenCompose(session -> {
final long id = requestId.incrementAndGet();
final CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>();
responseFutures.put(id, future);
final byte[] requestBytes = messageFactory.createRequest(
Optional.of(id), verb, requestPath, headers, body).toByteArray();
try {
session.getRemote().sendBytes(ByteBuffer.wrap(requestBytes));
} catch (IOException e) {
throw new RuntimeException(e);
}
return future;
});
}
@Override
public void onWebSocketBinary(final byte[] payload, final int offset, final int length) {
try {
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length);
if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.RESPONSE_MESSAGE) {
responseFutures.get(webSocketMessage.getResponseMessage().getRequestId())
.complete(webSocketMessage.getResponseMessage());
} else {
throw new RuntimeException("Unexpected message type: " + webSocketMessage.getType());
}
} catch (final Exception e) {
throw new RuntimeException(e);
}
}
}

View File

@ -57,6 +57,7 @@ import org.junit.jupiter.params.provider.MethodSource;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper; import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper;
import org.whispersystems.textsecuregcm.tests.util.TestPrincipal;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.websocket.WebSocketResourceProvider; import org.whispersystems.websocket.WebSocketResourceProvider;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider; import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
@ -175,7 +176,8 @@ class LoggingUnhandledExceptionMapperTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"), RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog,
TestPrincipal.reusableAuth("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@ -238,18 +240,4 @@ class LoggingUnhandledExceptionMapperTest {
throw new RuntimeException(); throw new RuntimeException();
} }
} }
public static class TestPrincipal implements Principal {
private final String name;
private TestPrincipal(String name) {
this.name = name;
}
@Override
public String getName() {
return name;
}
}
} }

View File

@ -28,8 +28,8 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.auth.WebSocketAuthenticator; import org.whispersystems.websocket.auth.PrincipalSupplier;
class WebSocketAccountAuthenticatorTest { class WebSocketAccountAuthenticatorTest {
@ -52,7 +52,7 @@ class WebSocketAccountAuthenticatorTest {
accountAuthenticator = mock(AccountAuthenticator.class); accountAuthenticator = mock(AccountAuthenticator.class);
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD)))) when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
.thenReturn(Optional.of(new AuthenticatedAccount(() -> new Pair<>(mock(Account.class), mock(Device.class))))); .thenReturn(Optional.of(new AuthenticatedAccount(mock(Account.class), mock(Device.class))));
when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD)))) when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD))))
.thenReturn(Optional.empty()); .thenReturn(Optional.empty());
@ -66,7 +66,7 @@ class WebSocketAccountAuthenticatorTest {
@Nullable final String authorizationHeaderValue, @Nullable final String authorizationHeaderValue,
final Map<String, List<String>> upgradeRequestParameters, final Map<String, List<String>> upgradeRequestParameters,
final boolean expectAccount, final boolean expectAccount,
final boolean expectCredentialsPresented) throws Exception { final boolean expectInvalid) throws Exception {
when(upgradeRequest.getParameterMap()).thenReturn(upgradeRequestParameters); when(upgradeRequest.getParameterMap()).thenReturn(upgradeRequestParameters);
if (authorizationHeaderValue != null) { if (authorizationHeaderValue != null) {
@ -74,13 +74,13 @@ class WebSocketAccountAuthenticatorTest {
} }
final WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator( final WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(
accountAuthenticator); accountAuthenticator,
mock(PrincipalSupplier.class));
final WebSocketAuthenticator.AuthenticationResult<AuthenticatedAccount> result = webSocketAuthenticator.authenticate( final ReusableAuth<AuthenticatedAccount> result = webSocketAuthenticator.authenticate(upgradeRequest);
upgradeRequest);
assertEquals(expectAccount, result.getUser().isPresent()); assertEquals(expectAccount, result.ref().isPresent());
assertEquals(expectCredentialsPresented, result.credentialsPresented()); assertEquals(expectInvalid, result.invalidCredentialsProvided());
} }
private static Stream<Arguments> testAuthenticate() { private static Stream<Arguments> testAuthenticate() {
@ -94,17 +94,17 @@ class WebSocketAccountAuthenticatorTest {
HeaderUtils.basicAuthHeader(INVALID_USER, INVALID_PASSWORD); HeaderUtils.basicAuthHeader(INVALID_USER, INVALID_PASSWORD);
return Stream.of( return Stream.of(
// if `Authorization` header is present, outcome should not depend on the value of query parameters // if `Authorization` header is present, outcome should not depend on the value of query parameters
Arguments.of(headerWithValidAuth, Map.of(), true, true), Arguments.of(headerWithValidAuth, Map.of(), true, false),
Arguments.of(headerWithInvalidAuth, Map.of(), false, true), Arguments.of(headerWithInvalidAuth, Map.of(), false, true),
Arguments.of("invalid header value", Map.of(), false, true), Arguments.of("invalid header value", Map.of(), false, true),
Arguments.of(headerWithValidAuth, paramsMapWithValidAuth, true, true), Arguments.of(headerWithValidAuth, paramsMapWithValidAuth, true, false),
Arguments.of(headerWithInvalidAuth, paramsMapWithValidAuth, false, true), Arguments.of(headerWithInvalidAuth, paramsMapWithValidAuth, false, true),
Arguments.of("invalid header value", paramsMapWithValidAuth, false, true), Arguments.of("invalid header value", paramsMapWithValidAuth, false, true),
Arguments.of(headerWithValidAuth, paramsMapWithInvalidAuth, true, true), Arguments.of(headerWithValidAuth, paramsMapWithInvalidAuth, true, false),
Arguments.of(headerWithInvalidAuth, paramsMapWithInvalidAuth, false, true), Arguments.of(headerWithInvalidAuth, paramsMapWithInvalidAuth, false, true),
Arguments.of("invalid header value", paramsMapWithInvalidAuth, false, true), Arguments.of("invalid header value", paramsMapWithInvalidAuth, false, true),
// if `Authorization` header is not set, outcome should match the query params based auth // if `Authorization` header is not set, outcome should match the query params based auth
Arguments.of(null, paramsMapWithValidAuth, true, true), Arguments.of(null, paramsMapWithValidAuth, true, false),
Arguments.of(null, paramsMapWithInvalidAuth, false, true), Arguments.of(null, paramsMapWithInvalidAuth, false, true),
Arguments.of(null, Map.of(), false, false) Arguments.of(null, Map.of(), false, false)
); );

View File

@ -125,7 +125,7 @@ class WebSocketConnectionIntegrationTest {
final WebSocketConnection webSocketConnection = new WebSocketConnection( final WebSocketConnection webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class), mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService), new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
new AuthenticatedAccount(() -> new Pair<>(account, device)), new AuthenticatedAccount(account, device),
device, device,
webSocketClient, webSocketClient,
scheduledExecutorService, scheduledExecutorService,
@ -210,7 +210,7 @@ class WebSocketConnectionIntegrationTest {
final WebSocketConnection webSocketConnection = new WebSocketConnection( final WebSocketConnection webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class), mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService), new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
new AuthenticatedAccount(() -> new Pair<>(account, device)), new AuthenticatedAccount(account, device),
device, device,
webSocketClient, webSocketClient,
scheduledExecutorService, scheduledExecutorService,
@ -276,7 +276,7 @@ class WebSocketConnectionIntegrationTest {
final WebSocketConnection webSocketConnection = new WebSocketConnection( final WebSocketConnection webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class), mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService), new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
new AuthenticatedAccount(() -> new Pair<>(account, device)), new AuthenticatedAccount(account, device),
device, device,
webSocketClient, webSocketClient,
100, // use a very short timeout, so that this test completes quickly 100, // use a very short timeout, so that this test completes quickly

View File

@ -64,9 +64,9 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketClient; import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult; import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.messages.WebSocketResponseMessage; import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.session.WebSocketSessionContext; import org.whispersystems.websocket.session.WebSocketSessionContext;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
@ -101,7 +101,7 @@ class WebSocketConnectionTest {
accountsManager = mock(AccountsManager.class); accountsManager = mock(AccountsManager.class);
account = mock(Account.class); account = mock(Account.class);
device = mock(Device.class); device = mock(Device.class);
auth = new AuthenticatedAccount(() -> new Pair<>(account, device)); auth = new AuthenticatedAccount(account, device);
upgradeRequest = mock(UpgradeRequest.class); upgradeRequest = mock(UpgradeRequest.class);
messagesManager = mock(MessagesManager.class); messagesManager = mock(MessagesManager.class);
receiptSender = mock(ReceiptSender.class); receiptSender = mock(ReceiptSender.class);
@ -118,18 +118,19 @@ class WebSocketConnectionTest {
@Test @Test
void testCredentials() throws Exception { void testCredentials() throws Exception {
WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator); WebSocketAccountAuthenticator webSocketAuthenticator =
new WebSocketAccountAuthenticator(accountAuthenticator, mock(PrincipalSupplier.class));
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, messagesManager, AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, messagesManager,
mock(PushNotificationManager.class), mock(ClientPresenceManager.class), mock(PushNotificationManager.class), mock(ClientPresenceManager.class),
retrySchedulingExecutor, messageDeliveryScheduler, clientReleaseManager); retrySchedulingExecutor, messageDeliveryScheduler, clientReleaseManager);
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD)))) when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
.thenReturn(Optional.of(new AuthenticatedAccount(() -> new Pair<>(account, device)))); .thenReturn(Optional.of(new AuthenticatedAccount(account, device)));
AuthenticationResult<AuthenticatedAccount> account = webSocketAuthenticator.authenticate(upgradeRequest); ReusableAuth<AuthenticatedAccount> account = webSocketAuthenticator.authenticate(upgradeRequest);
when(sessionContext.getAuthenticated()).thenReturn(account.getUser().orElse(null)); when(sessionContext.getAuthenticated()).thenReturn(account.ref().orElse(null));
when(sessionContext.getAuthenticated(AuthenticatedAccount.class)).thenReturn(account.getUser().orElse(null)); when(sessionContext.getAuthenticated(AuthenticatedAccount.class)).thenReturn(account.ref().orElse(null));
final WebSocketClient webSocketClient = mock(WebSocketClient.class); final WebSocketClient webSocketClient = mock(WebSocketClient.class);
when(webSocketClient.getUserAgent()).thenReturn("Signal-Android/6.22.8"); when(webSocketClient.getUserAgent()).thenReturn("Signal-Android/6.22.8");
@ -144,8 +145,8 @@ class WebSocketConnectionTest {
// unauthenticated // unauthenticated
when(upgradeRequest.getParameterMap()).thenReturn(Map.of()); when(upgradeRequest.getParameterMap()).thenReturn(Map.of());
account = webSocketAuthenticator.authenticate(upgradeRequest); account = webSocketAuthenticator.authenticate(upgradeRequest);
assertFalse(account.getUser().isPresent()); assertFalse(account.ref().isPresent());
assertFalse(account.credentialsPresented()); assertFalse(account.invalidCredentialsProvided());
connectListener.onWebSocketConnect(sessionContext); connectListener.onWebSocketConnect(sessionContext);
verify(sessionContext, times(2)).addWebsocketClosedListener( verify(sessionContext, times(2)).addWebsocketClosedListener(

View File

@ -0,0 +1,182 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.websocket;
import java.security.Principal;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import org.whispersystems.websocket.auth.PrincipalSupplier;
/**
* This class holds a principal that can be reused across requests on a websocket. Since two requests may operate
* concurrently on the same principal, and some principals contain non thread-safe mutable state, appropriate use of
* this class ensures that no data races occur. It also ensures that after a principal is modified, a subsequent request
* gets the up-to-date principal
*
* @param <T> The underlying principal type
* @see PrincipalSupplier
*/
public abstract sealed class ReusableAuth<T extends Principal> {
/**
* Get a reference to the underlying principal that callers pledge not to modify.
* <p>
* The reference returned will potentially be provided to many threads concurrently accessing the principal. Callers
* should use this method only if they can ensure that they will not modify the in-memory principal object AND they do
* not intend to modify the underlying canonical representation of the principal.
* <p>
* For example, if a caller retrieves a reference to a principal, does not modify the in memory state, but updates a
* field on a database that should be reflected in subsequent retrievals of the principal, they will have met the
* first criteria, but not the second. In that case they should instead use {@link #mutableRef()}.
* <p>
* If other callers have modified the underlying principal by using {@link #mutableRef()}, this method may need to
* refresh the principal via {@link PrincipalSupplier#refresh} which could be a blocking operation.
*
* @return If authenticated, a reference to the underlying principal that should not be modified
*/
public abstract Optional<T> ref();
public interface MutableRef<T> {
T ref();
void close();
}
/**
* Get a reference to the underlying principal that may be modified.
* <p>
* The underlying principal can be safely modified. Multiple threads may operate on the same {@link ReusableAuth} so
* long as they each have their own {@link MutableRef}. After any modifications, the caller must call
* {@link MutableRef#close} to notify the principal has become dirty. Close should be called after modifications but
* before sending a response on the websocket. This ensures that a request that comes in after a modification response
* is received is guaranteed to see the modification.
*
* @return If authenticated, a reference to the underlying principal that may be modified
*/
public abstract Optional<MutableRef<T>> mutableRef();
public boolean invalidCredentialsProvided() {
return switch (this) {
case Invalid<T> ignored -> true;
case ReusableAuth.Anonymous<T> ignored -> false;
case ReusableAuth.Authenticated<T> ignored-> false;
};
}
/**
* @return A {@link ReusableAuth} indicating no credential were provided
*/
public static <T extends Principal> ReusableAuth<T> anonymous() {
//noinspection unchecked
return (ReusableAuth<T>) Anonymous.ANON_RESULT;
}
/**
* @return A {@link ReusableAuth} indicating that invalid credentials were provided
*/
public static <T extends Principal> ReusableAuth<T> invalid() {
//noinspection unchecked
return (ReusableAuth<T>) Invalid.INVALID_RESULT;
}
/**
* Create a successfully authenticated {@link ReusableAuth}
*
* @param principal The authenticated principal
* @param principalSupplier Instructions for how to refresh or copy this principal
* @param <T> The principal type
* @return A {@link ReusableAuth} for a successfully authenticated principal
*/
public static <T extends Principal> ReusableAuth<T> authenticated(T principal,
PrincipalSupplier<T> principalSupplier) {
return new Authenticated<>(principal, principalSupplier);
}
private static final class Invalid<T extends Principal> extends ReusableAuth<T> {
@SuppressWarnings({"rawtypes"})
private static final ReusableAuth INVALID_RESULT = new Invalid();
@Override
public Optional<T> ref() {
return Optional.empty();
}
@Override
public Optional<MutableRef<T>> mutableRef() {
return Optional.empty();
}
}
private static final class Anonymous<T extends Principal> extends ReusableAuth<T> {
@SuppressWarnings({"rawtypes"})
private static final ReusableAuth ANON_RESULT = new Anonymous();
@Override
public Optional<T> ref() {
return Optional.empty();
}
@Override
public Optional<MutableRef<T>> mutableRef() {
return Optional.empty();
}
}
private static final class Authenticated<T extends Principal> extends ReusableAuth<T> {
private T basePrincipal;
private final AtomicBoolean needRefresh = new AtomicBoolean(false);
private final PrincipalSupplier<T> principalSupplier;
Authenticated(final T basePrincipal, PrincipalSupplier<T> principalSupplier) {
this.basePrincipal = basePrincipal;
this.principalSupplier = principalSupplier;
}
@Override
public Optional<T> ref() {
maybeRefresh();
return Optional.of(basePrincipal);
}
@Override
public Optional<MutableRef<T>> mutableRef() {
maybeRefresh();
return Optional.of(new AuthenticatedMutableRef(principalSupplier.deepCopy(basePrincipal)));
}
private void maybeRefresh() {
if (needRefresh.compareAndSet(true, false)) {
basePrincipal = principalSupplier.refresh(basePrincipal);
}
}
private class AuthenticatedMutableRef implements MutableRef<T> {
final T ref;
private AuthenticatedMutableRef(T ref) {
this.ref = ref;
}
public T ref() {
return ref;
}
public void close() {
needRefresh.set(true);
}
}
}
private ReusableAuth() {
}
}

View File

@ -58,7 +58,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
private final Map<Long, CompletableFuture<WebSocketResponseMessage>> requestMap = new ConcurrentHashMap<>(); private final Map<Long, CompletableFuture<WebSocketResponseMessage>> requestMap = new ConcurrentHashMap<>();
private final T authenticated; private final ReusableAuth<T> reusableAuth;
private final WebSocketMessageFactory messageFactory; private final WebSocketMessageFactory messageFactory;
private final Optional<WebSocketConnectListener> connectListener; private final Optional<WebSocketConnectListener> connectListener;
private final ApplicationHandler jerseyHandler; private final ApplicationHandler jerseyHandler;
@ -77,7 +77,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
String remoteAddressPropertyName, String remoteAddressPropertyName,
ApplicationHandler jerseyHandler, ApplicationHandler jerseyHandler,
WebsocketRequestLog requestLog, WebsocketRequestLog requestLog,
T authenticated, ReusableAuth<T> authenticated,
WebSocketMessageFactory messageFactory, WebSocketMessageFactory messageFactory,
Optional<WebSocketConnectListener> connectListener, Optional<WebSocketConnectListener> connectListener,
Duration idleTimeout) { Duration idleTimeout) {
@ -85,7 +85,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
this.remoteAddressPropertyName = remoteAddressPropertyName; this.remoteAddressPropertyName = remoteAddressPropertyName;
this.jerseyHandler = jerseyHandler; this.jerseyHandler = jerseyHandler;
this.requestLog = requestLog; this.requestLog = requestLog;
this.authenticated = authenticated; this.reusableAuth = authenticated;
this.messageFactory = messageFactory; this.messageFactory = messageFactory;
this.connectListener = connectListener; this.connectListener = connectListener;
this.idleTimeout = idleTimeout; this.idleTimeout = idleTimeout;
@ -97,7 +97,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
this.remoteEndpoint = session.getRemote(); this.remoteEndpoint = session.getRemote();
this.context = new WebSocketSessionContext( this.context = new WebSocketSessionContext(
new WebSocketClient(session, remoteEndpoint, messageFactory, requestMap)); new WebSocketClient(session, remoteEndpoint, messageFactory, requestMap));
this.context.setAuthenticated(authenticated); this.context.setAuthenticated(reusableAuth.ref().orElse(null));
this.session.setIdleTimeout(idleTimeout); this.session.setIdleTimeout(idleTimeout);
connectListener.ifPresent(listener -> listener.onWebSocketConnect(this.context)); connectListener.ifPresent(listener -> listener.onWebSocketConnect(this.context));
@ -162,6 +162,17 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
logger.debug("onWebSocketText!"); logger.debug("onWebSocketText!");
} }
/**
* The property name where {@link org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider} can find an
* {@link ReusableAuth} object that lives for the lifetime of the websocket
*/
public static final String REUSABLE_AUTH_PROPERTY = WebSocketResourceProvider.class.getName() + ".reusableAuth";
/**
* The property name where {@link org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider} can install a
* {@link org.whispersystems.websocket.ReusableAuth.MutableRef} for us to close when the request is finished
*/
public static final String RESOLVED_PRINCIPAL_PROPERTY = WebSocketResourceProvider.class.getName() + ".resolvedPrincipal";
private void handleRequest(WebSocketRequestMessage requestMessage) { private void handleRequest(WebSocketRequestMessage requestMessage) {
ContainerRequest containerRequest = new ContainerRequest(null, URI.create(requestMessage.getPath()), ContainerRequest containerRequest = new ContainerRequest(null, URI.create(requestMessage.getPath()),
requestMessage.getVerb(), new WebSocketSecurityContext(new ContextPrincipal(context)), requestMessage.getVerb(), new WebSocketSecurityContext(new ContextPrincipal(context)),
@ -173,19 +184,32 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
} }
containerRequest.setProperty(remoteAddressPropertyName, remoteAddress); containerRequest.setProperty(remoteAddressPropertyName, remoteAddress);
containerRequest.setProperty(REUSABLE_AUTH_PROPERTY, reusableAuth);
ByteArrayOutputStream responseBody = new ByteArrayOutputStream(); ByteArrayOutputStream responseBody = new ByteArrayOutputStream();
CompletableFuture<ContainerResponse> responseFuture = (CompletableFuture<ContainerResponse>) jerseyHandler.apply( CompletableFuture<ContainerResponse> responseFuture = (CompletableFuture<ContainerResponse>) jerseyHandler.apply(
containerRequest, responseBody); containerRequest, responseBody);
responseFuture.thenAccept(response -> { responseFuture
.whenComplete((ignoredResponse, ignoredError) -> {
// If the request ended up being one that mutates our principal, we have to close it to indicate we're done
// with the mutation operation
final Object resolvedPrincipal = containerRequest.getProperty(RESOLVED_PRINCIPAL_PROPERTY);
if (resolvedPrincipal instanceof ReusableAuth.MutableRef ref) {
ref.close();
} else if (resolvedPrincipal != null) {
logger.warn("unexpected resolved principal type {} : {}", resolvedPrincipal.getClass(), resolvedPrincipal);
}
})
.thenAccept(response -> {
try { try {
sendResponse(requestMessage, response, responseBody); sendResponse(requestMessage, response, responseBody);
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
requestLog.log(remoteAddress, containerRequest, response); requestLog.log(remoteAddress, containerRequest, response);
}).exceptionally(exception -> { })
.exceptionally(exception -> {
logger.warn("Websocket Error: " + requestMessage.getVerb() + " " + requestMessage.getPath() + "\n" logger.warn("Websocket Error: " + requestMessage.getVerb() + " " + requestMessage.getPath() + "\n"
+ requestMessage.getBody(), exception); + requestMessage.getBody(), exception);
try { try {

View File

@ -22,7 +22,6 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.websocket.auth.AuthenticationException; import org.whispersystems.websocket.auth.AuthenticationException;
import org.whispersystems.websocket.auth.WebSocketAuthenticator; import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider; import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
import org.whispersystems.websocket.configuration.WebSocketConfiguration; import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.session.WebSocketSessionContextValueFactoryProvider; import org.whispersystems.websocket.session.WebSocketSessionContextValueFactoryProvider;
@ -57,17 +56,17 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends Jetty
public Object createWebSocket(final JettyServerUpgradeRequest request, final JettyServerUpgradeResponse response) { public Object createWebSocket(final JettyServerUpgradeRequest request, final JettyServerUpgradeResponse response) {
try { try {
Optional<WebSocketAuthenticator<T>> authenticator = Optional.ofNullable(environment.getAuthenticator()); Optional<WebSocketAuthenticator<T>> authenticator = Optional.ofNullable(environment.getAuthenticator());
T authenticated = null;
final ReusableAuth<T> authenticated;
if (authenticator.isPresent()) { if (authenticator.isPresent()) {
AuthenticationResult<T> authenticationResult = authenticator.get().authenticate(request); authenticated = authenticator.get().authenticate(request);
if (authenticationResult.getUser().isEmpty() && authenticationResult.credentialsPresented()) { if (authenticated.invalidCredentialsProvided()) {
response.sendForbidden("Unauthorized"); response.sendForbidden("Unauthorized");
return null; return null;
} else {
authenticated = authenticationResult.getUser().orElse(null);
} }
} else {
authenticated = ReusableAuth.anonymous();
} }
return new WebSocketResourceProvider<>(getRemoteAddress(request), return new WebSocketResourceProvider<>(getRemoteAddress(request),

View File

@ -0,0 +1,26 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.websocket.auth;
import io.dropwizard.auth.Auth;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* An @{@link Auth} object annotated with {@link Mutable} indicates that the consumer of the object
* will modify the object or its underlying canonical source.
*
* Note: An {@link Auth} object that does not specify @{@link ReadOnly} will be assumed to be @Mutable
*
* @see org.whispersystems.websocket.ReusableAuth
*/
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD, ElementType.PARAMETER})
public @interface Mutable {
}

View File

@ -0,0 +1,58 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.websocket.auth;
/**
* Teach {@link org.whispersystems.websocket.ReusableAuth} how to make a deep copy of a principal (that is safe to
* concurrently modify while the original principal is being read), and how to refresh a principal after it has been
* potentially modified.
*
* @param <T> The underlying principal type
*/
public interface PrincipalSupplier<T> {
/**
* Re-fresh the principal after it has been modified.
* <p>
* If the principal is populated from a backing store, refresh should re-read it.
*
* @param t the potentially stale principal to refresh
* @return The up-to-date principal
*/
T refresh(T t);
/**
* Create a deep, in-memory copy of the principal. This should be identical to the original principal, but should
* share no mutable state with the original. It should be safe for two threads to independently write and read from
* two independent deep copies.
*
* @param t the principal to copy
* @return An in-memory copy of the principal
*/
T deepCopy(T t);
class ImmutablePrincipalSupplier<T> implements PrincipalSupplier<T> {
@SuppressWarnings({"rawtypes"})
private static final PrincipalSupplier INSTANCE = new ImmutablePrincipalSupplier();
@Override
public T refresh(final T t) {
return t;
}
@Override
public T deepCopy(final T t) {
return t;
}
}
/**
* @return A principal supplier that can be used if the principal type does not support modification.
*/
static <T> PrincipalSupplier<T> forImmutablePrincipal() {
//noinspection unchecked
return (PrincipalSupplier<T>) ImmutablePrincipalSupplier.INSTANCE;
}
}

View File

@ -0,0 +1,25 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.websocket.auth;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* An @{@link io.dropwizard.auth.Auth} object annotated with {@link ReadOnly} indicates that the consumer of the object
* will never modify the object, nor its underlying canonical source.
* <p>
* For example, a consumer of a @ReadOnly AuthenticatedAccount promises to never modify the in-memory
* AuthenticatedAccount and to never modify the underlying Account database for the account.
*
* @see org.whispersystems.websocket.ReusableAuth
*/
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD, ElementType.PARAMETER})
public @interface ReadOnly {
}

View File

@ -7,27 +7,8 @@ package org.whispersystems.websocket.auth;
import java.security.Principal; import java.security.Principal;
import java.util.Optional; import java.util.Optional;
import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.whispersystems.websocket.ReusableAuth;
public interface WebSocketAuthenticator<T extends Principal> { public interface WebSocketAuthenticator<T extends Principal> {
ReusableAuth<T> authenticate(UpgradeRequest request) throws AuthenticationException;
AuthenticationResult<T> authenticate(UpgradeRequest request) throws AuthenticationException;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
class AuthenticationResult<T> {
private final Optional<T> user;
private final boolean credentialsPresented;
public AuthenticationResult(final Optional<T> user, final boolean credentialsPresented) {
this.user = user;
this.credentialsPresented = credentialsPresented;
}
public Optional<T> getUser() {
return user;
}
public boolean credentialsPresented() {
return credentialsPresented;
}
}
} }

View File

@ -5,24 +5,28 @@
package org.whispersystems.websocket.auth; package org.whispersystems.websocket.auth;
import io.dropwizard.auth.Auth; import io.dropwizard.auth.Auth;
import java.lang.reflect.ParameterizedType;
import java.security.Principal;
import java.util.Optional;
import java.util.function.Function;
import javax.annotation.Nullable;
import javax.inject.Inject;
import javax.inject.Singleton;
import javax.ws.rs.WebApplicationException;
import org.glassfish.jersey.internal.inject.AbstractBinder; import org.glassfish.jersey.internal.inject.AbstractBinder;
import org.glassfish.jersey.server.ContainerRequest; import org.glassfish.jersey.server.ContainerRequest;
import org.glassfish.jersey.server.internal.inject.AbstractValueParamProvider; import org.glassfish.jersey.server.internal.inject.AbstractValueParamProvider;
import org.glassfish.jersey.server.internal.inject.MultivaluedParameterExtractorProvider; import org.glassfish.jersey.server.internal.inject.MultivaluedParameterExtractorProvider;
import org.glassfish.jersey.server.model.Parameter; import org.glassfish.jersey.server.model.Parameter;
import org.glassfish.jersey.server.spi.internal.ValueParamProvider; import org.glassfish.jersey.server.spi.internal.ValueParamProvider;
import org.slf4j.Logger;
import javax.annotation.Nullable; import org.slf4j.LoggerFactory;
import javax.inject.Inject; import org.whispersystems.websocket.ReusableAuth;
import javax.inject.Singleton; import org.whispersystems.websocket.WebSocketResourceProvider;
import javax.ws.rs.WebApplicationException;
import java.lang.reflect.ParameterizedType;
import java.security.Principal;
import java.util.Optional;
import java.util.function.Function;
@Singleton @Singleton
public class WebsocketAuthValueFactoryProvider<T extends Principal> extends AbstractValueParamProvider { public class WebsocketAuthValueFactoryProvider<T extends Principal> extends AbstractValueParamProvider {
private static final Logger logger = LoggerFactory.getLogger(WebsocketAuthValueFactoryProvider.class);
private final Class<T> principalClass; private final Class<T> principalClass;
@ -39,18 +43,38 @@ public class WebsocketAuthValueFactoryProvider<T extends Principal> extends Abst
return null; return null;
} }
if (parameter.getRawType() == Optional.class && final boolean readOnly = parameter.isAnnotationPresent(ReadOnly.class);
ParameterizedType.class.isAssignableFrom(parameter.getType().getClass()) &&
principalClass == ((ParameterizedType)parameter.getType()).getActualTypeArguments()[0]) if (parameter.getRawType() == Optional.class
{ && ParameterizedType.class.isAssignableFrom(parameter.getType().getClass())
return request -> new OptionalContainerRequestValueFactory(request).provide(); && principalClass == ((ParameterizedType) parameter.getType()).getActualTypeArguments()[0]) {
return containerRequest -> createPrincipal(containerRequest, readOnly);
} else if (principalClass.equals(parameter.getRawType())) { } else if (principalClass.equals(parameter.getRawType())) {
return request -> new StandardContainerRequestValueFactory(request).provide(); return containerRequest ->
createPrincipal(containerRequest, readOnly)
.orElseThrow(() -> new WebApplicationException("Authenticated resource", 401));
} else { } else {
throw new IllegalStateException("Can't inject unassignable principal: " + principalClass + " for parameter: " + parameter); throw new IllegalStateException("Can't inject unassignable principal: " + principalClass + " for parameter: " + parameter);
} }
} }
private Optional<? extends Principal> createPrincipal(final ContainerRequest request, final boolean readOnly) {
final Object obj = request.getProperty(WebSocketResourceProvider.REUSABLE_AUTH_PROPERTY);
if (!(obj instanceof ReusableAuth<?>)) {
logger.warn("Unexpected reusable auth property type {} : {}", obj.getClass(), obj);
return Optional.empty();
}
@SuppressWarnings("unchecked") final ReusableAuth<T> reusableAuth = (ReusableAuth<T>) obj;
if (readOnly) {
return reusableAuth.ref();
} else {
return reusableAuth.mutableRef().map(writeRef -> {
request.setProperty(WebSocketResourceProvider.RESOLVED_PRINCIPAL_PROPERTY, writeRef);
return writeRef.ref();
});
}
}
@Singleton @Singleton
static class WebsocketPrincipalClassProvider<T extends Principal> { static class WebsocketPrincipalClassProvider<T extends Principal> {
@ -80,38 +104,4 @@ public class WebsocketAuthValueFactoryProvider<T extends Principal> extends Abst
bind(WebsocketAuthValueFactoryProvider.class).to(ValueParamProvider.class).in(Singleton.class); bind(WebsocketAuthValueFactoryProvider.class).to(ValueParamProvider.class).in(Singleton.class);
} }
} }
private static class StandardContainerRequestValueFactory {
private final ContainerRequest request;
public StandardContainerRequestValueFactory(ContainerRequest request) {
this.request = request;
}
public Principal provide() {
final Principal principal = request.getSecurityContext().getUserPrincipal();
if (principal == null) {
throw new WebApplicationException("Authenticated resource", 401);
}
return principal;
}
}
private static class OptionalContainerRequestValueFactory {
private final ContainerRequest request;
public OptionalContainerRequestValueFactory(ContainerRequest request) {
this.request = request;
}
public Optional<Principal> provide() {
return Optional.ofNullable(request.getSecurityContext().getUserPrincipal());
}
}
} }

View File

@ -7,6 +7,7 @@ package org.whispersystems.websocket;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -27,6 +28,7 @@ import org.glassfish.jersey.server.ResourceConfig;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.whispersystems.websocket.auth.AuthenticationException; import org.whispersystems.websocket.auth.AuthenticationException;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.WebSocketAuthenticator; import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.configuration.WebSocketConfiguration; import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.setup.WebSocketEnvironment; import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ -56,8 +58,7 @@ public class WebSocketResourceProviderFactoryTest {
@Test @Test
void testUnauthorized() throws AuthenticationException, IOException { void testUnauthorized() throws AuthenticationException, IOException {
when(environment.getAuthenticator()).thenReturn(authenticator); when(environment.getAuthenticator()).thenReturn(authenticator);
when(authenticator.authenticate(eq(request))).thenReturn( when(authenticator.authenticate(eq(request))).thenReturn(ReusableAuth.invalid());
new WebSocketAuthenticator.AuthenticationResult<>(Optional.empty(), true));
when(environment.jersey()).thenReturn(jerseyEnvironment); when(environment.jersey()).thenReturn(jerseyEnvironment);
WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class, WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class,
@ -74,8 +75,8 @@ public class WebSocketResourceProviderFactoryTest {
Account account = new Account(); Account account = new Account();
when(environment.getAuthenticator()).thenReturn(authenticator); when(environment.getAuthenticator()).thenReturn(authenticator);
when(authenticator.authenticate(eq(request))).thenReturn( when(authenticator.authenticate(eq(request)))
new WebSocketAuthenticator.AuthenticationResult<>(Optional.of(account), true)); .thenReturn(ReusableAuth.authenticated(account, PrincipalSupplier.forImmutablePrincipal()));
when(environment.jersey()).thenReturn(jerseyEnvironment); when(environment.jersey()).thenReturn(jerseyEnvironment);
final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class); final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class);
when(httpServletRequest.getAttribute(REMOTE_ADDRESS_PROPERTY_NAME)).thenReturn("127.0.0.1"); when(httpServletRequest.getAttribute(REMOTE_ADDRESS_PROPERTY_NAME)).thenReturn("127.0.0.1");
@ -137,6 +138,7 @@ public class WebSocketResourceProviderFactoryTest {
public boolean implies(Subject subject) { public boolean implies(Subject subject) {
return false; return false;
} }
} }

View File

@ -59,6 +59,7 @@ import org.glassfish.jersey.server.ResourceConfig;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider; import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
import org.whispersystems.websocket.logging.WebsocketRequestLog; import org.whispersystems.websocket.logging.WebsocketRequestLog;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory; import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
@ -80,7 +81,7 @@ class WebSocketResourceProviderTest {
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, REMOTE_ADDRESS_PROPERTY_NAME,
applicationHandler, requestLog, applicationHandler, requestLog,
new TestPrincipal("fooz"), immutableTestPrincipal("fooz"),
new ProtobufWebSocketMessageFactory(), new ProtobufWebSocketMessageFactory(),
Optional.of(connectListener), Optional.of(connectListener),
Duration.ofMillis(30000)); Duration.ofMillis(30000));
@ -108,7 +109,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = mock(ApplicationHandler.class); ApplicationHandler applicationHandler = mock(ApplicationHandler.class);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"), REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class); Session session = mock(Session.class);
@ -184,7 +185,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = mock(ApplicationHandler.class); ApplicationHandler applicationHandler = mock(ApplicationHandler.class);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"), REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class); Session session = mock(Session.class);
@ -240,7 +241,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"), REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class); Session session = mock(Session.class);
@ -280,7 +281,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"), REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class); Session session = mock(Session.class);
@ -320,7 +321,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("authorizedUserName"), REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("authorizedUserName"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class); Session session = mock(Session.class);
@ -360,8 +361,8 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, null, new ProtobufWebSocketMessageFactory(), REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, ReusableAuth.anonymous(),
Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class); Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@ -399,7 +400,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("something"), REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("something"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class); Session session = mock(Session.class);
@ -439,8 +440,8 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, null, new ProtobufWebSocketMessageFactory(), REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, ReusableAuth.anonymous(),
Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class); Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@ -479,7 +480,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"), REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class); Session session = mock(Session.class);
@ -520,7 +521,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"), REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class); Session session = mock(Session.class);
@ -562,7 +563,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"), REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class); Session session = mock(Session.class);
@ -602,7 +603,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"), REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class); Session session = mock(Session.class);
@ -727,6 +728,10 @@ class WebSocketResourceProviderTest {
} }
} }
public static ReusableAuth<TestPrincipal> immutableTestPrincipal(final String name) {
return ReusableAuth.authenticated(new TestPrincipal(name), PrincipalSupplier.forImmutablePrincipal());
}
public static class TestException extends Exception { public static class TestException extends Exception {
public TestException(String message) { public TestException(String message) {