Lifecycle management for Account objects reused accross websocket requests
This commit is contained in:
parent
29ef3f0b41
commit
26ffa19f36
|
@ -188,6 +188,7 @@ import org.whispersystems.textsecuregcm.spam.SenderOverrideProvider;
|
|||
import org.whispersystems.textsecuregcm.spam.SpamChecker;
|
||||
import org.whispersystems.textsecuregcm.spam.SpamFilter;
|
||||
import org.whispersystems.textsecuregcm.storage.AccountLockManager;
|
||||
import org.whispersystems.textsecuregcm.storage.AccountPrincipalSupplier;
|
||||
import org.whispersystems.textsecuregcm.storage.Accounts;
|
||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||
import org.whispersystems.textsecuregcm.storage.ChangeNumberManager;
|
||||
|
@ -812,7 +813,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
|||
WebSocketEnvironment<AuthenticatedAccount> webSocketEnvironment = new WebSocketEnvironment<>(environment,
|
||||
config.getWebSocketConfiguration(), Duration.ofMillis(90000));
|
||||
webSocketEnvironment.jersey().register(new VirtualExecutorServiceProvider("managed-async-websocket-virtual-thread-"));
|
||||
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator));
|
||||
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator, new AccountPrincipalSupplier(accountsManager)));
|
||||
webSocketEnvironment.setConnectListener(
|
||||
new AuthenticatedConnectListener(receiptSender, messagesManager, pushNotificationManager,
|
||||
clientPresenceManager, websocketScheduledExecutor, messageDeliveryScheduler, clientReleaseManager));
|
||||
|
|
|
@ -21,7 +21,6 @@ import org.apache.commons.lang3.StringUtils;
|
|||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.storage.RefreshingAccountAndDeviceSupplier;
|
||||
import org.whispersystems.textsecuregcm.util.Pair;
|
||||
import org.whispersystems.textsecuregcm.util.Util;
|
||||
|
||||
|
@ -108,8 +107,7 @@ public class AccountAuthenticator implements Authenticator<BasicCredentials, Aut
|
|||
device.get(),
|
||||
SaltedTokenHash.generateFor(basicCredentials.getPassword())); // new credentials have current version
|
||||
}
|
||||
return Optional.of(new AuthenticatedAccount(
|
||||
new RefreshingAccountAndDeviceSupplier(authenticatedAccount, device.get().getId(), accountsManager)));
|
||||
return Optional.of(new AuthenticatedAccount(authenticatedAccount, device.get()));
|
||||
}
|
||||
|
||||
return Optional.empty();
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
|
||||
package org.whispersystems.textsecuregcm.auth;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
|
@ -45,10 +44,6 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
|
|||
this.accountsManager = accountsManager;
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
static Map<Byte, Boolean> buildDevicesEnabledMap(final Account account) {
|
||||
return account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::isEnabled));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleRequestFiltered(final RequestEvent requestEvent) {
|
||||
|
@ -60,10 +55,13 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
|
|||
setAccount(requestEvent.getContainerRequest(), account));
|
||||
}
|
||||
}
|
||||
|
||||
public static void setAccount(final ContainerRequest containerRequest, final Account account) {
|
||||
containerRequest.setProperty(ACCOUNT_UUID, account.getUuid());
|
||||
containerRequest.setProperty(DEVICES_ENABLED, buildDevicesEnabledMap(account));
|
||||
setAccount(containerRequest, ContainerRequestUtil.AccountInfo.fromAccount(account));
|
||||
}
|
||||
|
||||
private static void setAccount(final ContainerRequest containerRequest, final ContainerRequestUtil.AccountInfo info) {
|
||||
containerRequest.setProperty(ACCOUNT_UUID, info.accountId());
|
||||
containerRequest.setProperty(DEVICES_ENABLED, info.devicesEnabled());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -75,25 +73,28 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
|
|||
@SuppressWarnings("unchecked") final Map<Byte, Boolean> initialDevicesEnabled =
|
||||
(Map<Byte, Boolean>) requestEvent.getContainerRequest().getProperty(DEVICES_ENABLED);
|
||||
|
||||
return accountsManager.getByAccountIdentifier((UUID) requestEvent.getContainerRequest().getProperty(ACCOUNT_UUID)).map(account -> {
|
||||
final Set<Byte> deviceIdsToDisplace;
|
||||
final Map<Byte, Boolean> currentDevicesEnabled = buildDevicesEnabledMap(account);
|
||||
return accountsManager.getByAccountIdentifier((UUID) requestEvent.getContainerRequest().getProperty(ACCOUNT_UUID))
|
||||
.map(ContainerRequestUtil.AccountInfo::fromAccount)
|
||||
.map(account -> {
|
||||
final Set<Byte> deviceIdsToDisplace;
|
||||
final Map<Byte, Boolean> currentDevicesEnabled = account.devicesEnabled();
|
||||
|
||||
if (!initialDevicesEnabled.equals(currentDevicesEnabled)) {
|
||||
deviceIdsToDisplace = new HashSet<>(initialDevicesEnabled.keySet());
|
||||
deviceIdsToDisplace.addAll(currentDevicesEnabled.keySet());
|
||||
} else {
|
||||
deviceIdsToDisplace = Collections.emptySet();
|
||||
}
|
||||
if (!initialDevicesEnabled.equals(currentDevicesEnabled)) {
|
||||
deviceIdsToDisplace = new HashSet<>(initialDevicesEnabled.keySet());
|
||||
deviceIdsToDisplace.addAll(currentDevicesEnabled.keySet());
|
||||
} else {
|
||||
deviceIdsToDisplace = Collections.emptySet();
|
||||
}
|
||||
|
||||
return deviceIdsToDisplace.stream()
|
||||
.map(deviceId -> new Pair<>(account.getUuid(), deviceId))
|
||||
.collect(Collectors.toList());
|
||||
}).orElseGet(() -> {
|
||||
logger.error("Request had account, but it is no longer present");
|
||||
return Collections.emptyList();
|
||||
});
|
||||
} else
|
||||
return deviceIdsToDisplace.stream()
|
||||
.map(deviceId -> new Pair<>(account.accountId(), deviceId))
|
||||
.collect(Collectors.toList());
|
||||
}).orElseGet(() -> {
|
||||
logger.error("Request had account, but it is no longer present");
|
||||
return Collections.emptyList();
|
||||
});
|
||||
} else {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,24 +10,24 @@ import java.util.function.Supplier;
|
|||
import javax.security.auth.Subject;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.util.Pair;
|
||||
|
||||
public class AuthenticatedAccount implements Principal, AccountAndAuthenticatedDeviceHolder {
|
||||
private final Account account;
|
||||
private final Device device;
|
||||
|
||||
private final Supplier<Pair<Account, Device>> accountAndDevice;
|
||||
|
||||
public AuthenticatedAccount(final Supplier<Pair<Account, Device>> accountAndDevice) {
|
||||
this.accountAndDevice = accountAndDevice;
|
||||
public AuthenticatedAccount(final Account account, final Device device) {
|
||||
this.account = account;
|
||||
this.device = device;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Account getAccount() {
|
||||
return accountAndDevice.get().first();
|
||||
return account;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Device getAuthenticatedDevice() {
|
||||
return accountAndDevice.get().second();
|
||||
return device;
|
||||
}
|
||||
|
||||
// Principal implementation
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.auth;
|
||||
|
||||
import java.lang.annotation.ElementType;
|
||||
import java.lang.annotation.Retention;
|
||||
import java.lang.annotation.RetentionPolicy;
|
||||
import java.lang.annotation.Target;
|
||||
|
||||
/**
|
||||
* Indicates that an endpoint changes the phone number and PNI keys associated with an account, and that
|
||||
* any websockets associated with the account may need to be refreshed after a call to that endpoint.
|
||||
*/
|
||||
@Target(ElementType.METHOD)
|
||||
@Retention(RetentionPolicy.RUNTIME)
|
||||
public @interface ChangesPhoneNumber {
|
||||
}
|
|
@ -7,15 +7,42 @@ package org.whispersystems.textsecuregcm.auth;
|
|||
|
||||
import org.glassfish.jersey.server.ContainerRequest;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import javax.ws.rs.core.SecurityContext;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
class ContainerRequestUtil {
|
||||
|
||||
static Optional<Account> getAuthenticatedAccount(final ContainerRequest request) {
|
||||
private static Map<Byte, Boolean> buildDevicesEnabledMap(final Account account) {
|
||||
return account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::isEnabled));
|
||||
}
|
||||
|
||||
/**
|
||||
* A read-only subset of the authenticated Account object, to enforce that filter-based consumers do not perform
|
||||
* account modifying operations.
|
||||
*/
|
||||
record AccountInfo(UUID accountId, String e164, Map<Byte, Boolean> devicesEnabled) {
|
||||
|
||||
static AccountInfo fromAccount(final Account account) {
|
||||
return new AccountInfo(
|
||||
account.getUuid(),
|
||||
account.getNumber(),
|
||||
buildDevicesEnabledMap(account));
|
||||
}
|
||||
}
|
||||
|
||||
static Optional<AccountInfo> getAuthenticatedAccount(final ContainerRequest request) {
|
||||
return Optional.ofNullable(request.getSecurityContext())
|
||||
.map(SecurityContext::getUserPrincipal)
|
||||
.map(principal -> principal instanceof AccountAndAuthenticatedDeviceHolder
|
||||
? ((AccountAndAuthenticatedDeviceHolder) principal).getAccount() : null);
|
||||
.map(principal -> {
|
||||
if (principal instanceof AccountAndAuthenticatedDeviceHolder aaadh) {
|
||||
return aaadh.getAccount();
|
||||
}
|
||||
return null;
|
||||
})
|
||||
.map(AccountInfo::fromAccount);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,40 +7,50 @@ package org.whispersystems.textsecuregcm.auth;
|
|||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.stream.Collectors;
|
||||
import org.glassfish.jersey.server.monitoring.RequestEvent;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||
import org.whispersystems.textsecuregcm.util.Pair;
|
||||
|
||||
public class PhoneNumberChangeRefreshRequirementProvider implements WebsocketRefreshRequirementProvider {
|
||||
|
||||
private static final String ACCOUNT_UUID =
|
||||
PhoneNumberChangeRefreshRequirementProvider.class.getName() + ".accountUuid";
|
||||
|
||||
private static final String INITIAL_NUMBER_KEY =
|
||||
PhoneNumberChangeRefreshRequirementProvider.class.getName() + ".initialNumber";
|
||||
private final AccountsManager accountsManager;
|
||||
|
||||
public PhoneNumberChangeRefreshRequirementProvider(final AccountsManager accountsManager) {
|
||||
this.accountsManager = accountsManager;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleRequestFiltered(final RequestEvent requestEvent) {
|
||||
if (requestEvent.getUriInfo().getMatchedResourceMethod().getInvocable().getHandlingMethod()
|
||||
.getAnnotation(ChangesPhoneNumber.class) == null) {
|
||||
return;
|
||||
}
|
||||
ContainerRequestUtil.getAuthenticatedAccount(requestEvent.getContainerRequest())
|
||||
.ifPresent(account -> requestEvent.getContainerRequest().setProperty(INITIAL_NUMBER_KEY, account.getNumber()));
|
||||
.ifPresent(account -> {
|
||||
requestEvent.getContainerRequest().setProperty(INITIAL_NUMBER_KEY, account.e164());
|
||||
requestEvent.getContainerRequest().setProperty(ACCOUNT_UUID, account.accountId());
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Pair<UUID, Byte>> handleRequestFinished(final RequestEvent requestEvent) {
|
||||
final String initialNumber = (String) requestEvent.getContainerRequest().getProperty(INITIAL_NUMBER_KEY);
|
||||
|
||||
if (initialNumber != null) {
|
||||
final Optional<Account> maybeAuthenticatedAccount =
|
||||
ContainerRequestUtil.getAuthenticatedAccount(requestEvent.getContainerRequest());
|
||||
|
||||
return maybeAuthenticatedAccount
|
||||
.filter(account -> !initialNumber.equals(account.getNumber()))
|
||||
.map(account -> account.getDevices().stream()
|
||||
.map(device -> new Pair<>(account.getUuid(), device.getId()))
|
||||
.collect(Collectors.toList()))
|
||||
.orElse(Collections.emptyList());
|
||||
} else {
|
||||
if (initialNumber == null) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
return accountsManager.getByAccountIdentifier((UUID) requestEvent.getContainerRequest().getProperty(ACCOUNT_UUID))
|
||||
.filter(account -> !initialNumber.equals(account.getNumber()))
|
||||
.map(account -> account.getDevices().stream()
|
||||
.map(device -> new Pair<>(account.getUuid(), device.getId()))
|
||||
.collect(Collectors.toList()))
|
||||
.orElse(Collections.emptyList());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@ public class WebsocketRefreshApplicationEventListener implements ApplicationEven
|
|||
|
||||
this.websocketRefreshRequestEventListener = new WebsocketRefreshRequestEventListener(clientPresenceManager,
|
||||
new AuthEnablementRefreshRequirementProvider(accountsManager),
|
||||
new PhoneNumberChangeRefreshRequirementProvider());
|
||||
new PhoneNumberChangeRefreshRequirementProvider(accountsManager));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -36,6 +36,7 @@ import javax.ws.rs.WebApplicationException;
|
|||
import javax.ws.rs.core.MediaType;
|
||||
import javax.ws.rs.core.Response;
|
||||
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
|
||||
import org.whispersystems.textsecuregcm.auth.ChangesPhoneNumber;
|
||||
import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager;
|
||||
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
|
||||
import org.whispersystems.textsecuregcm.entities.AccountDataReportResponse;
|
||||
|
@ -90,6 +91,7 @@ public class AccountControllerV2 {
|
|||
@Path("/number")
|
||||
@Consumes(MediaType.APPLICATION_JSON)
|
||||
@Produces(MediaType.APPLICATION_JSON)
|
||||
@ChangesPhoneNumber
|
||||
@Operation(summary = "Change number", description = "Changes a phone number for an existing account.")
|
||||
@ApiResponse(responseCode = "200", description = "The phone number associated with the authenticated account was changed successfully", useReturnTypeSchema = true)
|
||||
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.storage;
|
||||
|
||||
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
|
||||
import org.whispersystems.websocket.auth.PrincipalSupplier;
|
||||
|
||||
public class AccountPrincipalSupplier implements PrincipalSupplier<AuthenticatedAccount> {
|
||||
|
||||
private final AccountsManager accountsManager;
|
||||
|
||||
public AccountPrincipalSupplier(final AccountsManager accountsManager) {
|
||||
this.accountsManager = accountsManager;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AuthenticatedAccount refresh(final AuthenticatedAccount oldAccount) {
|
||||
final Account account = accountsManager.getByAccountIdentifier(oldAccount.getAccount().getUuid())
|
||||
.orElseThrow(() -> new RefreshingAccountNotFoundException("Could not find account"));
|
||||
final Device device = account.getDevice(oldAccount.getAuthenticatedDevice().getId())
|
||||
.orElseThrow(() -> new RefreshingAccountNotFoundException("Could not find device"));
|
||||
return new AuthenticatedAccount(account, device);
|
||||
}
|
||||
|
||||
@Override
|
||||
public AuthenticatedAccount deepCopy(final AuthenticatedAccount authenticatedAccount) {
|
||||
final Account cloned = AccountUtil.cloneAccountAsNotStale(authenticatedAccount.getAccount());
|
||||
return new AuthenticatedAccount(
|
||||
cloned,
|
||||
cloned.getDevice(authenticatedAccount.getAuthenticatedDevice().getId())
|
||||
.orElseThrow(() -> new IllegalStateException(
|
||||
"Could not find device from a clone of an account where the device was present")));
|
||||
}
|
||||
}
|
|
@ -5,10 +5,12 @@
|
|||
|
||||
package org.whispersystems.textsecuregcm.storage;
|
||||
|
||||
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
|
||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||
import org.whispersystems.websocket.auth.PrincipalSupplier;
|
||||
import java.io.IOException;
|
||||
|
||||
class AccountUtil {
|
||||
public class AccountUtil {
|
||||
|
||||
static Account cloneAccountAsNotStale(final Account account) {
|
||||
try {
|
||||
|
|
|
@ -1,35 +0,0 @@
|
|||
/*
|
||||
* Copyright 2021 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.storage;
|
||||
|
||||
import java.util.function.Supplier;
|
||||
import org.whispersystems.textsecuregcm.util.Pair;
|
||||
|
||||
public class RefreshingAccountAndDeviceSupplier implements Supplier<Pair<Account, Device>> {
|
||||
|
||||
private Account account;
|
||||
private Device device;
|
||||
private final AccountsManager accountsManager;
|
||||
|
||||
public RefreshingAccountAndDeviceSupplier(Account account, byte deviceId, AccountsManager accountsManager) {
|
||||
this.account = account;
|
||||
this.device = account.getDevice(deviceId)
|
||||
.orElseThrow(() -> new RefreshingAccountAndDeviceNotFoundException("Could not find device"));
|
||||
this.accountsManager = accountsManager;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Pair<Account, Device> get() {
|
||||
if (account.isStale()) {
|
||||
account = accountsManager.getByAccountIdentifier(account.getUuid())
|
||||
.orElseThrow(() -> new RuntimeException("Could not find account"));
|
||||
device = account.getDevice(device.getId())
|
||||
.orElseThrow(() -> new RefreshingAccountAndDeviceNotFoundException("Could not find device"));
|
||||
}
|
||||
|
||||
return new Pair<>(account, device);
|
||||
}
|
||||
}
|
|
@ -5,9 +5,9 @@
|
|||
|
||||
package org.whispersystems.textsecuregcm.storage;
|
||||
|
||||
public class RefreshingAccountAndDeviceNotFoundException extends RuntimeException {
|
||||
public class RefreshingAccountNotFoundException extends RuntimeException {
|
||||
|
||||
public RefreshingAccountAndDeviceNotFoundException(final String message) {
|
||||
public RefreshingAccountNotFoundException(final String message) {
|
||||
super(message);
|
||||
}
|
||||
|
|
@ -11,41 +11,40 @@ import com.google.common.net.HttpHeaders;
|
|||
import io.dropwizard.auth.basic.BasicCredentials;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import javax.annotation.Nullable;
|
||||
import org.eclipse.jetty.websocket.api.UpgradeRequest;
|
||||
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
|
||||
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
|
||||
import org.whispersystems.websocket.ReusableAuth;
|
||||
import org.whispersystems.websocket.auth.AuthenticationException;
|
||||
import org.whispersystems.websocket.auth.PrincipalSupplier;
|
||||
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
|
||||
|
||||
|
||||
public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<AuthenticatedAccount> {
|
||||
|
||||
private static final AuthenticationResult<AuthenticatedAccount> CREDENTIALS_NOT_PRESENTED =
|
||||
new AuthenticationResult<>(Optional.empty(), false);
|
||||
private static final ReusableAuth<AuthenticatedAccount> CREDENTIALS_NOT_PRESENTED = ReusableAuth.anonymous();
|
||||
|
||||
private static final AuthenticationResult<AuthenticatedAccount> INVALID_CREDENTIALS_PRESENTED =
|
||||
new AuthenticationResult<>(Optional.empty(), true);
|
||||
private static final ReusableAuth<AuthenticatedAccount> INVALID_CREDENTIALS_PRESENTED = ReusableAuth.invalid();
|
||||
|
||||
private final AccountAuthenticator accountAuthenticator;
|
||||
private final PrincipalSupplier<AuthenticatedAccount> principalSupplier;
|
||||
|
||||
|
||||
public WebSocketAccountAuthenticator(final AccountAuthenticator accountAuthenticator) {
|
||||
public WebSocketAccountAuthenticator(final AccountAuthenticator accountAuthenticator,
|
||||
final PrincipalSupplier<AuthenticatedAccount> principalSupplier) {
|
||||
this.accountAuthenticator = accountAuthenticator;
|
||||
this.principalSupplier = principalSupplier;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AuthenticationResult<AuthenticatedAccount> authenticate(final UpgradeRequest request)
|
||||
public ReusableAuth<AuthenticatedAccount> authenticate(final UpgradeRequest request)
|
||||
throws AuthenticationException {
|
||||
try {
|
||||
final AuthenticationResult<AuthenticatedAccount> authResultFromHeader =
|
||||
authenticatedAccountFromHeaderAuth(request.getHeader(HttpHeaders.AUTHORIZATION));
|
||||
// the logic here is that if the `Authorization` header was set for the request,
|
||||
// it takes the priority and we use the result of the header-based auth
|
||||
// ignoring the result of the query-based auth.
|
||||
if (authResultFromHeader.credentialsPresented()) {
|
||||
return authResultFromHeader;
|
||||
// If the `Authorization` header was set for the request it takes priority, and we use the result of the
|
||||
// header-based auth ignoring the result of the query-based auth.
|
||||
final String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION);
|
||||
if (authHeader != null) {
|
||||
return authenticatedAccountFromHeaderAuth(authHeader);
|
||||
}
|
||||
return authenticatedAccountFromQueryParams(request);
|
||||
} catch (final Exception e) {
|
||||
|
@ -55,7 +54,7 @@ public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<Aut
|
|||
}
|
||||
}
|
||||
|
||||
private AuthenticationResult<AuthenticatedAccount> authenticatedAccountFromQueryParams(final UpgradeRequest request) {
|
||||
private ReusableAuth<AuthenticatedAccount> authenticatedAccountFromQueryParams(final UpgradeRequest request) {
|
||||
final Map<String, List<String>> parameters = request.getParameterMap();
|
||||
final List<String> usernames = parameters.get("login");
|
||||
final List<String> passwords = parameters.get("password");
|
||||
|
@ -65,16 +64,19 @@ public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<Aut
|
|||
}
|
||||
final BasicCredentials credentials = new BasicCredentials(usernames.get(0).replace(" ", "+"),
|
||||
passwords.get(0).replace(" ", "+"));
|
||||
return new AuthenticationResult<>(accountAuthenticator.authenticate(credentials), true);
|
||||
return accountAuthenticator.authenticate(credentials)
|
||||
.map(authenticatedAccount -> ReusableAuth.authenticated(authenticatedAccount, this.principalSupplier))
|
||||
.orElse(INVALID_CREDENTIALS_PRESENTED);
|
||||
}
|
||||
|
||||
private AuthenticationResult<AuthenticatedAccount> authenticatedAccountFromHeaderAuth(@Nullable final String authHeader)
|
||||
private ReusableAuth<AuthenticatedAccount> authenticatedAccountFromHeaderAuth(@Nullable final String authHeader)
|
||||
throws AuthenticationException {
|
||||
if (authHeader == null) {
|
||||
return CREDENTIALS_NOT_PRESENTED;
|
||||
}
|
||||
return basicCredentialsFromAuthHeader(authHeader)
|
||||
.map(credentials -> new AuthenticationResult<>(accountAuthenticator.authenticate(credentials), true))
|
||||
.flatMap(credentials -> accountAuthenticator.authenticate(credentials))
|
||||
.map(authenticatedAccount -> ReusableAuth.authenticated(authenticatedAccount, this.principalSupplier))
|
||||
.orElse(INVALID_CREDENTIALS_PRESENTED);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,279 @@
|
|||
package org.whispersystems.textsecuregcm;
|
||||
|
||||
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.reset;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME;
|
||||
|
||||
import io.dropwizard.auth.Auth;
|
||||
import io.dropwizard.core.Application;
|
||||
import io.dropwizard.core.Configuration;
|
||||
import io.dropwizard.core.setup.Environment;
|
||||
import io.dropwizard.testing.junit5.DropwizardAppExtension;
|
||||
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.util.EnumSet;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.TimeoutException;
|
||||
import java.util.stream.IntStream;
|
||||
import javax.servlet.DispatcherType;
|
||||
import javax.servlet.ServletRegistration;
|
||||
import javax.ws.rs.GET;
|
||||
import javax.ws.rs.Path;
|
||||
import javax.ws.rs.PathParam;
|
||||
import org.eclipse.jetty.websocket.client.WebSocketClient;
|
||||
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
|
||||
import org.glassfish.jersey.server.ManagedAsync;
|
||||
import org.glassfish.jersey.server.ServerProperties;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
|
||||
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
|
||||
import org.whispersystems.textsecuregcm.storage.RefreshingAccountNotFoundException;
|
||||
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
|
||||
import org.whispersystems.websocket.ReusableAuth;
|
||||
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
|
||||
import org.whispersystems.websocket.auth.PrincipalSupplier;
|
||||
import org.whispersystems.websocket.auth.ReadOnly;
|
||||
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
|
||||
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
|
||||
import org.whispersystems.websocket.setup.WebSocketEnvironment;
|
||||
|
||||
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||
public class WebsocketReuseAuthIntegrationTest {
|
||||
|
||||
private static final AuthenticatedAccount ACCOUNT = mock(AuthenticatedAccount.class);
|
||||
@SuppressWarnings("unchecked")
|
||||
private static final PrincipalSupplier<AuthenticatedAccount> PRINCIPAL_SUPPLIER = mock(PrincipalSupplier.class);
|
||||
private static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION =
|
||||
new DropwizardAppExtension<>(TestApplication.class);
|
||||
|
||||
|
||||
private WebSocketClient client;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() throws Exception {
|
||||
reset(PRINCIPAL_SUPPLIER);
|
||||
reset(ACCOUNT);
|
||||
when(ACCOUNT.getName()).thenReturn("original");
|
||||
client = new WebSocketClient();
|
||||
client.start();
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void tearDown() throws Exception {
|
||||
client.stop();
|
||||
}
|
||||
|
||||
|
||||
public static class TestApplication extends Application<Configuration> {
|
||||
|
||||
@Override
|
||||
public void run(final Configuration configuration, final Environment environment) throws Exception {
|
||||
final TestController testController = new TestController();
|
||||
|
||||
final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration();
|
||||
|
||||
final WebSocketEnvironment<AuthenticatedAccount> webSocketEnvironment =
|
||||
new WebSocketEnvironment<>(environment, webSocketConfiguration);
|
||||
|
||||
environment.jersey().register(testController);
|
||||
environment.servlets()
|
||||
.addFilter("RemoteAddressFilter", new RemoteAddressFilter(true))
|
||||
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
|
||||
webSocketEnvironment.jersey().register(testController);
|
||||
webSocketEnvironment.jersey().register(new RemoteAddressFilter(true));
|
||||
webSocketEnvironment.setAuthenticator(upgradeRequest -> ReusableAuth.authenticated(ACCOUNT, PRINCIPAL_SUPPLIER));
|
||||
|
||||
webSocketEnvironment.jersey().property(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE);
|
||||
webSocketEnvironment.setConnectListener(webSocketSessionContext -> {
|
||||
});
|
||||
|
||||
final WebSocketResourceProviderFactory<AuthenticatedAccount> webSocketServlet =
|
||||
new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedAccount.class,
|
||||
webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME);
|
||||
|
||||
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
|
||||
|
||||
final ServletRegistration.Dynamic websocketServlet =
|
||||
environment.servlets().addServlet("WebSocket", webSocketServlet);
|
||||
|
||||
websocketServlet.addMapping("/websocket");
|
||||
websocketServlet.setAsyncSupported(true);
|
||||
}
|
||||
}
|
||||
|
||||
private WebSocketResponseMessage make1WebsocketRequest(final String requestPath) throws IOException {
|
||||
|
||||
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
|
||||
|
||||
client.connect(testWebsocketListener,
|
||||
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
|
||||
return testWebsocketListener.doGet(requestPath).join();
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource(strings = {"/test/read-auth", "/test/optional-read-auth"})
|
||||
public void readAuth(final String path) throws IOException {
|
||||
final WebSocketResponseMessage response = make1WebsocketRequest(path);
|
||||
assertThat(response.getStatus()).isEqualTo(200);
|
||||
verifyNoMoreInteractions(PRINCIPAL_SUPPLIER);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource(strings = {"/test/write-auth", "/test/optional-write-auth"})
|
||||
public void writeAuth(final String path) throws IOException {
|
||||
final AuthenticatedAccount copiedAccount = mock(AuthenticatedAccount.class);
|
||||
when(copiedAccount.getName()).thenReturn("copy");
|
||||
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(copiedAccount);
|
||||
|
||||
final WebSocketResponseMessage response = make1WebsocketRequest(path);
|
||||
assertThat(response.getStatus()).isEqualTo(200);
|
||||
assertThat(response.getBody().map(String::new)).get().isEqualTo("copy");
|
||||
verify(PRINCIPAL_SUPPLIER, times(1)).deepCopy(any());
|
||||
verifyNoMoreInteractions(PRINCIPAL_SUPPLIER);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void readAfterWrite() throws IOException {
|
||||
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(ACCOUNT);
|
||||
final AuthenticatedAccount account2 = mock(AuthenticatedAccount.class);
|
||||
when(account2.getName()).thenReturn("refresh");
|
||||
when(PRINCIPAL_SUPPLIER.refresh(any())).thenReturn(account2);
|
||||
|
||||
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
|
||||
client.connect(testWebsocketListener,
|
||||
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
|
||||
|
||||
final WebSocketResponseMessage readResponse = testWebsocketListener.doGet("/test/read-auth").join();
|
||||
assertThat(readResponse.getBody().map(String::new)).get().isEqualTo("original");
|
||||
|
||||
final WebSocketResponseMessage writeResponse = testWebsocketListener.doGet("/test/write-auth").join();
|
||||
assertThat(writeResponse.getBody().map(String::new)).get().isEqualTo("original");
|
||||
|
||||
final WebSocketResponseMessage readResponse2 = testWebsocketListener.doGet("/test/read-auth").join();
|
||||
assertThat(readResponse2.getBody().map(String::new)).get().isEqualTo("refresh");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void readAfterWriteRefreshFails() throws IOException {
|
||||
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(ACCOUNT);
|
||||
when(PRINCIPAL_SUPPLIER.refresh(any())).thenThrow(RefreshingAccountNotFoundException.class);
|
||||
|
||||
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
|
||||
client.connect(testWebsocketListener,
|
||||
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
|
||||
|
||||
final WebSocketResponseMessage writeResponse = testWebsocketListener.doGet("/test/write-auth").join();
|
||||
assertThat(writeResponse.getBody().map(String::new)).get().isEqualTo("original");
|
||||
|
||||
final WebSocketResponseMessage readResponse2 = testWebsocketListener.doGet("/test/read-auth").join();
|
||||
assertThat(readResponse2.getStatus()).isEqualTo(500);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void readConcurrentWithWrite() throws IOException, ExecutionException, InterruptedException, TimeoutException {
|
||||
final AuthenticatedAccount deepCopy = mock(AuthenticatedAccount.class);
|
||||
when(deepCopy.getName()).thenReturn("deepCopy");
|
||||
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(deepCopy);
|
||||
|
||||
final AuthenticatedAccount refresh = mock(AuthenticatedAccount.class);
|
||||
when(refresh.getName()).thenReturn("refresh");
|
||||
when(PRINCIPAL_SUPPLIER.refresh(any())).thenReturn(refresh);
|
||||
|
||||
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
|
||||
client.connect(testWebsocketListener,
|
||||
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
|
||||
|
||||
// start a write request that takes a while to finish
|
||||
final CompletableFuture<WebSocketResponseMessage> writeResponse =
|
||||
testWebsocketListener.doGet("/test/start-delayed-write/foo");
|
||||
|
||||
// send a bunch of reads, they should reflect the original auth
|
||||
final List<CompletableFuture<WebSocketResponseMessage>> futures = IntStream.range(0, 10)
|
||||
.boxed().map(i -> testWebsocketListener.doGet("/test/read-auth"))
|
||||
.toList();
|
||||
CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)).join();
|
||||
for (CompletableFuture<WebSocketResponseMessage> future : futures) {
|
||||
assertThat(future.join().getBody().map(String::new)).get().isEqualTo("original");
|
||||
}
|
||||
|
||||
assertThat(writeResponse.isDone()).isFalse();
|
||||
|
||||
// finish the delayed write request
|
||||
testWebsocketListener.doGet("/test/finish-delayed-write/foo").get(1, TimeUnit.SECONDS);
|
||||
assertThat(writeResponse.join().getBody().map(String::new)).get().isEqualTo("deepCopy");
|
||||
|
||||
// subsequent reads should have the refreshed auth
|
||||
final WebSocketResponseMessage readResponse = testWebsocketListener.doGet("/test/read-auth").join();
|
||||
assertThat(readResponse.getBody().map(String::new)).get().isEqualTo("refresh");
|
||||
}
|
||||
|
||||
|
||||
@Path("/test")
|
||||
public static class TestController {
|
||||
|
||||
private final ConcurrentHashMap<String, CountDownLatch> delayedWriteLatches = new ConcurrentHashMap<>();
|
||||
|
||||
@GET
|
||||
@Path("/read-auth")
|
||||
@ManagedAsync
|
||||
public String readAuth(@ReadOnly @Auth final AuthenticatedAccount account) {
|
||||
return account.getName();
|
||||
}
|
||||
|
||||
@GET
|
||||
@Path("/optional-read-auth")
|
||||
@ManagedAsync
|
||||
public String optionalReadAuth(@ReadOnly @Auth final Optional<AuthenticatedAccount> account) {
|
||||
return account.map(AuthenticatedAccount::getName).orElse("empty");
|
||||
}
|
||||
|
||||
@GET
|
||||
@Path("/write-auth")
|
||||
@ManagedAsync
|
||||
public String writeAuth(@Auth final AuthenticatedAccount account) {
|
||||
return account.getName();
|
||||
}
|
||||
|
||||
@GET
|
||||
@Path("/optional-write-auth")
|
||||
@ManagedAsync
|
||||
public String optionalWriteAuth(@Auth final Optional<AuthenticatedAccount> account) {
|
||||
return account.map(AuthenticatedAccount::getName).orElse("empty");
|
||||
}
|
||||
|
||||
@GET
|
||||
@Path("/start-delayed-write/{id}")
|
||||
@ManagedAsync
|
||||
public String startDelayedWrite(@Auth final AuthenticatedAccount account, @PathParam("id") String id)
|
||||
throws InterruptedException {
|
||||
delayedWriteLatches.computeIfAbsent(id, i -> new CountDownLatch(1)).await();
|
||||
return account.getName();
|
||||
}
|
||||
|
||||
@GET
|
||||
@Path("/finish-delayed-write/{id}")
|
||||
@ManagedAsync
|
||||
public String finishDelayedWrite(@PathParam("id") String id) {
|
||||
delayedWriteLatches.computeIfAbsent(id, i -> new CountDownLatch(1)).countDown();
|
||||
return "ok";
|
||||
}
|
||||
}
|
||||
}
|
|
@ -7,8 +7,6 @@ package org.whispersystems.textsecuregcm.auth;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertAll;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.never;
|
||||
|
@ -30,7 +28,6 @@ import java.nio.ByteBuffer;
|
|||
import java.nio.charset.StandardCharsets;
|
||||
import java.security.Principal;
|
||||
import java.time.Duration;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Base64;
|
||||
import java.util.LinkedList;
|
||||
|
@ -76,7 +73,9 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
|||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
|
||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||
import org.whispersystems.websocket.ReusableAuth;
|
||||
import org.whispersystems.websocket.WebSocketResourceProvider;
|
||||
import org.whispersystems.websocket.auth.PrincipalSupplier;
|
||||
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
|
||||
import org.whispersystems.websocket.logging.WebsocketRequestLog;
|
||||
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
|
||||
|
@ -132,38 +131,6 @@ class AuthEnablementRefreshRequirementProviderTest {
|
|||
.forEach(device -> when(clientPresenceManager.isPresent(uuid, device.getId())).thenReturn(true));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBuildDevicesEnabled() {
|
||||
|
||||
final byte disabledDeviceId = 3;
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
|
||||
final List<Device> devices = new ArrayList<>();
|
||||
when(account.getDevices()).thenReturn(devices);
|
||||
|
||||
IntStream.range(1, 5)
|
||||
.forEach(id -> {
|
||||
final Device device = mock(Device.class);
|
||||
when(device.getId()).thenReturn((byte) id);
|
||||
when(device.isEnabled()).thenReturn(id != disabledDeviceId);
|
||||
devices.add(device);
|
||||
});
|
||||
|
||||
final Map<Byte, Boolean> devicesEnabled = AuthEnablementRefreshRequirementProvider.buildDevicesEnabledMap(account);
|
||||
|
||||
assertEquals(4, devicesEnabled.size());
|
||||
|
||||
assertAll(devicesEnabled.entrySet().stream()
|
||||
.map(deviceAndEnabled -> () -> {
|
||||
if (deviceAndEnabled.getKey().equals(disabledDeviceId)) {
|
||||
assertFalse(deviceAndEnabled.getValue());
|
||||
} else {
|
||||
assertTrue(deviceAndEnabled.getValue());
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void testDeviceEnabledChanged(final Map<Byte, Boolean> initialEnabled, final Map<Byte, Boolean> finalEnabled) {
|
||||
|
@ -308,7 +275,7 @@ class AuthEnablementRefreshRequirementProviderTest {
|
|||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||
|
||||
provider = new WebSocketResourceProvider<>("127.0.0.1", RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME,
|
||||
applicationHandler, requestLog, new TestPrincipal("test", account, authenticatedDevice),
|
||||
applicationHandler, requestLog, TestPrincipal.reusableAuth("test", account, authenticatedDevice),
|
||||
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||
|
||||
remoteEndpoint = mock(RemoteEndpoint.class);
|
||||
|
@ -349,7 +316,7 @@ class AuthEnablementRefreshRequirementProviderTest {
|
|||
private final Account account;
|
||||
private final Device device;
|
||||
|
||||
private TestPrincipal(String name, final Account account, final Device device) {
|
||||
private TestPrincipal(final String name, final Account account, final Device device) {
|
||||
this.name = name;
|
||||
this.account = account;
|
||||
this.device = device;
|
||||
|
@ -369,6 +336,11 @@ class AuthEnablementRefreshRequirementProviderTest {
|
|||
public Device getAuthenticatedDevice() {
|
||||
return device;
|
||||
}
|
||||
|
||||
public static ReusableAuth<TestPrincipal> reusableAuth(final String name, final Account account, final Device device) {
|
||||
return ReusableAuth.authenticated(new TestPrincipal(name, account, device), PrincipalSupplier.forImmutablePrincipal());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Path("/v1/test")
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.auth;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertAll;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class ContainerRequestUtilTest {
|
||||
|
||||
@Test
|
||||
void testBuildDevicesEnabled() {
|
||||
|
||||
final byte disabledDeviceId = 3;
|
||||
|
||||
final Account account = mock(Account.class);
|
||||
|
||||
final List<Device> devices = new ArrayList<>();
|
||||
when(account.getDevices()).thenReturn(devices);
|
||||
|
||||
IntStream.range(1, 5)
|
||||
.forEach(id -> {
|
||||
final Device device = mock(Device.class);
|
||||
when(device.getId()).thenReturn((byte) id);
|
||||
when(device.isEnabled()).thenReturn(id != disabledDeviceId);
|
||||
devices.add(device);
|
||||
});
|
||||
|
||||
final Map<Byte, Boolean> devicesEnabled = ContainerRequestUtil.AccountInfo.fromAccount(account).devicesEnabled();
|
||||
|
||||
assertEquals(4, devicesEnabled.size());
|
||||
|
||||
assertAll(devicesEnabled.entrySet().stream()
|
||||
.map(deviceAndEnabled -> () -> {
|
||||
if (deviceAndEnabled.getKey().equals(disabledDeviceId)) {
|
||||
assertFalse(deviceAndEnabled.getValue());
|
||||
} else {
|
||||
assertTrue(deviceAndEnabled.getValue());
|
||||
}
|
||||
}));
|
||||
}
|
||||
}
|
|
@ -5,108 +5,292 @@
|
|||
|
||||
package org.whispersystems.textsecuregcm.auth;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyString;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.reset;
|
||||
import static org.mockito.Mockito.timeout;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME;
|
||||
|
||||
import com.google.common.net.HttpHeaders;
|
||||
import io.dropwizard.auth.Auth;
|
||||
import io.dropwizard.auth.AuthDynamicFeature;
|
||||
import io.dropwizard.auth.basic.BasicCredentialAuthFilter;
|
||||
import io.dropwizard.core.Application;
|
||||
import io.dropwizard.core.Configuration;
|
||||
import io.dropwizard.core.setup.Environment;
|
||||
import io.dropwizard.testing.junit5.DropwizardAppExtension;
|
||||
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.EnumSet;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import javax.annotation.Nullable;
|
||||
import javax.ws.rs.core.SecurityContext;
|
||||
import org.glassfish.jersey.server.ContainerRequest;
|
||||
import org.glassfish.jersey.server.monitoring.RequestEvent;
|
||||
import javax.servlet.DispatcherType;
|
||||
import javax.servlet.ServletRegistration;
|
||||
import javax.ws.rs.GET;
|
||||
import javax.ws.rs.Path;
|
||||
import javax.ws.rs.client.Invocation;
|
||||
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
|
||||
import org.eclipse.jetty.websocket.client.WebSocketClient;
|
||||
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
|
||||
import org.glassfish.jersey.server.ManagedAsync;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.EnumSource;
|
||||
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
|
||||
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.util.Pair;
|
||||
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
|
||||
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
|
||||
import org.whispersystems.textsecuregcm.util.HeaderUtils;
|
||||
import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator;
|
||||
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
|
||||
import org.whispersystems.websocket.auth.PrincipalSupplier;
|
||||
import org.whispersystems.websocket.auth.ReadOnly;
|
||||
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
|
||||
import org.whispersystems.websocket.setup.WebSocketEnvironment;
|
||||
|
||||
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||
class PhoneNumberChangeRefreshRequirementProviderTest {
|
||||
|
||||
private PhoneNumberChangeRefreshRequirementProvider provider;
|
||||
|
||||
private Account account;
|
||||
private RequestEvent requestEvent;
|
||||
private ContainerRequest request;
|
||||
|
||||
private static final UUID ACCOUNT_UUID = UUID.randomUUID();
|
||||
private static final String NUMBER = "+18005551234";
|
||||
private static final String CHANGED_NUMBER = "+18005554321";
|
||||
private static final String TEST_CRED_HEADER = HeaderUtils.basicAuthHeader("test", "password");
|
||||
|
||||
|
||||
private static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION = new DropwizardAppExtension<>(
|
||||
TestApplication.class);
|
||||
|
||||
private static final AccountAuthenticator AUTHENTICATOR = mock(AccountAuthenticator.class);
|
||||
private static final AccountsManager ACCOUNTS_MANAGER = mock(AccountsManager.class);
|
||||
private static final ClientPresenceManager CLIENT_PRESENCE = mock(ClientPresenceManager.class);
|
||||
|
||||
private WebSocketClient client;
|
||||
private final Account account1 = new Account();
|
||||
private final Account account2 = new Account();
|
||||
private final Device authenticatedDevice = DevicesHelper.createDevice(Device.PRIMARY_ID);
|
||||
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
provider = new PhoneNumberChangeRefreshRequirementProvider();
|
||||
void setUp() throws Exception {
|
||||
reset(AUTHENTICATOR, CLIENT_PRESENCE, ACCOUNTS_MANAGER);
|
||||
client = new WebSocketClient();
|
||||
client.start();
|
||||
|
||||
account = mock(Account.class);
|
||||
final Device device = mock(Device.class);
|
||||
final UUID uuid = UUID.randomUUID();
|
||||
account1.setUuid(uuid);
|
||||
account1.addDevice(authenticatedDevice);
|
||||
account1.setNumber(NUMBER, UUID.randomUUID());
|
||||
|
||||
when(account.getUuid()).thenReturn(ACCOUNT_UUID);
|
||||
when(account.getNumber()).thenReturn(NUMBER);
|
||||
when(account.getDevices()).thenReturn(List.of(device));
|
||||
when(device.getId()).thenReturn(Device.PRIMARY_ID);
|
||||
account2.setUuid(uuid);
|
||||
account2.addDevice(authenticatedDevice);
|
||||
account2.setNumber(CHANGED_NUMBER, UUID.randomUUID());
|
||||
|
||||
request = mock(ContainerRequest.class);
|
||||
}
|
||||
|
||||
final Map<String, Object> requestProperties = new HashMap<>();
|
||||
@AfterEach
|
||||
void tearDown() throws Exception {
|
||||
client.stop();
|
||||
}
|
||||
|
||||
doAnswer(invocation -> {
|
||||
requestProperties.put(invocation.getArgument(0, String.class), invocation.getArgument(1));
|
||||
return null;
|
||||
}).when(request).setProperty(anyString(), any());
|
||||
|
||||
when(request.getProperty(anyString())).thenAnswer(
|
||||
invocation -> requestProperties.get(invocation.getArgument(0, String.class)));
|
||||
public static class TestApplication extends Application<Configuration> {
|
||||
|
||||
requestEvent = mock(RequestEvent.class);
|
||||
when(requestEvent.getContainerRequest()).thenReturn(request);
|
||||
@Override
|
||||
public void run(final Configuration configuration, final Environment environment) throws Exception {
|
||||
final TestController testController = new TestController();
|
||||
|
||||
final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration();
|
||||
|
||||
final WebSocketEnvironment<AuthenticatedAccount> webSocketEnvironment =
|
||||
new WebSocketEnvironment<>(environment, webSocketConfiguration);
|
||||
|
||||
environment.jersey().register(testController);
|
||||
webSocketEnvironment.jersey().register(testController);
|
||||
environment.servlets()
|
||||
.addFilter("RemoteAddressFilter", new RemoteAddressFilter(true))
|
||||
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
|
||||
webSocketEnvironment.jersey().register(new RemoteAddressFilter(true));
|
||||
webSocketEnvironment.jersey()
|
||||
.register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE));
|
||||
environment.jersey()
|
||||
.register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE));
|
||||
webSocketEnvironment.setConnectListener(webSocketSessionContext -> {
|
||||
});
|
||||
|
||||
|
||||
environment.jersey().register(new AuthDynamicFeature(new BasicCredentialAuthFilter.Builder<AuthenticatedAccount>()
|
||||
.setAuthenticator(AUTHENTICATOR)
|
||||
.buildAuthFilter()));
|
||||
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(AUTHENTICATOR, mock(PrincipalSupplier.class)));
|
||||
|
||||
final WebSocketResourceProviderFactory<AuthenticatedAccount> webSocketServlet =
|
||||
new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedAccount.class,
|
||||
webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME);
|
||||
|
||||
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
|
||||
|
||||
final ServletRegistration.Dynamic websocketServlet =
|
||||
environment.servlets().addServlet("WebSocket", webSocketServlet);
|
||||
|
||||
websocketServlet.addMapping("/websocket");
|
||||
websocketServlet.setAsyncSupported(true);
|
||||
}
|
||||
}
|
||||
|
||||
enum Protocol { HTTP, WEBSOCKET }
|
||||
|
||||
private void makeAnonymousRequest(final Protocol protocol, final String requestPath) throws IOException {
|
||||
makeRequest(protocol, requestPath, true);
|
||||
}
|
||||
|
||||
/*
|
||||
* Make an authenticated request that will return account1 as the principal
|
||||
*/
|
||||
private void makeAuthenticatedRequest(
|
||||
final Protocol protocol,
|
||||
final String requestPath) throws IOException {
|
||||
when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedAccount(account1, authenticatedDevice)));
|
||||
makeRequest(protocol,requestPath, false);
|
||||
}
|
||||
|
||||
private void makeRequest(final Protocol protocol, final String requestPath, final boolean anonymous) throws IOException {
|
||||
switch (protocol) {
|
||||
case WEBSOCKET -> {
|
||||
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
|
||||
final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
|
||||
if (!anonymous) {
|
||||
upgradeRequest.setHeader(HttpHeaders.AUTHORIZATION, TEST_CRED_HEADER);
|
||||
}
|
||||
client.connect(
|
||||
testWebsocketListener,
|
||||
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())),
|
||||
upgradeRequest);
|
||||
testWebsocketListener.sendRequest(requestPath, "GET", Collections.emptyList(), Optional.empty()).join();
|
||||
}
|
||||
case HTTP -> {
|
||||
final Invocation.Builder request = DROPWIZARD_APP_EXTENSION.client()
|
||||
.target("http://127.0.0.1:%s%s".formatted(DROPWIZARD_APP_EXTENSION.getLocalPort(), requestPath))
|
||||
.request();
|
||||
if (!anonymous) {
|
||||
request.header(HttpHeaders.AUTHORIZATION, TEST_CRED_HEADER);
|
||||
}
|
||||
request.get();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@EnumSource(Protocol.class)
|
||||
void handleRequestNoChange(final Protocol protocol) throws IOException {
|
||||
when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account1));
|
||||
makeAuthenticatedRequest(protocol, "/test/annotated");
|
||||
|
||||
// Event listeners can fire after responses are sent
|
||||
verify(ACCOUNTS_MANAGER, timeout(5000).times(1)).getByAccountIdentifier(eq(account1.getUuid()));
|
||||
verifyNoMoreInteractions(CLIENT_PRESENCE);
|
||||
verifyNoMoreInteractions(ACCOUNTS_MANAGER);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@EnumSource(Protocol.class)
|
||||
void handleRequestChange(final Protocol protocol) throws IOException {
|
||||
when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account2));
|
||||
when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedAccount(account1, authenticatedDevice)));
|
||||
|
||||
makeAuthenticatedRequest(protocol, "/test/annotated");
|
||||
|
||||
// Make sure we disconnect the account if the account has changed numbers. Event listeners can fire after responses
|
||||
// are sent, so use a timeout.
|
||||
verify(CLIENT_PRESENCE, timeout(5000))
|
||||
.disconnectPresence(eq(account1.getUuid()), eq(authenticatedDevice.getId()));
|
||||
verifyNoMoreInteractions(CLIENT_PRESENCE);
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleRequestNoChange() {
|
||||
setAuthenticatedAccount(request, account);
|
||||
void handleRequestChangeAsyncEndpoint() throws IOException {
|
||||
when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account2));
|
||||
when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedAccount(account1, authenticatedDevice)));
|
||||
|
||||
provider.handleRequestFiltered(requestEvent);
|
||||
assertEquals(Collections.emptyList(), provider.handleRequestFinished(requestEvent));
|
||||
// Event listeners with asynchronous HTTP endpoints don't currently correctly maintain state between request and
|
||||
// response
|
||||
makeAuthenticatedRequest(Protocol.WEBSOCKET, "/test/async-annotated");
|
||||
|
||||
// Make sure we disconnect the account if the account has changed numbers. Event listeners can fire after responses
|
||||
// are sent, so use a timeout.
|
||||
verify(CLIENT_PRESENCE, timeout(5000))
|
||||
.disconnectPresence(eq(account1.getUuid()), eq(authenticatedDevice.getId()));
|
||||
verifyNoMoreInteractions(CLIENT_PRESENCE);
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleRequestNumberChange() {
|
||||
setAuthenticatedAccount(request, account);
|
||||
@ParameterizedTest
|
||||
@EnumSource(Protocol.class)
|
||||
void handleRequestNotAnnotated(final Protocol protocol) throws IOException, InterruptedException {
|
||||
makeAuthenticatedRequest(protocol,"/test/not-annotated");
|
||||
|
||||
provider.handleRequestFiltered(requestEvent);
|
||||
when(account.getNumber()).thenReturn(CHANGED_NUMBER);
|
||||
assertEquals(List.of(new Pair<>(ACCOUNT_UUID, Device.PRIMARY_ID)), provider.handleRequestFinished(requestEvent));
|
||||
// Give a tick for event listeners to run. Racy, but should occasionally catch an errant running listener if one is
|
||||
// introduced.
|
||||
Thread.sleep(100);
|
||||
|
||||
// Shouldn't even read the account if the method has not been annotated
|
||||
verifyNoMoreInteractions(ACCOUNTS_MANAGER);
|
||||
verifyNoMoreInteractions(CLIENT_PRESENCE);
|
||||
}
|
||||
|
||||
@Test
|
||||
void handleRequestNoAuthenticatedAccount() {
|
||||
final ContainerRequest request = mock(ContainerRequest.class);
|
||||
setAuthenticatedAccount(request, null);
|
||||
@ParameterizedTest
|
||||
@EnumSource(Protocol.class)
|
||||
void handleRequestNotAuthenticated(final Protocol protocol) throws IOException, InterruptedException {
|
||||
makeAnonymousRequest(protocol, "/test/not-authenticated");
|
||||
|
||||
when(requestEvent.getContainerRequest()).thenReturn(request);
|
||||
// Give a tick for event listeners to run. Racy, but should occasionally catch an errant running listener if one is
|
||||
// introduced.
|
||||
Thread.sleep(100);
|
||||
|
||||
provider.handleRequestFiltered(requestEvent);
|
||||
assertEquals(Collections.emptyList(), provider.handleRequestFinished(requestEvent));
|
||||
// Shouldn't even read the account if the method has not been annotated
|
||||
verifyNoMoreInteractions(ACCOUNTS_MANAGER);
|
||||
verifyNoMoreInteractions(CLIENT_PRESENCE);
|
||||
}
|
||||
|
||||
private static void setAuthenticatedAccount(final ContainerRequest mockRequest, @Nullable final Account account) {
|
||||
final SecurityContext securityContext = mock(SecurityContext.class);
|
||||
|
||||
when(mockRequest.getSecurityContext()).thenReturn(securityContext);
|
||||
@Path("/test")
|
||||
public static class TestController {
|
||||
|
||||
if (account != null) {
|
||||
final AuthenticatedAccount authenticatedAccount = mock(AuthenticatedAccount.class);
|
||||
@GET
|
||||
@Path("/annotated")
|
||||
@ChangesPhoneNumber
|
||||
public String annotated(@ReadOnly @Auth final AuthenticatedAccount account) {
|
||||
return "ok";
|
||||
}
|
||||
|
||||
when(securityContext.getUserPrincipal()).thenReturn(authenticatedAccount);
|
||||
when(authenticatedAccount.getAccount()).thenReturn(account);
|
||||
} else {
|
||||
when(securityContext.getUserPrincipal()).thenReturn(null);
|
||||
@GET
|
||||
@Path("/async-annotated")
|
||||
@ChangesPhoneNumber
|
||||
@ManagedAsync
|
||||
public String asyncAnnotated(@ReadOnly @Auth final AuthenticatedAccount account) {
|
||||
return "ok";
|
||||
}
|
||||
|
||||
@GET
|
||||
@Path("/not-authenticated")
|
||||
@ChangesPhoneNumber
|
||||
public String notAuthenticated() {
|
||||
return "ok";
|
||||
}
|
||||
|
||||
@GET
|
||||
@Path("/not-annotated")
|
||||
public String notAnnotated(@ReadOnly @Auth final AuthenticatedAccount account) {
|
||||
return "ok";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -87,7 +87,7 @@ class BasicCredentialAuthenticationInterceptorTest {
|
|||
when(device.getId()).thenReturn(Device.PRIMARY_ID);
|
||||
|
||||
when(accountAuthenticator.authenticate(any()))
|
||||
.thenReturn(Optional.of(new AuthenticatedAccount(() -> new Pair<>(account, device))));
|
||||
.thenReturn(Optional.of(new AuthenticatedAccount(account, device)));
|
||||
} else {
|
||||
when(accountAuthenticator.authenticate(any()))
|
||||
.thenReturn(Optional.empty());
|
||||
|
|
|
@ -39,7 +39,7 @@ class DirectoryControllerV2Test {
|
|||
when(account.getUuid()).thenReturn(uuid);
|
||||
|
||||
final ExternalServiceCredentials credentials = (ExternalServiceCredentials) controller.getAuthToken(
|
||||
new AuthenticatedAccount(() -> new Pair<>(account, mock(Device.class)))).getEntity();
|
||||
new AuthenticatedAccount(account, mock(Device.class))).getEntity();
|
||||
|
||||
assertEquals(credentials.username(), "d369bc712e2e0dd36258");
|
||||
assertEquals(credentials.password(), "1633738643:4433b0fab41f25f79dd4");
|
||||
|
|
|
@ -51,7 +51,10 @@ import org.junit.jupiter.api.Test;
|
|||
import org.mockito.ArgumentCaptor;
|
||||
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
|
||||
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
|
||||
import org.whispersystems.textsecuregcm.tests.util.TestPrincipal;
|
||||
import org.whispersystems.websocket.ReusableAuth;
|
||||
import org.whispersystems.websocket.WebSocketResourceProvider;
|
||||
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
|
||||
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
|
||||
import org.whispersystems.websocket.logging.WebsocketRequestLog;
|
||||
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
|
||||
|
@ -139,7 +142,7 @@ class MetricsRequestEventListenerTest {
|
|||
final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||
final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||
final WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
|
||||
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, TestPrincipal.reusableAuth("foo"),
|
||||
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||
|
||||
final Session session = mock(Session.class);
|
||||
|
@ -201,7 +204,7 @@ class MetricsRequestEventListenerTest {
|
|||
final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||
final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||
final WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
|
||||
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, TestPrincipal.reusableAuth("foo"),
|
||||
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||
|
||||
final Session session = mock(Session.class);
|
||||
|
@ -252,19 +255,6 @@ class MetricsRequestEventListenerTest {
|
|||
return SubProtocol.WebSocketMessage.parseFrom(responseCaptor.getValue().array()).getResponse();
|
||||
}
|
||||
|
||||
public static class TestPrincipal implements Principal {
|
||||
|
||||
private final String name;
|
||||
|
||||
private TestPrincipal(String name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
}
|
||||
|
||||
@Path("/v1/test")
|
||||
public static class TestResource {
|
||||
|
|
|
@ -1,71 +0,0 @@
|
|||
/*
|
||||
* Copyright 2021 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package org.whispersystems.textsecuregcm.storage;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotSame;
|
||||
import static org.junit.jupiter.api.Assertions.assertSame;
|
||||
import static org.mockito.Mockito.any;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.whispersystems.textsecuregcm.util.Pair;
|
||||
|
||||
class RefreshingAccountAndDeviceSupplierTest {
|
||||
|
||||
@Test
|
||||
void test() {
|
||||
|
||||
final AccountsManager accountsManager = mock(AccountsManager.class);
|
||||
|
||||
final UUID uuid = UUID.randomUUID();
|
||||
final byte deviceId = 2;
|
||||
|
||||
final Account initialAccount = mock(Account.class);
|
||||
final Device initialDevice = mock(Device.class);
|
||||
|
||||
when(initialAccount.getUuid()).thenReturn(uuid);
|
||||
when(initialDevice.getId()).thenReturn(deviceId);
|
||||
when(initialAccount.getDevice(deviceId)).thenReturn(Optional.of(initialDevice));
|
||||
|
||||
when(accountsManager.getByAccountIdentifier(any(UUID.class))).thenAnswer(answer -> {
|
||||
final Account account = mock(Account.class);
|
||||
final Device device = mock(Device.class);
|
||||
|
||||
when(account.getUuid()).thenReturn(answer.getArgument(0, UUID.class));
|
||||
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
|
||||
when(device.getId()).thenReturn(deviceId);
|
||||
|
||||
return Optional.of(account);
|
||||
});
|
||||
|
||||
final RefreshingAccountAndDeviceSupplier refreshingAccountAndDeviceSupplier = new RefreshingAccountAndDeviceSupplier(
|
||||
initialAccount, deviceId, accountsManager);
|
||||
|
||||
Pair<Account, Device> accountAndDevice = refreshingAccountAndDeviceSupplier.get();
|
||||
|
||||
assertSame(initialAccount, accountAndDevice.first());
|
||||
assertSame(initialDevice, accountAndDevice.second());
|
||||
|
||||
accountAndDevice = refreshingAccountAndDeviceSupplier.get();
|
||||
|
||||
assertSame(initialAccount, accountAndDevice.first());
|
||||
assertSame(initialDevice, accountAndDevice.second());
|
||||
|
||||
when(initialAccount.isStale()).thenReturn(true);
|
||||
|
||||
accountAndDevice = refreshingAccountAndDeviceSupplier.get();
|
||||
|
||||
assertNotSame(initialAccount, accountAndDevice.first());
|
||||
assertNotSame(initialDevice, accountAndDevice.second());
|
||||
|
||||
assertEquals(uuid, accountAndDevice.first().getUuid());
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.tests.util;
|
||||
|
||||
import java.security.Principal;
|
||||
import org.whispersystems.websocket.ReusableAuth;
|
||||
import org.whispersystems.websocket.auth.PrincipalSupplier;
|
||||
|
||||
public class TestPrincipal implements Principal {
|
||||
|
||||
private final String name;
|
||||
|
||||
private TestPrincipal(String name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
public static ReusableAuth<TestPrincipal> reusableAuth(final String name) {
|
||||
return ReusableAuth.authenticated(new TestPrincipal(name), PrincipalSupplier.forImmutablePrincipal());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,79 @@
|
|||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.textsecuregcm.tests.util;
|
||||
|
||||
import org.eclipse.jetty.websocket.api.Session;
|
||||
import org.eclipse.jetty.websocket.api.WebSocketListener;
|
||||
import org.whispersystems.websocket.messages.WebSocketMessage;
|
||||
import org.whispersystems.websocket.messages.WebSocketMessageFactory;
|
||||
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
|
||||
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
|
||||
public class TestWebsocketListener implements WebSocketListener {
|
||||
|
||||
private final AtomicLong requestId = new AtomicLong();
|
||||
private final CompletableFuture<Session> started = new CompletableFuture<>();
|
||||
private final ConcurrentHashMap<Long, CompletableFuture<WebSocketResponseMessage>> responseFutures = new ConcurrentHashMap<>();
|
||||
private final WebSocketMessageFactory messageFactory;
|
||||
|
||||
public TestWebsocketListener() {
|
||||
this.messageFactory = new ProtobufWebSocketMessageFactory();
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void onWebSocketConnect(final Session session) {
|
||||
started.complete(session);
|
||||
|
||||
}
|
||||
|
||||
public CompletableFuture<WebSocketResponseMessage> doGet(final String requestPath) {
|
||||
return sendRequest(requestPath, "GET", List.of("Accept: application/json"), Optional.empty());
|
||||
}
|
||||
|
||||
public CompletableFuture<WebSocketResponseMessage> sendRequest(
|
||||
final String requestPath,
|
||||
final String verb,
|
||||
final List<String> headers,
|
||||
final Optional<byte[]> body) {
|
||||
return started.thenCompose(session -> {
|
||||
final long id = requestId.incrementAndGet();
|
||||
final CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>();
|
||||
responseFutures.put(id, future);
|
||||
final byte[] requestBytes = messageFactory.createRequest(
|
||||
Optional.of(id), verb, requestPath, headers, body).toByteArray();
|
||||
try {
|
||||
session.getRemote().sendBytes(ByteBuffer.wrap(requestBytes));
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return future;
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onWebSocketBinary(final byte[] payload, final int offset, final int length) {
|
||||
try {
|
||||
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length);
|
||||
if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.RESPONSE_MESSAGE) {
|
||||
responseFutures.get(webSocketMessage.getResponseMessage().getRequestId())
|
||||
.complete(webSocketMessage.getResponseMessage());
|
||||
} else {
|
||||
throw new RuntimeException("Unexpected message type: " + webSocketMessage.getType());
|
||||
}
|
||||
} catch (final Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -57,6 +57,7 @@ import org.junit.jupiter.params.provider.MethodSource;
|
|||
import org.slf4j.Logger;
|
||||
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
|
||||
import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper;
|
||||
import org.whispersystems.textsecuregcm.tests.util.TestPrincipal;
|
||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||
import org.whispersystems.websocket.WebSocketResourceProvider;
|
||||
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
|
||||
|
@ -175,7 +176,8 @@ class LoggingUnhandledExceptionMapperTest {
|
|||
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
|
||||
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog,
|
||||
TestPrincipal.reusableAuth("foo"),
|
||||
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||
|
||||
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
||||
|
@ -238,18 +240,4 @@ class LoggingUnhandledExceptionMapperTest {
|
|||
throw new RuntimeException();
|
||||
}
|
||||
}
|
||||
|
||||
public static class TestPrincipal implements Principal {
|
||||
|
||||
private final String name;
|
||||
|
||||
private TestPrincipal(String name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -28,8 +28,8 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
|
|||
import org.whispersystems.textsecuregcm.storage.Account;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.util.HeaderUtils;
|
||||
import org.whispersystems.textsecuregcm.util.Pair;
|
||||
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
|
||||
import org.whispersystems.websocket.ReusableAuth;
|
||||
import org.whispersystems.websocket.auth.PrincipalSupplier;
|
||||
|
||||
class WebSocketAccountAuthenticatorTest {
|
||||
|
||||
|
@ -52,7 +52,7 @@ class WebSocketAccountAuthenticatorTest {
|
|||
accountAuthenticator = mock(AccountAuthenticator.class);
|
||||
|
||||
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
|
||||
.thenReturn(Optional.of(new AuthenticatedAccount(() -> new Pair<>(mock(Account.class), mock(Device.class)))));
|
||||
.thenReturn(Optional.of(new AuthenticatedAccount(mock(Account.class), mock(Device.class))));
|
||||
|
||||
when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD))))
|
||||
.thenReturn(Optional.empty());
|
||||
|
@ -66,7 +66,7 @@ class WebSocketAccountAuthenticatorTest {
|
|||
@Nullable final String authorizationHeaderValue,
|
||||
final Map<String, List<String>> upgradeRequestParameters,
|
||||
final boolean expectAccount,
|
||||
final boolean expectCredentialsPresented) throws Exception {
|
||||
final boolean expectInvalid) throws Exception {
|
||||
|
||||
when(upgradeRequest.getParameterMap()).thenReturn(upgradeRequestParameters);
|
||||
if (authorizationHeaderValue != null) {
|
||||
|
@ -74,13 +74,13 @@ class WebSocketAccountAuthenticatorTest {
|
|||
}
|
||||
|
||||
final WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(
|
||||
accountAuthenticator);
|
||||
accountAuthenticator,
|
||||
mock(PrincipalSupplier.class));
|
||||
|
||||
final WebSocketAuthenticator.AuthenticationResult<AuthenticatedAccount> result = webSocketAuthenticator.authenticate(
|
||||
upgradeRequest);
|
||||
final ReusableAuth<AuthenticatedAccount> result = webSocketAuthenticator.authenticate(upgradeRequest);
|
||||
|
||||
assertEquals(expectAccount, result.getUser().isPresent());
|
||||
assertEquals(expectCredentialsPresented, result.credentialsPresented());
|
||||
assertEquals(expectAccount, result.ref().isPresent());
|
||||
assertEquals(expectInvalid, result.invalidCredentialsProvided());
|
||||
}
|
||||
|
||||
private static Stream<Arguments> testAuthenticate() {
|
||||
|
@ -94,17 +94,17 @@ class WebSocketAccountAuthenticatorTest {
|
|||
HeaderUtils.basicAuthHeader(INVALID_USER, INVALID_PASSWORD);
|
||||
return Stream.of(
|
||||
// if `Authorization` header is present, outcome should not depend on the value of query parameters
|
||||
Arguments.of(headerWithValidAuth, Map.of(), true, true),
|
||||
Arguments.of(headerWithValidAuth, Map.of(), true, false),
|
||||
Arguments.of(headerWithInvalidAuth, Map.of(), false, true),
|
||||
Arguments.of("invalid header value", Map.of(), false, true),
|
||||
Arguments.of(headerWithValidAuth, paramsMapWithValidAuth, true, true),
|
||||
Arguments.of(headerWithValidAuth, paramsMapWithValidAuth, true, false),
|
||||
Arguments.of(headerWithInvalidAuth, paramsMapWithValidAuth, false, true),
|
||||
Arguments.of("invalid header value", paramsMapWithValidAuth, false, true),
|
||||
Arguments.of(headerWithValidAuth, paramsMapWithInvalidAuth, true, true),
|
||||
Arguments.of(headerWithValidAuth, paramsMapWithInvalidAuth, true, false),
|
||||
Arguments.of(headerWithInvalidAuth, paramsMapWithInvalidAuth, false, true),
|
||||
Arguments.of("invalid header value", paramsMapWithInvalidAuth, false, true),
|
||||
// if `Authorization` header is not set, outcome should match the query params based auth
|
||||
Arguments.of(null, paramsMapWithValidAuth, true, true),
|
||||
Arguments.of(null, paramsMapWithValidAuth, true, false),
|
||||
Arguments.of(null, paramsMapWithInvalidAuth, false, true),
|
||||
Arguments.of(null, Map.of(), false, false)
|
||||
);
|
||||
|
|
|
@ -125,7 +125,7 @@ class WebSocketConnectionIntegrationTest {
|
|||
final WebSocketConnection webSocketConnection = new WebSocketConnection(
|
||||
mock(ReceiptSender.class),
|
||||
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
|
||||
new AuthenticatedAccount(() -> new Pair<>(account, device)),
|
||||
new AuthenticatedAccount(account, device),
|
||||
device,
|
||||
webSocketClient,
|
||||
scheduledExecutorService,
|
||||
|
@ -210,7 +210,7 @@ class WebSocketConnectionIntegrationTest {
|
|||
final WebSocketConnection webSocketConnection = new WebSocketConnection(
|
||||
mock(ReceiptSender.class),
|
||||
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
|
||||
new AuthenticatedAccount(() -> new Pair<>(account, device)),
|
||||
new AuthenticatedAccount(account, device),
|
||||
device,
|
||||
webSocketClient,
|
||||
scheduledExecutorService,
|
||||
|
@ -276,7 +276,7 @@ class WebSocketConnectionIntegrationTest {
|
|||
final WebSocketConnection webSocketConnection = new WebSocketConnection(
|
||||
mock(ReceiptSender.class),
|
||||
new MessagesManager(messagesDynamoDb, messagesCache, reportMessageManager, sharedExecutorService),
|
||||
new AuthenticatedAccount(() -> new Pair<>(account, device)),
|
||||
new AuthenticatedAccount(account, device),
|
||||
device,
|
||||
webSocketClient,
|
||||
100, // use a very short timeout, so that this test completes quickly
|
||||
|
|
|
@ -64,9 +64,9 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
|||
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
|
||||
import org.whispersystems.textsecuregcm.storage.Device;
|
||||
import org.whispersystems.textsecuregcm.storage.MessagesManager;
|
||||
import org.whispersystems.textsecuregcm.util.Pair;
|
||||
import org.whispersystems.websocket.ReusableAuth;
|
||||
import org.whispersystems.websocket.WebSocketClient;
|
||||
import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult;
|
||||
import org.whispersystems.websocket.auth.PrincipalSupplier;
|
||||
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
|
||||
import org.whispersystems.websocket.session.WebSocketSessionContext;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
@ -101,7 +101,7 @@ class WebSocketConnectionTest {
|
|||
accountsManager = mock(AccountsManager.class);
|
||||
account = mock(Account.class);
|
||||
device = mock(Device.class);
|
||||
auth = new AuthenticatedAccount(() -> new Pair<>(account, device));
|
||||
auth = new AuthenticatedAccount(account, device);
|
||||
upgradeRequest = mock(UpgradeRequest.class);
|
||||
messagesManager = mock(MessagesManager.class);
|
||||
receiptSender = mock(ReceiptSender.class);
|
||||
|
@ -118,18 +118,19 @@ class WebSocketConnectionTest {
|
|||
|
||||
@Test
|
||||
void testCredentials() throws Exception {
|
||||
WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator);
|
||||
WebSocketAccountAuthenticator webSocketAuthenticator =
|
||||
new WebSocketAccountAuthenticator(accountAuthenticator, mock(PrincipalSupplier.class));
|
||||
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, messagesManager,
|
||||
mock(PushNotificationManager.class), mock(ClientPresenceManager.class),
|
||||
retrySchedulingExecutor, messageDeliveryScheduler, clientReleaseManager);
|
||||
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
|
||||
|
||||
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
|
||||
.thenReturn(Optional.of(new AuthenticatedAccount(() -> new Pair<>(account, device))));
|
||||
.thenReturn(Optional.of(new AuthenticatedAccount(account, device)));
|
||||
|
||||
AuthenticationResult<AuthenticatedAccount> account = webSocketAuthenticator.authenticate(upgradeRequest);
|
||||
when(sessionContext.getAuthenticated()).thenReturn(account.getUser().orElse(null));
|
||||
when(sessionContext.getAuthenticated(AuthenticatedAccount.class)).thenReturn(account.getUser().orElse(null));
|
||||
ReusableAuth<AuthenticatedAccount> account = webSocketAuthenticator.authenticate(upgradeRequest);
|
||||
when(sessionContext.getAuthenticated()).thenReturn(account.ref().orElse(null));
|
||||
when(sessionContext.getAuthenticated(AuthenticatedAccount.class)).thenReturn(account.ref().orElse(null));
|
||||
|
||||
final WebSocketClient webSocketClient = mock(WebSocketClient.class);
|
||||
when(webSocketClient.getUserAgent()).thenReturn("Signal-Android/6.22.8");
|
||||
|
@ -144,8 +145,8 @@ class WebSocketConnectionTest {
|
|||
// unauthenticated
|
||||
when(upgradeRequest.getParameterMap()).thenReturn(Map.of());
|
||||
account = webSocketAuthenticator.authenticate(upgradeRequest);
|
||||
assertFalse(account.getUser().isPresent());
|
||||
assertFalse(account.credentialsPresented());
|
||||
assertFalse(account.ref().isPresent());
|
||||
assertFalse(account.invalidCredentialsProvided());
|
||||
|
||||
connectListener.onWebSocketConnect(sessionContext);
|
||||
verify(sessionContext, times(2)).addWebsocketClosedListener(
|
||||
|
|
|
@ -0,0 +1,182 @@
|
|||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.websocket;
|
||||
|
||||
import java.security.Principal;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import org.whispersystems.websocket.auth.PrincipalSupplier;
|
||||
|
||||
/**
|
||||
* This class holds a principal that can be reused across requests on a websocket. Since two requests may operate
|
||||
* concurrently on the same principal, and some principals contain non thread-safe mutable state, appropriate use of
|
||||
* this class ensures that no data races occur. It also ensures that after a principal is modified, a subsequent request
|
||||
* gets the up-to-date principal
|
||||
*
|
||||
* @param <T> The underlying principal type
|
||||
* @see PrincipalSupplier
|
||||
*/
|
||||
public abstract sealed class ReusableAuth<T extends Principal> {
|
||||
|
||||
/**
|
||||
* Get a reference to the underlying principal that callers pledge not to modify.
|
||||
* <p>
|
||||
* The reference returned will potentially be provided to many threads concurrently accessing the principal. Callers
|
||||
* should use this method only if they can ensure that they will not modify the in-memory principal object AND they do
|
||||
* not intend to modify the underlying canonical representation of the principal.
|
||||
* <p>
|
||||
* For example, if a caller retrieves a reference to a principal, does not modify the in memory state, but updates a
|
||||
* field on a database that should be reflected in subsequent retrievals of the principal, they will have met the
|
||||
* first criteria, but not the second. In that case they should instead use {@link #mutableRef()}.
|
||||
* <p>
|
||||
* If other callers have modified the underlying principal by using {@link #mutableRef()}, this method may need to
|
||||
* refresh the principal via {@link PrincipalSupplier#refresh} which could be a blocking operation.
|
||||
*
|
||||
* @return If authenticated, a reference to the underlying principal that should not be modified
|
||||
*/
|
||||
public abstract Optional<T> ref();
|
||||
|
||||
|
||||
public interface MutableRef<T> {
|
||||
|
||||
T ref();
|
||||
|
||||
void close();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a reference to the underlying principal that may be modified.
|
||||
* <p>
|
||||
* The underlying principal can be safely modified. Multiple threads may operate on the same {@link ReusableAuth} so
|
||||
* long as they each have their own {@link MutableRef}. After any modifications, the caller must call
|
||||
* {@link MutableRef#close} to notify the principal has become dirty. Close should be called after modifications but
|
||||
* before sending a response on the websocket. This ensures that a request that comes in after a modification response
|
||||
* is received is guaranteed to see the modification.
|
||||
*
|
||||
* @return If authenticated, a reference to the underlying principal that may be modified
|
||||
*/
|
||||
public abstract Optional<MutableRef<T>> mutableRef();
|
||||
|
||||
public boolean invalidCredentialsProvided() {
|
||||
return switch (this) {
|
||||
case Invalid<T> ignored -> true;
|
||||
case ReusableAuth.Anonymous<T> ignored -> false;
|
||||
case ReusableAuth.Authenticated<T> ignored-> false;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* @return A {@link ReusableAuth} indicating no credential were provided
|
||||
*/
|
||||
public static <T extends Principal> ReusableAuth<T> anonymous() {
|
||||
//noinspection unchecked
|
||||
return (ReusableAuth<T>) Anonymous.ANON_RESULT;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return A {@link ReusableAuth} indicating that invalid credentials were provided
|
||||
*/
|
||||
public static <T extends Principal> ReusableAuth<T> invalid() {
|
||||
//noinspection unchecked
|
||||
return (ReusableAuth<T>) Invalid.INVALID_RESULT;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a successfully authenticated {@link ReusableAuth}
|
||||
*
|
||||
* @param principal The authenticated principal
|
||||
* @param principalSupplier Instructions for how to refresh or copy this principal
|
||||
* @param <T> The principal type
|
||||
* @return A {@link ReusableAuth} for a successfully authenticated principal
|
||||
*/
|
||||
public static <T extends Principal> ReusableAuth<T> authenticated(T principal,
|
||||
PrincipalSupplier<T> principalSupplier) {
|
||||
return new Authenticated<>(principal, principalSupplier);
|
||||
}
|
||||
|
||||
|
||||
private static final class Invalid<T extends Principal> extends ReusableAuth<T> {
|
||||
|
||||
@SuppressWarnings({"rawtypes"})
|
||||
private static final ReusableAuth INVALID_RESULT = new Invalid();
|
||||
|
||||
@Override
|
||||
public Optional<T> ref() {
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<MutableRef<T>> mutableRef() {
|
||||
return Optional.empty();
|
||||
}
|
||||
}
|
||||
|
||||
private static final class Anonymous<T extends Principal> extends ReusableAuth<T> {
|
||||
|
||||
@SuppressWarnings({"rawtypes"})
|
||||
private static final ReusableAuth ANON_RESULT = new Anonymous();
|
||||
|
||||
@Override
|
||||
public Optional<T> ref() {
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<MutableRef<T>> mutableRef() {
|
||||
return Optional.empty();
|
||||
}
|
||||
}
|
||||
|
||||
private static final class Authenticated<T extends Principal> extends ReusableAuth<T> {
|
||||
|
||||
private T basePrincipal;
|
||||
private final AtomicBoolean needRefresh = new AtomicBoolean(false);
|
||||
private final PrincipalSupplier<T> principalSupplier;
|
||||
|
||||
Authenticated(final T basePrincipal, PrincipalSupplier<T> principalSupplier) {
|
||||
this.basePrincipal = basePrincipal;
|
||||
this.principalSupplier = principalSupplier;
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<T> ref() {
|
||||
maybeRefresh();
|
||||
return Optional.of(basePrincipal);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<MutableRef<T>> mutableRef() {
|
||||
maybeRefresh();
|
||||
return Optional.of(new AuthenticatedMutableRef(principalSupplier.deepCopy(basePrincipal)));
|
||||
}
|
||||
|
||||
private void maybeRefresh() {
|
||||
if (needRefresh.compareAndSet(true, false)) {
|
||||
basePrincipal = principalSupplier.refresh(basePrincipal);
|
||||
}
|
||||
}
|
||||
|
||||
private class AuthenticatedMutableRef implements MutableRef<T> {
|
||||
|
||||
final T ref;
|
||||
|
||||
private AuthenticatedMutableRef(T ref) {
|
||||
this.ref = ref;
|
||||
}
|
||||
|
||||
public T ref() {
|
||||
return ref;
|
||||
}
|
||||
|
||||
public void close() {
|
||||
needRefresh.set(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private ReusableAuth() {
|
||||
}
|
||||
}
|
|
@ -58,7 +58,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
|
|||
|
||||
private final Map<Long, CompletableFuture<WebSocketResponseMessage>> requestMap = new ConcurrentHashMap<>();
|
||||
|
||||
private final T authenticated;
|
||||
private final ReusableAuth<T> reusableAuth;
|
||||
private final WebSocketMessageFactory messageFactory;
|
||||
private final Optional<WebSocketConnectListener> connectListener;
|
||||
private final ApplicationHandler jerseyHandler;
|
||||
|
@ -77,7 +77,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
|
|||
String remoteAddressPropertyName,
|
||||
ApplicationHandler jerseyHandler,
|
||||
WebsocketRequestLog requestLog,
|
||||
T authenticated,
|
||||
ReusableAuth<T> authenticated,
|
||||
WebSocketMessageFactory messageFactory,
|
||||
Optional<WebSocketConnectListener> connectListener,
|
||||
Duration idleTimeout) {
|
||||
|
@ -85,7 +85,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
|
|||
this.remoteAddressPropertyName = remoteAddressPropertyName;
|
||||
this.jerseyHandler = jerseyHandler;
|
||||
this.requestLog = requestLog;
|
||||
this.authenticated = authenticated;
|
||||
this.reusableAuth = authenticated;
|
||||
this.messageFactory = messageFactory;
|
||||
this.connectListener = connectListener;
|
||||
this.idleTimeout = idleTimeout;
|
||||
|
@ -97,7 +97,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
|
|||
this.remoteEndpoint = session.getRemote();
|
||||
this.context = new WebSocketSessionContext(
|
||||
new WebSocketClient(session, remoteEndpoint, messageFactory, requestMap));
|
||||
this.context.setAuthenticated(authenticated);
|
||||
this.context.setAuthenticated(reusableAuth.ref().orElse(null));
|
||||
this.session.setIdleTimeout(idleTimeout);
|
||||
|
||||
connectListener.ifPresent(listener -> listener.onWebSocketConnect(this.context));
|
||||
|
@ -162,6 +162,17 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
|
|||
logger.debug("onWebSocketText!");
|
||||
}
|
||||
|
||||
/**
|
||||
* The property name where {@link org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider} can find an
|
||||
* {@link ReusableAuth} object that lives for the lifetime of the websocket
|
||||
*/
|
||||
public static final String REUSABLE_AUTH_PROPERTY = WebSocketResourceProvider.class.getName() + ".reusableAuth";
|
||||
|
||||
/**
|
||||
* The property name where {@link org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider} can install a
|
||||
* {@link org.whispersystems.websocket.ReusableAuth.MutableRef} for us to close when the request is finished
|
||||
*/
|
||||
public static final String RESOLVED_PRINCIPAL_PROPERTY = WebSocketResourceProvider.class.getName() + ".resolvedPrincipal";
|
||||
private void handleRequest(WebSocketRequestMessage requestMessage) {
|
||||
ContainerRequest containerRequest = new ContainerRequest(null, URI.create(requestMessage.getPath()),
|
||||
requestMessage.getVerb(), new WebSocketSecurityContext(new ContextPrincipal(context)),
|
||||
|
@ -173,30 +184,43 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
|
|||
}
|
||||
|
||||
containerRequest.setProperty(remoteAddressPropertyName, remoteAddress);
|
||||
containerRequest.setProperty(REUSABLE_AUTH_PROPERTY, reusableAuth);
|
||||
|
||||
ByteArrayOutputStream responseBody = new ByteArrayOutputStream();
|
||||
CompletableFuture<ContainerResponse> responseFuture = (CompletableFuture<ContainerResponse>) jerseyHandler.apply(
|
||||
containerRequest, responseBody);
|
||||
|
||||
responseFuture.thenAccept(response -> {
|
||||
try {
|
||||
sendResponse(requestMessage, response, responseBody);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
requestLog.log(remoteAddress, containerRequest, response);
|
||||
}).exceptionally(exception -> {
|
||||
logger.warn("Websocket Error: " + requestMessage.getVerb() + " " + requestMessage.getPath() + "\n"
|
||||
+ requestMessage.getBody(), exception);
|
||||
try {
|
||||
sendErrorResponse(requestMessage, Response.status(500).build());
|
||||
} catch (IOException e) {
|
||||
logger.warn("Failed to send error response", e);
|
||||
}
|
||||
requestLog.log(remoteAddress, containerRequest,
|
||||
new ContainerResponse(containerRequest, Response.status(500).build()));
|
||||
return null;
|
||||
});
|
||||
responseFuture
|
||||
.whenComplete((ignoredResponse, ignoredError) -> {
|
||||
// If the request ended up being one that mutates our principal, we have to close it to indicate we're done
|
||||
// with the mutation operation
|
||||
final Object resolvedPrincipal = containerRequest.getProperty(RESOLVED_PRINCIPAL_PROPERTY);
|
||||
if (resolvedPrincipal instanceof ReusableAuth.MutableRef ref) {
|
||||
ref.close();
|
||||
} else if (resolvedPrincipal != null) {
|
||||
logger.warn("unexpected resolved principal type {} : {}", resolvedPrincipal.getClass(), resolvedPrincipal);
|
||||
}
|
||||
})
|
||||
.thenAccept(response -> {
|
||||
try {
|
||||
sendResponse(requestMessage, response, responseBody);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
requestLog.log(remoteAddress, containerRequest, response);
|
||||
})
|
||||
.exceptionally(exception -> {
|
||||
logger.warn("Websocket Error: " + requestMessage.getVerb() + " " + requestMessage.getPath() + "\n"
|
||||
+ requestMessage.getBody(), exception);
|
||||
try {
|
||||
sendErrorResponse(requestMessage, Response.status(500).build());
|
||||
} catch (IOException e) {
|
||||
logger.warn("Failed to send error response", e);
|
||||
}
|
||||
requestLog.log(remoteAddress, containerRequest,
|
||||
new ContainerResponse(containerRequest, Response.status(500).build()));
|
||||
return null;
|
||||
});
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
|
|
|
@ -22,7 +22,6 @@ import org.slf4j.Logger;
|
|||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.websocket.auth.AuthenticationException;
|
||||
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
|
||||
import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult;
|
||||
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
|
||||
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
|
||||
import org.whispersystems.websocket.session.WebSocketSessionContextValueFactoryProvider;
|
||||
|
@ -57,17 +56,17 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends Jetty
|
|||
public Object createWebSocket(final JettyServerUpgradeRequest request, final JettyServerUpgradeResponse response) {
|
||||
try {
|
||||
Optional<WebSocketAuthenticator<T>> authenticator = Optional.ofNullable(environment.getAuthenticator());
|
||||
T authenticated = null;
|
||||
|
||||
final ReusableAuth<T> authenticated;
|
||||
if (authenticator.isPresent()) {
|
||||
AuthenticationResult<T> authenticationResult = authenticator.get().authenticate(request);
|
||||
authenticated = authenticator.get().authenticate(request);
|
||||
|
||||
if (authenticationResult.getUser().isEmpty() && authenticationResult.credentialsPresented()) {
|
||||
if (authenticated.invalidCredentialsProvided()) {
|
||||
response.sendForbidden("Unauthorized");
|
||||
return null;
|
||||
} else {
|
||||
authenticated = authenticationResult.getUser().orElse(null);
|
||||
}
|
||||
} else {
|
||||
authenticated = ReusableAuth.anonymous();
|
||||
}
|
||||
|
||||
return new WebSocketResourceProvider<>(getRemoteAddress(request),
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.websocket.auth;
|
||||
|
||||
import io.dropwizard.auth.Auth;
|
||||
|
||||
import java.lang.annotation.ElementType;
|
||||
import java.lang.annotation.Retention;
|
||||
import java.lang.annotation.RetentionPolicy;
|
||||
import java.lang.annotation.Target;
|
||||
|
||||
/**
|
||||
* An @{@link Auth} object annotated with {@link Mutable} indicates that the consumer of the object
|
||||
* will modify the object or its underlying canonical source.
|
||||
*
|
||||
* Note: An {@link Auth} object that does not specify @{@link ReadOnly} will be assumed to be @Mutable
|
||||
*
|
||||
* @see org.whispersystems.websocket.ReusableAuth
|
||||
*/
|
||||
@Retention(RetentionPolicy.RUNTIME)
|
||||
@Target({ElementType.FIELD, ElementType.PARAMETER})
|
||||
public @interface Mutable {
|
||||
}
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.websocket.auth;
|
||||
|
||||
/**
|
||||
* Teach {@link org.whispersystems.websocket.ReusableAuth} how to make a deep copy of a principal (that is safe to
|
||||
* concurrently modify while the original principal is being read), and how to refresh a principal after it has been
|
||||
* potentially modified.
|
||||
*
|
||||
* @param <T> The underlying principal type
|
||||
*/
|
||||
public interface PrincipalSupplier<T> {
|
||||
|
||||
/**
|
||||
* Re-fresh the principal after it has been modified.
|
||||
* <p>
|
||||
* If the principal is populated from a backing store, refresh should re-read it.
|
||||
*
|
||||
* @param t the potentially stale principal to refresh
|
||||
* @return The up-to-date principal
|
||||
*/
|
||||
T refresh(T t);
|
||||
|
||||
/**
|
||||
* Create a deep, in-memory copy of the principal. This should be identical to the original principal, but should
|
||||
* share no mutable state with the original. It should be safe for two threads to independently write and read from
|
||||
* two independent deep copies.
|
||||
*
|
||||
* @param t the principal to copy
|
||||
* @return An in-memory copy of the principal
|
||||
*/
|
||||
T deepCopy(T t);
|
||||
|
||||
class ImmutablePrincipalSupplier<T> implements PrincipalSupplier<T> {
|
||||
@SuppressWarnings({"rawtypes"})
|
||||
private static final PrincipalSupplier INSTANCE = new ImmutablePrincipalSupplier();
|
||||
|
||||
@Override
|
||||
public T refresh(final T t) {
|
||||
return t;
|
||||
}
|
||||
|
||||
@Override
|
||||
public T deepCopy(final T t) {
|
||||
return t;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @return A principal supplier that can be used if the principal type does not support modification.
|
||||
*/
|
||||
static <T> PrincipalSupplier<T> forImmutablePrincipal() {
|
||||
//noinspection unchecked
|
||||
return (PrincipalSupplier<T>) ImmutablePrincipalSupplier.INSTANCE;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
/*
|
||||
* Copyright 2024 Signal Messenger, LLC
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
package org.whispersystems.websocket.auth;
|
||||
|
||||
import java.lang.annotation.ElementType;
|
||||
import java.lang.annotation.Retention;
|
||||
import java.lang.annotation.RetentionPolicy;
|
||||
import java.lang.annotation.Target;
|
||||
|
||||
/**
|
||||
* An @{@link io.dropwizard.auth.Auth} object annotated with {@link ReadOnly} indicates that the consumer of the object
|
||||
* will never modify the object, nor its underlying canonical source.
|
||||
* <p>
|
||||
* For example, a consumer of a @ReadOnly AuthenticatedAccount promises to never modify the in-memory
|
||||
* AuthenticatedAccount and to never modify the underlying Account database for the account.
|
||||
*
|
||||
* @see org.whispersystems.websocket.ReusableAuth
|
||||
*/
|
||||
@Retention(RetentionPolicy.RUNTIME)
|
||||
@Target({ElementType.FIELD, ElementType.PARAMETER})
|
||||
public @interface ReadOnly {
|
||||
}
|
||||
|
|
@ -7,27 +7,8 @@ package org.whispersystems.websocket.auth;
|
|||
import java.security.Principal;
|
||||
import java.util.Optional;
|
||||
import org.eclipse.jetty.websocket.api.UpgradeRequest;
|
||||
import org.whispersystems.websocket.ReusableAuth;
|
||||
|
||||
public interface WebSocketAuthenticator<T extends Principal> {
|
||||
|
||||
AuthenticationResult<T> authenticate(UpgradeRequest request) throws AuthenticationException;
|
||||
|
||||
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
|
||||
class AuthenticationResult<T> {
|
||||
private final Optional<T> user;
|
||||
private final boolean credentialsPresented;
|
||||
|
||||
public AuthenticationResult(final Optional<T> user, final boolean credentialsPresented) {
|
||||
this.user = user;
|
||||
this.credentialsPresented = credentialsPresented;
|
||||
}
|
||||
|
||||
public Optional<T> getUser() {
|
||||
return user;
|
||||
}
|
||||
|
||||
public boolean credentialsPresented() {
|
||||
return credentialsPresented;
|
||||
}
|
||||
}
|
||||
ReusableAuth<T> authenticate(UpgradeRequest request) throws AuthenticationException;
|
||||
}
|
||||
|
|
|
@ -5,24 +5,28 @@
|
|||
package org.whispersystems.websocket.auth;
|
||||
|
||||
import io.dropwizard.auth.Auth;
|
||||
import java.lang.reflect.ParameterizedType;
|
||||
import java.security.Principal;
|
||||
import java.util.Optional;
|
||||
import java.util.function.Function;
|
||||
import javax.annotation.Nullable;
|
||||
import javax.inject.Inject;
|
||||
import javax.inject.Singleton;
|
||||
import javax.ws.rs.WebApplicationException;
|
||||
import org.glassfish.jersey.internal.inject.AbstractBinder;
|
||||
import org.glassfish.jersey.server.ContainerRequest;
|
||||
import org.glassfish.jersey.server.internal.inject.AbstractValueParamProvider;
|
||||
import org.glassfish.jersey.server.internal.inject.MultivaluedParameterExtractorProvider;
|
||||
import org.glassfish.jersey.server.model.Parameter;
|
||||
import org.glassfish.jersey.server.spi.internal.ValueParamProvider;
|
||||
|
||||
import javax.annotation.Nullable;
|
||||
import javax.inject.Inject;
|
||||
import javax.inject.Singleton;
|
||||
import javax.ws.rs.WebApplicationException;
|
||||
import java.lang.reflect.ParameterizedType;
|
||||
import java.security.Principal;
|
||||
import java.util.Optional;
|
||||
import java.util.function.Function;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.websocket.ReusableAuth;
|
||||
import org.whispersystems.websocket.WebSocketResourceProvider;
|
||||
|
||||
@Singleton
|
||||
public class WebsocketAuthValueFactoryProvider<T extends Principal> extends AbstractValueParamProvider {
|
||||
private static final Logger logger = LoggerFactory.getLogger(WebsocketAuthValueFactoryProvider.class);
|
||||
|
||||
private final Class<T> principalClass;
|
||||
|
||||
|
@ -39,18 +43,38 @@ public class WebsocketAuthValueFactoryProvider<T extends Principal> extends Abst
|
|||
return null;
|
||||
}
|
||||
|
||||
if (parameter.getRawType() == Optional.class &&
|
||||
ParameterizedType.class.isAssignableFrom(parameter.getType().getClass()) &&
|
||||
principalClass == ((ParameterizedType)parameter.getType()).getActualTypeArguments()[0])
|
||||
{
|
||||
return request -> new OptionalContainerRequestValueFactory(request).provide();
|
||||
final boolean readOnly = parameter.isAnnotationPresent(ReadOnly.class);
|
||||
|
||||
if (parameter.getRawType() == Optional.class
|
||||
&& ParameterizedType.class.isAssignableFrom(parameter.getType().getClass())
|
||||
&& principalClass == ((ParameterizedType) parameter.getType()).getActualTypeArguments()[0]) {
|
||||
return containerRequest -> createPrincipal(containerRequest, readOnly);
|
||||
} else if (principalClass.equals(parameter.getRawType())) {
|
||||
return request -> new StandardContainerRequestValueFactory(request).provide();
|
||||
return containerRequest ->
|
||||
createPrincipal(containerRequest, readOnly)
|
||||
.orElseThrow(() -> new WebApplicationException("Authenticated resource", 401));
|
||||
} else {
|
||||
throw new IllegalStateException("Can't inject unassignable principal: " + principalClass + " for parameter: " + parameter);
|
||||
}
|
||||
}
|
||||
|
||||
private Optional<? extends Principal> createPrincipal(final ContainerRequest request, final boolean readOnly) {
|
||||
final Object obj = request.getProperty(WebSocketResourceProvider.REUSABLE_AUTH_PROPERTY);
|
||||
if (!(obj instanceof ReusableAuth<?>)) {
|
||||
logger.warn("Unexpected reusable auth property type {} : {}", obj.getClass(), obj);
|
||||
return Optional.empty();
|
||||
}
|
||||
@SuppressWarnings("unchecked") final ReusableAuth<T> reusableAuth = (ReusableAuth<T>) obj;
|
||||
if (readOnly) {
|
||||
return reusableAuth.ref();
|
||||
} else {
|
||||
return reusableAuth.mutableRef().map(writeRef -> {
|
||||
request.setProperty(WebSocketResourceProvider.RESOLVED_PRINCIPAL_PROPERTY, writeRef);
|
||||
return writeRef.ref();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@Singleton
|
||||
static class WebsocketPrincipalClassProvider<T extends Principal> {
|
||||
|
||||
|
@ -80,38 +104,4 @@ public class WebsocketAuthValueFactoryProvider<T extends Principal> extends Abst
|
|||
bind(WebsocketAuthValueFactoryProvider.class).to(ValueParamProvider.class).in(Singleton.class);
|
||||
}
|
||||
}
|
||||
|
||||
private static class StandardContainerRequestValueFactory {
|
||||
|
||||
private final ContainerRequest request;
|
||||
|
||||
public StandardContainerRequestValueFactory(ContainerRequest request) {
|
||||
this.request = request;
|
||||
}
|
||||
|
||||
public Principal provide() {
|
||||
final Principal principal = request.getSecurityContext().getUserPrincipal();
|
||||
|
||||
if (principal == null) {
|
||||
throw new WebApplicationException("Authenticated resource", 401);
|
||||
}
|
||||
|
||||
return principal;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private static class OptionalContainerRequestValueFactory {
|
||||
|
||||
private final ContainerRequest request;
|
||||
|
||||
public OptionalContainerRequestValueFactory(ContainerRequest request) {
|
||||
this.request = request;
|
||||
}
|
||||
|
||||
public Optional<Principal> provide() {
|
||||
return Optional.ofNullable(request.getSecurityContext().getUserPrincipal());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ package org.whispersystems.websocket;
|
|||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
|
@ -27,6 +28,7 @@ import org.glassfish.jersey.server.ResourceConfig;
|
|||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.whispersystems.websocket.auth.AuthenticationException;
|
||||
import org.whispersystems.websocket.auth.PrincipalSupplier;
|
||||
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
|
||||
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
|
||||
import org.whispersystems.websocket.setup.WebSocketEnvironment;
|
||||
|
@ -56,8 +58,7 @@ public class WebSocketResourceProviderFactoryTest {
|
|||
@Test
|
||||
void testUnauthorized() throws AuthenticationException, IOException {
|
||||
when(environment.getAuthenticator()).thenReturn(authenticator);
|
||||
when(authenticator.authenticate(eq(request))).thenReturn(
|
||||
new WebSocketAuthenticator.AuthenticationResult<>(Optional.empty(), true));
|
||||
when(authenticator.authenticate(eq(request))).thenReturn(ReusableAuth.invalid());
|
||||
when(environment.jersey()).thenReturn(jerseyEnvironment);
|
||||
|
||||
WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class,
|
||||
|
@ -74,8 +75,8 @@ public class WebSocketResourceProviderFactoryTest {
|
|||
Account account = new Account();
|
||||
|
||||
when(environment.getAuthenticator()).thenReturn(authenticator);
|
||||
when(authenticator.authenticate(eq(request))).thenReturn(
|
||||
new WebSocketAuthenticator.AuthenticationResult<>(Optional.of(account), true));
|
||||
when(authenticator.authenticate(eq(request)))
|
||||
.thenReturn(ReusableAuth.authenticated(account, PrincipalSupplier.forImmutablePrincipal()));
|
||||
when(environment.jersey()).thenReturn(jerseyEnvironment);
|
||||
final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class);
|
||||
when(httpServletRequest.getAttribute(REMOTE_ADDRESS_PROPERTY_NAME)).thenReturn("127.0.0.1");
|
||||
|
@ -137,6 +138,7 @@ public class WebSocketResourceProviderFactoryTest {
|
|||
public boolean implies(Subject subject) {
|
||||
return false;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -59,6 +59,7 @@ import org.glassfish.jersey.server.ResourceConfig;
|
|||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.stubbing.Answer;
|
||||
import org.whispersystems.websocket.auth.PrincipalSupplier;
|
||||
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
|
||||
import org.whispersystems.websocket.logging.WebsocketRequestLog;
|
||||
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
|
||||
|
@ -80,7 +81,7 @@ class WebSocketResourceProviderTest {
|
|||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||
REMOTE_ADDRESS_PROPERTY_NAME,
|
||||
applicationHandler, requestLog,
|
||||
new TestPrincipal("fooz"),
|
||||
immutableTestPrincipal("fooz"),
|
||||
new ProtobufWebSocketMessageFactory(),
|
||||
Optional.of(connectListener),
|
||||
Duration.ofMillis(30000));
|
||||
|
@ -108,7 +109,7 @@ class WebSocketResourceProviderTest {
|
|||
ApplicationHandler applicationHandler = mock(ApplicationHandler.class);
|
||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
|
||||
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"),
|
||||
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||
|
||||
Session session = mock(Session.class);
|
||||
|
@ -184,7 +185,7 @@ class WebSocketResourceProviderTest {
|
|||
ApplicationHandler applicationHandler = mock(ApplicationHandler.class);
|
||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
|
||||
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"),
|
||||
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||
|
||||
Session session = mock(Session.class);
|
||||
|
@ -240,7 +241,7 @@ class WebSocketResourceProviderTest {
|
|||
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
|
||||
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"),
|
||||
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||
|
||||
Session session = mock(Session.class);
|
||||
|
@ -280,7 +281,7 @@ class WebSocketResourceProviderTest {
|
|||
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
|
||||
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"),
|
||||
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||
|
||||
Session session = mock(Session.class);
|
||||
|
@ -320,7 +321,7 @@ class WebSocketResourceProviderTest {
|
|||
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("authorizedUserName"),
|
||||
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("authorizedUserName"),
|
||||
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||
|
||||
Session session = mock(Session.class);
|
||||
|
@ -360,8 +361,8 @@ class WebSocketResourceProviderTest {
|
|||
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, null, new ProtobufWebSocketMessageFactory(),
|
||||
Optional.empty(), Duration.ofMillis(30000));
|
||||
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, ReusableAuth.anonymous(),
|
||||
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||
|
||||
Session session = mock(Session.class);
|
||||
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
||||
|
@ -399,7 +400,7 @@ class WebSocketResourceProviderTest {
|
|||
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("something"),
|
||||
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("something"),
|
||||
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||
|
||||
Session session = mock(Session.class);
|
||||
|
@ -439,8 +440,8 @@ class WebSocketResourceProviderTest {
|
|||
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, null, new ProtobufWebSocketMessageFactory(),
|
||||
Optional.empty(), Duration.ofMillis(30000));
|
||||
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, ReusableAuth.anonymous(),
|
||||
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||
|
||||
Session session = mock(Session.class);
|
||||
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
||||
|
@ -479,7 +480,7 @@ class WebSocketResourceProviderTest {
|
|||
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"),
|
||||
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"),
|
||||
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||
|
||||
Session session = mock(Session.class);
|
||||
|
@ -520,7 +521,7 @@ class WebSocketResourceProviderTest {
|
|||
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"),
|
||||
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"),
|
||||
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||
|
||||
Session session = mock(Session.class);
|
||||
|
@ -562,7 +563,7 @@ class WebSocketResourceProviderTest {
|
|||
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"),
|
||||
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"),
|
||||
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||
|
||||
Session session = mock(Session.class);
|
||||
|
@ -602,7 +603,7 @@ class WebSocketResourceProviderTest {
|
|||
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"),
|
||||
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"),
|
||||
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||
|
||||
Session session = mock(Session.class);
|
||||
|
@ -727,6 +728,10 @@ class WebSocketResourceProviderTest {
|
|||
}
|
||||
}
|
||||
|
||||
public static ReusableAuth<TestPrincipal> immutableTestPrincipal(final String name) {
|
||||
return ReusableAuth.authenticated(new TestPrincipal(name), PrincipalSupplier.forImmutablePrincipal());
|
||||
}
|
||||
|
||||
public static class TestException extends Exception {
|
||||
|
||||
public TestException(String message) {
|
||||
|
|
Loading…
Reference in New Issue