Lifecycle management for Account objects reused accross websocket requests

This commit is contained in:
Ravi Khadiwala 2024-02-06 16:59:42 -06:00 committed by ravi-signal
parent 29ef3f0b41
commit 26ffa19f36
38 changed files with 1317 additions and 457 deletions

View File

@ -188,6 +188,7 @@ import org.whispersystems.textsecuregcm.spam.SenderOverrideProvider;
import org.whispersystems.textsecuregcm.spam.SpamChecker;
import org.whispersystems.textsecuregcm.spam.SpamFilter;
import org.whispersystems.textsecuregcm.storage.AccountLockManager;
import org.whispersystems.textsecuregcm.storage.AccountPrincipalSupplier;
import org.whispersystems.textsecuregcm.storage.Accounts;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ChangeNumberManager;
@ -812,7 +813,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
WebSocketEnvironment<AuthenticatedAccount> webSocketEnvironment = new WebSocketEnvironment<>(environment,
config.getWebSocketConfiguration(), Duration.ofMillis(90000));
webSocketEnvironment.jersey().register(new VirtualExecutorServiceProvider("managed-async-websocket-virtual-thread-"));
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator));
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator, new AccountPrincipalSupplier(accountsManager)));
webSocketEnvironment.setConnectListener(
new AuthenticatedConnectListener(receiptSender, messagesManager, pushNotificationManager,
clientPresenceManager, websocketScheduledExecutor, messageDeliveryScheduler, clientReleaseManager));

View File

@ -21,7 +21,6 @@ import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.RefreshingAccountAndDeviceSupplier;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util;
@ -108,8 +107,7 @@ public class AccountAuthenticator implements Authenticator<BasicCredentials, Aut
device.get(),
SaltedTokenHash.generateFor(basicCredentials.getPassword())); // new credentials have current version
}
return Optional.of(new AuthenticatedAccount(
new RefreshingAccountAndDeviceSupplier(authenticatedAccount, device.get().getId(), accountsManager)));
return Optional.of(new AuthenticatedAccount(authenticatedAccount, device.get()));
}
return Optional.empty();

View File

@ -5,7 +5,6 @@
package org.whispersystems.textsecuregcm.auth;
import com.google.common.annotations.VisibleForTesting;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
@ -45,10 +44,6 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
this.accountsManager = accountsManager;
}
@VisibleForTesting
static Map<Byte, Boolean> buildDevicesEnabledMap(final Account account) {
return account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::isEnabled));
}
@Override
public void handleRequestFiltered(final RequestEvent requestEvent) {
@ -60,10 +55,13 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
setAccount(requestEvent.getContainerRequest(), account));
}
}
public static void setAccount(final ContainerRequest containerRequest, final Account account) {
containerRequest.setProperty(ACCOUNT_UUID, account.getUuid());
containerRequest.setProperty(DEVICES_ENABLED, buildDevicesEnabledMap(account));
setAccount(containerRequest, ContainerRequestUtil.AccountInfo.fromAccount(account));
}
private static void setAccount(final ContainerRequest containerRequest, final ContainerRequestUtil.AccountInfo info) {
containerRequest.setProperty(ACCOUNT_UUID, info.accountId());
containerRequest.setProperty(DEVICES_ENABLED, info.devicesEnabled());
}
@Override
@ -75,25 +73,28 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
@SuppressWarnings("unchecked") final Map<Byte, Boolean> initialDevicesEnabled =
(Map<Byte, Boolean>) requestEvent.getContainerRequest().getProperty(DEVICES_ENABLED);
return accountsManager.getByAccountIdentifier((UUID) requestEvent.getContainerRequest().getProperty(ACCOUNT_UUID)).map(account -> {
final Set<Byte> deviceIdsToDisplace;
final Map<Byte, Boolean> currentDevicesEnabled = buildDevicesEnabledMap(account);
return accountsManager.getByAccountIdentifier((UUID) requestEvent.getContainerRequest().getProperty(ACCOUNT_UUID))
.map(ContainerRequestUtil.AccountInfo::fromAccount)
.map(account -> {
final Set<Byte> deviceIdsToDisplace;
final Map<Byte, Boolean> currentDevicesEnabled = account.devicesEnabled();
if (!initialDevicesEnabled.equals(currentDevicesEnabled)) {
deviceIdsToDisplace = new HashSet<>(initialDevicesEnabled.keySet());
deviceIdsToDisplace.addAll(currentDevicesEnabled.keySet());
} else {
deviceIdsToDisplace = Collections.emptySet();
}
if (!initialDevicesEnabled.equals(currentDevicesEnabled)) {
deviceIdsToDisplace = new HashSet<>(initialDevicesEnabled.keySet());
deviceIdsToDisplace.addAll(currentDevicesEnabled.keySet());
} else {
deviceIdsToDisplace = Collections.emptySet();
}
return deviceIdsToDisplace.stream()
.map(deviceId -> new Pair<>(account.getUuid(), deviceId))
.collect(Collectors.toList());
}).orElseGet(() -> {
logger.error("Request had account, but it is no longer present");
return Collections.emptyList();
});
} else
return deviceIdsToDisplace.stream()
.map(deviceId -> new Pair<>(account.accountId(), deviceId))
.collect(Collectors.toList());
}).orElseGet(() -> {
logger.error("Request had account, but it is no longer present");
return Collections.emptyList();
});
} else {
return Collections.emptyList();
}
}
}

View File

@ -10,24 +10,24 @@ import java.util.function.Supplier;
import javax.security.auth.Subject;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Pair;
public class AuthenticatedAccount implements Principal, AccountAndAuthenticatedDeviceHolder {
private final Account account;
private final Device device;
private final Supplier<Pair<Account, Device>> accountAndDevice;
public AuthenticatedAccount(final Supplier<Pair<Account, Device>> accountAndDevice) {
this.accountAndDevice = accountAndDevice;
public AuthenticatedAccount(final Account account, final Device device) {
this.account = account;
this.device = device;
}
@Override
public Account getAccount() {
return accountAndDevice.get().first();
return account;
}
@Override
public Device getAuthenticatedDevice() {
return accountAndDevice.get().second();
return device;
}
// Principal implementation

View File

@ -0,0 +1,20 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* Indicates that an endpoint changes the phone number and PNI keys associated with an account, and that
* any websockets associated with the account may need to be refreshed after a call to that endpoint.
*/
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface ChangesPhoneNumber {
}

View File

@ -7,15 +7,42 @@ package org.whispersystems.textsecuregcm.auth;
import org.glassfish.jersey.server.ContainerRequest;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import javax.ws.rs.core.SecurityContext;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.Collectors;
class ContainerRequestUtil {
static Optional<Account> getAuthenticatedAccount(final ContainerRequest request) {
private static Map<Byte, Boolean> buildDevicesEnabledMap(final Account account) {
return account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::isEnabled));
}
/**
* A read-only subset of the authenticated Account object, to enforce that filter-based consumers do not perform
* account modifying operations.
*/
record AccountInfo(UUID accountId, String e164, Map<Byte, Boolean> devicesEnabled) {
static AccountInfo fromAccount(final Account account) {
return new AccountInfo(
account.getUuid(),
account.getNumber(),
buildDevicesEnabledMap(account));
}
}
static Optional<AccountInfo> getAuthenticatedAccount(final ContainerRequest request) {
return Optional.ofNullable(request.getSecurityContext())
.map(SecurityContext::getUserPrincipal)
.map(principal -> principal instanceof AccountAndAuthenticatedDeviceHolder
? ((AccountAndAuthenticatedDeviceHolder) principal).getAccount() : null);
.map(principal -> {
if (principal instanceof AccountAndAuthenticatedDeviceHolder aaadh) {
return aaadh.getAccount();
}
return null;
})
.map(AccountInfo::fromAccount);
}
}

View File

@ -7,40 +7,50 @@ package org.whispersystems.textsecuregcm.auth;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.Collectors;
import org.glassfish.jersey.server.monitoring.RequestEvent;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.util.Pair;
public class PhoneNumberChangeRefreshRequirementProvider implements WebsocketRefreshRequirementProvider {
private static final String ACCOUNT_UUID =
PhoneNumberChangeRefreshRequirementProvider.class.getName() + ".accountUuid";
private static final String INITIAL_NUMBER_KEY =
PhoneNumberChangeRefreshRequirementProvider.class.getName() + ".initialNumber";
private final AccountsManager accountsManager;
public PhoneNumberChangeRefreshRequirementProvider(final AccountsManager accountsManager) {
this.accountsManager = accountsManager;
}
@Override
public void handleRequestFiltered(final RequestEvent requestEvent) {
if (requestEvent.getUriInfo().getMatchedResourceMethod().getInvocable().getHandlingMethod()
.getAnnotation(ChangesPhoneNumber.class) == null) {
return;
}
ContainerRequestUtil.getAuthenticatedAccount(requestEvent.getContainerRequest())
.ifPresent(account -> requestEvent.getContainerRequest().setProperty(INITIAL_NUMBER_KEY, account.getNumber()));
.ifPresent(account -> {
requestEvent.getContainerRequest().setProperty(INITIAL_NUMBER_KEY, account.e164());
requestEvent.getContainerRequest().setProperty(ACCOUNT_UUID, account.accountId());
});
}
@Override
public List<Pair<UUID, Byte>> handleRequestFinished(final RequestEvent requestEvent) {
final String initialNumber = (String) requestEvent.getContainerRequest().getProperty(INITIAL_NUMBER_KEY);
if (initialNumber != null) {
final Optional<Account> maybeAuthenticatedAccount =
ContainerRequestUtil.getAuthenticatedAccount(requestEvent.getContainerRequest());
return maybeAuthenticatedAccount
.filter(account -> !initialNumber.equals(account.getNumber()))
.map(account -> account.getDevices().stream()
.map(device -> new Pair<>(account.getUuid(), device.getId()))
.collect(Collectors.toList()))
.orElse(Collections.emptyList());
} else {
if (initialNumber == null) {
return Collections.emptyList();
}
return accountsManager.getByAccountIdentifier((UUID) requestEvent.getContainerRequest().getProperty(ACCOUNT_UUID))
.filter(account -> !initialNumber.equals(account.getNumber()))
.map(account -> account.getDevices().stream()
.map(device -> new Pair<>(account.getUuid(), device.getId()))
.collect(Collectors.toList()))
.orElse(Collections.emptyList());
}
}

View File

@ -24,7 +24,7 @@ public class WebsocketRefreshApplicationEventListener implements ApplicationEven
this.websocketRefreshRequestEventListener = new WebsocketRefreshRequestEventListener(clientPresenceManager,
new AuthEnablementRefreshRequirementProvider(accountsManager),
new PhoneNumberChangeRefreshRequirementProvider());
new PhoneNumberChangeRefreshRequirementProvider(accountsManager));
}
@Override

View File

@ -36,6 +36,7 @@ import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.ChangesPhoneNumber;
import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager;
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
import org.whispersystems.textsecuregcm.entities.AccountDataReportResponse;
@ -90,6 +91,7 @@ public class AccountControllerV2 {
@Path("/number")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
@ChangesPhoneNumber
@Operation(summary = "Change number", description = "Changes a phone number for an existing account.")
@ApiResponse(responseCode = "200", description = "The phone number associated with the authenticated account was changed successfully", useReturnTypeSchema = true)
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")

View File

@ -0,0 +1,36 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.websocket.auth.PrincipalSupplier;
public class AccountPrincipalSupplier implements PrincipalSupplier<AuthenticatedAccount> {
private final AccountsManager accountsManager;
public AccountPrincipalSupplier(final AccountsManager accountsManager) {
this.accountsManager = accountsManager;
}
@Override
public AuthenticatedAccount refresh(final AuthenticatedAccount oldAccount) {
final Account account = accountsManager.getByAccountIdentifier(oldAccount.getAccount().getUuid())
.orElseThrow(() -> new RefreshingAccountNotFoundException("Could not find account"));
final Device device = account.getDevice(oldAccount.getAuthenticatedDevice().getId())
.orElseThrow(() -> new RefreshingAccountNotFoundException("Could not find device"));
return new AuthenticatedAccount(account, device);
}
@Override
public AuthenticatedAccount deepCopy(final AuthenticatedAccount authenticatedAccount) {
final Account cloned = AccountUtil.cloneAccountAsNotStale(authenticatedAccount.getAccount());
return new AuthenticatedAccount(
cloned,
cloned.getDevice(authenticatedAccount.getAuthenticatedDevice().getId())
.orElseThrow(() -> new IllegalStateException(
"Could not find device from a clone of an account where the device was present")));
}
}

View File

@ -5,10 +5,12 @@
package org.whispersystems.textsecuregcm.storage;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import java.io.IOException;
class AccountUtil {
public class AccountUtil {
static Account cloneAccountAsNotStale(final Account account) {
try {

View File

@ -1,35 +0,0 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import java.util.function.Supplier;
import org.whispersystems.textsecuregcm.util.Pair;
public class RefreshingAccountAndDeviceSupplier implements Supplier<Pair<Account, Device>> {
private Account account;
private Device device;
private final AccountsManager accountsManager;
public RefreshingAccountAndDeviceSupplier(Account account, byte deviceId, AccountsManager accountsManager) {
this.account = account;
this.device = account.getDevice(deviceId)
.orElseThrow(() -> new RefreshingAccountAndDeviceNotFoundException("Could not find device"));
this.accountsManager = accountsManager;
}
@Override
public Pair<Account, Device> get() {
if (account.isStale()) {
account = accountsManager.getByAccountIdentifier(account.getUuid())
.orElseThrow(() -> new RuntimeException("Could not find account"));
device = account.getDevice(device.getId())
.orElseThrow(() -> new RefreshingAccountAndDeviceNotFoundException("Could not find device"));
}
return new Pair<>(account, device);
}
}

View File

@ -5,9 +5,9 @@
package org.whispersystems.textsecuregcm.storage;
public class RefreshingAccountAndDeviceNotFoundException extends RuntimeException {
public class RefreshingAccountNotFoundException extends RuntimeException {
public RefreshingAccountAndDeviceNotFoundException(final String message) {
public RefreshingAccountNotFoundException(final String message) {
super(message);
}

View File

@ -11,41 +11,40 @@ import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.basic.BasicCredentials;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import javax.annotation.Nullable;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.auth.AuthenticationException;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<AuthenticatedAccount> {
private static final AuthenticationResult<AuthenticatedAccount> CREDENTIALS_NOT_PRESENTED =
new AuthenticationResult<>(Optional.empty(), false);
private static final ReusableAuth<AuthenticatedAccount> CREDENTIALS_NOT_PRESENTED = ReusableAuth.anonymous();
private static final AuthenticationResult<AuthenticatedAccount> INVALID_CREDENTIALS_PRESENTED =
new AuthenticationResult<>(Optional.empty(), true);
private static final ReusableAuth<AuthenticatedAccount> INVALID_CREDENTIALS_PRESENTED = ReusableAuth.invalid();
private final AccountAuthenticator accountAuthenticator;
private final PrincipalSupplier<AuthenticatedAccount> principalSupplier;
public WebSocketAccountAuthenticator(final AccountAuthenticator accountAuthenticator) {
public WebSocketAccountAuthenticator(final AccountAuthenticator accountAuthenticator,
final PrincipalSupplier<AuthenticatedAccount> principalSupplier) {
this.accountAuthenticator = accountAuthenticator;
this.principalSupplier = principalSupplier;
}
@Override
public AuthenticationResult<AuthenticatedAccount> authenticate(final UpgradeRequest request)
public ReusableAuth<AuthenticatedAccount> authenticate(final UpgradeRequest request)
throws AuthenticationException {
try {
final AuthenticationResult<AuthenticatedAccount> authResultFromHeader =
authenticatedAccountFromHeaderAuth(request.getHeader(HttpHeaders.AUTHORIZATION));
// the logic here is that if the `Authorization` header was set for the request,
// it takes the priority and we use the result of the header-based auth
// ignoring the result of the query-based auth.
if (authResultFromHeader.credentialsPresented()) {
return authResultFromHeader;
// If the `Authorization` header was set for the request it takes priority, and we use the result of the
// header-based auth ignoring the result of the query-based auth.
final String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION);
if (authHeader != null) {
return authenticatedAccountFromHeaderAuth(authHeader);
}
return authenticatedAccountFromQueryParams(request);
} catch (final Exception e) {
@ -55,7 +54,7 @@ public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<Aut
}
}
private AuthenticationResult<AuthenticatedAccount> authenticatedAccountFromQueryParams(final UpgradeRequest request) {
private ReusableAuth<AuthenticatedAccount> authenticatedAccountFromQueryParams(final UpgradeRequest request) {
final Map<String, List<String>> parameters = request.getParameterMap();
final List<String> usernames = parameters.get("login");
final List<String> passwords = parameters.get("password");
@ -65,16 +64,19 @@ public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<Aut
}
final BasicCredentials credentials = new BasicCredentials(usernames.get(0).replace(" ", "+"),
passwords.get(0).replace(" ", "+"));
return new AuthenticationResult<>(accountAuthenticator.authenticate(credentials), true);
return accountAuthenticator.authenticate(credentials)
.map(authenticatedAccount -> ReusableAuth.authenticated(authenticatedAccount, this.principalSupplier))
.orElse(INVALID_CREDENTIALS_PRESENTED);
}
private AuthenticationResult<AuthenticatedAccount> authenticatedAccountFromHeaderAuth(@Nullable final String authHeader)
private ReusableAuth<AuthenticatedAccount> authenticatedAccountFromHeaderAuth(@Nullable final String authHeader)
throws AuthenticationException {
if (authHeader == null) {
return CREDENTIALS_NOT_PRESENTED;
}
return basicCredentialsFromAuthHeader(authHeader)
.map(credentials -> new AuthenticationResult<>(accountAuthenticator.authenticate(credentials), true))
.flatMap(credentials -> accountAuthenticator.authenticate(credentials))
.map(authenticatedAccount -> ReusableAuth.authenticated(authenticatedAccount, this.principalSupplier))
.orElse(INVALID_CREDENTIALS_PRESENTED);
}
}

View File

@ -0,0 +1,279 @@
package org.whispersystems.textsecuregcm;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME;
import io.dropwizard.auth.Auth;
import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import java.io.IOException;
import java.net.URI;
import java.util.EnumSet;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.IntStream;
import javax.servlet.DispatcherType;
import javax.servlet.ServletRegistration;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.glassfish.jersey.server.ManagedAsync;
import org.glassfish.jersey.server.ServerProperties;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.storage.RefreshingAccountNotFoundException;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.ReadOnly;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ExtendWith(DropwizardExtensionsSupport.class)
public class WebsocketReuseAuthIntegrationTest {
private static final AuthenticatedAccount ACCOUNT = mock(AuthenticatedAccount.class);
@SuppressWarnings("unchecked")
private static final PrincipalSupplier<AuthenticatedAccount> PRINCIPAL_SUPPLIER = mock(PrincipalSupplier.class);
private static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION =
new DropwizardAppExtension<>(TestApplication.class);
private WebSocketClient client;
@BeforeEach
void setUp() throws Exception {
reset(PRINCIPAL_SUPPLIER);
reset(ACCOUNT);
when(ACCOUNT.getName()).thenReturn("original");
client = new WebSocketClient();
client.start();
}
@AfterEach
void tearDown() throws Exception {
client.stop();
}
public static class TestApplication extends Application<Configuration> {
@Override
public void run(final Configuration configuration, final Environment environment) throws Exception {
final TestController testController = new TestController();
final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration();
final WebSocketEnvironment<AuthenticatedAccount> webSocketEnvironment =
new WebSocketEnvironment<>(environment, webSocketConfiguration);
environment.jersey().register(testController);
environment.servlets()
.addFilter("RemoteAddressFilter", new RemoteAddressFilter(true))
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(testController);
webSocketEnvironment.jersey().register(new RemoteAddressFilter(true));
webSocketEnvironment.setAuthenticator(upgradeRequest -> ReusableAuth.authenticated(ACCOUNT, PRINCIPAL_SUPPLIER));
webSocketEnvironment.jersey().property(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE);
webSocketEnvironment.setConnectListener(webSocketSessionContext -> {
});
final WebSocketResourceProviderFactory<AuthenticatedAccount> webSocketServlet =
new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedAccount.class,
webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
final ServletRegistration.Dynamic websocketServlet =
environment.servlets().addServlet("WebSocket", webSocketServlet);
websocketServlet.addMapping("/websocket");
websocketServlet.setAsyncSupported(true);
}
}
private WebSocketResponseMessage make1WebsocketRequest(final String requestPath) throws IOException {
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
return testWebsocketListener.doGet(requestPath).join();
}
@ParameterizedTest
@ValueSource(strings = {"/test/read-auth", "/test/optional-read-auth"})
public void readAuth(final String path) throws IOException {
final WebSocketResponseMessage response = make1WebsocketRequest(path);
assertThat(response.getStatus()).isEqualTo(200);
verifyNoMoreInteractions(PRINCIPAL_SUPPLIER);
}
@ParameterizedTest
@ValueSource(strings = {"/test/write-auth", "/test/optional-write-auth"})
public void writeAuth(final String path) throws IOException {
final AuthenticatedAccount copiedAccount = mock(AuthenticatedAccount.class);
when(copiedAccount.getName()).thenReturn("copy");
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(copiedAccount);
final WebSocketResponseMessage response = make1WebsocketRequest(path);
assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.getBody().map(String::new)).get().isEqualTo("copy");
verify(PRINCIPAL_SUPPLIER, times(1)).deepCopy(any());
verifyNoMoreInteractions(PRINCIPAL_SUPPLIER);
}
@Test
public void readAfterWrite() throws IOException {
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(ACCOUNT);
final AuthenticatedAccount account2 = mock(AuthenticatedAccount.class);
when(account2.getName()).thenReturn("refresh");
when(PRINCIPAL_SUPPLIER.refresh(any())).thenReturn(account2);
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
final WebSocketResponseMessage readResponse = testWebsocketListener.doGet("/test/read-auth").join();
assertThat(readResponse.getBody().map(String::new)).get().isEqualTo("original");
final WebSocketResponseMessage writeResponse = testWebsocketListener.doGet("/test/write-auth").join();
assertThat(writeResponse.getBody().map(String::new)).get().isEqualTo("original");
final WebSocketResponseMessage readResponse2 = testWebsocketListener.doGet("/test/read-auth").join();
assertThat(readResponse2.getBody().map(String::new)).get().isEqualTo("refresh");
}
@Test
public void readAfterWriteRefreshFails() throws IOException {
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(ACCOUNT);
when(PRINCIPAL_SUPPLIER.refresh(any())).thenThrow(RefreshingAccountNotFoundException.class);
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
final WebSocketResponseMessage writeResponse = testWebsocketListener.doGet("/test/write-auth").join();
assertThat(writeResponse.getBody().map(String::new)).get().isEqualTo("original");
final WebSocketResponseMessage readResponse2 = testWebsocketListener.doGet("/test/read-auth").join();
assertThat(readResponse2.getStatus()).isEqualTo(500);
}
@Test
public void readConcurrentWithWrite() throws IOException, ExecutionException, InterruptedException, TimeoutException {
final AuthenticatedAccount deepCopy = mock(AuthenticatedAccount.class);
when(deepCopy.getName()).thenReturn("deepCopy");
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(deepCopy);
final AuthenticatedAccount refresh = mock(AuthenticatedAccount.class);
when(refresh.getName()).thenReturn("refresh");
when(PRINCIPAL_SUPPLIER.refresh(any())).thenReturn(refresh);
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
// start a write request that takes a while to finish
final CompletableFuture<WebSocketResponseMessage> writeResponse =
testWebsocketListener.doGet("/test/start-delayed-write/foo");
// send a bunch of reads, they should reflect the original auth
final List<CompletableFuture<WebSocketResponseMessage>> futures = IntStream.range(0, 10)
.boxed().map(i -> testWebsocketListener.doGet("/test/read-auth"))
.toList();
CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)).join();
for (CompletableFuture<WebSocketResponseMessage> future : futures) {
assertThat(future.join().getBody().map(String::new)).get().isEqualTo("original");
}
assertThat(writeResponse.isDone()).isFalse();
// finish the delayed write request
testWebsocketListener.doGet("/test/finish-delayed-write/foo").get(1, TimeUnit.SECONDS);
assertThat(writeResponse.join().getBody().map(String::new)).get().isEqualTo("deepCopy");
// subsequent reads should have the refreshed auth
final WebSocketResponseMessage readResponse = testWebsocketListener.doGet("/test/read-auth").join();
assertThat(readResponse.getBody().map(String::new)).get().isEqualTo("refresh");
}
@Path("/test")
public static class TestController {
private final ConcurrentHashMap<String, CountDownLatch> delayedWriteLatches = new ConcurrentHashMap<>();
@GET
@Path("/read-auth")
@ManagedAsync
public String readAuth(@ReadOnly @Auth final AuthenticatedAccount account) {
return account.getName();
}
@GET
@Path("/optional-read-auth")
@ManagedAsync
public String optionalReadAuth(@ReadOnly @Auth final Optional<AuthenticatedAccount> account) {
return account.map(AuthenticatedAccount::getName).orElse("empty");
}
@GET
@Path("/write-auth")
@ManagedAsync
public String writeAuth(@Auth final AuthenticatedAccount account) {
return account.getName();
}
@GET
@Path("/optional-write-auth")
@ManagedAsync
public String optionalWriteAuth(@Auth final Optional<AuthenticatedAccount> account) {
return account.map(AuthenticatedAccount::getName).orElse("empty");
}
@GET
@Path("/start-delayed-write/{id}")
@ManagedAsync
public String startDelayedWrite(@Auth final AuthenticatedAccount account, @PathParam("id") String id)
throws InterruptedException {
delayedWriteLatches.computeIfAbsent(id, i -> new CountDownLatch(1)).await();
return account.getName();
}
@GET
@Path("/finish-delayed-write/{id}")
@ManagedAsync
public String finishDelayedWrite(@PathParam("id") String id) {
delayedWriteLatches.computeIfAbsent(id, i -> new CountDownLatch(1)).countDown();
return "ok";
}
}
}

View File

@ -7,8 +7,6 @@ package org.whispersystems.textsecuregcm.auth;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
@ -30,7 +28,6 @@ import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.Principal;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.LinkedList;
@ -76,7 +73,9 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketResourceProvider;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
import org.whispersystems.websocket.logging.WebsocketRequestLog;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
@ -132,38 +131,6 @@ class AuthEnablementRefreshRequirementProviderTest {
.forEach(device -> when(clientPresenceManager.isPresent(uuid, device.getId())).thenReturn(true));
}
@Test
void testBuildDevicesEnabled() {
final byte disabledDeviceId = 3;
final Account account = mock(Account.class);
final List<Device> devices = new ArrayList<>();
when(account.getDevices()).thenReturn(devices);
IntStream.range(1, 5)
.forEach(id -> {
final Device device = mock(Device.class);
when(device.getId()).thenReturn((byte) id);
when(device.isEnabled()).thenReturn(id != disabledDeviceId);
devices.add(device);
});
final Map<Byte, Boolean> devicesEnabled = AuthEnablementRefreshRequirementProvider.buildDevicesEnabledMap(account);
assertEquals(4, devicesEnabled.size());
assertAll(devicesEnabled.entrySet().stream()
.map(deviceAndEnabled -> () -> {
if (deviceAndEnabled.getKey().equals(disabledDeviceId)) {
assertFalse(deviceAndEnabled.getValue());
} else {
assertTrue(deviceAndEnabled.getValue());
}
}));
}
@ParameterizedTest
@MethodSource
void testDeviceEnabledChanged(final Map<Byte, Boolean> initialEnabled, final Map<Byte, Boolean> finalEnabled) {
@ -308,7 +275,7 @@ class AuthEnablementRefreshRequirementProviderTest {
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
provider = new WebSocketResourceProvider<>("127.0.0.1", RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME,
applicationHandler, requestLog, new TestPrincipal("test", account, authenticatedDevice),
applicationHandler, requestLog, TestPrincipal.reusableAuth("test", account, authenticatedDevice),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
remoteEndpoint = mock(RemoteEndpoint.class);
@ -349,7 +316,7 @@ class AuthEnablementRefreshRequirementProviderTest {
private final Account account;
private final Device device;
private TestPrincipal(String name, final Account account, final Device device) {
private TestPrincipal(final String name, final Account account, final Device device) {
this.name = name;
this.account = account;
this.device = device;
@ -369,6 +336,11 @@ class AuthEnablementRefreshRequirementProviderTest {
public Device getAuthenticatedDevice() {
return device;
}
public static ReusableAuth<TestPrincipal> reusableAuth(final String name, final Account account, final Device device) {
return ReusableAuth.authenticated(new TestPrincipal(name, account, device), PrincipalSupplier.forImmutablePrincipal());
}
}
@Path("/v1/test")

View File

@ -0,0 +1,55 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.IntStream;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class ContainerRequestUtilTest {
@Test
void testBuildDevicesEnabled() {
final byte disabledDeviceId = 3;
final Account account = mock(Account.class);
final List<Device> devices = new ArrayList<>();
when(account.getDevices()).thenReturn(devices);
IntStream.range(1, 5)
.forEach(id -> {
final Device device = mock(Device.class);
when(device.getId()).thenReturn((byte) id);
when(device.isEnabled()).thenReturn(id != disabledDeviceId);
devices.add(device);
});
final Map<Byte, Boolean> devicesEnabled = ContainerRequestUtil.AccountInfo.fromAccount(account).devicesEnabled();
assertEquals(4, devicesEnabled.size());
assertAll(devicesEnabled.entrySet().stream()
.map(deviceAndEnabled -> () -> {
if (deviceAndEnabled.getKey().equals(disabledDeviceId)) {
assertFalse(deviceAndEnabled.getValue());
} else {
assertTrue(deviceAndEnabled.getValue());
}
}));
}
}

View File

@ -5,108 +5,292 @@
package org.whispersystems.textsecuregcm.auth;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME;
import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.Auth;
import io.dropwizard.auth.AuthDynamicFeature;
import io.dropwizard.auth.basic.BasicCredentialAuthFilter;
import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import java.io.IOException;
import java.net.URI;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.EnumSet;
import java.util.Optional;
import java.util.UUID;
import javax.annotation.Nullable;
import javax.ws.rs.core.SecurityContext;
import org.glassfish.jersey.server.ContainerRequest;
import org.glassfish.jersey.server.monitoring.RequestEvent;
import javax.servlet.DispatcherType;
import javax.servlet.ServletRegistration;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.client.Invocation;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.glassfish.jersey.server.ManagedAsync;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.ReadOnly;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ExtendWith(DropwizardExtensionsSupport.class)
class PhoneNumberChangeRefreshRequirementProviderTest {
private PhoneNumberChangeRefreshRequirementProvider provider;
private Account account;
private RequestEvent requestEvent;
private ContainerRequest request;
private static final UUID ACCOUNT_UUID = UUID.randomUUID();
private static final String NUMBER = "+18005551234";
private static final String CHANGED_NUMBER = "+18005554321";
private static final String TEST_CRED_HEADER = HeaderUtils.basicAuthHeader("test", "password");
private static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION = new DropwizardAppExtension<>(
TestApplication.class);
private static final AccountAuthenticator AUTHENTICATOR = mock(AccountAuthenticator.class);
private static final AccountsManager ACCOUNTS_MANAGER = mock(AccountsManager.class);
private static final ClientPresenceManager CLIENT_PRESENCE = mock(ClientPresenceManager.class);
private WebSocketClient client;
private final Account account1 = new Account();
private final Account account2 = new Account();
private final Device authenticatedDevice = DevicesHelper.createDevice(Device.PRIMARY_ID);
@BeforeEach
void setUp() {
provider = new PhoneNumberChangeRefreshRequirementProvider();
void setUp() throws Exception {
reset(AUTHENTICATOR, CLIENT_PRESENCE, ACCOUNTS_MANAGER);
client = new WebSocketClient();
client.start();
account = mock(Account.class);
final Device device = mock(Device.class);
final UUID uuid = UUID.randomUUID();
account1.setUuid(uuid);
account1.addDevice(authenticatedDevice);
account1.setNumber(NUMBER, UUID.randomUUID());
when(account.getUuid()).thenReturn(ACCOUNT_UUID);
when(account.getNumber()).thenReturn(NUMBER);
when(account.getDevices()).thenReturn(List.of(device));
when(device.getId()).thenReturn(Device.PRIMARY_ID);
account2.setUuid(uuid);
account2.addDevice(authenticatedDevice);
account2.setNumber(CHANGED_NUMBER, UUID.randomUUID());
request = mock(ContainerRequest.class);
}
final Map<String, Object> requestProperties = new HashMap<>();
@AfterEach
void tearDown() throws Exception {
client.stop();
}
doAnswer(invocation -> {
requestProperties.put(invocation.getArgument(0, String.class), invocation.getArgument(1));
return null;
}).when(request).setProperty(anyString(), any());
when(request.getProperty(anyString())).thenAnswer(
invocation -> requestProperties.get(invocation.getArgument(0, String.class)));
public static class TestApplication extends Application<Configuration> {
requestEvent = mock(RequestEvent.class);
when(requestEvent.getContainerRequest()).thenReturn(request);
@Override
public void run(final Configuration configuration, final Environment environment) throws Exception {
final TestController testController = new TestController();
final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration();
final WebSocketEnvironment<AuthenticatedAccount> webSocketEnvironment =
new WebSocketEnvironment<>(environment, webSocketConfiguration);
environment.jersey().register(testController);
webSocketEnvironment.jersey().register(testController);
environment.servlets()
.addFilter("RemoteAddressFilter", new RemoteAddressFilter(true))
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(new RemoteAddressFilter(true));
webSocketEnvironment.jersey()
.register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE));
environment.jersey()
.register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE));
webSocketEnvironment.setConnectListener(webSocketSessionContext -> {
});
environment.jersey().register(new AuthDynamicFeature(new BasicCredentialAuthFilter.Builder<AuthenticatedAccount>()
.setAuthenticator(AUTHENTICATOR)
.buildAuthFilter()));
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(AUTHENTICATOR, mock(PrincipalSupplier.class)));
final WebSocketResourceProviderFactory<AuthenticatedAccount> webSocketServlet =
new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedAccount.class,
webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
final ServletRegistration.Dynamic websocketServlet =
environment.servlets().addServlet("WebSocket", webSocketServlet);
websocketServlet.addMapping("/websocket");
websocketServlet.setAsyncSupported(true);
}
}
enum Protocol { HTTP, WEBSOCKET }
private void makeAnonymousRequest(final Protocol protocol, final String requestPath) throws IOException {
makeRequest(protocol, requestPath, true);
}
/*
* Make an authenticated request that will return account1 as the principal
*/
private void makeAuthenticatedRequest(
final Protocol protocol,
final String requestPath) throws IOException {
when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedAccount(account1, authenticatedDevice)));
makeRequest(protocol,requestPath, false);
}
private void makeRequest(final Protocol protocol, final String requestPath, final boolean anonymous) throws IOException {
switch (protocol) {
case WEBSOCKET -> {
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
if (!anonymous) {
upgradeRequest.setHeader(HttpHeaders.AUTHORIZATION, TEST_CRED_HEADER);
}
client.connect(
testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())),
upgradeRequest);
testWebsocketListener.sendRequest(requestPath, "GET", Collections.emptyList(), Optional.empty()).join();
}
case HTTP -> {
final Invocation.Builder request = DROPWIZARD_APP_EXTENSION.client()
.target("http://127.0.0.1:%s%s".formatted(DROPWIZARD_APP_EXTENSION.getLocalPort(), requestPath))
.request();
if (!anonymous) {
request.header(HttpHeaders.AUTHORIZATION, TEST_CRED_HEADER);
}
request.get();
}
}
}
@ParameterizedTest
@EnumSource(Protocol.class)
void handleRequestNoChange(final Protocol protocol) throws IOException {
when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account1));
makeAuthenticatedRequest(protocol, "/test/annotated");
// Event listeners can fire after responses are sent
verify(ACCOUNTS_MANAGER, timeout(5000).times(1)).getByAccountIdentifier(eq(account1.getUuid()));
verifyNoMoreInteractions(CLIENT_PRESENCE);
verifyNoMoreInteractions(ACCOUNTS_MANAGER);
}
@ParameterizedTest
@EnumSource(Protocol.class)
void handleRequestChange(final Protocol protocol) throws IOException {
when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account2));
when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedAccount(account1, authenticatedDevice)));
makeAuthenticatedRequest(protocol, "/test/annotated");
// Make sure we disconnect the account if the account has changed numbers. Event listeners can fire after responses
// are sent, so use a timeout.
verify(CLIENT_PRESENCE, timeout(5000))
.disconnectPresence(eq(account1.getUuid()), eq(authenticatedDevice.getId()));
verifyNoMoreInteractions(CLIENT_PRESENCE);
}
@Test
void handleRequestNoChange() {
setAuthenticatedAccount(request, account);
void handleRequestChangeAsyncEndpoint() throws IOException {
when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account2));
when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedAccount(account1, authenticatedDevice)));
provider.handleRequestFiltered(requestEvent);
assertEquals(Collections.emptyList(), provider.handleRequestFinished(requestEvent));
// Event listeners with asynchronous HTTP endpoints don't currently correctly maintain state between request and
// response
makeAuthenticatedRequest(Protocol.WEBSOCKET, "/test/async-annotated");
// Make sure we disconnect the account if the account has changed numbers. Event listeners can fire after responses
// are sent, so use a timeout.
verify(CLIENT_PRESENCE, timeout(5000))
.disconnectPresence(eq(account1.getUuid()), eq(authenticatedDevice.getId()));
verifyNoMoreInteractions(CLIENT_PRESENCE);
}
@Test
void handleRequestNumberChange() {
setAuthenticatedAccount(request, account);
@ParameterizedTest
@EnumSource(Protocol.class)
void handleRequestNotAnnotated(final Protocol protocol) throws IOException, InterruptedException {
makeAuthenticatedRequest(protocol,"/test/not-annotated");
provider.handleRequestFiltered(requestEvent);
when(account.getNumber()).thenReturn(CHANGED_NUMBER);
assertEquals(List.of(new Pair<>(ACCOUNT_UUID, Device.PRIMARY_ID)), provider.handleRequestFinished(requestEvent));
// Give a tick for event listeners to run. Racy, but should occasionally catch an errant running listener if one is
// introduced.
Thread.sleep(100);
// Shouldn't even read the account if the method has not been annotated
verifyNoMoreInteractions(ACCOUNTS_MANAGER);
verifyNoMoreInteractions(CLIENT_PRESENCE);
}
@Test
void handleRequestNoAuthenticatedAccount() {
final ContainerRequest request = mock(ContainerRequest.class);
setAuthenticatedAccount(request, null);
@ParameterizedTest
@EnumSource(Protocol.class)
void handleRequestNotAuthenticated(final Protocol protocol) throws IOException, InterruptedException {
makeAnonymousRequest(protocol, "/test/not-authenticated");
when(requestEvent.getContainerRequest()).thenReturn(request);
// Give a tick for event listeners to run. Racy, but should occasionally catch an errant running listener if one is
// introduced.
Thread.sleep(100);
provider.handleRequestFiltered(requestEvent);
assertEquals(Collections.emptyList(), provider.handleRequestFinished(requestEvent));
// Shouldn't even read the account if the method has not been annotated
verifyNoMoreInteractions(ACCOUNTS_MANAGER);
verifyNoMoreInteractions(CLIENT_PRESENCE);
}
private static void setAuthenticatedAccount(final ContainerRequest mockRequest, @Nullable final Account account) {
final SecurityContext securityContext = mock(SecurityContext.class);
when(mockRequest.getSecurityContext()).thenReturn(securityContext);
@Path("/test")
public static class TestController {
if (account != null) {
final AuthenticatedAccount authenticatedAccount = mock(AuthenticatedAccount.class);
@GET
@Path("/annotated")
@ChangesPhoneNumber
public String annotated(@ReadOnly @Auth final AuthenticatedAccount account) {
return "ok";
}
when(securityContext.getUserPrincipal()).thenReturn(authenticatedAccount);
when(authenticatedAccount.getAccount()).thenReturn(account);
} else {
when(securityContext.getUserPrincipal()).thenReturn(null);
@GET
@Path("/async-annotated")
@ChangesPhoneNumber
@ManagedAsync
public String asyncAnnotated(@ReadOnly @Auth final AuthenticatedAccount account) {
return "ok";
}
@GET
@Path("/not-authenticated")
@ChangesPhoneNumber
public String notAuthenticated() {
return "ok";
}
@GET
@Path("/not-annotated")
public String notAnnotated(@ReadOnly @Auth final AuthenticatedAccount account) {
return "ok";
}
}
}

View File

@ -87,7 +87,7 @@ class BasicCredentialAuthenticationInterceptorTest {
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(accountAuthenticator.authenticate(any()))
.thenReturn(Optional.of(new AuthenticatedAccount(() -> new Pair<>(account, device))));
.thenReturn(Optional.of(new AuthenticatedAccount(account, device)));
} else {
when(accountAuthenticator.authenticate(any()))
.thenReturn(Optional.empty());

View File

@ -39,7 +39,7 @@ class DirectoryControllerV2Test {
when(account.getUuid()).thenReturn(uuid);
final ExternalServiceCredentials credentials = (ExternalServiceCredentials) controller.getAuthToken(
new AuthenticatedAccount(() -> new Pair<>(account, mock(Device.class)))).getEntity();
new AuthenticatedAccount(account, mock(Device.class))).getEntity();
assertEquals(credentials.username(), "d369bc712e2e0dd36258");
assertEquals(credentials.password(), "1633738643:4433b0fab41f25f79dd4");

View File

@ -51,7 +51,10 @@ import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.tests.util.TestPrincipal;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketResourceProvider;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
import org.whispersystems.websocket.logging.WebsocketRequestLog;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
@ -139,7 +142,7 @@ class MetricsRequestEventListenerTest {
final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
final WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, TestPrincipal.reusableAuth("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
final Session session = mock(Session.class);
@ -201,7 +204,7 @@ class MetricsRequestEventListenerTest {
final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
final WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, TestPrincipal.reusableAuth("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
final Session session = mock(Session.class);
@ -252,19 +255,6 @@ class MetricsRequestEventListenerTest {
return SubProtocol.WebSocketMessage.parseFrom(responseCaptor.getValue().array()).getResponse();
}
public static class TestPrincipal implements Principal {
private final String name;
private TestPrincipal(String name) {
this.name = name;
}
@Override
public String getName() {
return name;
}
}
@Path("/v1/test")
public static class TestResource {

View File

@ -1,71 +0,0 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.util.Optional;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.util.Pair;
class RefreshingAccountAndDeviceSupplierTest {
@Test
void test() {
final AccountsManager accountsManager = mock(AccountsManager.class);
final UUID uuid = UUID.randomUUID();
final byte deviceId = 2;
final Account initialAccount = mock(Account.class);
final Device initialDevice = mock(Device.class);
when(initialAccount.getUuid()).thenReturn(uuid);
when(initialDevice.getId()).thenReturn(deviceId);
when(initialAccount.getDevice(deviceId)).thenReturn(Optional.of(initialDevice));
when(accountsManager.getByAccountIdentifier(any(UUID.class))).thenAnswer(answer -> {
final Account account = mock(Account.class);
final Device device = mock(Device.class);
when(account.getUuid()).thenReturn(answer.getArgument(0, UUID.class));
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(device.getId()).thenReturn(deviceId);
return Optional.of(account);
});
final RefreshingAccountAndDeviceSupplier refreshingAccountAndDeviceSupplier = new RefreshingAccountAndDeviceSupplier(
initialAccount, deviceId, accountsManager);
Pair<Account, Device> accountAndDevice = refreshingAccountAndDeviceSupplier.get();
assertSame(initialAccount, accountAndDevice.first());
assertSame(initialDevice, accountAndDevice.second());
accountAndDevice = refreshingAccountAndDeviceSupplier.get();
assertSame(initialAccount, accountAndDevice.first());
assertSame(initialDevice, accountAndDevice.second());
when(initialAccount.isStale()).thenReturn(true);
accountAndDevice = refreshingAccountAndDeviceSupplier.get();
assertNotSame(initialAccount, accountAndDevice.first());
assertNotSame(initialDevice, accountAndDevice.second());
assertEquals(uuid, accountAndDevice.first().getUuid());
}
}

View File

@ -0,0 +1,27 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.tests.util;
import java.security.Principal;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.auth.PrincipalSupplier;
public class TestPrincipal implements Principal {
private final String name;
private TestPrincipal(String name) {
this.name = name;
}
@Override
public String getName() {
return name;
}
public static ReusableAuth<TestPrincipal> reusableAuth(final String name) {
return ReusableAuth.authenticated(new TestPrincipal(name), PrincipalSupplier.forImmutablePrincipal());
}
}

View File

@ -0,0 +1,79 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.tests.util;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.whispersystems.websocket.messages.WebSocketMessage;
import org.whispersystems.websocket.messages.WebSocketMessageFactory;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
public class TestWebsocketListener implements WebSocketListener {
private final AtomicLong requestId = new AtomicLong();
private final CompletableFuture<Session> started = new CompletableFuture<>();
private final ConcurrentHashMap<Long, CompletableFuture<WebSocketResponseMessage>> responseFutures = new ConcurrentHashMap<>();
private final WebSocketMessageFactory messageFactory;
public TestWebsocketListener() {
this.messageFactory = new ProtobufWebSocketMessageFactory();
}
@Override
public void onWebSocketConnect(final Session session) {
started.complete(session);
}
public CompletableFuture<WebSocketResponseMessage> doGet(final String requestPath) {
return sendRequest(requestPath, "GET", List.of("Accept: application/json"), Optional.empty());
}
public CompletableFuture<WebSocketResponseMessage> sendRequest(
final String requestPath,
final String verb,
final List<String> headers,
final Optional<byte[]> body) {
return started.thenCompose(session -> {
final long id = requestId.incrementAndGet();
final CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>();
responseFutures.put(id, future);
final byte[] requestBytes = messageFactory.createRequest(
Optional.of(id), verb, requestPath, headers, body).toByteArray();
try {
session.getRemote().sendBytes(ByteBuffer.wrap(requestBytes));
} catch (IOException e) {
throw new RuntimeException(e);
}
return future;
});
}
@Override
public void onWebSocketBinary(final byte[] payload, final int offset, final int length) {
try {
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length);
if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.RESPONSE_MESSAGE) {
responseFutures.get(webSocketMessage.getResponseMessage().getRequestId())
.complete(webSocketMessage.getResponseMessage());
} else {
throw new RuntimeException("Unexpected message type: " + webSocketMessage.getType());
}
} catch (final Exception e) {
throw new RuntimeException(e);
}
}
}

View File

@ -57,6 +57,7 @@ import org.junit.jupiter.params.provider.MethodSource;
import org.slf4j.Logger;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper;
import org.whispersystems.textsecuregcm.tests.util.TestPrincipal;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.websocket.WebSocketResourceProvider;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
@ -175,7 +176,8 @@ class LoggingUnhandledExceptionMapperTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog,
TestPrincipal.reusableAuth("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@ -238,18 +240,4 @@ class LoggingUnhandledExceptionMapperTest {
throw new RuntimeException();
}
}
public static class TestPrincipal implements Principal {
private final String name;
private TestPrincipal(String name) {
this.name = name;
}
@Override
public String getName() {
return name;
}
}
}

View File

@ -28,8 +28,8 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.auth.PrincipalSupplier;
class WebSocketAccountAuthenticatorTest {
@ -52,7 +52,7 @@ class WebSocketAccountAuthenticatorTest {
accountAuthenticator = mock(AccountAuthenticator.class);
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
.thenReturn(Optional.of(new AuthenticatedAccount(() -> new Pair<>(mock(Account.class), mock(Device.class)))));
.thenReturn(Optional.of(new AuthenticatedAccount(mock(Account.class), mock(Device.class))));
when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD))))
.thenReturn(Optional.empty());
@ -66,7 +66,7 @@ class WebSocketAccountAuthenticatorTest {
@Nullable final String authorizationHeaderValue,
final Map<String, List<String>> upgradeRequestParameters,
final boolean expectAccount,
final boolean expectCredentialsPresented) throws Exception {
final boolean expectInvalid) throws Exception {
when(upgradeRequest.getParameterMap()).thenReturn(upgradeRequestParameters);
if (authorizationHeaderValue != null) {
@ -74,13 +74,13 @@ class WebSocketAccountAuthenticatorTest {
}
final WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(
accountAuthenticator);
accountAuthenticator,
mock(PrincipalSupplier.class));
final WebSocketAuthenticator.AuthenticationResult<AuthenticatedAccount> result = webSocketAuthenticator.authenticate(
upgradeRequest);
final ReusableAuth<AuthenticatedAccount> result = webSocketAuthenticator.authenticate(upgradeRequest);
assertEquals(expectAccount, result.getUser().isPresent());
assertEquals(expectCredentialsPresented, result.credentialsPresented());
assertEquals(expectAccount, result.ref().isPresent());
assertEquals(expectInvalid, result.invalidCredentialsProvided());
}
private static Stream<Arguments> testAuthenticate() {
@ -94,17 +94,17 @@ class WebSocketAccountAuthenticatorTest {
HeaderUtils.basicAuthHeader(INVALID_USER, INVALID_PASSWORD);
return Stream.of(
// if `Authorization` header is present, outcome should not depend on the value of query parameters
Arguments.of(headerWithValidAuth, Map.of(), true, true),
Arguments.of(headerWithValidAuth, Map.of(), true, false),
Arguments.of(headerWithInvalidAuth, Map.of(), false, true),
Arguments.of("invalid header value", Map.of(), false, true),
Arguments.of(headerWithValidAuth, paramsMapWithValidAuth, true, true),
Arguments.of(headerWithValidAuth, paramsMapWithValidAuth, true, false),
Arguments.of(headerWithInvalidAuth, paramsMapWithValidAuth, false, true),
Arguments.of("invalid header value", paramsMapWithValidAuth, false, true),
Arguments.of(headerWithValidAuth, paramsMapWithInvalidAuth, true, true),
Arguments.of(headerWithValidAuth, paramsMapWithInvalidAuth, true, false),
Arguments.of(headerWithInvalidAuth, paramsMapWithInvalidAuth, false, true),
Arguments.of("invalid header value", paramsMapWithInvalidAuth, false, true),
// if `Authorization` header is not set, outcome should match the query params based auth
Arguments.of(null, paramsMapWithValidAuth, true, true),
Arguments.of(null, paramsMapWithValidAuth, true, false),
Arguments.of(null, paramsMapWithInvalidAuth, false, true),
Arguments.of(null, Map.of(), false, false)
);

View File

@ -125,7 +125,7 @@ class WebSocketConnectionIntegrationTest {
final WebSocketConnection webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
new AuthenticatedAccount(() -> new Pair<>(account, device)),
new AuthenticatedAccount(account, device),
device,
webSocketClient,
scheduledExecutorService,
@ -210,7 +210,7 @@ class WebSocketConnectionIntegrationTest {
final WebSocketConnection webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
new AuthenticatedAccount(() -> new Pair<>(account, device)),
new AuthenticatedAccount(account, device),
device,
webSocketClient,
scheduledExecutorService,
@ -276,7 +276,7 @@ class WebSocketConnectionIntegrationTest {
final WebSocketConnection webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
new AuthenticatedAccount(() -> new Pair<>(account, device)),
new AuthenticatedAccount(account, device),
device,
webSocketClient,
100, // use a very short timeout, so that this test completes quickly

View File

@ -64,9 +64,9 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import reactor.core.publisher.Flux;
@ -101,7 +101,7 @@ class WebSocketConnectionTest {
accountsManager = mock(AccountsManager.class);
account = mock(Account.class);
device = mock(Device.class);
auth = new AuthenticatedAccount(() -> new Pair<>(account, device));
auth = new AuthenticatedAccount(account, device);
upgradeRequest = mock(UpgradeRequest.class);
messagesManager = mock(MessagesManager.class);
receiptSender = mock(ReceiptSender.class);
@ -118,18 +118,19 @@ class WebSocketConnectionTest {
@Test
void testCredentials() throws Exception {
WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator);
WebSocketAccountAuthenticator webSocketAuthenticator =
new WebSocketAccountAuthenticator(accountAuthenticator, mock(PrincipalSupplier.class));
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, messagesManager,
mock(PushNotificationManager.class), mock(ClientPresenceManager.class),
retrySchedulingExecutor, messageDeliveryScheduler, clientReleaseManager);
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
.thenReturn(Optional.of(new AuthenticatedAccount(() -> new Pair<>(account, device))));
.thenReturn(Optional.of(new AuthenticatedAccount(account, device)));
AuthenticationResult<AuthenticatedAccount> account = webSocketAuthenticator.authenticate(upgradeRequest);
when(sessionContext.getAuthenticated()).thenReturn(account.getUser().orElse(null));
when(sessionContext.getAuthenticated(AuthenticatedAccount.class)).thenReturn(account.getUser().orElse(null));
ReusableAuth<AuthenticatedAccount> account = webSocketAuthenticator.authenticate(upgradeRequest);
when(sessionContext.getAuthenticated()).thenReturn(account.ref().orElse(null));
when(sessionContext.getAuthenticated(AuthenticatedAccount.class)).thenReturn(account.ref().orElse(null));
final WebSocketClient webSocketClient = mock(WebSocketClient.class);
when(webSocketClient.getUserAgent()).thenReturn("Signal-Android/6.22.8");
@ -144,8 +145,8 @@ class WebSocketConnectionTest {
// unauthenticated
when(upgradeRequest.getParameterMap()).thenReturn(Map.of());
account = webSocketAuthenticator.authenticate(upgradeRequest);
assertFalse(account.getUser().isPresent());
assertFalse(account.credentialsPresented());
assertFalse(account.ref().isPresent());
assertFalse(account.invalidCredentialsProvided());
connectListener.onWebSocketConnect(sessionContext);
verify(sessionContext, times(2)).addWebsocketClosedListener(

View File

@ -0,0 +1,182 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.websocket;
import java.security.Principal;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import org.whispersystems.websocket.auth.PrincipalSupplier;
/**
* This class holds a principal that can be reused across requests on a websocket. Since two requests may operate
* concurrently on the same principal, and some principals contain non thread-safe mutable state, appropriate use of
* this class ensures that no data races occur. It also ensures that after a principal is modified, a subsequent request
* gets the up-to-date principal
*
* @param <T> The underlying principal type
* @see PrincipalSupplier
*/
public abstract sealed class ReusableAuth<T extends Principal> {
/**
* Get a reference to the underlying principal that callers pledge not to modify.
* <p>
* The reference returned will potentially be provided to many threads concurrently accessing the principal. Callers
* should use this method only if they can ensure that they will not modify the in-memory principal object AND they do
* not intend to modify the underlying canonical representation of the principal.
* <p>
* For example, if a caller retrieves a reference to a principal, does not modify the in memory state, but updates a
* field on a database that should be reflected in subsequent retrievals of the principal, they will have met the
* first criteria, but not the second. In that case they should instead use {@link #mutableRef()}.
* <p>
* If other callers have modified the underlying principal by using {@link #mutableRef()}, this method may need to
* refresh the principal via {@link PrincipalSupplier#refresh} which could be a blocking operation.
*
* @return If authenticated, a reference to the underlying principal that should not be modified
*/
public abstract Optional<T> ref();
public interface MutableRef<T> {
T ref();
void close();
}
/**
* Get a reference to the underlying principal that may be modified.
* <p>
* The underlying principal can be safely modified. Multiple threads may operate on the same {@link ReusableAuth} so
* long as they each have their own {@link MutableRef}. After any modifications, the caller must call
* {@link MutableRef#close} to notify the principal has become dirty. Close should be called after modifications but
* before sending a response on the websocket. This ensures that a request that comes in after a modification response
* is received is guaranteed to see the modification.
*
* @return If authenticated, a reference to the underlying principal that may be modified
*/
public abstract Optional<MutableRef<T>> mutableRef();
public boolean invalidCredentialsProvided() {
return switch (this) {
case Invalid<T> ignored -> true;
case ReusableAuth.Anonymous<T> ignored -> false;
case ReusableAuth.Authenticated<T> ignored-> false;
};
}
/**
* @return A {@link ReusableAuth} indicating no credential were provided
*/
public static <T extends Principal> ReusableAuth<T> anonymous() {
//noinspection unchecked
return (ReusableAuth<T>) Anonymous.ANON_RESULT;
}
/**
* @return A {@link ReusableAuth} indicating that invalid credentials were provided
*/
public static <T extends Principal> ReusableAuth<T> invalid() {
//noinspection unchecked
return (ReusableAuth<T>) Invalid.INVALID_RESULT;
}
/**
* Create a successfully authenticated {@link ReusableAuth}
*
* @param principal The authenticated principal
* @param principalSupplier Instructions for how to refresh or copy this principal
* @param <T> The principal type
* @return A {@link ReusableAuth} for a successfully authenticated principal
*/
public static <T extends Principal> ReusableAuth<T> authenticated(T principal,
PrincipalSupplier<T> principalSupplier) {
return new Authenticated<>(principal, principalSupplier);
}
private static final class Invalid<T extends Principal> extends ReusableAuth<T> {
@SuppressWarnings({"rawtypes"})
private static final ReusableAuth INVALID_RESULT = new Invalid();
@Override
public Optional<T> ref() {
return Optional.empty();
}
@Override
public Optional<MutableRef<T>> mutableRef() {
return Optional.empty();
}
}
private static final class Anonymous<T extends Principal> extends ReusableAuth<T> {
@SuppressWarnings({"rawtypes"})
private static final ReusableAuth ANON_RESULT = new Anonymous();
@Override
public Optional<T> ref() {
return Optional.empty();
}
@Override
public Optional<MutableRef<T>> mutableRef() {
return Optional.empty();
}
}
private static final class Authenticated<T extends Principal> extends ReusableAuth<T> {
private T basePrincipal;
private final AtomicBoolean needRefresh = new AtomicBoolean(false);
private final PrincipalSupplier<T> principalSupplier;
Authenticated(final T basePrincipal, PrincipalSupplier<T> principalSupplier) {
this.basePrincipal = basePrincipal;
this.principalSupplier = principalSupplier;
}
@Override
public Optional<T> ref() {
maybeRefresh();
return Optional.of(basePrincipal);
}
@Override
public Optional<MutableRef<T>> mutableRef() {
maybeRefresh();
return Optional.of(new AuthenticatedMutableRef(principalSupplier.deepCopy(basePrincipal)));
}
private void maybeRefresh() {
if (needRefresh.compareAndSet(true, false)) {
basePrincipal = principalSupplier.refresh(basePrincipal);
}
}
private class AuthenticatedMutableRef implements MutableRef<T> {
final T ref;
private AuthenticatedMutableRef(T ref) {
this.ref = ref;
}
public T ref() {
return ref;
}
public void close() {
needRefresh.set(true);
}
}
}
private ReusableAuth() {
}
}

View File

@ -58,7 +58,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
private final Map<Long, CompletableFuture<WebSocketResponseMessage>> requestMap = new ConcurrentHashMap<>();
private final T authenticated;
private final ReusableAuth<T> reusableAuth;
private final WebSocketMessageFactory messageFactory;
private final Optional<WebSocketConnectListener> connectListener;
private final ApplicationHandler jerseyHandler;
@ -77,7 +77,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
String remoteAddressPropertyName,
ApplicationHandler jerseyHandler,
WebsocketRequestLog requestLog,
T authenticated,
ReusableAuth<T> authenticated,
WebSocketMessageFactory messageFactory,
Optional<WebSocketConnectListener> connectListener,
Duration idleTimeout) {
@ -85,7 +85,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
this.remoteAddressPropertyName = remoteAddressPropertyName;
this.jerseyHandler = jerseyHandler;
this.requestLog = requestLog;
this.authenticated = authenticated;
this.reusableAuth = authenticated;
this.messageFactory = messageFactory;
this.connectListener = connectListener;
this.idleTimeout = idleTimeout;
@ -97,7 +97,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
this.remoteEndpoint = session.getRemote();
this.context = new WebSocketSessionContext(
new WebSocketClient(session, remoteEndpoint, messageFactory, requestMap));
this.context.setAuthenticated(authenticated);
this.context.setAuthenticated(reusableAuth.ref().orElse(null));
this.session.setIdleTimeout(idleTimeout);
connectListener.ifPresent(listener -> listener.onWebSocketConnect(this.context));
@ -162,6 +162,17 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
logger.debug("onWebSocketText!");
}
/**
* The property name where {@link org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider} can find an
* {@link ReusableAuth} object that lives for the lifetime of the websocket
*/
public static final String REUSABLE_AUTH_PROPERTY = WebSocketResourceProvider.class.getName() + ".reusableAuth";
/**
* The property name where {@link org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider} can install a
* {@link org.whispersystems.websocket.ReusableAuth.MutableRef} for us to close when the request is finished
*/
public static final String RESOLVED_PRINCIPAL_PROPERTY = WebSocketResourceProvider.class.getName() + ".resolvedPrincipal";
private void handleRequest(WebSocketRequestMessage requestMessage) {
ContainerRequest containerRequest = new ContainerRequest(null, URI.create(requestMessage.getPath()),
requestMessage.getVerb(), new WebSocketSecurityContext(new ContextPrincipal(context)),
@ -173,30 +184,43 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
}
containerRequest.setProperty(remoteAddressPropertyName, remoteAddress);
containerRequest.setProperty(REUSABLE_AUTH_PROPERTY, reusableAuth);
ByteArrayOutputStream responseBody = new ByteArrayOutputStream();
CompletableFuture<ContainerResponse> responseFuture = (CompletableFuture<ContainerResponse>) jerseyHandler.apply(
containerRequest, responseBody);
responseFuture.thenAccept(response -> {
try {
sendResponse(requestMessage, response, responseBody);
} catch (IOException e) {
throw new RuntimeException(e);
}
requestLog.log(remoteAddress, containerRequest, response);
}).exceptionally(exception -> {
logger.warn("Websocket Error: " + requestMessage.getVerb() + " " + requestMessage.getPath() + "\n"
+ requestMessage.getBody(), exception);
try {
sendErrorResponse(requestMessage, Response.status(500).build());
} catch (IOException e) {
logger.warn("Failed to send error response", e);
}
requestLog.log(remoteAddress, containerRequest,
new ContainerResponse(containerRequest, Response.status(500).build()));
return null;
});
responseFuture
.whenComplete((ignoredResponse, ignoredError) -> {
// If the request ended up being one that mutates our principal, we have to close it to indicate we're done
// with the mutation operation
final Object resolvedPrincipal = containerRequest.getProperty(RESOLVED_PRINCIPAL_PROPERTY);
if (resolvedPrincipal instanceof ReusableAuth.MutableRef ref) {
ref.close();
} else if (resolvedPrincipal != null) {
logger.warn("unexpected resolved principal type {} : {}", resolvedPrincipal.getClass(), resolvedPrincipal);
}
})
.thenAccept(response -> {
try {
sendResponse(requestMessage, response, responseBody);
} catch (IOException e) {
throw new RuntimeException(e);
}
requestLog.log(remoteAddress, containerRequest, response);
})
.exceptionally(exception -> {
logger.warn("Websocket Error: " + requestMessage.getVerb() + " " + requestMessage.getPath() + "\n"
+ requestMessage.getBody(), exception);
try {
sendErrorResponse(requestMessage, Response.status(500).build());
} catch (IOException e) {
logger.warn("Failed to send error response", e);
}
requestLog.log(remoteAddress, containerRequest,
new ContainerResponse(containerRequest, Response.status(500).build()));
return null;
});
}
@VisibleForTesting

View File

@ -22,7 +22,6 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.websocket.auth.AuthenticationException;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.session.WebSocketSessionContextValueFactoryProvider;
@ -57,17 +56,17 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends Jetty
public Object createWebSocket(final JettyServerUpgradeRequest request, final JettyServerUpgradeResponse response) {
try {
Optional<WebSocketAuthenticator<T>> authenticator = Optional.ofNullable(environment.getAuthenticator());
T authenticated = null;
final ReusableAuth<T> authenticated;
if (authenticator.isPresent()) {
AuthenticationResult<T> authenticationResult = authenticator.get().authenticate(request);
authenticated = authenticator.get().authenticate(request);
if (authenticationResult.getUser().isEmpty() && authenticationResult.credentialsPresented()) {
if (authenticated.invalidCredentialsProvided()) {
response.sendForbidden("Unauthorized");
return null;
} else {
authenticated = authenticationResult.getUser().orElse(null);
}
} else {
authenticated = ReusableAuth.anonymous();
}
return new WebSocketResourceProvider<>(getRemoteAddress(request),

View File

@ -0,0 +1,26 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.websocket.auth;
import io.dropwizard.auth.Auth;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* An @{@link Auth} object annotated with {@link Mutable} indicates that the consumer of the object
* will modify the object or its underlying canonical source.
*
* Note: An {@link Auth} object that does not specify @{@link ReadOnly} will be assumed to be @Mutable
*
* @see org.whispersystems.websocket.ReusableAuth
*/
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD, ElementType.PARAMETER})
public @interface Mutable {
}

View File

@ -0,0 +1,58 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.websocket.auth;
/**
* Teach {@link org.whispersystems.websocket.ReusableAuth} how to make a deep copy of a principal (that is safe to
* concurrently modify while the original principal is being read), and how to refresh a principal after it has been
* potentially modified.
*
* @param <T> The underlying principal type
*/
public interface PrincipalSupplier<T> {
/**
* Re-fresh the principal after it has been modified.
* <p>
* If the principal is populated from a backing store, refresh should re-read it.
*
* @param t the potentially stale principal to refresh
* @return The up-to-date principal
*/
T refresh(T t);
/**
* Create a deep, in-memory copy of the principal. This should be identical to the original principal, but should
* share no mutable state with the original. It should be safe for two threads to independently write and read from
* two independent deep copies.
*
* @param t the principal to copy
* @return An in-memory copy of the principal
*/
T deepCopy(T t);
class ImmutablePrincipalSupplier<T> implements PrincipalSupplier<T> {
@SuppressWarnings({"rawtypes"})
private static final PrincipalSupplier INSTANCE = new ImmutablePrincipalSupplier();
@Override
public T refresh(final T t) {
return t;
}
@Override
public T deepCopy(final T t) {
return t;
}
}
/**
* @return A principal supplier that can be used if the principal type does not support modification.
*/
static <T> PrincipalSupplier<T> forImmutablePrincipal() {
//noinspection unchecked
return (PrincipalSupplier<T>) ImmutablePrincipalSupplier.INSTANCE;
}
}

View File

@ -0,0 +1,25 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.websocket.auth;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* An @{@link io.dropwizard.auth.Auth} object annotated with {@link ReadOnly} indicates that the consumer of the object
* will never modify the object, nor its underlying canonical source.
* <p>
* For example, a consumer of a @ReadOnly AuthenticatedAccount promises to never modify the in-memory
* AuthenticatedAccount and to never modify the underlying Account database for the account.
*
* @see org.whispersystems.websocket.ReusableAuth
*/
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD, ElementType.PARAMETER})
public @interface ReadOnly {
}

View File

@ -7,27 +7,8 @@ package org.whispersystems.websocket.auth;
import java.security.Principal;
import java.util.Optional;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.whispersystems.websocket.ReusableAuth;
public interface WebSocketAuthenticator<T extends Principal> {
AuthenticationResult<T> authenticate(UpgradeRequest request) throws AuthenticationException;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
class AuthenticationResult<T> {
private final Optional<T> user;
private final boolean credentialsPresented;
public AuthenticationResult(final Optional<T> user, final boolean credentialsPresented) {
this.user = user;
this.credentialsPresented = credentialsPresented;
}
public Optional<T> getUser() {
return user;
}
public boolean credentialsPresented() {
return credentialsPresented;
}
}
ReusableAuth<T> authenticate(UpgradeRequest request) throws AuthenticationException;
}

View File

@ -5,24 +5,28 @@
package org.whispersystems.websocket.auth;
import io.dropwizard.auth.Auth;
import java.lang.reflect.ParameterizedType;
import java.security.Principal;
import java.util.Optional;
import java.util.function.Function;
import javax.annotation.Nullable;
import javax.inject.Inject;
import javax.inject.Singleton;
import javax.ws.rs.WebApplicationException;
import org.glassfish.jersey.internal.inject.AbstractBinder;
import org.glassfish.jersey.server.ContainerRequest;
import org.glassfish.jersey.server.internal.inject.AbstractValueParamProvider;
import org.glassfish.jersey.server.internal.inject.MultivaluedParameterExtractorProvider;
import org.glassfish.jersey.server.model.Parameter;
import org.glassfish.jersey.server.spi.internal.ValueParamProvider;
import javax.annotation.Nullable;
import javax.inject.Inject;
import javax.inject.Singleton;
import javax.ws.rs.WebApplicationException;
import java.lang.reflect.ParameterizedType;
import java.security.Principal;
import java.util.Optional;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketResourceProvider;
@Singleton
public class WebsocketAuthValueFactoryProvider<T extends Principal> extends AbstractValueParamProvider {
private static final Logger logger = LoggerFactory.getLogger(WebsocketAuthValueFactoryProvider.class);
private final Class<T> principalClass;
@ -39,18 +43,38 @@ public class WebsocketAuthValueFactoryProvider<T extends Principal> extends Abst
return null;
}
if (parameter.getRawType() == Optional.class &&
ParameterizedType.class.isAssignableFrom(parameter.getType().getClass()) &&
principalClass == ((ParameterizedType)parameter.getType()).getActualTypeArguments()[0])
{
return request -> new OptionalContainerRequestValueFactory(request).provide();
final boolean readOnly = parameter.isAnnotationPresent(ReadOnly.class);
if (parameter.getRawType() == Optional.class
&& ParameterizedType.class.isAssignableFrom(parameter.getType().getClass())
&& principalClass == ((ParameterizedType) parameter.getType()).getActualTypeArguments()[0]) {
return containerRequest -> createPrincipal(containerRequest, readOnly);
} else if (principalClass.equals(parameter.getRawType())) {
return request -> new StandardContainerRequestValueFactory(request).provide();
return containerRequest ->
createPrincipal(containerRequest, readOnly)
.orElseThrow(() -> new WebApplicationException("Authenticated resource", 401));
} else {
throw new IllegalStateException("Can't inject unassignable principal: " + principalClass + " for parameter: " + parameter);
}
}
private Optional<? extends Principal> createPrincipal(final ContainerRequest request, final boolean readOnly) {
final Object obj = request.getProperty(WebSocketResourceProvider.REUSABLE_AUTH_PROPERTY);
if (!(obj instanceof ReusableAuth<?>)) {
logger.warn("Unexpected reusable auth property type {} : {}", obj.getClass(), obj);
return Optional.empty();
}
@SuppressWarnings("unchecked") final ReusableAuth<T> reusableAuth = (ReusableAuth<T>) obj;
if (readOnly) {
return reusableAuth.ref();
} else {
return reusableAuth.mutableRef().map(writeRef -> {
request.setProperty(WebSocketResourceProvider.RESOLVED_PRINCIPAL_PROPERTY, writeRef);
return writeRef.ref();
});
}
}
@Singleton
static class WebsocketPrincipalClassProvider<T extends Principal> {
@ -80,38 +104,4 @@ public class WebsocketAuthValueFactoryProvider<T extends Principal> extends Abst
bind(WebsocketAuthValueFactoryProvider.class).to(ValueParamProvider.class).in(Singleton.class);
}
}
private static class StandardContainerRequestValueFactory {
private final ContainerRequest request;
public StandardContainerRequestValueFactory(ContainerRequest request) {
this.request = request;
}
public Principal provide() {
final Principal principal = request.getSecurityContext().getUserPrincipal();
if (principal == null) {
throw new WebApplicationException("Authenticated resource", 401);
}
return principal;
}
}
private static class OptionalContainerRequestValueFactory {
private final ContainerRequest request;
public OptionalContainerRequestValueFactory(ContainerRequest request) {
this.request = request;
}
public Optional<Principal> provide() {
return Optional.ofNullable(request.getSecurityContext().getUserPrincipal());
}
}
}

View File

@ -7,6 +7,7 @@ package org.whispersystems.websocket;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
@ -27,6 +28,7 @@ import org.glassfish.jersey.server.ResourceConfig;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.websocket.auth.AuthenticationException;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ -56,8 +58,7 @@ public class WebSocketResourceProviderFactoryTest {
@Test
void testUnauthorized() throws AuthenticationException, IOException {
when(environment.getAuthenticator()).thenReturn(authenticator);
when(authenticator.authenticate(eq(request))).thenReturn(
new WebSocketAuthenticator.AuthenticationResult<>(Optional.empty(), true));
when(authenticator.authenticate(eq(request))).thenReturn(ReusableAuth.invalid());
when(environment.jersey()).thenReturn(jerseyEnvironment);
WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class,
@ -74,8 +75,8 @@ public class WebSocketResourceProviderFactoryTest {
Account account = new Account();
when(environment.getAuthenticator()).thenReturn(authenticator);
when(authenticator.authenticate(eq(request))).thenReturn(
new WebSocketAuthenticator.AuthenticationResult<>(Optional.of(account), true));
when(authenticator.authenticate(eq(request)))
.thenReturn(ReusableAuth.authenticated(account, PrincipalSupplier.forImmutablePrincipal()));
when(environment.jersey()).thenReturn(jerseyEnvironment);
final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class);
when(httpServletRequest.getAttribute(REMOTE_ADDRESS_PROPERTY_NAME)).thenReturn("127.0.0.1");
@ -137,6 +138,7 @@ public class WebSocketResourceProviderFactoryTest {
public boolean implies(Subject subject) {
return false;
}
}

View File

@ -59,6 +59,7 @@ import org.glassfish.jersey.server.ResourceConfig;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
import org.whispersystems.websocket.logging.WebsocketRequestLog;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
@ -80,7 +81,7 @@ class WebSocketResourceProviderTest {
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME,
applicationHandler, requestLog,
new TestPrincipal("fooz"),
immutableTestPrincipal("fooz"),
new ProtobufWebSocketMessageFactory(),
Optional.of(connectListener),
Duration.ofMillis(30000));
@ -108,7 +109,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = mock(ApplicationHandler.class);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -184,7 +185,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = mock(ApplicationHandler.class);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -240,7 +241,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -280,7 +281,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -320,7 +321,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("authorizedUserName"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("authorizedUserName"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -360,8 +361,8 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, null, new ProtobufWebSocketMessageFactory(),
Optional.empty(), Duration.ofMillis(30000));
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, ReusableAuth.anonymous(),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@ -399,7 +400,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("something"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("something"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -439,8 +440,8 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, null, new ProtobufWebSocketMessageFactory(),
Optional.empty(), Duration.ofMillis(30000));
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, ReusableAuth.anonymous(),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@ -479,7 +480,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -520,7 +521,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -562,7 +563,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -602,7 +603,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -727,6 +728,10 @@ class WebSocketResourceProviderTest {
}
}
public static ReusableAuth<TestPrincipal> immutableTestPrincipal(final String name) {
return ReusableAuth.authenticated(new TestPrincipal(name), PrincipalSupplier.forImmutablePrincipal());
}
public static class TestException extends Exception {
public TestException(String message) {