Don't cache authenticated accounts in memory

This commit is contained in:
Jon Chambers 2025-06-23 08:40:05 -05:00 committed by GitHub
parent 9dfe51eac4
commit c952baa672
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
86 changed files with 961 additions and 2264 deletions

View File

@ -85,7 +85,6 @@ import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator
import org.whispersystems.textsecuregcm.auth.IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter;
import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager;
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventListener;
import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.backup.BackupAuthManager;
@ -212,7 +211,6 @@ import org.whispersystems.textsecuregcm.spam.RegistrationRecoveryChecker;
import org.whispersystems.textsecuregcm.spam.SpamChecker;
import org.whispersystems.textsecuregcm.spam.SpamFilter;
import org.whispersystems.textsecuregcm.storage.AccountLockManager;
import org.whispersystems.textsecuregcm.storage.AccountPrincipalSupplier;
import org.whispersystems.textsecuregcm.storage.Accounts;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ChangeNumberManager;
@ -980,23 +978,19 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.jersey().register(MultiRecipientMessageProvider.class);
environment.jersey().register(new AuthDynamicFeature(accountAuthFilter));
environment.jersey().register(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class));
environment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager,
disconnectionRequestManager));
environment.jersey().register(new TimestampResponseFilter());
///
WebSocketEnvironment<AuthenticatedDevice> webSocketEnvironment = new WebSocketEnvironment<>(environment,
config.getWebSocketConfiguration(), Duration.ofMillis(90000));
webSocketEnvironment.jersey().register(new VirtualExecutorServiceProvider("managed-async-websocket-virtual-thread-"));
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator, new AccountPrincipalSupplier(accountsManager)));
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator));
webSocketEnvironment.setAuthenticatedWebSocketUpgradeFilter(new IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter(
keysManager, config.idlePrimaryDeviceReminderConfiguration().minIdleDuration(), Clock.systemUTC()));
webSocketEnvironment.setConnectListener(
new AuthenticatedConnectListener(receiptSender, messagesManager, messageMetrics, pushNotificationManager,
new AuthenticatedConnectListener(accountsManager, receiptSender, messagesManager, messageMetrics, pushNotificationManager,
pushNotificationScheduler, webSocketConnectionEventManager, websocketScheduledExecutor,
messageDeliveryScheduler, clientReleaseManager, messageDeliveryLoopMonitor, experimentEnrollmentManager));
webSocketEnvironment.jersey()
.register(new WebsocketRefreshApplicationEventListener(accountsManager, disconnectionRequestManager));
webSocketEnvironment.jersey().register(new RateLimitByIpFilter(rateLimiters));
webSocketEnvironment.jersey().register(new RequestStatisticsFilter(TrafficSource.WEBSOCKET));
webSocketEnvironment.jersey().register(MultiRecipientMessageProvider.class);
@ -1083,15 +1077,15 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
registrationLockVerificationManager, rateLimiters),
new AttachmentControllerV4(rateLimiters, gcsAttachmentGenerator, tusAttachmentGenerator,
experimentEnrollmentManager),
new ArchiveController(backupAuthManager, backupManager, backupMetrics),
new ArchiveController(accountsManager, backupAuthManager, backupManager, backupMetrics),
new CallRoutingControllerV2(rateLimiters, cloudflareTurnCredentialsManager),
new CallLinkController(rateLimiters, callingGenericZkSecretParams),
new CertificateController(new CertificateGenerator(config.getDeliveryCertificate().certificate(),
new CertificateController(accountsManager, new CertificateGenerator(config.getDeliveryCertificate().certificate(),
config.getDeliveryCertificate().ecPrivateKey(), config.getDeliveryCertificate().expiresDays()),
zkAuthOperations, callingGenericZkSecretParams, clock),
new ChallengeController(rateLimitChallengeManager, challengeConstraintChecker),
new ChallengeController(accountsManager, rateLimitChallengeManager, challengeConstraintChecker),
new DeviceController(accountsManager, clientPublicKeysManager, rateLimiters, persistentTimer, config.getMaxDevices()),
new DeviceCheckController(clock, backupAuthManager, appleDeviceCheckManager, rateLimiters,
new DeviceCheckController(clock, accountsManager, backupAuthManager, appleDeviceCheckManager, rateLimiters,
config.getDeviceCheck().backupRedemptionLevel(),
config.getDeviceCheck().backupRedemptionDuration()),
new DirectoryV2Controller(directoryV2CredentialsGenerator),
@ -1140,8 +1134,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
WebSocketEnvironment<AuthenticatedDevice> provisioningEnvironment = new WebSocketEnvironment<>(environment,
webSocketEnvironment.getRequestLog(), Duration.ofMillis(60000));
provisioningEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager,
disconnectionRequestManager));
provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(provisioningManager, provisioningWebsocketTimeoutExecutor, Duration.ofSeconds(90)));
provisioningEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET, clientReleaseManager));
provisioningEnvironment.jersey().register(new KeepAliveController(webSocketConnectionEventManager));

View File

@ -7,10 +7,20 @@ package org.whispersystems.textsecuregcm.auth;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import java.time.Instant;
import java.util.UUID;
public interface AccountAndAuthenticatedDeviceHolder {
UUID getAccountIdentifier();
byte getDeviceId();
Instant getPrimaryDeviceLastSeen();
@Deprecated(forRemoval = true)
Account getAccount();
@Deprecated(forRemoval = true)
Device getAuthenticatedDevice();
}

View File

@ -6,7 +6,10 @@
package org.whispersystems.textsecuregcm.auth;
import java.security.Principal;
import java.time.Instant;
import java.util.UUID;
import javax.security.auth.Subject;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
@ -30,6 +33,21 @@ public class AuthenticatedDevice implements Principal, AccountAndAuthenticatedDe
return device;
}
@Override
public UUID getAccountIdentifier() {
return account.getIdentifier(IdentityType.ACI);
}
@Override
public byte getDeviceId() {
return device.getId();
}
@Override
public Instant getPrimaryDeviceLastSeen() {
return Instant.ofEpochMilli(account.getPrimaryDevice().getLastSeen());
}
// Principal implementation
@Override

View File

@ -31,9 +31,9 @@ public class CertificateGenerator {
this.serverCertificate = ServerCertificate.parseFrom(serverCertificate);
}
public byte[] createFor(Account account, Device device, boolean includeE164) throws InvalidKeyException {
public byte[] createFor(final Account account, final byte deviceId, boolean includeE164) throws InvalidKeyException {
SenderCertificate.Certificate.Builder builder = SenderCertificate.Certificate.newBuilder()
.setSenderDevice(Math.toIntExact(device.getId()))
.setSenderDevice(Math.toIntExact(deviceId))
.setExpires(System.currentTimeMillis() + TimeUnit.DAYS.toMillis(expiresDays))
.setIdentityKey(ByteString.copyFrom(account.getIdentityKey(IdentityType.ACI).serialize()))
.setSigner(serverCertificate)

View File

@ -1,44 +0,0 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import jakarta.ws.rs.core.SecurityContext;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import org.glassfish.jersey.server.ContainerRequest;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
class ContainerRequestUtil {
/**
* A read-only subset of the authenticated Account object, to enforce that filter-based consumers do not perform
* account modifying operations.
*/
record AccountInfo(UUID accountId, String e164, Set<Byte> deviceIds) {
static AccountInfo fromAccount(final Account account) {
return new AccountInfo(
account.getUuid(),
account.getNumber(),
account.getDevices().stream().map(Device::getId).collect(Collectors.toSet()));
}
}
static Optional<AccountInfo> getAuthenticatedAccount(final ContainerRequest request) {
return Optional.ofNullable(request.getSecurityContext())
.map(SecurityContext::getUserPrincipal)
.map(principal -> {
if (principal instanceof AccountAndAuthenticatedDeviceHolder aaadh) {
return aaadh.getAccount();
}
return null;
})
.map(AccountInfo::fromAccount);
}
}

View File

@ -8,17 +8,16 @@ package org.whispersystems.textsecuregcm.auth;
import com.google.common.annotations.VisibleForTesting;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.auth.AuthenticatedWebSocketUpgradeFilter;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Optional;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.websocket.auth.AuthenticatedWebSocketUpgradeFilter;
public class IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter implements
AuthenticatedWebSocketUpgradeFilter<AuthenticatedDevice> {
@ -58,21 +57,19 @@ public class IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter implements
}
@Override
public void handleAuthentication(final ReusableAuth<AuthenticatedDevice> authenticated,
public void handleAuthentication(final Optional<AuthenticatedDevice> authenticated,
final JettyServerUpgradeRequest request,
final JettyServerUpgradeResponse response) {
// No action needed if the connection is unauthenticated (in which case we don't know when we've last seen the
// primary device) or if the authenticated device IS the primary device
authenticated.ref()
.filter(authenticatedDevice -> !authenticatedDevice.getAuthenticatedDevice().isPrimary())
authenticated
.filter(authenticatedDevice -> authenticatedDevice.getDeviceId() != Device.PRIMARY_ID)
.ifPresent(authenticatedDevice -> {
final Instant primaryDeviceLastSeen =
Instant.ofEpochMilli(authenticatedDevice.getAccount().getPrimaryDevice().getLastSeen());
final Instant primaryDeviceLastSeen = authenticatedDevice.getPrimaryDeviceLastSeen();
if (primaryDeviceLastSeen.isBefore(clock.instant().minus(PQ_KEY_CHECK_THRESHOLD)) &&
keysManager.getLastResort(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI), Device.PRIMARY_ID)
.join().isEmpty()) {
keysManager.getLastResort(authenticatedDevice.getAccountIdentifier(), Device.PRIMARY_ID).join().isEmpty()) {
response.addHeader(ALERT_HEADER, CRITICAL_IDLE_PRIMARY_DEVICE_ALERT);
CRITICAL_IDLE_PRIMARY_WARNING_COUNTER.increment();

View File

@ -1,96 +0,0 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import org.glassfish.jersey.server.ContainerRequest;
import org.glassfish.jersey.server.monitoring.RequestEvent;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.util.Pair;
/**
* This {@link WebsocketRefreshRequirementProvider} observes intra-request changes in devices linked to an
* {@link Account} and triggers a WebSocket refresh if that set changes. If a change in linked devices is observed, then
* any active WebSocket connections for the account must be closed in order for clients to get a refreshed
* {@link io.dropwizard.auth.Auth} object with a current device list.
*
* @see AuthenticatedDevice
*/
public class LinkedDeviceRefreshRequirementProvider implements WebsocketRefreshRequirementProvider {
private final AccountsManager accountsManager;
private static final Logger logger = LoggerFactory.getLogger(LinkedDeviceRefreshRequirementProvider.class);
private static final String ACCOUNT_UUID = LinkedDeviceRefreshRequirementProvider.class.getName() + ".accountUuid";
private static final String LINKED_DEVICE_IDS = LinkedDeviceRefreshRequirementProvider.class.getName() + ".deviceIds";
public LinkedDeviceRefreshRequirementProvider(final AccountsManager accountsManager) {
this.accountsManager = accountsManager;
}
@Override
public void handleRequestFiltered(final RequestEvent requestEvent) {
if (requestEvent.getUriInfo().getMatchedResourceMethod().getInvocable().getHandlingMethod().getAnnotation(
ChangesLinkedDevices.class) != null) {
// The authenticated principal, if any, will be available after filters have run. Now that the account is known,
// capture a snapshot of the account's linked devices before carrying out the requests business logic.
ContainerRequestUtil.getAuthenticatedAccount(requestEvent.getContainerRequest())
.ifPresent(account -> setAccount(requestEvent.getContainerRequest(), account));
}
}
public static void setAccount(final ContainerRequest containerRequest, final Account account) {
setAccount(containerRequest, ContainerRequestUtil.AccountInfo.fromAccount(account));
}
private static void setAccount(final ContainerRequest containerRequest, final ContainerRequestUtil.AccountInfo info) {
containerRequest.setProperty(ACCOUNT_UUID, info.accountId());
containerRequest.setProperty(LINKED_DEVICE_IDS, info.deviceIds());
}
@Override
public List<Pair<UUID, Byte>> handleRequestFinished(final RequestEvent requestEvent) {
// Now that the request is finished, check whether the set of linked devices has changed. If the value did change or
// if a devices was added or removed, all devices must disconnect and reauthenticate.
if (requestEvent.getContainerRequest().getProperty(LINKED_DEVICE_IDS) != null) {
@SuppressWarnings("unchecked") final Set<Byte> initialLinkedDeviceIds =
(Set<Byte>) requestEvent.getContainerRequest().getProperty(LINKED_DEVICE_IDS);
return accountsManager.getByAccountIdentifier((UUID) requestEvent.getContainerRequest().getProperty(ACCOUNT_UUID))
.map(ContainerRequestUtil.AccountInfo::fromAccount)
.map(accountInfo -> {
final Set<Byte> deviceIdsToDisplace;
final Set<Byte> currentLinkedDeviceIds = accountInfo.deviceIds();
if (!initialLinkedDeviceIds.equals(currentLinkedDeviceIds)) {
deviceIdsToDisplace = new HashSet<>(initialLinkedDeviceIds);
deviceIdsToDisplace.addAll(currentLinkedDeviceIds);
} else {
deviceIdsToDisplace = Collections.emptySet();
}
return deviceIdsToDisplace.stream()
.map(deviceId -> new Pair<>(accountInfo.accountId(), deviceId))
.collect(Collectors.toList());
}).orElseGet(() -> {
logger.error("Request had account, but it is no longer present");
return Collections.emptyList();
});
} else {
return Collections.emptyList();
}
}
}

View File

@ -1,56 +0,0 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;
import org.glassfish.jersey.server.monitoring.RequestEvent;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.util.Pair;
public class PhoneNumberChangeRefreshRequirementProvider implements WebsocketRefreshRequirementProvider {
private static final String ACCOUNT_UUID =
PhoneNumberChangeRefreshRequirementProvider.class.getName() + ".accountUuid";
private static final String INITIAL_NUMBER_KEY =
PhoneNumberChangeRefreshRequirementProvider.class.getName() + ".initialNumber";
private final AccountsManager accountsManager;
public PhoneNumberChangeRefreshRequirementProvider(final AccountsManager accountsManager) {
this.accountsManager = accountsManager;
}
@Override
public void handleRequestFiltered(final RequestEvent requestEvent) {
if (requestEvent.getUriInfo().getMatchedResourceMethod().getInvocable().getHandlingMethod()
.getAnnotation(ChangesPhoneNumber.class) == null) {
return;
}
ContainerRequestUtil.getAuthenticatedAccount(requestEvent.getContainerRequest())
.ifPresent(account -> {
requestEvent.getContainerRequest().setProperty(INITIAL_NUMBER_KEY, account.e164());
requestEvent.getContainerRequest().setProperty(ACCOUNT_UUID, account.accountId());
});
}
@Override
public List<Pair<UUID, Byte>> handleRequestFinished(final RequestEvent requestEvent) {
final String initialNumber = (String) requestEvent.getContainerRequest().getProperty(INITIAL_NUMBER_KEY);
if (initialNumber == null) {
return Collections.emptyList();
}
return accountsManager.getByAccountIdentifier((UUID) requestEvent.getContainerRequest().getProperty(ACCOUNT_UUID))
.filter(account -> !initialNumber.equals(account.getNumber()))
.map(account -> account.getDevices().stream()
.map(device -> new Pair<>(account.getUuid(), device.getId()))
.collect(Collectors.toList()))
.orElse(Collections.emptyList());
}
}

View File

@ -1,38 +0,0 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import org.glassfish.jersey.server.monitoring.ApplicationEvent;
import org.glassfish.jersey.server.monitoring.ApplicationEventListener;
import org.glassfish.jersey.server.monitoring.RequestEvent;
import org.glassfish.jersey.server.monitoring.RequestEventListener;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
/**
* Delegates request events to a listener that watches for intra-request changes that require websocket refreshes
*/
public class WebsocketRefreshApplicationEventListener implements ApplicationEventListener {
private final WebsocketRefreshRequestEventListener websocketRefreshRequestEventListener;
public WebsocketRefreshApplicationEventListener(final AccountsManager accountsManager,
final DisconnectionRequestManager disconnectionRequestManager) {
this.websocketRefreshRequestEventListener = new WebsocketRefreshRequestEventListener(
disconnectionRequestManager,
new LinkedDeviceRefreshRequirementProvider(accountsManager),
new PhoneNumberChangeRefreshRequirementProvider(accountsManager));
}
@Override
public void onEvent(final ApplicationEvent event) {
}
@Override
public RequestEventListener onRequest(final RequestEvent requestEvent) {
return websocketRefreshRequestEventListener;
}
}

View File

@ -1,74 +0,0 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import jakarta.ws.rs.container.ResourceInfo;
import jakarta.ws.rs.core.Context;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.glassfish.jersey.server.monitoring.RequestEvent;
import org.glassfish.jersey.server.monitoring.RequestEvent.Type;
import org.glassfish.jersey.server.monitoring.RequestEventListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class WebsocketRefreshRequestEventListener implements RequestEventListener {
private final DisconnectionRequestManager disconnectionRequestManager;
private final WebsocketRefreshRequirementProvider[] providers;
private static final Counter DISPLACED_ACCOUNTS = Metrics.counter(
name(WebsocketRefreshRequestEventListener.class, "displacedAccounts"));
private static final Counter DISPLACED_DEVICES = Metrics.counter(
name(WebsocketRefreshRequestEventListener.class, "displacedDevices"));
private static final Logger logger = LoggerFactory.getLogger(WebsocketRefreshRequestEventListener.class);
public WebsocketRefreshRequestEventListener(
final DisconnectionRequestManager disconnectionRequestManager,
final WebsocketRefreshRequirementProvider... providers) {
this.disconnectionRequestManager = disconnectionRequestManager;
this.providers = providers;
}
@Context
private ResourceInfo resourceInfo;
@Override
public void onEvent(final RequestEvent event) {
if (event.getType() == Type.REQUEST_FILTERED) {
for (final WebsocketRefreshRequirementProvider provider : providers) {
provider.handleRequestFiltered(event);
}
} else if (event.getType() == Type.FINISHED) {
final AtomicInteger displacedDevices = new AtomicInteger(0);
Arrays.stream(providers)
.flatMap(provider -> provider.handleRequestFinished(event).stream())
.distinct()
.forEach(pair -> {
try {
displacedDevices.incrementAndGet();
disconnectionRequestManager.requestDisconnection(pair.first(), List.of(pair.second()));
} catch (final Exception e) {
logger.error("Could not displace device presence", e);
}
});
if (displacedDevices.get() > 0) {
DISPLACED_ACCOUNTS.increment();
DISPLACED_DEVICES.increment(displacedDevices.get());
}
}
}
}

View File

@ -1,34 +0,0 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import java.util.List;
import java.util.UUID;
import org.glassfish.jersey.server.monitoring.RequestEvent;
import org.whispersystems.textsecuregcm.util.Pair;
/**
* A websocket refresh requirement provider watches for intra-request changes (e.g. to authentication status) that
* require a websocket refresh.
*/
public interface WebsocketRefreshRequirementProvider {
/**
* Processes a request after filters have run and the request has been mapped to a destination controller.
*
* @param requestEvent the request event to observe
*/
void handleRequestFiltered(RequestEvent requestEvent);
/**
* Processes a request after all normal request handling has been completed.
*
* @param requestEvent the request event to observe
* @return a list of pairs of account UUID/device ID pairs identifying websockets that need to be refreshed as a
* result of the observed request
*/
List<Pair<UUID, Byte>> handleRequestFinished(RequestEvent requestEvent);
}

View File

@ -66,8 +66,6 @@ import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.UsernameHashZkProofVerifier;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.websocket.auth.Mutable;
import org.whispersystems.websocket.auth.ReadOnly;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@Path("/v1/accounts")
@ -97,11 +95,14 @@ public class AccountController {
@Path("/gcm/")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public void setGcmRegistrationId(@Mutable @Auth AuthenticatedDevice auth,
public void setGcmRegistrationId(@Auth AuthenticatedDevice auth,
@NotNull @Valid GcmRegistrationId registrationId) {
final Account account = auth.getAccount();
final Device device = auth.getAuthenticatedDevice();
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
final Device device = account.getDevice(auth.getDeviceId())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
if (Objects.equals(device.getGcmId(), registrationId.gcmRegistrationId())) {
return;
@ -116,9 +117,12 @@ public class AccountController {
@DELETE
@Path("/gcm/")
public void deleteGcmRegistrationId(@Mutable @Auth AuthenticatedDevice auth) {
Account account = auth.getAccount();
Device device = auth.getAuthenticatedDevice();
public void deleteGcmRegistrationId(@Auth AuthenticatedDevice auth) {
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
final Device device = account.getDevice(auth.getDeviceId())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
accounts.updateDevice(account, device.getId(), d -> {
d.setGcmId(null);
@ -131,11 +135,14 @@ public class AccountController {
@Path("/apn/")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public void setApnRegistrationId(@Mutable @Auth AuthenticatedDevice auth,
public void setApnRegistrationId(@Auth AuthenticatedDevice auth,
@NotNull @Valid ApnRegistrationId registrationId) {
final Account account = auth.getAccount();
final Device device = auth.getAuthenticatedDevice();
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
final Device device = account.getDevice(auth.getDeviceId())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
// Unlike FCM tokens, we need current "last updated" timestamps for APNs tokens and so update device records
// unconditionally
@ -148,9 +155,12 @@ public class AccountController {
@DELETE
@Path("/apn/")
public void deleteApnRegistrationId(@Mutable @Auth AuthenticatedDevice auth) {
Account account = auth.getAccount();
Device device = auth.getAuthenticatedDevice();
public void deleteApnRegistrationId(@Auth AuthenticatedDevice auth) {
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
final Device device = account.getDevice(auth.getDeviceId())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
accounts.updateDevice(account, device.getId(), d -> {
d.setApnId(null);
@ -166,17 +176,23 @@ public class AccountController {
@PUT
@Produces(MediaType.APPLICATION_JSON)
@Path("/registration_lock")
public void setRegistrationLock(@Mutable @Auth AuthenticatedDevice auth, @NotNull @Valid RegistrationLock accountLock) {
SaltedTokenHash credentials = SaltedTokenHash.generateFor(accountLock.getRegistrationLock());
public void setRegistrationLock(@Auth AuthenticatedDevice auth, @NotNull @Valid RegistrationLock accountLock) {
final SaltedTokenHash credentials = SaltedTokenHash.generateFor(accountLock.getRegistrationLock());
accounts.update(auth.getAccount(),
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
accounts.update(account,
a -> a.setRegistrationLock(credentials.hash(), credentials.salt()));
}
@DELETE
@Path("/registration_lock")
public void removeRegistrationLock(@Mutable @Auth AuthenticatedDevice auth) {
accounts.update(auth.getAccount(), a -> a.setRegistrationLock(null, null));
public void removeRegistrationLock(@Auth AuthenticatedDevice auth) {
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
accounts.update(account, a -> a.setRegistrationLock(null, null));
}
@PUT
@ -190,7 +206,7 @@ public class AccountController {
@ApiResponse(responseCode = "204", description = "Device name changed successfully")
@ApiResponse(responseCode = "404", description = "No device found with the given ID")
@ApiResponse(responseCode = "403", description = "Not authorized to change the name of the device with the given ID")
public void setName(@Mutable @Auth final AuthenticatedDevice auth,
public void setName(@Auth final AuthenticatedDevice auth,
@NotNull @Valid final DeviceName deviceName,
@Nullable
@ -199,15 +215,16 @@ public class AccountController {
requiredMode = Schema.RequiredMode.NOT_REQUIRED)
final Byte deviceId) {
final Account account = auth.getAccount();
final byte targetDeviceId = deviceId == null ? auth.getAuthenticatedDevice().getId() : deviceId;
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
final byte targetDeviceId = deviceId == null ? auth.getDeviceId() : deviceId;
if (account.getDevice(targetDeviceId).isEmpty()) {
throw new NotFoundException();
}
final boolean mayChangeName = auth.getAuthenticatedDevice().isPrimary() ||
auth.getAuthenticatedDevice().getId() == targetDeviceId;
final boolean mayChangeName = auth.getDeviceId() == Device.PRIMARY_ID || auth.getDeviceId() == targetDeviceId;
if (!mayChangeName) {
throw new ForbiddenException();
@ -221,14 +238,14 @@ public class AccountController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public void setAccountAttributes(
@Mutable @Auth AuthenticatedDevice auth,
@Auth AuthenticatedDevice auth,
@HeaderParam(HeaderUtils.X_SIGNAL_AGENT) String userAgent,
@NotNull @Valid AccountAttributes attributes) {
final Account account = auth.getAccount();
final byte deviceId = auth.getAuthenticatedDevice().getId();
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
final Account updatedAccount = accounts.update(account, a -> {
a.getDevice(deviceId).ifPresent(d -> {
a.getDevice(auth.getDeviceId()).ifPresent(d -> {
d.setFetchesMessages(attributes.getFetchesMessages());
d.setName(attributes.getName());
d.setLastSeen(Util.todayInMillis());
@ -252,8 +269,11 @@ public class AccountController {
@GET
@Path("/whoami")
@Produces(MediaType.APPLICATION_JSON)
public AccountIdentityResponse whoAmI(@ReadOnly @Auth AuthenticatedDevice auth) {
return AccountIdentityResponseBuilder.fromAccount(auth.getAccount());
public AccountIdentityResponse whoAmI(@Auth final AuthenticatedDevice auth) {
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
return AccountIdentityResponseBuilder.fromAccount(account);
}
@DELETE
@ -267,8 +287,11 @@ public class AccountController {
)
@ApiResponse(responseCode = "204", description = "Username successfully deleted.", useReturnTypeSchema = true)
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
public CompletableFuture<Response> deleteUsernameHash(@Mutable @Auth final AuthenticatedDevice auth) {
return accounts.clearUsernameHash(auth.getAccount())
public CompletableFuture<Response> deleteUsernameHash(@Auth final AuthenticatedDevice auth) {
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
return accounts.clearUsernameHash(account)
.thenApply(Util.ASYNC_EMPTY_RESPONSE);
}
@ -289,10 +312,13 @@ public class AccountController {
@ApiResponse(responseCode = "422", description = "Invalid request format.")
@ApiResponse(responseCode = "429", description = "Ratelimited.")
public CompletableFuture<ReserveUsernameHashResponse> reserveUsernameHash(
@Mutable @Auth final AuthenticatedDevice auth,
@Auth final AuthenticatedDevice auth,
@NotNull @Valid final ReserveUsernameHashRequest usernameRequest) throws RateLimitExceededException {
rateLimiters.getUsernameReserveLimiter().validate(auth.getAccount().getUuid());
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
rateLimiters.getUsernameReserveLimiter().validate(auth.getAccountIdentifier());
for (final byte[] hash : usernameRequest.usernameHashes()) {
if (hash.length != USERNAME_HASH_LENGTH) {
@ -300,7 +326,7 @@ public class AccountController {
}
}
return accounts.reserveUsernameHash(auth.getAccount(), usernameRequest.usernameHashes())
return accounts.reserveUsernameHash(account, usernameRequest.usernameHashes())
.thenApply(reservation -> new ReserveUsernameHashResponse(reservation.reservedUsernameHash()))
.exceptionally(throwable -> {
if (ExceptionUtils.unwrap(throwable) instanceof UsernameHashNotAvailableException) {
@ -329,18 +355,21 @@ public class AccountController {
@ApiResponse(responseCode = "422", description = "Invalid request format.")
@ApiResponse(responseCode = "429", description = "Ratelimited.")
public CompletableFuture<UsernameHashResponse> confirmUsernameHash(
@Mutable @Auth final AuthenticatedDevice auth,
@Auth final AuthenticatedDevice auth,
@NotNull @Valid final ConfirmUsernameHashRequest confirmRequest) {
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
try {
usernameHashZkProofVerifier.verifyProof(confirmRequest.zkProof(), confirmRequest.usernameHash());
} catch (final BaseUsernameException e) {
throw new WebApplicationException(Response.status(422).build());
}
return rateLimiters.getUsernameSetLimiter().validateAsync(auth.getAccount().getUuid())
return rateLimiters.getUsernameSetLimiter().validateAsync(account.getUuid())
.thenCompose(ignored -> accounts.confirmReservedUsernameHash(
auth.getAccount(),
account,
confirmRequest.usernameHash(),
confirmRequest.encryptedUsername()))
.thenApply(updatedAccount -> new UsernameHashResponse(updatedAccount.getUsernameHash()
@ -374,7 +403,7 @@ public class AccountController {
@ApiResponse(responseCode = "400", description = "Request must not be authenticated.")
@ApiResponse(responseCode = "404", description = "Account not found for the given username.")
public CompletableFuture<AccountIdentifierResponse> lookupUsernameHash(
@ReadOnly @Auth final Optional<AuthenticatedDevice> maybeAuthenticatedAccount,
@Auth final Optional<AuthenticatedDevice> maybeAuthenticatedAccount,
@PathParam("usernameHash") final String usernameHash) {
requireNotAuthenticated(maybeAuthenticatedAccount);
@ -413,12 +442,14 @@ public class AccountController {
@ApiResponse(responseCode = "422", description = "Invalid request format.")
@ApiResponse(responseCode = "429", description = "Ratelimited.")
public UsernameLinkHandle updateUsernameLink(
@Mutable @Auth final AuthenticatedDevice auth,
@Auth final AuthenticatedDevice auth,
@NotNull @Valid final EncryptedUsername encryptedUsername) throws RateLimitExceededException {
// check ratelimiter for username link operations
rateLimiters.forDescriptor(RateLimiters.For.USERNAME_LINK_OPERATION).validate(auth.getAccount().getUuid());
final Account account = auth.getAccount();
// check ratelimiter for username link operations
rateLimiters.forDescriptor(RateLimiters.For.USERNAME_LINK_OPERATION).validate(auth.getAccountIdentifier());
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
// check if username hash is set for the account
if (account.getUsernameHash().isEmpty()) {
@ -431,7 +462,7 @@ public class AccountController {
} else {
usernameLinkHandle = UUID.randomUUID();
}
updateUsernameLink(auth.getAccount(), usernameLinkHandle, encryptedUsername.usernameLinkEncryptedValue());
updateUsernameLink(account, usernameLinkHandle, encryptedUsername.usernameLinkEncryptedValue());
return new UsernameLinkHandle(usernameLinkHandle);
}
@ -447,10 +478,14 @@ public class AccountController {
@ApiResponse(responseCode = "204", description = "Username Link successfully deleted.", useReturnTypeSchema = true)
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
@ApiResponse(responseCode = "429", description = "Ratelimited.")
public void deleteUsernameLink(@Mutable @Auth final AuthenticatedDevice auth) throws RateLimitExceededException {
public void deleteUsernameLink(@Auth final AuthenticatedDevice auth) throws RateLimitExceededException {
// check ratelimiter for username link operations
rateLimiters.forDescriptor(RateLimiters.For.USERNAME_LINK_OPERATION).validate(auth.getAccount().getUuid());
clearUsernameLink(auth.getAccount());
rateLimiters.forDescriptor(RateLimiters.For.USERNAME_LINK_OPERATION).validate(auth.getAccountIdentifier());
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
clearUsernameLink(account);
}
@GET
@ -470,7 +505,7 @@ public class AccountController {
@ApiResponse(responseCode = "422", description = "Invalid request format.")
@ApiResponse(responseCode = "429", description = "Ratelimited.")
public CompletableFuture<EncryptedUsername> lookupUsernameLink(
@ReadOnly @Auth final Optional<AuthenticatedDevice> maybeAuthenticatedAccount,
@Auth final Optional<AuthenticatedDevice> maybeAuthenticatedAccount,
@PathParam("uuid") final UUID usernameLinkHandle) {
requireNotAuthenticated(maybeAuthenticatedAccount);
@ -496,7 +531,7 @@ public class AccountController {
@Path("/account/{identifier}")
@RateLimitedByIp(RateLimiters.For.CHECK_ACCOUNT_EXISTENCE)
public Response accountExists(
@ReadOnly @Auth final Optional<AuthenticatedDevice> authenticatedAccount,
@Auth final Optional<AuthenticatedDevice> authenticatedAccount,
@Parameter(description = "An ACI or PNI account identifier to check")
@PathParam("identifier") final ServiceIdentifier accountIdentifier) {
@ -511,8 +546,11 @@ public class AccountController {
@DELETE
@Path("/me")
public CompletableFuture<Response> deleteAccount(@Mutable @Auth AuthenticatedDevice auth) {
return accounts.delete(auth.getAccount(), AccountsManager.DeletionReason.USER_REQUEST).thenApply(Util.ASYNC_EMPTY_RESPONSE);
public CompletableFuture<Response> deleteAccount(@Auth AuthenticatedDevice auth) {
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
return accounts.delete(account, AccountsManager.DeletionReason.USER_REQUEST).thenApply(Util.ASYNC_EMPTY_RESPONSE);
}
private void clearUsernameLink(final Account account) {

View File

@ -55,8 +55,7 @@ import org.whispersystems.textsecuregcm.push.MessageTooLargeException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ChangeNumberManager;
import org.whispersystems.websocket.auth.Mutable;
import org.whispersystems.websocket.auth.ReadOnly;
import org.whispersystems.textsecuregcm.storage.Device;
@Path("/v2/accounts")
@io.swagger.v3.oas.annotations.tags.Tag(name = "Account")
@ -101,12 +100,12 @@ public class AccountControllerV2 {
@ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header(
name = "Retry-After",
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public AccountIdentityResponse changeNumber(@Mutable @Auth final AuthenticatedDevice authenticatedDevice,
public AccountIdentityResponse changeNumber(@Auth final AuthenticatedDevice authenticatedDevice,
@NotNull @Valid final ChangeNumberRequest request,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgentString,
@Context final ContainerRequestContext requestContext) throws RateLimitExceededException, InterruptedException {
if (!authenticatedDevice.getAuthenticatedDevice().isPrimary()) {
if (authenticatedDevice.getDeviceId() != Device.PRIMARY_ID) {
throw new ForbiddenException();
}
@ -116,8 +115,11 @@ public class AccountControllerV2 {
final String number = request.number();
final Account account = accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
// Only verify and check reglock if there's a data change to be made...
if (!authenticatedDevice.getAccount().getNumber().equals(number)) {
if (!account.getNumber().equals(number)) {
rateLimiters.getRegistrationLimiter().validate(number);
@ -139,7 +141,7 @@ public class AccountControllerV2 {
// ...but always attempt to make the change in case a client retries and needs to re-send messages
try {
final Account updatedAccount = changeNumberManager.changeNumber(
authenticatedDevice.getAccount(),
account,
request.number(),
request.pniIdentityKey(),
request.devicePniSignedPrekeys(),
@ -185,11 +187,11 @@ public class AccountControllerV2 {
content = @Content(schema = @Schema(implementation = StaleDevicesResponse.class)))
@ApiResponse(responseCode = "413", description = "One or more device messages was too large")
public AccountIdentityResponse distributePhoneNumberIdentityKeys(
@Mutable @Auth final AuthenticatedDevice authenticatedDevice,
@Auth final AuthenticatedDevice authenticatedDevice,
@HeaderParam(HttpHeaders.USER_AGENT) @Nullable final String userAgentString,
@NotNull @Valid final PhoneNumberIdentityKeyDistributionRequest request) {
if (!authenticatedDevice.getAuthenticatedDevice().isPrimary()) {
if (authenticatedDevice.getDeviceId() != Device.PRIMARY_ID) {
throw new ForbiddenException();
}
@ -197,9 +199,12 @@ public class AccountControllerV2 {
throw new WebApplicationException("Invalid signature", 422);
}
final Account account = accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
try {
final Account updatedAccount = changeNumberManager.updatePniKeys(
authenticatedDevice.getAccount(),
account,
request.pniIdentityKey(),
request.devicePniSignedPrekeys(),
request.devicePniPqLastResortPrekeys(),
@ -235,10 +240,13 @@ public class AccountControllerV2 {
@Operation(summary = "Sets whether the account should be discoverable by phone number in the directory.")
@ApiResponse(responseCode = "204", description = "The setting was successfully updated.")
public void setPhoneNumberDiscoverability(
@Mutable @Auth AuthenticatedDevice auth,
@NotNull @Valid PhoneNumberDiscoverabilityRequest phoneNumberDiscoverability
) {
accountsManager.update(auth.getAccount(), a -> a.setDiscoverableByPhoneNumber(
@Auth AuthenticatedDevice auth,
@NotNull @Valid PhoneNumberDiscoverabilityRequest phoneNumberDiscoverability) {
final Account account = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
accountsManager.update(account, a -> a.setDiscoverableByPhoneNumber(
phoneNumberDiscoverability.discoverableByPhoneNumber()));
}
@ -249,9 +257,10 @@ public class AccountControllerV2 {
@ApiResponse(responseCode = "200",
description = "Response with data report. A plain text representation is a field in the response.",
useReturnTypeSchema = true)
public AccountDataReportResponse getAccountDataReport(@ReadOnly @Auth final AuthenticatedDevice auth) {
public AccountDataReportResponse getAccountDataReport(@Auth final AuthenticatedDevice auth) {
final Account account = auth.getAccount();
final Account account = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
return new AccountDataReportResponse(UUID.randomUUID(), Instant.now(),
new AccountDataReportResponse.AccountAndDevicesDataReport(

View File

@ -37,6 +37,7 @@ import jakarta.ws.rs.PUT;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.QueryParam;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import java.io.IOException;
@ -71,14 +72,15 @@ import org.whispersystems.textsecuregcm.backup.MediaEncryptionParameters;
import org.whispersystems.textsecuregcm.entities.RemoteAttachment;
import org.whispersystems.textsecuregcm.metrics.BackupMetrics;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.BackupAuthCredentialAdapter;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
import org.whispersystems.textsecuregcm.util.ByteArrayBase64UrlAdapter;
import org.whispersystems.textsecuregcm.util.ECPublicKeyAdapter;
import org.whispersystems.textsecuregcm.util.ExactlySize;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.websocket.auth.Mutable;
import org.whispersystems.websocket.auth.ReadOnly;
import reactor.core.publisher.Mono;
@Path("/v1/archives")
@ -88,14 +90,18 @@ public class ArchiveController {
public final static String X_SIGNAL_ZK_AUTH = "X-Signal-ZK-Auth";
public final static String X_SIGNAL_ZK_AUTH_SIGNATURE = "X-Signal-ZK-Auth-Signature";
private final AccountsManager accountsManager;
private final BackupAuthManager backupAuthManager;
private final BackupManager backupManager;
private final BackupMetrics backupMetrics;
public ArchiveController(
final AccountsManager accountsManager,
final BackupAuthManager backupAuthManager,
final BackupManager backupManager,
final BackupMetrics backupMetrics) {
this.accountsManager = accountsManager;
this.backupAuthManager = backupAuthManager;
this.backupManager = backupManager;
this.backupMetrics = backupMetrics;
@ -138,13 +144,22 @@ public class ArchiveController {
@ApiResponse(responseCode = "403", description = "The device did not have permission to set the backup-id. Only the primary device can set the backup-id for an account")
@ApiResponse(responseCode = "429", description = "Rate limited. Too many attempts to change the backup-id have been made")
public CompletionStage<Response> setBackupId(
@Mutable @Auth final AuthenticatedDevice account,
@Auth final AuthenticatedDevice authenticatedDevice,
@Valid @NotNull final SetBackupIdRequest setBackupIdRequest) throws RateLimitExceededException {
return this.backupAuthManager
.commitBackupId(account.getAccount(), account.getAuthenticatedDevice(),
setBackupIdRequest.messagesBackupAuthCredentialRequest,
setBackupIdRequest.mediaBackupAuthCredentialRequest)
.thenApply(Util.ASYNC_EMPTY_RESPONSE);
return accountsManager.getByAccountIdentifierAsync(authenticatedDevice.getAccountIdentifier())
.thenCompose(maybeAccount -> {
final Account account = maybeAccount
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
final Device device = account.getDevice(authenticatedDevice.getDeviceId())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
return backupAuthManager
.commitBackupId(account, device, setBackupIdRequest.messagesBackupAuthCredentialRequest,
setBackupIdRequest.mediaBackupAuthCredentialRequest)
.thenApply(Util.ASYNC_EMPTY_RESPONSE);
});
}
public record RedeemBackupReceiptRequest(
@ -188,12 +203,17 @@ public class ArchiveController {
@ApiResponse(responseCode = "409", description = "The target account does not have a backup-id commitment")
@ApiResponse(responseCode = "429", description = "Rate limited.")
public CompletionStage<Response> redeemReceipt(
@Mutable @Auth final AuthenticatedDevice account,
@Auth final AuthenticatedDevice authenticatedDevice,
@Valid @NotNull final RedeemBackupReceiptRequest redeemBackupReceiptRequest) {
return this.backupAuthManager.redeemReceipt(
account.getAccount(),
redeemBackupReceiptRequest.receiptCredentialPresentation())
.thenApply(Util.ASYNC_EMPTY_RESPONSE);
return accountsManager.getByAccountIdentifierAsync(authenticatedDevice.getAccountIdentifier())
.thenCompose(maybeAccount -> {
final Account account = maybeAccount
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
return backupAuthManager.redeemReceipt(account, redeemBackupReceiptRequest.receiptCredentialPresentation())
.thenApply(Util.ASYNC_EMPTY_RESPONSE);
});
}
public record BackupAuthCredentialsResponse(
@ -252,7 +272,7 @@ public class ArchiveController {
@ApiResponse(responseCode = "404", description = "Could not find an existing blinded backup id")
@ApiResponse(responseCode = "429", description = "Rate limited.")
public CompletionStage<BackupAuthCredentialsResponse> getBackupZKCredentials(
@Mutable @Auth AuthenticatedDevice auth,
@Auth AuthenticatedDevice authenticatedDevice,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@NotNull @QueryParam("redemptionStartSeconds") Long startSeconds,
@NotNull @QueryParam("redemptionEndSeconds") Long endSeconds) {
@ -260,27 +280,33 @@ public class ArchiveController {
final Map<BackupCredentialType, List<BackupAuthCredentialsResponse.BackupAuthCredential>> credentialsByType =
new ConcurrentHashMap<>();
return CompletableFuture.allOf(Arrays.stream(BackupCredentialType.values())
.map(credentialType -> this.backupAuthManager.getBackupAuthCredentials(
auth.getAccount(),
credentialType,
Instant.ofEpochSecond(startSeconds), Instant.ofEpochSecond(endSeconds))
.thenAccept(credentials -> {
backupMetrics.updateGetCredentialCounter(
UserAgentTagUtil.getPlatformTag(userAgent),
credentialType,
credentials.size());
credentialsByType.put(credentialType, credentials.stream()
.map(credential -> new BackupAuthCredentialsResponse.BackupAuthCredential(
credential.credential().serialize(),
credential.redemptionTime().getEpochSecond()))
.toList());
}))
.toArray(CompletableFuture[]::new))
.thenApply(ignored -> new BackupAuthCredentialsResponse(credentialsByType.entrySet().stream()
.collect(Collectors.toMap(
e -> BackupAuthCredentialsResponse.CredentialType.fromLibsignalType(e.getKey()),
Map.Entry::getValue))));
return accountsManager.getByAccountIdentifierAsync(authenticatedDevice.getAccountIdentifier())
.thenCompose(maybeAccount -> {
final Account account = maybeAccount
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
return CompletableFuture.allOf(Arrays.stream(BackupCredentialType.values())
.map(credentialType -> this.backupAuthManager.getBackupAuthCredentials(
account,
credentialType,
Instant.ofEpochSecond(startSeconds), Instant.ofEpochSecond(endSeconds))
.thenAccept(credentials -> {
backupMetrics.updateGetCredentialCounter(
UserAgentTagUtil.getPlatformTag(userAgent),
credentialType,
credentials.size());
credentialsByType.put(credentialType, credentials.stream()
.map(credential -> new BackupAuthCredentialsResponse.BackupAuthCredential(
credential.credential().serialize(),
credential.redemptionTime().getEpochSecond()))
.toList());
}))
.toArray(CompletableFuture[]::new))
.thenApply(ignored -> new BackupAuthCredentialsResponse(credentialsByType.entrySet().stream()
.collect(Collectors.toMap(
e -> BackupAuthCredentialsResponse.CredentialType.fromLibsignalType(e.getKey()),
Map.Entry::getValue))));
});
}
@ -343,7 +369,7 @@ public class ArchiveController {
@ApiResponse(responseCode = "429", description = "Rate limited.")
@ApiResponseZkAuth
public CompletionStage<ReadAuthResponse> readAuth(
@ReadOnly @Auth final Optional<AuthenticatedDevice> account,
@Auth final Optional<AuthenticatedDevice> account,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class))
@ -395,7 +421,7 @@ public class ArchiveController {
@ApiResponse(responseCode = "429", description = "Rate limited.")
@ApiResponseZkAuth
public CompletionStage<BackupInfoResponse> backupInfo(
@ReadOnly @Auth final Optional<AuthenticatedDevice> account,
@Auth final Optional<AuthenticatedDevice> account,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class))
@ -441,7 +467,7 @@ public class ArchiveController {
@ApiResponse(responseCode = "204", description = "The public key was set")
@ApiResponse(responseCode = "429", description = "Rate limited.")
public CompletionStage<Response> setPublicKey(
@ReadOnly @Auth final Optional<AuthenticatedDevice> account,
@Auth final Optional<AuthenticatedDevice> account,
@Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class))
@NotNull
@ -481,7 +507,7 @@ public class ArchiveController {
@ApiResponse(responseCode = "429", description = "Rate limited.")
@ApiResponseZkAuth
public CompletionStage<UploadDescriptorResponse> backup(
@ReadOnly @Auth final Optional<AuthenticatedDevice> account,
@Auth final Optional<AuthenticatedDevice> account,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class))
@ -518,7 +544,7 @@ public class ArchiveController {
@ApiResponse(responseCode = "429", description = "Rate limited.")
@ApiResponseZkAuth
public CompletionStage<UploadDescriptorResponse> uploadTemporaryAttachment(
@ReadOnly @Auth final Optional<AuthenticatedDevice> account,
@Auth final Optional<AuthenticatedDevice> account,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@ -606,7 +632,7 @@ public class ArchiveController {
@ApiResponse(responseCode = "429", description = "Rate limited.")
@ApiResponseZkAuth
public CompletionStage<CopyMediaResponse> copyMedia(
@ReadOnly @Auth final Optional<AuthenticatedDevice> account,
@Auth final Optional<AuthenticatedDevice> account,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class))
@ -705,7 +731,7 @@ public class ArchiveController {
@ApiResponse(responseCode = "429", description = "Rate limited.")
@ApiResponseZkAuth
public CompletionStage<Response> copyMedia(
@ReadOnly @Auth final Optional<AuthenticatedDevice> account,
@Auth final Optional<AuthenticatedDevice> account,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class))
@ -744,7 +770,7 @@ public class ArchiveController {
@ApiResponse(responseCode = "429", description = "Rate limited.")
@ApiResponseZkAuth
public CompletionStage<Response> refresh(
@ReadOnly @Auth final Optional<AuthenticatedDevice> account,
@Auth final Optional<AuthenticatedDevice> account,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class))
@ -811,7 +837,7 @@ public class ArchiveController {
@ApiResponse(responseCode = "429", description = "Rate limited.")
@ApiResponseZkAuth
public CompletionStage<ListResponse> listMedia(
@ReadOnly @Auth final Optional<AuthenticatedDevice> account,
@Auth final Optional<AuthenticatedDevice> account,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class))
@ -867,7 +893,7 @@ public class ArchiveController {
@ApiResponse(responseCode = "429", description = "Rate limited.")
@ApiResponseZkAuth
public CompletionStage<Response> deleteMedia(
@ReadOnly @Auth final Optional<AuthenticatedDevice> account,
@Auth final Optional<AuthenticatedDevice> account,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class))
@ -904,7 +930,7 @@ public class ArchiveController {
@ApiResponse(responseCode = "429", description = "Rate limited.")
@ApiResponseZkAuth
public CompletionStage<Response> deleteBackup(
@ReadOnly @Auth final Optional<AuthenticatedDevice> account,
@Auth final Optional<AuthenticatedDevice> account,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class))

View File

@ -26,7 +26,6 @@ import org.whispersystems.textsecuregcm.entities.AttachmentDescriptorV3;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.websocket.auth.ReadOnly;
/**
@ -78,11 +77,11 @@ public class AttachmentControllerV4 {
@ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header(
name = "Retry-After",
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public AttachmentDescriptorV3 getAttachmentUploadForm(@ReadOnly @Auth AuthenticatedDevice auth)
public AttachmentDescriptorV3 getAttachmentUploadForm(@Auth AuthenticatedDevice auth)
throws RateLimitExceededException {
rateLimiter.validate(auth.getAccount().getUuid());
rateLimiter.validate(auth.getAccountIdentifier());
final String key = generateAttachmentKey();
final boolean useCdn3 = this.experimentEnrollmentManager.isEnrolled(auth.getAccount().getUuid(), CDN3_EXPERIMENT_NAME);
final boolean useCdn3 = this.experimentEnrollmentManager.isEnrolled(auth.getAccountIdentifier(), CDN3_EXPERIMENT_NAME);
int cdn = useCdn3 ? 3 : 2;
final AttachmentGenerator.Descriptor descriptor = this.attachmentGenerators.get(cdn).generateAttachment(key);
return new AttachmentDescriptorV3(cdn, key, descriptor.headers(), descriptor.signedUploadLocation());

View File

@ -20,7 +20,6 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.entities.CreateCallLinkCredential;
import org.whispersystems.textsecuregcm.entities.GetCreateCallLinkCredentialsRequest;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v1/call-link")
@io.swagger.v3.oas.annotations.tags.Tag(name = "CallLink")
@ -52,11 +51,11 @@ public class CallLinkController {
@ApiResponse(responseCode = "422", description = "Invalid request format.")
@ApiResponse(responseCode = "429", description = "Ratelimited.")
public CreateCallLinkCredential getCreateAuth(
final @ReadOnly @Auth AuthenticatedDevice auth,
final @Auth AuthenticatedDevice auth,
final @NotNull @Valid GetCreateCallLinkCredentialsRequest request
) throws RateLimitExceededException {
rateLimiters.getCreateCallLinkLimiter().validate(auth.getAccount().getUuid());
rateLimiters.getCreateCallLinkLimiter().validate(auth.getAccountIdentifier());
final Instant truncatedDayTimestamp = Instant.now().truncatedTo(ChronoUnit.DAYS);
@ -68,7 +67,7 @@ public class CallLinkController {
}
return new CreateCallLinkCredential(
createCallLinkCredentialRequest.issueCredential(new ServiceId.Aci(auth.getAccount().getUuid()), truncatedDayTimestamp, genericServerSecretParams).serialize(),
createCallLinkCredentialRequest.issueCredential(new ServiceId.Aci(auth.getAccountIdentifier()), truncatedDayTimestamp, genericServerSecretParams).serialize(),
truncatedDayTimestamp.getEpochSecond()
);
}

View File

@ -18,11 +18,9 @@ import jakarta.ws.rs.Produces;
import jakarta.ws.rs.core.MediaType;
import java.io.IOException;
import java.util.List;
import java.util.UUID;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.CloudflareTurnCredentialsManager;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.websocket.auth.ReadOnly;
@io.swagger.v3.oas.annotations.tags.Tag(name = "Calling")
@Path("/v2/calling")
@ -56,11 +54,10 @@ public class CallRoutingControllerV2 {
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
@ApiResponse(responseCode = "422", description = "Invalid request format.")
@ApiResponse(responseCode = "429", description = "Rate limited.")
public GetCallingRelaysResponse getCallingRelays(final @ReadOnly @Auth AuthenticatedDevice auth)
public GetCallingRelaysResponse getCallingRelays(final @Auth AuthenticatedDevice auth)
throws RateLimitExceededException, IOException {
final UUID aci = auth.getAccount().getUuid();
rateLimiters.getCallEndpointLimiter().validate(aci);
rateLimiters.getCallEndpointLimiter().validate(auth.getAccountIdentifier());
try {
return new GetCallingRelaysResponse(List.of(cloudflareTurnCredentialsManager.retrieveFromCloudflare()));

View File

@ -8,18 +8,18 @@ package org.whispersystems.textsecuregcm.controllers;
import static com.codahale.metrics.MetricRegistry.name;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.Auth;
import io.micrometer.core.instrument.Metrics;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.ws.rs.BadRequestException;
import jakarta.ws.rs.DefaultValue;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.HeaderParam;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.QueryParam;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import java.security.InvalidKeyException;
import java.time.Clock;
import java.time.Duration;
@ -38,13 +38,16 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.CertificateGenerator;
import org.whispersystems.textsecuregcm.entities.DeliveryCertificate;
import org.whispersystems.textsecuregcm.entities.GroupCredentials;
import org.whispersystems.websocket.auth.ReadOnly;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@Path("/v1/certificate")
@Tag(name = "Certificate")
public class CertificateController {
private final AccountsManager accountsManager;
private final CertificateGenerator certificateGenerator;
private final ServerZkAuthOperations serverZkAuthOperations;
private final GenericServerSecretParams genericServerSecretParams;
@ -56,10 +59,13 @@ public class CertificateController {
private static final String INCLUDE_E164_TAG_NAME = "includeE164";
public CertificateController(
final AccountsManager accountsManager,
@Nonnull CertificateGenerator certificateGenerator,
@Nonnull ServerZkAuthOperations serverZkAuthOperations,
@Nonnull GenericServerSecretParams genericServerSecretParams,
@Nonnull Clock clock) {
this.accountsManager = accountsManager;
this.certificateGenerator = Objects.requireNonNull(certificateGenerator);
this.serverZkAuthOperations = Objects.requireNonNull(serverZkAuthOperations);
this.genericServerSecretParams = genericServerSecretParams;
@ -69,23 +75,25 @@ public class CertificateController {
@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/delivery")
public DeliveryCertificate getDeliveryCertificate(@ReadOnly @Auth AuthenticatedDevice auth,
public DeliveryCertificate getDeliveryCertificate(@Auth AuthenticatedDevice auth,
@QueryParam("includeE164") @DefaultValue("true") boolean includeE164)
throws InvalidKeyException {
Metrics.counter(GENERATE_DELIVERY_CERTIFICATE_COUNTER_NAME, INCLUDE_E164_TAG_NAME, String.valueOf(includeE164))
.increment();
final Account account = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
return new DeliveryCertificate(
certificateGenerator.createFor(auth.getAccount(), auth.getAuthenticatedDevice(), includeE164));
certificateGenerator.createFor(account, auth.getDeviceId(), includeE164));
}
@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/auth/group")
public GroupCredentials getGroupAuthenticationCredentials(
@ReadOnly @Auth AuthenticatedDevice auth,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent,
@Auth AuthenticatedDevice auth,
@QueryParam("redemptionStartSeconds") long startSeconds,
@QueryParam("redemptionEndSeconds") long endSeconds) {
@ -102,13 +110,16 @@ public class CertificateController {
throw new BadRequestException();
}
final Account account = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
final List<GroupCredentials.GroupCredential> credentials = new ArrayList<>();
final List<GroupCredentials.CallLinkAuthCredential> callLinkAuthCredentials = new ArrayList<>();
Instant redemption = redemptionStart;
ServiceId.Aci aci = new ServiceId.Aci(auth.getAccount().getUuid());
ServiceId.Pni pni = new ServiceId.Pni(auth.getAccount().getPhoneNumberIdentifier());
final ServiceId.Aci aci = new ServiceId.Aci(account.getIdentifier(IdentityType.ACI));
final ServiceId.Pni pni = new ServiceId.Pni(account.getIdentifier(IdentityType.PNI));
while (!redemption.isAfter(redemptionEnd)) {
AuthCredentialWithPniResponse authCredentialWithPni = serverZkAuthOperations.issueAuthCredentialWithPniZkc(aci, pni, redemption);

View File

@ -25,6 +25,7 @@ import jakarta.ws.rs.POST;
import jakarta.ws.rs.PUT;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.container.ContainerRequestContext;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.MediaType;
@ -40,12 +41,14 @@ import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.spam.ChallengeConstraintChecker;
import org.whispersystems.textsecuregcm.spam.ChallengeConstraintChecker.ChallengeConstraints;
import org.whispersystems.websocket.auth.ReadOnly;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
@Path("/v1/challenge")
@Tag(name = "Challenge")
public class ChallengeController {
private final AccountsManager accountsManager;
private final RateLimitChallengeManager rateLimitChallengeManager;
private final ChallengeConstraintChecker challengeConstraintChecker;
@ -53,8 +56,10 @@ public class ChallengeController {
private static final String CHALLENGE_TYPE_TAG = "type";
public ChallengeController(
final AccountsManager accountsManager,
final RateLimitChallengeManager rateLimitChallengeManager,
final ChallengeConstraintChecker challengeConstraintChecker) {
this.accountsManager = accountsManager;
this.rateLimitChallengeManager = rateLimitChallengeManager;
this.challengeConstraintChecker = challengeConstraintChecker;
}
@ -77,15 +82,18 @@ public class ChallengeController {
@ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header(
name = "Retry-After",
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public Response handleChallengeResponse(@ReadOnly @Auth final AuthenticatedDevice auth,
public Response handleChallengeResponse(@Auth final AuthenticatedDevice auth,
@Valid final AnswerChallengeRequest answerRequest,
@Context ContainerRequestContext requestContext,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) throws RateLimitExceededException, IOException {
final Account account = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(userAgent));
final ChallengeConstraints constraints = challengeConstraintChecker.challengeConstraints(
requestContext, auth.getAccount());
requestContext, account);
try {
if (answerRequest instanceof final AnswerPushChallengeRequest pushChallengeRequest) {
tags = tags.and(CHALLENGE_TYPE_TAG, "push");
@ -93,14 +101,14 @@ public class ChallengeController {
if (!constraints.pushPermitted()) {
return Response.status(429).build();
}
rateLimitChallengeManager.answerPushChallenge(auth.getAccount(), pushChallengeRequest.getChallenge());
rateLimitChallengeManager.answerPushChallenge(account, pushChallengeRequest.getChallenge());
} else if (answerRequest instanceof AnswerCaptchaChallengeRequest captchaChallengeRequest) {
tags = tags.and(CHALLENGE_TYPE_TAG, "captcha");
final String remoteAddress = (String) requestContext.getProperty(
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
boolean success = rateLimitChallengeManager.answerCaptchaChallenge(
auth.getAccount(),
account,
captchaChallengeRequest.getCaptcha(),
remoteAddress,
userAgent,
@ -163,15 +171,18 @@ public class ChallengeController {
@ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header(
name = "Retry-After",
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public Response requestPushChallenge(@ReadOnly @Auth final AuthenticatedDevice auth,
public Response requestPushChallenge(@Auth final AuthenticatedDevice auth,
@Context ContainerRequestContext requestContext) {
final ChallengeConstraints constraints = challengeConstraintChecker.challengeConstraints(
requestContext, auth.getAccount());
final Account account = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
final ChallengeConstraints constraints = challengeConstraintChecker.challengeConstraints(requestContext, account);
if (!constraints.pushPermitted()) {
return Response.status(429).build();
}
try {
rateLimitChallengeManager.sendPushChallenge(auth.getAccount());
rateLimitChallengeManager.sendPushChallenge(account);
return Response.status(200).build();
} catch (final NotPushRegisteredException e) {
return Response.status(404).build();

View File

@ -33,6 +33,7 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.backup.BackupAuthManager;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.devicecheck.AppleDeviceCheckManager;
import org.whispersystems.textsecuregcm.storage.devicecheck.ChallengeNotFoundException;
import org.whispersystems.textsecuregcm.storage.devicecheck.DeviceCheckKeyIdNotFoundException;
@ -41,7 +42,6 @@ import org.whispersystems.textsecuregcm.storage.devicecheck.DuplicatePublicKeyEx
import org.whispersystems.textsecuregcm.storage.devicecheck.RequestReuseException;
import org.whispersystems.textsecuregcm.storage.devicecheck.TooManyKeysException;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.websocket.auth.ReadOnly;
/**
* Process platform device attestations.
@ -55,6 +55,7 @@ import org.whispersystems.websocket.auth.ReadOnly;
public class DeviceCheckController {
private final Clock clock;
private final AccountsManager accountsManager;
private final BackupAuthManager backupAuthManager;
private final AppleDeviceCheckManager deviceCheckManager;
private final RateLimiters rateLimiters;
@ -63,12 +64,14 @@ public class DeviceCheckController {
public DeviceCheckController(
final Clock clock,
final AccountsManager accountsManager,
final BackupAuthManager backupAuthManager,
final AppleDeviceCheckManager deviceCheckManager,
final RateLimiters rateLimiters,
final long backupRedemptionLevel,
final Duration backupRedemptionDuration) {
this.clock = clock;
this.accountsManager = accountsManager;
this.backupAuthManager = backupAuthManager;
this.deviceCheckManager = deviceCheckManager;
this.backupRedemptionLevel = backupRedemptionLevel;
@ -94,14 +97,17 @@ public class DeviceCheckController {
@ApiResponse(responseCode = "200", description = "The response body includes a challenge")
@ApiResponse(responseCode = "429", description = "Ratelimited.")
@ManagedAsync
public ChallengeResponse attestChallenge(@ReadOnly @Auth AuthenticatedDevice authenticatedDevice)
public ChallengeResponse attestChallenge(@Auth AuthenticatedDevice authenticatedDevice)
throws RateLimitExceededException {
rateLimiters.forDescriptor(RateLimiters.For.DEVICE_CHECK_CHALLENGE)
.validate(authenticatedDevice.getAccount().getUuid());
.validate(authenticatedDevice.getAccountIdentifier());
final Account account = accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
return new ChallengeResponse(deviceCheckManager.createChallenge(
AppleDeviceCheckManager.ChallengeType.ATTEST,
authenticatedDevice.getAccount()));
account));
}
@PUT
@ -125,7 +131,7 @@ public class DeviceCheckController {
@ApiResponse(responseCode = "409", description = "The provided keyId has already been registered to a different account")
@ManagedAsync
public void attest(
@ReadOnly @Auth final AuthenticatedDevice authenticatedDevice,
@Auth final AuthenticatedDevice authenticatedDevice,
@Valid
@NotNull
@ -135,8 +141,11 @@ public class DeviceCheckController {
@RequestBody(description = "The attestation data, created by [attestKey](https://developer.apple.com/documentation/devicecheck/dcappattestservice/attestkey(_:clientdatahash:completionhandler:))")
@NotNull final byte[] attestation) {
final Account account = accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
try {
deviceCheckManager.registerAttestation(authenticatedDevice.getAccount(), parseKeyId(keyId), attestation);
deviceCheckManager.registerAttestation(account, parseKeyId(keyId), attestation);
} catch (TooManyKeysException e) {
throw new WebApplicationException(Response.status(413).build());
} catch (ChallengeNotFoundException e) {
@ -166,17 +175,19 @@ public class DeviceCheckController {
@ApiResponse(responseCode = "429", description = "Ratelimited.")
@ManagedAsync
public ChallengeResponse assertChallenge(
@ReadOnly @Auth AuthenticatedDevice authenticatedDevice,
@Auth AuthenticatedDevice authenticatedDevice,
@Parameter(schema = @Schema(description = "The type of action you will make an assertion for",
allowableValues = {"backup"},
implementation = String.class))
@QueryParam("action") Action action) throws RateLimitExceededException {
rateLimiters.forDescriptor(RateLimiters.For.DEVICE_CHECK_CHALLENGE)
.validate(authenticatedDevice.getAccount().getUuid());
return new ChallengeResponse(
deviceCheckManager.createChallenge(toChallengeType(action),
authenticatedDevice.getAccount()));
.validate(authenticatedDevice.getAccountIdentifier());
final Account account = accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
return new ChallengeResponse(deviceCheckManager.createChallenge(toChallengeType(action), account));
}
@POST
@ -199,7 +210,7 @@ public class DeviceCheckController {
@ApiResponse(responseCode = "401", description = "The assertion could not be verified")
@ManagedAsync
public void assertion(
@ReadOnly @Auth final AuthenticatedDevice authenticatedDevice,
@Auth final AuthenticatedDevice authenticatedDevice,
@Valid
@NotNull
@ -218,9 +229,12 @@ public class DeviceCheckController {
@RequestBody(description = "The assertion created by [generateAssertion](https://developer.apple.com/documentation/devicecheck/dcappattestservice/generateassertion(_:clientdatahash:completionhandler:))")
@NotNull final byte[] assertion) {
final Account account = accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
try {
deviceCheckManager.validateAssert(
authenticatedDevice.getAccount(),
account,
parseKeyId(keyId),
toChallengeType(request.assertionRequest().action()),
request.assertionRequest().challenge(),
@ -237,7 +251,7 @@ public class DeviceCheckController {
// The request assertion was validated, execute it
switch (request.assertionRequest().action()) {
case BACKUP -> backupAuthManager.extendBackupVoucher(
authenticatedDevice.getAccount(),
account,
new Account.BackupVoucher(backupRedemptionLevel, clock.instant().plus(backupRedemptionDuration)))
.join();
}

View File

@ -34,7 +34,6 @@ import jakarta.ws.rs.PathParam;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.QueryParam;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import java.time.Duration;
@ -51,11 +50,9 @@ import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.glassfish.jersey.server.ContainerRequest;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.BasicAuthorizationHeader;
import org.whispersystems.textsecuregcm.auth.ChangesLinkedDevices;
import org.whispersystems.textsecuregcm.auth.LinkedDeviceRefreshRequirementProvider;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest;
import org.whispersystems.textsecuregcm.entities.DeviceInfo;
@ -87,11 +84,10 @@ import org.whispersystems.textsecuregcm.util.DeviceCapabilityAdapter;
import org.whispersystems.textsecuregcm.util.EnumMapUtil;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.LinkDeviceToken;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import org.whispersystems.websocket.auth.Mutable;
import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v1/devices")
@Tag(name = "Devices")
@ -152,10 +148,10 @@ public class DeviceController {
@GET
@Produces(MediaType.APPLICATION_JSON)
public DeviceInfoList getDevices(@ReadOnly @Auth AuthenticatedDevice auth) {
public DeviceInfoList getDevices(@Auth AuthenticatedDevice auth) {
// Devices may change their own names (and primary devices may change the names of linked devices) and so the device
// state associated with the authenticated account may be stale. Fetch a fresh copy to compensate.
return accounts.getByAccountIdentifier(auth.getAccount().getIdentifier(IdentityType.ACI))
return accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.map(account -> new DeviceInfoList(account.getDevices().stream()
.map(DeviceInfo::forDevice)
.toList()))
@ -166,9 +162,8 @@ public class DeviceController {
@Produces(MediaType.APPLICATION_JSON)
@Path("/{device_id}")
@ChangesLinkedDevices
public void removeDevice(@Mutable @Auth AuthenticatedDevice auth, @PathParam("device_id") byte deviceId) {
if (auth.getAuthenticatedDevice().getId() != Device.PRIMARY_ID &&
auth.getAuthenticatedDevice().getId() != deviceId) {
public void removeDevice(@Auth AuthenticatedDevice auth, @PathParam("device_id") byte deviceId) {
if (auth.getDeviceId() != Device.PRIMARY_ID && auth.getDeviceId() != deviceId) {
throw new WebApplicationException(Response.Status.UNAUTHORIZED);
}
@ -176,13 +171,16 @@ public class DeviceController {
throw new ForbiddenException();
}
accounts.removeDevice(auth.getAccount(), deviceId).join();
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
accounts.removeDevice(account, deviceId).join();
}
/**
* Generates a signed device-linking token. Generally, primary devices will include the signed device-linking token in
* a provisioning message to a new device, and then the new device will include the token in its request to
* {@link #linkDevice(BasicAuthorizationHeader, String, LinkDeviceRequest, ContainerRequest)}.
* {@link #linkDevice(BasicAuthorizationHeader, String, LinkDeviceRequest)}.
*
* @param auth the authenticated account/device
*
@ -207,10 +205,11 @@ public class DeviceController {
@ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header(
name = "Retry-After",
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public LinkDeviceToken createDeviceToken(@ReadOnly @Auth AuthenticatedDevice auth)
public LinkDeviceToken createDeviceToken(@Auth AuthenticatedDevice auth)
throws RateLimitExceededException, DeviceLimitExceededException {
final Account account = auth.getAccount();
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
rateLimiters.getAllocateDeviceLimiter().validate(account.getUuid());
@ -224,7 +223,7 @@ public class DeviceController {
throw new DeviceLimitExceededException(account.getDevices().size(), maxDeviceLimit);
}
if (auth.getAuthenticatedDevice().getId() != Device.PRIMARY_ID) {
if (auth.getDeviceId() != Device.PRIMARY_ID) {
throw new WebApplicationException(Response.Status.UNAUTHORIZED);
}
@ -252,8 +251,7 @@ public class DeviceController {
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public LinkDeviceResponse linkDevice(@HeaderParam(HttpHeaders.AUTHORIZATION) BasicAuthorizationHeader authorizationHeader,
@HeaderParam(HttpHeaders.USER_AGENT) @Nullable String userAgent,
@NotNull @Valid LinkDeviceRequest linkDeviceRequest,
@Context ContainerRequest containerRequest)
@NotNull @Valid LinkDeviceRequest linkDeviceRequest)
throws RateLimitExceededException, DeviceLimitExceededException {
final Account account = accounts.checkDeviceLinkingToken(linkDeviceRequest.verificationCode())
@ -279,11 +277,6 @@ public class DeviceController {
throw new WebApplicationException(Response.status(422).build());
}
// Normally, the "do we need to refresh somebody's websockets" listener can do this on its own. In this case,
// we're not using the conventional authentication system, and so we need to give it a hint so it knows who the
// active user is and what their device states look like.
LinkedDeviceRefreshRequirementProvider.setAccount(containerRequest, account);
final int maxDeviceLimit = maxDeviceConfiguration.getOrDefault(account.getNumber(), MAX_DEVICES);
if (account.getDevices().size() >= maxDeviceLimit) {
@ -351,7 +344,7 @@ public class DeviceController {
@ApiResponse(responseCode = "400", description = "The given token identifier or timeout was invalid")
@ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay")
public CompletionStage<Response> waitForLinkedDevice(
@ReadOnly @Auth final AuthenticatedDevice authenticatedDevice,
@Auth final AuthenticatedDevice authenticatedDevice,
@PathParam("tokenIdentifier")
@Schema(description = "A 'link device' token identifier provided by the 'create link device token' endpoint")
@ -374,12 +367,18 @@ public class DeviceController {
final AtomicInteger linkedDeviceListenerCounter = getCounterForLinkedDeviceListeners(userAgent);
linkedDeviceListenerCounter.incrementAndGet();
return rateLimiters.getWaitForLinkedDeviceLimiter()
.validateAsync(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI))
.thenCompose(ignored -> persistentTimer.start(WAIT_FOR_LINKED_DEVICE_TIMER_NAMESPACE, tokenIdentifier))
.thenCompose(sample -> accounts.waitForNewLinkedDevice(
authenticatedDevice.getAccount().getUuid(),
authenticatedDevice.getAuthenticatedDevice(),
return rateLimiters.getWaitForLinkedDeviceLimiter().validateAsync(authenticatedDevice.getAccountIdentifier())
.thenCompose(ignored -> accounts.getByAccountIdentifierAsync(authenticatedDevice.getAccountIdentifier()))
.thenCompose(maybeAccount -> {
final Account account = maybeAccount.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
return persistentTimer.start(WAIT_FOR_LINKED_DEVICE_TIMER_NAMESPACE, tokenIdentifier)
.thenApply(sample -> new Pair<>(account, sample));
})
.thenCompose(accountAndSample -> accounts.waitForNewLinkedDevice(
authenticatedDevice.getAccountIdentifier(),
accountAndSample.first().getDevice(authenticatedDevice.getDeviceId())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)),
tokenIdentifier,
Duration.ofSeconds(timeoutSeconds))
.thenApply(maybeDeviceInfo -> maybeDeviceInfo
@ -391,7 +390,7 @@ public class DeviceController {
linkedDeviceListenerCounter.decrementAndGet();
if (response != null && response.getStatus() == Response.Status.OK.getStatusCode()) {
sample.stop(Timer.builder(WAIT_FOR_LINKED_DEVICE_TIMER_NAME)
accountAndSample.second().stop(Timer.builder(WAIT_FOR_LINKED_DEVICE_TIMER_NAME)
.publishPercentileHistogram(true)
.tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent)))
.register(Metrics.globalRegistry));
@ -410,14 +409,15 @@ public class DeviceController {
@PUT
@Produces(MediaType.APPLICATION_JSON)
@Path("/capabilities")
public void setCapabilities(@Mutable @Auth final AuthenticatedDevice auth,
public void setCapabilities(@Auth final AuthenticatedDevice auth,
@NotNull
final Map<String, Boolean> capabilities) {
assert (auth.getAuthenticatedDevice() != null);
final byte deviceId = auth.getAuthenticatedDevice().getId();
accounts.updateDevice(auth.getAccount(), deviceId,
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
accounts.updateDevice(account, auth.getDeviceId(),
d -> d.setCapabilities(DeviceCapabilityAdapter.mapToSet(capabilities)));
}
@ -435,12 +435,13 @@ public class DeviceController {
@ApiResponse(responseCode = "200", description = "Public key stored successfully")
@ApiResponse(responseCode = "401", description = "Account authentication check failed")
@ApiResponse(responseCode = "422", description = "Invalid request format")
public CompletableFuture<Void> setPublicKey(@Mutable @Auth final AuthenticatedDevice auth,
public CompletableFuture<Void> setPublicKey(@Auth final AuthenticatedDevice auth,
final SetPublicKeyRequest setPublicKeyRequest) {
return clientPublicKeysManager.setPublicKey(auth.getAccount(),
auth.getAuthenticatedDevice().getId(),
setPublicKeyRequest.publicKey());
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
return clientPublicKeysManager.setPublicKey(account, auth.getDeviceId(), setPublicKeyRequest.publicKey());
}
private static boolean isCapabilityDowngrade(final Account account, final Set<DeviceCapability> capabilities) {
@ -531,15 +532,21 @@ public class DeviceController {
@ApiResponse(responseCode = "204", description = "Success")
@ApiResponse(responseCode = "422", description = "The request object could not be parsed or was otherwise invalid")
@ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay")
public CompletionStage<Void> recordTransferArchiveUploaded(@ReadOnly @Auth final AuthenticatedDevice authenticatedDevice,
public CompletionStage<Void> recordTransferArchiveUploaded(@Auth final AuthenticatedDevice authenticatedDevice,
@NotNull @Valid final TransferArchiveUploadedRequest transferArchiveUploadedRequest) {
return rateLimiters.getUploadTransferArchiveLimiter()
.validateAsync(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI))
.thenCompose(ignored -> accounts.recordTransferArchiveUpload(authenticatedDevice.getAccount(),
transferArchiveUploadedRequest.destinationDeviceId(),
Instant.ofEpochMilli(transferArchiveUploadedRequest.destinationDeviceCreated()),
transferArchiveUploadedRequest.transferArchive()));
.validateAsync(authenticatedDevice.getAccountIdentifier())
.thenCompose(ignored -> accounts.getByAccountIdentifierAsync(authenticatedDevice.getAccountIdentifier()))
.thenCompose(maybeAccount -> {
final Account account = maybeAccount.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
return accounts.recordTransferArchiveUpload(account,
transferArchiveUploadedRequest.destinationDeviceId(),
Instant.ofEpochMilli(transferArchiveUploadedRequest.destinationDeviceCreated()),
transferArchiveUploadedRequest.transferArchive());
});
}
@GET
@ -558,7 +565,7 @@ public class DeviceController {
@ApiResponse(responseCode = "204", description = "No transfer archive was uploaded before the call completed; clients may repeat the call to continue waiting")
@ApiResponse(responseCode = "400", description = "The given timeout was invalid")
@ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay")
public CompletionStage<Response> waitForTransferArchive(@ReadOnly @Auth final AuthenticatedDevice authenticatedDevice,
public CompletionStage<Response> waitForTransferArchive(@Auth final AuthenticatedDevice authenticatedDevice,
@QueryParam("timeout")
@DefaultValue("30")
@ -575,24 +582,30 @@ public class DeviceController {
@HeaderParam(HttpHeaders.USER_AGENT) @Nullable String userAgent) {
final String rateLimiterKey = authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI) +
":" + authenticatedDevice.getAuthenticatedDevice().getId();
final String rateLimiterKey = authenticatedDevice.getAccountIdentifier() + ":" + authenticatedDevice.getDeviceId();
return rateLimiters.getWaitForTransferArchiveLimiter().validateAsync(rateLimiterKey)
.thenCompose(ignored -> persistentTimer.start(WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAMESPACE, rateLimiterKey))
.thenCompose(sample -> accounts.waitForTransferArchive(authenticatedDevice.getAccount(),
authenticatedDevice.getAuthenticatedDevice(),
.thenCompose(ignored -> accounts.getByAccountIdentifierAsync(authenticatedDevice.getAccountIdentifier()))
.thenCompose(maybeAccount -> {
final Account account = maybeAccount.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
return persistentTimer.start(WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAMESPACE, rateLimiterKey)
.thenApply(sample -> new Pair<>(account, sample));
})
.thenCompose(accountAndSample -> accounts.waitForTransferArchive(accountAndSample.first(),
accountAndSample.first().getDevice(authenticatedDevice.getDeviceId())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)),
Duration.ofSeconds(timeoutSeconds))
.thenApply(maybeTransferArchive -> maybeTransferArchive
.map(transferArchive -> Response.status(Response.Status.OK).entity(transferArchive).build())
.orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build()))
.whenComplete((response, throwable) -> {
if (response != null && response.getStatus() == Response.Status.OK.getStatusCode()) {
sample.stop(Timer.builder(WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAME)
accountAndSample.second().stop(Timer.builder(WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAME)
.publishPercentileHistogram(true)
.tags(Tags.of(
UserAgentTagUtil.getPlatformTag(userAgent),
primaryPlatformTag(authenticatedDevice.getAccount())))
primaryPlatformTag(accountAndSample.first())))
.register(Metrics.globalRegistry));
}
}));

View File

@ -14,12 +14,10 @@ import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.core.MediaType;
import java.time.Clock;
import java.util.UUID;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator;
import org.whispersystems.textsecuregcm.configuration.DirectoryV2ClientConfiguration;
import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v2/directory")
@Tag(name = "Directory")
@ -57,8 +55,7 @@ public class DirectoryV2Controller {
"""
)
@ApiResponse(responseCode = "200", description = "`JSON` with generated credentials.", useReturnTypeSchema = true)
public ExternalServiceCredentials getAuthToken(final @ReadOnly @Auth AuthenticatedDevice auth) {
final UUID uuid = auth.getAccount().getUuid();
return directoryServiceTokenGenerator.generateForUuid(uuid);
public ExternalServiceCredentials getAuthToken(final @Auth AuthenticatedDevice auth) {
return directoryServiceTokenGenerator.generateForUuid(auth.getAccountIdentifier());
}
}

View File

@ -15,6 +15,7 @@ import jakarta.ws.rs.Consumes;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.Response.Status;
@ -33,10 +34,10 @@ import org.signal.libsignal.zkgroup.receipts.ServerZkReceiptOperations;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.configuration.BadgesConfiguration;
import org.whispersystems.textsecuregcm.entities.RedeemReceiptRequest;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountBadge;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.RedeemedReceiptsManager;
import org.whispersystems.websocket.auth.Mutable;
@Path("/v1/donation")
@Tag(name = "Donations")
@ -86,7 +87,7 @@ public class DonationController {
""")
@ApiResponse(responseCode = "429", description = "Rate limited.")
public CompletionStage<Response> redeemReceipt(
@Mutable @Auth final AuthenticatedDevice auth,
@Auth final AuthenticatedDevice auth,
@NotNull @Valid final RedeemReceiptRequest request) {
return CompletableFuture.supplyAsync(() -> {
ReceiptCredentialPresentation receiptCredentialPresentation;
@ -118,23 +119,29 @@ public class DonationController {
.type(MediaType.TEXT_PLAIN_TYPE)
.build());
}
return redeemedReceiptsManager.put(
receiptSerial, receiptExpiration.getEpochSecond(), receiptLevel, auth.getAccount().getUuid())
.thenCompose(receiptMatched -> {
if (!receiptMatched) {
return CompletableFuture.completedFuture(Response.status(Status.BAD_REQUEST)
.entity("receipt serial is already redeemed")
.type(MediaType.TEXT_PLAIN_TYPE)
.build());
}
return accountsManager.updateAsync(auth.getAccount(), a -> {
a.addBadge(clock, new AccountBadge(badgeId, receiptExpiration, request.isVisible()));
if (request.isPrimary()) {
a.makeBadgePrimaryIfExists(clock, badgeId);
return accountsManager.getByAccountIdentifierAsync(auth.getAccountIdentifier())
.thenCompose(maybeAccount -> {
final Account account = maybeAccount.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
return redeemedReceiptsManager.put(
receiptSerial, receiptExpiration.getEpochSecond(), receiptLevel, auth.getAccountIdentifier())
.thenCompose(receiptMatched -> {
if (!receiptMatched) {
return CompletableFuture.completedFuture(Response.status(Status.BAD_REQUEST)
.entity("receipt serial is already redeemed")
.type(MediaType.TEXT_PLAIN_TYPE)
.build());
}
})
.thenApply(ignored -> Response.ok().build());
return accountsManager.updateAsync(account, a -> {
a.addBadge(clock, new AccountBadge(badgeId, receiptExpiration, request.isVisible()));
if (request.isPrimary()) {
a.makeBadgePrimaryIfExists(clock, badgeId);
}
})
.thenApply(ignored -> Response.ok().build());
});
});
}).thenCompose(Function.identity());
}

View File

@ -5,9 +5,8 @@
package org.whispersystems.textsecuregcm.controllers;
import org.whispersystems.textsecuregcm.auth.TurnToken;
import java.util.List;
import org.whispersystems.textsecuregcm.auth.TurnToken;
public record GetCallingRelaysResponse(List<TurnToken> relays) {
}

View File

@ -23,7 +23,6 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager;
import org.whispersystems.websocket.auth.ReadOnly;
import org.whispersystems.websocket.session.WebSocketSession;
import org.whispersystems.websocket.session.WebSocketSessionContext;
@ -45,16 +44,16 @@ public class KeepAliveController {
}
@GET
public Response getKeepAlive(@ReadOnly @Auth Optional<AuthenticatedDevice> maybeAuth,
public Response getKeepAlive(@Auth Optional<AuthenticatedDevice> maybeAuth,
@WebSocketSession WebSocketSessionContext context) {
maybeAuth.ifPresent(auth -> {
if (!webSocketConnectionEventManager.isLocallyPresent(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId())) {
if (!webSocketConnectionEventManager.isLocallyPresent(auth.getAccountIdentifier(), auth.getDeviceId())) {
final Duration age = Duration.between(context.getClient().getCreated(), Instant.now());
logger.debug("***** No local subscription found for {}::{}; age = {}ms, User-Agent = {}",
auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(), age.toMillis(),
auth.getAccountIdentifier(), auth.getDeviceId(), age.toMillis(),
context.getClient().getUserAgent());
context.getClient().close(1000, "OK");

View File

@ -49,7 +49,6 @@ import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceCl
import org.whispersystems.textsecuregcm.limits.RateLimitedByIp;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v1/key-transparency")
@Tag(name = "KeyTransparency")
@ -90,7 +89,7 @@ public class KeyTransparencyController {
@RateLimitedByIp(RateLimiters.For.KEY_TRANSPARENCY_SEARCH_PER_IP)
@Produces(MediaType.APPLICATION_JSON)
public KeyTransparencySearchResponse search(
@ReadOnly @Auth final Optional<AuthenticatedDevice> authenticatedAccount,
@Auth final Optional<AuthenticatedDevice> authenticatedAccount,
@NotNull @Valid final KeyTransparencySearchRequest request) {
// Disallow clients from making authenticated requests to this endpoint
@ -142,7 +141,7 @@ public class KeyTransparencyController {
@RateLimitedByIp(RateLimiters.For.KEY_TRANSPARENCY_MONITOR_PER_IP)
@Produces(MediaType.APPLICATION_JSON)
public KeyTransparencyMonitorResponse monitor(
@ReadOnly @Auth final Optional<AuthenticatedDevice> authenticatedAccount,
@Auth final Optional<AuthenticatedDevice> authenticatedAccount,
@NotNull @Valid final KeyTransparencyMonitorRequest request) {
// Disallow clients from making authenticated requests to this endpoint
@ -204,7 +203,7 @@ public class KeyTransparencyController {
@RateLimitedByIp(RateLimiters.For.KEY_TRANSPARENCY_DISTINGUISHED_PER_IP)
@Produces(MediaType.APPLICATION_JSON)
public KeyTransparencyDistinguishedKeyResponse getDistinguishedKey(
@ReadOnly @Auth final Optional<AuthenticatedDevice> authenticatedAccount,
@Auth final Optional<AuthenticatedDevice> authenticatedAccount,
@Parameter(description = "The distinguished tree head size returned by a previously verified call")
@QueryParam("lastTreeHeadSize") @Valid final Optional<@Positive Long> lastTreeHeadSize) {

View File

@ -74,7 +74,6 @@ import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.websocket.auth.ReadOnly;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@Path("/v2/keys")
@ -111,16 +110,21 @@ public class KeysController {
description = "Gets the number of one-time prekeys uploaded for this device and still available")
@ApiResponse(responseCode = "200", description = "Body contains the number of available one-time prekeys for the device.", useReturnTypeSchema = true)
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
public CompletableFuture<PreKeyCount> getStatus(@ReadOnly @Auth final AuthenticatedDevice auth,
public CompletableFuture<PreKeyCount> getStatus(@Auth final AuthenticatedDevice auth,
@QueryParam("identity") @DefaultValue("aci") final IdentityType identityType) {
final CompletableFuture<Integer> ecCountFuture =
keysManager.getEcCount(auth.getAccount().getIdentifier(identityType), auth.getAuthenticatedDevice().getId());
return accounts.getByAccountIdentifierAsync(auth.getAccountIdentifier())
.thenCompose(maybeAccount -> {
final Account account = maybeAccount.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
final CompletableFuture<Integer> pqCountFuture =
keysManager.getPqCount(auth.getAccount().getIdentifier(identityType), auth.getAuthenticatedDevice().getId());
final CompletableFuture<Integer> ecCountFuture =
keysManager.getEcCount(account.getIdentifier(identityType), auth.getDeviceId());
return ecCountFuture.thenCombine(pqCountFuture, PreKeyCount::new);
final CompletableFuture<Integer> pqCountFuture =
keysManager.getPqCount(account.getIdentifier(identityType), auth.getDeviceId());
return ecCountFuture.thenCombine(pqCountFuture, PreKeyCount::new);
});
}
@PUT
@ -132,7 +136,7 @@ public class KeysController {
@ApiResponse(responseCode = "403", description = "Attempt to change identity key from a non-primary device.")
@ApiResponse(responseCode = "422", description = "Invalid request format.")
public CompletableFuture<Response> setKeys(
@ReadOnly @Auth final AuthenticatedDevice auth,
@Auth final AuthenticatedDevice auth,
@RequestBody @NotNull @Valid final SetKeysRequest setKeysRequest,
@Parameter(allowEmptyValue=true)
@ -143,63 +147,70 @@ public class KeysController {
@QueryParam("identity") @DefaultValue("aci") final IdentityType identityType,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) {
final Account account = auth.getAccount();
final Device device = auth.getAuthenticatedDevice();
final UUID identifier = account.getIdentifier(identityType);
return accounts.getByAccountIdentifierAsync(auth.getAccountIdentifier())
.thenCompose(maybeAccount -> {
final Account account = maybeAccount
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
checkSignedPreKeySignatures(setKeysRequest, account.getIdentityKey(identityType), userAgent);
final Device device = account.getDevice(auth.getDeviceId())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
final Tag platformTag = UserAgentTagUtil.getPlatformTag(userAgent);
final Tag primaryDeviceTag = Tag.of(PRIMARY_DEVICE_TAG_NAME, String.valueOf(auth.getAuthenticatedDevice().isPrimary()));
final Tag identityTypeTag = Tag.of(IDENTITY_TYPE_TAG_NAME, identityType.name());
final UUID identifier = account.getIdentifier(identityType);
final List<CompletableFuture<Void>> storeFutures = new ArrayList<>(4);
checkSignedPreKeySignatures(setKeysRequest, account.getIdentityKey(identityType), userAgent);
if (!setKeysRequest.preKeys().isEmpty()) {
final Tags tags = Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "ec"));
final Tag platformTag = UserAgentTagUtil.getPlatformTag(userAgent);
final Tag primaryDeviceTag = Tag.of(PRIMARY_DEVICE_TAG_NAME, String.valueOf(auth.getDeviceId() == Device.PRIMARY_ID));
final Tag identityTypeTag = Tag.of(IDENTITY_TYPE_TAG_NAME, identityType.name());
Metrics.counter(STORE_KEYS_COUNTER_NAME, tags).increment();
final List<CompletableFuture<Void>> storeFutures = new ArrayList<>(4);
DistributionSummary.builder(STORE_KEY_BUNDLE_SIZE_DISTRIBUTION_NAME)
.tags(tags)
.publishPercentileHistogram()
.register(Metrics.globalRegistry)
.record(setKeysRequest.preKeys().size());
if (!setKeysRequest.preKeys().isEmpty()) {
final Tags tags = Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "ec"));
storeFutures.add(keysManager.storeEcOneTimePreKeys(identifier, device.getId(), setKeysRequest.preKeys()));
}
Metrics.counter(STORE_KEYS_COUNTER_NAME, tags).increment();
if (setKeysRequest.signedPreKey() != null) {
Metrics.counter(STORE_KEYS_COUNTER_NAME,
Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "ec-signed")))
.increment();
DistributionSummary.builder(STORE_KEY_BUNDLE_SIZE_DISTRIBUTION_NAME)
.tags(tags)
.publishPercentileHistogram()
.register(Metrics.globalRegistry)
.record(setKeysRequest.preKeys().size());
storeFutures.add(keysManager.storeEcSignedPreKeys(identifier, device.getId(), setKeysRequest.signedPreKey()));
}
storeFutures.add(keysManager.storeEcOneTimePreKeys(identifier, device.getId(), setKeysRequest.preKeys()));
}
if (!setKeysRequest.pqPreKeys().isEmpty()) {
final Tags tags = Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "kyber"));
Metrics.counter(STORE_KEYS_COUNTER_NAME, tags).increment();
if (setKeysRequest.signedPreKey() != null) {
Metrics.counter(STORE_KEYS_COUNTER_NAME,
Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "ec-signed")))
.increment();
DistributionSummary.builder(STORE_KEY_BUNDLE_SIZE_DISTRIBUTION_NAME)
.tags(tags)
.publishPercentileHistogram()
.register(Metrics.globalRegistry)
.record(setKeysRequest.pqPreKeys().size());
storeFutures.add(keysManager.storeEcSignedPreKeys(identifier, device.getId(), setKeysRequest.signedPreKey()));
}
storeFutures.add(keysManager.storeKemOneTimePreKeys(identifier, device.getId(), setKeysRequest.pqPreKeys()));
}
if (!setKeysRequest.pqPreKeys().isEmpty()) {
final Tags tags = Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "kyber"));
Metrics.counter(STORE_KEYS_COUNTER_NAME, tags).increment();
if (setKeysRequest.pqLastResortPreKey() != null) {
Metrics.counter(STORE_KEYS_COUNTER_NAME,
Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "kyber-last-resort")))
.increment();
DistributionSummary.builder(STORE_KEY_BUNDLE_SIZE_DISTRIBUTION_NAME)
.tags(tags)
.publishPercentileHistogram()
.register(Metrics.globalRegistry)
.record(setKeysRequest.pqPreKeys().size());
storeFutures.add(keysManager.storePqLastResort(identifier, device.getId(), setKeysRequest.pqLastResortPreKey()));
}
storeFutures.add(keysManager.storeKemOneTimePreKeys(identifier, device.getId(), setKeysRequest.pqPreKeys()));
}
return CompletableFuture.allOf(storeFutures.toArray(EMPTY_FUTURE_ARRAY))
.thenApply(Util.ASYNC_EMPTY_RESPONSE);
if (setKeysRequest.pqLastResortPreKey() != null) {
Metrics.counter(STORE_KEYS_COUNTER_NAME,
Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "kyber-last-resort")))
.increment();
storeFutures.add(keysManager.storePqLastResort(identifier, device.getId(), setKeysRequest.pqLastResortPreKey()));
}
return CompletableFuture.allOf(storeFutures.toArray(EMPTY_FUTURE_ARRAY))
.thenApply(Util.ASYNC_EMPTY_RESPONSE);
});
}
private void checkSignedPreKeySignatures(final SetKeysRequest setKeysRequest,
@ -253,64 +264,69 @@ public class KeysController {
""")
@ApiResponse(responseCode = "422", description = "Invalid request format")
public CompletableFuture<Response> checkKeys(
@ReadOnly @Auth final AuthenticatedDevice auth,
@Auth final AuthenticatedDevice auth,
@RequestBody @NotNull @Valid final CheckKeysRequest checkKeysRequest) {
final UUID identifier = auth.getAccount().getIdentifier(checkKeysRequest.identityType());
final byte deviceId = auth.getAuthenticatedDevice().getId();
return accounts.getByAccountIdentifierAsync(auth.getAccountIdentifier())
.thenCompose(maybeAccount -> {
final Account account = maybeAccount.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
final CompletableFuture<Optional<ECSignedPreKey>> ecSignedPreKeyFuture =
keysManager.getEcSignedPreKey(identifier, deviceId);
final UUID identifier = account.getIdentifier(checkKeysRequest.identityType());
final byte deviceId = auth.getDeviceId();
final CompletableFuture<Optional<KEMSignedPreKey>> lastResortKeyFuture =
keysManager.getLastResort(identifier, deviceId);
final CompletableFuture<Optional<ECSignedPreKey>> ecSignedPreKeyFuture =
keysManager.getEcSignedPreKey(identifier, deviceId);
return CompletableFuture.allOf(ecSignedPreKeyFuture, lastResortKeyFuture)
.thenApply(ignored -> {
final Optional<ECSignedPreKey> maybeSignedPreKey = ecSignedPreKeyFuture.join();
final Optional<KEMSignedPreKey> maybeLastResortKey = lastResortKeyFuture.join();
final CompletableFuture<Optional<KEMSignedPreKey>> lastResortKeyFuture =
keysManager.getLastResort(identifier, deviceId);
final boolean digestsMatch;
return CompletableFuture.allOf(ecSignedPreKeyFuture, lastResortKeyFuture)
.thenApply(ignored -> {
final Optional<ECSignedPreKey> maybeSignedPreKey = ecSignedPreKeyFuture.join();
final Optional<KEMSignedPreKey> maybeLastResortKey = lastResortKeyFuture.join();
if (maybeSignedPreKey.isPresent() && maybeLastResortKey.isPresent()) {
final IdentityKey identityKey = auth.getAccount().getIdentityKey(checkKeysRequest.identityType());
final ECSignedPreKey ecSignedPreKey = maybeSignedPreKey.get();
final KEMSignedPreKey lastResortKey = maybeLastResortKey.get();
final boolean digestsMatch;
final MessageDigest messageDigest;
if (maybeSignedPreKey.isPresent() && maybeLastResortKey.isPresent()) {
final IdentityKey identityKey = account.getIdentityKey(checkKeysRequest.identityType());
final ECSignedPreKey ecSignedPreKey = maybeSignedPreKey.get();
final KEMSignedPreKey lastResortKey = maybeLastResortKey.get();
try {
messageDigest = MessageDigest.getInstance("SHA-256");
} catch (final NoSuchAlgorithmException e) {
throw new AssertionError("Every implementation of the Java platform is required to support SHA-256", e);
}
final MessageDigest messageDigest;
messageDigest.update(identityKey.serialize());
try {
messageDigest = MessageDigest.getInstance("SHA-256");
} catch (final NoSuchAlgorithmException e) {
throw new AssertionError("Every implementation of the Java platform is required to support SHA-256", e);
}
{
final ByteBuffer ecSignedPreKeyIdBuffer = ByteBuffer.allocate(Long.BYTES);
ecSignedPreKeyIdBuffer.putLong(ecSignedPreKey.keyId());
ecSignedPreKeyIdBuffer.flip();
messageDigest.update(identityKey.serialize());
messageDigest.update(ecSignedPreKeyIdBuffer);
messageDigest.update(ecSignedPreKey.serializedPublicKey());
}
{
final ByteBuffer ecSignedPreKeyIdBuffer = ByteBuffer.allocate(Long.BYTES);
ecSignedPreKeyIdBuffer.putLong(ecSignedPreKey.keyId());
ecSignedPreKeyIdBuffer.flip();
{
final ByteBuffer lastResortKeyIdBuffer = ByteBuffer.allocate(Long.BYTES);
lastResortKeyIdBuffer.putLong(lastResortKey.keyId());
lastResortKeyIdBuffer.flip();
messageDigest.update(ecSignedPreKeyIdBuffer);
messageDigest.update(ecSignedPreKey.serializedPublicKey());
}
messageDigest.update(lastResortKeyIdBuffer);
messageDigest.update(lastResortKey.serializedPublicKey());
}
{
final ByteBuffer lastResortKeyIdBuffer = ByteBuffer.allocate(Long.BYTES);
lastResortKeyIdBuffer.putLong(lastResortKey.keyId());
lastResortKeyIdBuffer.flip();
digestsMatch = MessageDigest.isEqual(messageDigest.digest(), checkKeysRequest.digest());
} else {
digestsMatch = false;
}
messageDigest.update(lastResortKeyIdBuffer);
messageDigest.update(lastResortKey.serializedPublicKey());
}
return Response.status(digestsMatch ? Response.Status.OK : Response.Status.CONFLICT).build();
digestsMatch = MessageDigest.isEqual(messageDigest.digest(), checkKeysRequest.digest());
} else {
digestsMatch = false;
}
return Response.status(digestsMatch ? Response.Status.OK : Response.Status.CONFLICT).build();
});
});
}
@ -327,7 +343,7 @@ public class KeysController {
name = "Retry-After",
description = "If present, a positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public PreKeyResponse getDeviceKeys(
@ReadOnly @Auth Optional<AuthenticatedDevice> auth,
@Auth Optional<AuthenticatedDevice> maybeAuthenticatedDevice,
@HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional<Anonymous> accessKey,
@HeaderParam(HeaderUtils.GROUP_SEND_TOKEN) Optional<GroupSendTokenHeader> groupSendToken,
@ -340,15 +356,18 @@ public class KeysController {
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent)
throws RateLimitExceededException {
if (auth.isEmpty() && accessKey.isEmpty() && groupSendToken.isEmpty()) {
if (maybeAuthenticatedDevice.isEmpty() && accessKey.isEmpty() && groupSendToken.isEmpty()) {
throw new WebApplicationException(Response.Status.UNAUTHORIZED);
}
final Optional<Account> account = auth.map(AuthenticatedDevice::getAccount);
final Optional<Account> account = maybeAuthenticatedDevice
.map(authenticatedDevice -> accounts.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)));
final Optional<Account> maybeTarget = accounts.getByServiceIdentifier(targetIdentifier);
if (groupSendToken.isPresent()) {
if (auth.isPresent() || accessKey.isPresent()) {
if (maybeAuthenticatedDevice.isPresent() || accessKey.isPresent()) {
throw new BadRequestException();
}
try {
@ -364,7 +383,7 @@ public class KeysController {
if (account.isPresent()) {
rateLimiters.getPreKeysLimiter().validate(
account.get().getUuid() + "." + auth.get().getAuthenticatedDevice().getId() + "__" + targetIdentifier.uuid()
account.get().getUuid() + "." + maybeAuthenticatedDevice.get().getDeviceId() + "__" + targetIdentifier.uuid()
+ "." + deviceId);
}

View File

@ -105,6 +105,7 @@ import org.whispersystems.textsecuregcm.spam.SpamChecker;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
@ -112,7 +113,6 @@ import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.websocket.WebSocketConnection;
import org.whispersystems.websocket.WebsocketHeaders;
import org.whispersystems.websocket.auth.ReadOnly;
import reactor.core.scheduler.Scheduler;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@ -236,7 +236,7 @@ public class MessageController {
@ApiResponse(
responseCode="428",
description="The sender should complete a challenge before proceeding")
public Response sendMessage(@ReadOnly @Auth final Optional<AuthenticatedDevice> source,
public Response sendMessage(@Auth final Optional<AuthenticatedDevice> source,
@Parameter(description="The recipient's unidentified access key")
@HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) final Optional<Anonymous> accessKey,
@ -274,12 +274,14 @@ public class MessageController {
sendStoryMessage(destinationIdentifier, messages, context);
} else if (source.isPresent()) {
final AuthenticatedDevice authenticatedDevice = source.get();
final Account account = accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
if (authenticatedDevice.getAccount().isIdentifiedBy(destinationIdentifier)) {
if (account.isIdentifiedBy(destinationIdentifier)) {
needsSync = false;
sendSyncMessage(source.get(), destinationIdentifier, messages, context);
sendSyncMessage(source.get(), account, destinationIdentifier, messages, context);
} else {
needsSync = authenticatedDevice.getAccount().getDevices().size() > 1;
needsSync = account.getDevices().size() > 1;
sendIdentifiedSenderIndividualMessage(authenticatedDevice, destinationIdentifier, messages, context);
}
} else {
@ -302,7 +304,7 @@ public class MessageController {
final Account destination =
accountsManager.getByServiceIdentifier(destinationIdentifier).orElseThrow(NotFoundException::new);
rateLimiters.getMessagesLimiter().validate(source.getAccount().getUuid(), destination.getUuid());
rateLimiters.getMessagesLimiter().validate(source.getAccountIdentifier(), destination.getUuid());
sendIndividualMessage(destination,
destinationIdentifier,
@ -314,6 +316,7 @@ public class MessageController {
}
private void sendSyncMessage(final AuthenticatedDevice source,
final Account sourceAccount,
final ServiceIdentifier destinationIdentifier,
final IncomingMessageList messages,
final ContainerRequestContext context)
@ -323,7 +326,7 @@ public class MessageController {
throw new WebApplicationException(Status.FORBIDDEN);
}
sendIndividualMessage(source.getAccount(),
sendIndividualMessage(sourceAccount,
destinationIdentifier,
source,
messages,
@ -420,8 +423,8 @@ public class MessageController {
try {
return message.toEnvelope(
destinationIdentifier,
sender != null ? sender.getAccount() : null,
sender != null ? sender.getAuthenticatedDevice().getId() : null,
sender != null ? new AciServiceIdentifier(sender.getAccountIdentifier()) : null,
sender != null ? sender.getDeviceId() : null,
messages.timestamp() == 0 ? System.currentTimeMillis() : messages.timestamp(),
isStory,
messages.online(),
@ -437,7 +440,7 @@ public class MessageController {
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId));
final Optional<Byte> syncMessageSenderDeviceId = messageType == MessageType.SYNC
? Optional.ofNullable(sender).map(authenticatedDevice -> authenticatedDevice.getAuthenticatedDevice().getId())
? Optional.ofNullable(sender).map(AuthenticatedDevice::getDeviceId)
: Optional.empty();
try {
@ -755,31 +758,37 @@ public class MessageController {
@Timed
@GET
@Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<OutgoingMessageEntityList> getPendingMessages(@ReadOnly @Auth AuthenticatedDevice auth,
public CompletableFuture<OutgoingMessageEntityList> getPendingMessages(@Auth AuthenticatedDevice auth,
@HeaderParam(WebsocketHeaders.X_SIGNAL_RECEIVE_STORIES) String receiveStoriesHeader,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent) {
boolean shouldReceiveStories = WebsocketHeaders.parseReceiveStoriesHeader(receiveStoriesHeader);
return accountsManager.getByAccountIdentifierAsync(auth.getAccountIdentifier())
.thenCompose(maybeAccount -> {
final Account account = maybeAccount.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
final Device device = account.getDevice(auth.getDeviceId())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
pushNotificationManager.handleMessagesRetrieved(auth.getAccount(), auth.getAuthenticatedDevice(), userAgent);
final boolean shouldReceiveStories = WebsocketHeaders.parseReceiveStoriesHeader(receiveStoriesHeader);
return messagesManager.getMessagesForDevice(
auth.getAccount().getUuid(),
auth.getAuthenticatedDevice(),
false)
.map(messagesAndHasMore -> {
Stream<Envelope> envelopes = messagesAndHasMore.first().stream();
if (!shouldReceiveStories) {
envelopes = envelopes.filter(e -> !e.getStory());
}
pushNotificationManager.handleMessagesRetrieved(account, device, userAgent);
return messagesManager.getMessagesForDevice(
auth.getAccountIdentifier(),
device,
false)
.map(messagesAndHasMore -> {
Stream<Envelope> envelopes = messagesAndHasMore.first().stream();
if (!shouldReceiveStories) {
envelopes = envelopes.filter(e -> !e.getStory());
}
final OutgoingMessageEntityList messages = new OutgoingMessageEntityList(envelopes
.map(OutgoingMessageEntity::fromEnvelope)
.peek(outgoingMessageEntity -> {
messageMetrics.measureAccountOutgoingMessageUuidMismatches(auth.getAccount(), outgoingMessageEntity);
messageMetrics.measureAccountOutgoingMessageUuidMismatches(account, outgoingMessageEntity);
messageMetrics.measureOutgoingMessageLatency(outgoingMessageEntity.serverTimestamp(),
"rest",
auth.getAuthenticatedDevice().isPrimary(),
auth.getDeviceId() == Device.PRIMARY_ID,
outgoingMessageEntity.urgent(),
// Messages fetched via this endpoint (as opposed to WebSocketConnection) are never ephemeral
// because, by definition, the client doesn't have a "live" connection via which to receive
@ -791,26 +800,27 @@ public class MessageController {
.collect(Collectors.toList()),
messagesAndHasMore.second());
Metrics.summary(OUTGOING_MESSAGE_LIST_SIZE_BYTES_DISTRIBUTION_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent)))
.record(estimateMessageListSizeBytes(messages));
Metrics.summary(OUTGOING_MESSAGE_LIST_SIZE_BYTES_DISTRIBUTION_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent)))
.record(estimateMessageListSizeBytes(messages));
if (!messages.messages().isEmpty()) {
messageDeliveryLoopMonitor.recordDeliveryAttempt(auth.getAccount().getIdentifier(IdentityType.ACI),
auth.getAuthenticatedDevice().getId(),
messages.messages().getFirst().guid(),
userAgent,
"rest");
}
if (!messages.messages().isEmpty()) {
messageDeliveryLoopMonitor.recordDeliveryAttempt(auth.getAccountIdentifier(),
auth.getDeviceId(),
messages.messages().getFirst().guid(),
userAgent,
"rest");
}
if (messagesAndHasMore.second()) {
pushNotificationScheduler.scheduleDelayedNotification(auth.getAccount(), auth.getAuthenticatedDevice(), NOTIFY_FOR_REMAINING_MESSAGES_DELAY);
}
if (messagesAndHasMore.second()) {
pushNotificationScheduler.scheduleDelayedNotification(account, device, NOTIFY_FOR_REMAINING_MESSAGES_DELAY);
}
return messages;
})
.timeout(Duration.ofSeconds(5))
.subscribeOn(messageDeliveryScheduler)
.toFuture();
return messages;
})
.timeout(Duration.ofSeconds(5))
.subscribeOn(messageDeliveryScheduler)
.toFuture();
});
}
private static long estimateMessageListSizeBytes(final OutgoingMessageEntityList messageList) {
@ -827,22 +837,27 @@ public class MessageController {
@Timed
@DELETE
@Path("/uuid/{uuid}")
public CompletableFuture<Response> removePendingMessage(@ReadOnly @Auth AuthenticatedDevice auth, @PathParam("uuid") UUID uuid) {
public CompletableFuture<Response> removePendingMessage(@Auth AuthenticatedDevice auth, @PathParam("uuid") UUID uuid) {
final Account account = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
final Device device = account.getDevice(auth.getDeviceId())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
return messagesManager.delete(
auth.getAccount().getUuid(),
auth.getAuthenticatedDevice(),
auth.getAccountIdentifier(),
device,
uuid,
null)
.thenAccept(maybeRemovedMessage -> maybeRemovedMessage.ifPresent(removedMessage -> {
WebSocketConnection.recordMessageDeliveryDuration(removedMessage.serverTimestamp(),
auth.getAuthenticatedDevice());
WebSocketConnection.recordMessageDeliveryDuration(removedMessage.serverTimestamp(), device);
if (removedMessage.sourceServiceId().isPresent()
&& removedMessage.envelopeType() != Type.SERVER_DELIVERY_RECEIPT) {
if (removedMessage.sourceServiceId().get() instanceof AciServiceIdentifier aciServiceIdentifier) {
try {
receiptSender.sendReceipt(removedMessage.destinationServiceId(), auth.getAuthenticatedDevice().getId(),
receiptSender.sendReceipt(removedMessage.destinationServiceId(), auth.getDeviceId(),
aciServiceIdentifier, removedMessage.clientTimestamp());
} catch (Exception e) {
logger.warn("Failed to send delivery receipt", e);
@ -863,7 +878,7 @@ public class MessageController {
@Consumes(MediaType.APPLICATION_JSON)
@Path("/report/{source}/{messageGuid}")
public Response reportSpamMessage(
@ReadOnly @Auth AuthenticatedDevice auth,
@Auth AuthenticatedDevice auth,
@PathParam("source") String source,
@PathParam("messageGuid") UUID messageGuid,
@Nullable SpamReport spamReport,
@ -899,7 +914,7 @@ public class MessageController {
}
}
UUID spamReporterUuid = auth.getAccount().getUuid();
UUID spamReporterUuid = auth.getAccountIdentifier();
// spam report token is optional, but if provided ensure it is non-empty.
final Optional<byte[]> maybeSpamReportToken =

View File

@ -5,8 +5,8 @@
package org.whispersystems.textsecuregcm.controllers;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import java.util.Map;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
public class MultiRecipientMismatchedDevicesException extends Exception {

View File

@ -69,7 +69,6 @@ import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import org.whispersystems.websocket.auth.ReadOnly;
/**
@ -163,7 +162,7 @@ public class OneTimeDonationController {
@StringToClassMapItem(key = "error", value = String.class)
})))
public CompletableFuture<Response> createBoostPaymentIntent(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@NotNull @Valid CreateBoostRequest request,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) {
@ -249,7 +248,7 @@ public class OneTimeDonationController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> createPayPalBoost(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@NotNull @Valid CreatePayPalBoostRequest request,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@Context ContainerRequestContext containerRequestContext) {
@ -296,7 +295,7 @@ public class OneTimeDonationController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> confirmPayPalBoost(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@NotNull @Valid ConfirmPayPalBoostRequest request,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) {
@ -342,7 +341,7 @@ public class OneTimeDonationController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> createBoostReceiptCredentials(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@NotNull @Valid final CreateBoostReceiptCredentialsRequest request,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) {

View File

@ -17,7 +17,6 @@ import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator
import org.whispersystems.textsecuregcm.configuration.PaymentsServiceConfiguration;
import org.whispersystems.textsecuregcm.currency.CurrencyConversionManager;
import org.whispersystems.textsecuregcm.entities.CurrencyConversionEntityList;
import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v1/payments")
@Tag(name = "Payments")
@ -43,14 +42,14 @@ public class PaymentsController {
@GET
@Path("/auth")
@Produces(MediaType.APPLICATION_JSON)
public ExternalServiceCredentials getAuth(final @ReadOnly @Auth AuthenticatedDevice auth) {
return paymentsServiceCredentialsGenerator.generateForUuid(auth.getAccount().getUuid());
public ExternalServiceCredentials getAuth(final @Auth AuthenticatedDevice auth) {
return paymentsServiceCredentialsGenerator.generateForUuid(auth.getAccountIdentifier());
}
@GET
@Path("/conversions")
@Produces(MediaType.APPLICATION_JSON)
public CurrencyConversionEntityList getConversions(final @ReadOnly @Auth AuthenticatedDevice auth) {
public CurrencyConversionEntityList getConversions(final @Auth AuthenticatedDevice auth) {
return currencyManager.getCurrencyConversions().orElseThrow();
}
}

View File

@ -26,6 +26,7 @@ import jakarta.ws.rs.Path;
import jakarta.ws.rs.PathParam;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.QueryParam;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.container.ContainerRequestContext;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.HttpHeaders;
@ -94,8 +95,6 @@ import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.ProfileHelper;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.websocket.auth.Mutable;
import org.whispersystems.websocket.auth.ReadOnly;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@Path("/v1/profile")
@ -152,15 +151,18 @@ public class ProfileController {
@PUT
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
public Response setProfile(@Mutable @Auth AuthenticatedDevice auth, @NotNull @Valid CreateProfileRequest request) {
public Response setProfile(@Auth AuthenticatedDevice auth, @NotNull @Valid CreateProfileRequest request) {
final Optional<VersionedProfile> currentProfile = profilesManager.get(auth.getAccount().getUuid(),
request.version());
final Account account = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
final Optional<VersionedProfile> currentProfile =
profilesManager.get(auth.getAccountIdentifier(), request.version());
if (request.paymentAddress() != null && request.paymentAddress().length != 0) {
final boolean hasDisallowedPrefix =
dynamicConfigurationManager.getConfiguration().getPaymentsConfiguration().getDisallowedPrefixes().stream()
.anyMatch(prefix -> auth.getAccount().getNumber().startsWith(prefix));
.anyMatch(prefix -> account.getNumber().startsWith(prefix));
if (hasDisallowedPrefix && currentProfile.map(VersionedProfile::paymentAddress).isEmpty()) {
return Response.status(Response.Status.FORBIDDEN).build();
@ -179,7 +181,7 @@ public class ProfileController {
case UPDATE -> ProfileHelper.generateAvatarObjectName();
};
profilesManager.set(auth.getAccount().getUuid(),
profilesManager.set(auth.getAccountIdentifier(),
new VersionedProfile(
request.version(),
request.name(),
@ -194,7 +196,7 @@ public class ProfileController {
currentAvatar.ifPresent(s -> profilesManager.deleteAvatar(s).join());
}
accountsManager.update(auth.getAccount(), a -> {
accountsManager.update(account, a -> {
final List<AccountBadge> updatedBadges = request.badges()
.map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges, a.getBadges()))
@ -216,7 +218,7 @@ public class ProfileController {
@Path("/{identifier}/{version}")
@ManagedAsync
public VersionedProfileResponse getProfile(
@ReadOnly @Auth Optional<AuthenticatedDevice> auth,
@Auth Optional<AuthenticatedDevice> maybeAuthenticatedDevice,
@HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional<Anonymous> accessKey,
@Context ContainerRequestContext containerRequestContext,
@PathParam("identifier") AciServiceIdentifier accountIdentifier,
@ -224,7 +226,11 @@ public class ProfileController {
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent)
throws RateLimitExceededException {
final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount);
final Optional<Account> maybeRequester =
maybeAuthenticatedDevice.map(
authenticatedDevice -> accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)));
final Account targetAccount = verifyPermissionToReceiveProfile(maybeRequester, accessKey, accountIdentifier, "getVersionedProfile", userAgent);
return buildVersionedProfileResponse(targetAccount,
@ -238,7 +244,7 @@ public class ProfileController {
@Produces(MediaType.APPLICATION_JSON)
@Path("/{identifier}/{version}/{credentialRequest}")
public CredentialProfileResponse getProfile(
@ReadOnly @Auth Optional<AuthenticatedDevice> auth,
@Auth Optional<AuthenticatedDevice> maybeAuthenticatedDevice,
@HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional<Anonymous> accessKey,
@Context ContainerRequestContext containerRequestContext,
@PathParam("identifier") AciServiceIdentifier accountIdentifier,
@ -252,7 +258,11 @@ public class ProfileController {
throw new BadRequestException();
}
final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount);
final Optional<Account> maybeRequester =
maybeAuthenticatedDevice.map(
authenticatedDevice -> accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)));
final Account targetAccount = verifyPermissionToReceiveProfile(maybeRequester, accessKey, accountIdentifier, "credentialRequest", userAgent);
final boolean isSelf = maybeRequester.map(requester -> ProfileHelper.isSelfProfileRequest(requester.getUuid(), accountIdentifier)).orElse(false);
@ -270,7 +280,7 @@ public class ProfileController {
@Path("/{identifier}")
@ManagedAsync
public BaseProfileResponse getUnversionedProfile(
@ReadOnly @Auth Optional<AuthenticatedDevice> auth,
@Auth Optional<AuthenticatedDevice> maybeAuthenticatedDevice,
@HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional<Anonymous> accessKey,
@HeaderParam(HeaderUtils.GROUP_SEND_TOKEN) Optional<GroupSendTokenHeader> groupSendToken,
@Context ContainerRequestContext containerRequestContext,
@ -278,7 +288,10 @@ public class ProfileController {
@PathParam("identifier") ServiceIdentifier identifier)
throws RateLimitExceededException {
final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount);
final Optional<Account> maybeRequester =
maybeAuthenticatedDevice.map(
authenticatedDevice -> accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)));
final Account targetAccount;
if (groupSendToken.isPresent()) {

View File

@ -34,7 +34,6 @@ import org.whispersystems.textsecuregcm.entities.ProvisioningMessage;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.ProvisioningManager;
import org.whispersystems.websocket.auth.ReadOnly;
/**
* The provisioning controller facilitates transmission of provisioning messages from the primary device associated with
@ -77,7 +76,7 @@ public class ProvisioningController {
@ApiResponse(responseCode="204", description="The provisioning message was delivered to the given provisioning address")
@ApiResponse(responseCode="400", description="The provisioning message was too large")
@ApiResponse(responseCode="404", description="No device with the given provisioning address was connected at the time of the request")
public void sendProvisioningMessage(@ReadOnly @Auth final AuthenticatedDevice auth,
public void sendProvisioningMessage(@Auth final AuthenticatedDevice auth,
@Parameter(description = "The temporary provisioning address to which to send a provisioning message")
@PathParam("destination") final String provisioningAddress,
@ -93,7 +92,7 @@ public class ProvisioningController {
throw new WebApplicationException(Response.Status.BAD_REQUEST);
}
rateLimiters.getMessagesLimiter().validate(auth.getAccount().getUuid());
rateLimiters.getMessagesLimiter().validate(auth.getAccountIdentifier());
final boolean subscriberPresent =
provisioningManager.sendProvisioningMessage(provisioningAddress, Base64.getMimeDecoder().decode(message.body()));

View File

@ -30,7 +30,6 @@ import org.whispersystems.textsecuregcm.entities.UserRemoteConfigList;
import org.whispersystems.textsecuregcm.storage.RemoteConfigsManager;
import org.whispersystems.textsecuregcm.util.Conversions;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v1/config")
@Tag(name = "Remote Config")
@ -64,7 +63,7 @@ public class RemoteConfigController {
"""
)
@ApiResponse(responseCode = "200", description = "Remote configuration values for the authenticated user", useReturnTypeSchema = true)
public UserRemoteConfigList getAll(@ReadOnly @Auth AuthenticatedDevice auth) {
public UserRemoteConfigList getAll(@Auth AuthenticatedDevice auth) {
try {
MessageDigest digest = MessageDigest.getInstance("SHA1");
@ -73,7 +72,7 @@ public class RemoteConfigController {
return new UserRemoteConfigList(Stream.concat(remoteConfigsManager.getAll().stream().map(config -> {
final byte[] hashKey = config.getHashKey() != null ? config.getHashKey().getBytes(StandardCharsets.UTF_8)
: config.getName().getBytes(StandardCharsets.UTF_8);
boolean inBucket = isInBucket(digest, auth.getAccount().getUuid(), hashKey, config.getPercentage(),
boolean inBucket = isInBucket(digest, auth.getAccountIdentifier(), hashKey, config.getPercentage(),
config.getUuids());
return new UserRemoteConfig(config.getName(), inBucket,
inBucket ? config.getValue() : config.getDefaultValue());

View File

@ -17,7 +17,6 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator;
import org.whispersystems.textsecuregcm.configuration.SecureStorageServiceConfiguration;
import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v1/storage")
@Tag(name = "Secure Storage")
@ -47,7 +46,7 @@ public class SecureStorageController {
"""
)
@ApiResponse(responseCode = "200", description = "`JSON` with generated credentials.", useReturnTypeSchema = true)
public ExternalServiceCredentials getAuth(@ReadOnly @Auth AuthenticatedDevice auth) {
return storageServiceCredentialsGenerator.generateForUuid(auth.getAccount().getUuid());
public ExternalServiceCredentials getAuth(@Auth AuthenticatedDevice auth) {
return storageServiceCredentialsGenerator.generateForUuid(auth.getAccountIdentifier());
}
}

View File

@ -34,7 +34,6 @@ import org.whispersystems.textsecuregcm.limits.RateLimitedByIp;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v2/backup")
@Tag(name = "Secure Value Recovery")
@ -78,8 +77,8 @@ public class SecureValueRecovery2Controller {
)
@ApiResponse(responseCode = "200", description = "`JSON` with generated credentials.", useReturnTypeSchema = true)
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
public ExternalServiceCredentials getAuth(@ReadOnly @Auth final AuthenticatedDevice auth) {
return backupServiceCredentialGenerator.generateFor(auth.getAccount().getUuid().toString());
public ExternalServiceCredentials getAuth(@Auth final AuthenticatedDevice auth) {
return backupServiceCredentialGenerator.generateFor(auth.getAccountIdentifier().toString());
}

View File

@ -28,7 +28,6 @@ import org.whispersystems.textsecuregcm.s3.PolicySigner;
import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v1/sticker")
@Tag(name = "Stickers")
@ -47,10 +46,10 @@ public class StickerController {
@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/pack/form/{count}")
public StickerPackFormUploadAttributes getStickersForm(@ReadOnly @Auth AuthenticatedDevice auth,
public StickerPackFormUploadAttributes getStickersForm(@Auth AuthenticatedDevice auth,
@PathParam("count") @Min(1) @Max(201) int stickerCount)
throws RateLimitExceededException {
rateLimiters.getStickerPackLimiter().validate(auth.getAccount().getUuid());
rateLimiters.getStickerPackLimiter().validate(auth.getAccountIdentifier());
ZonedDateTime now = ZonedDateTime.now(ZoneOffset.UTC);
String packId = generatePackId();

View File

@ -88,7 +88,6 @@ import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v1/subscription")
@io.swagger.v3.oas.annotations.tags.Tag(name = "Subscriptions")
@ -220,7 +219,7 @@ public class SubscriptionController {
@Path("/{subscriberId}")
@Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> deleteSubscriber(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId) throws SubscriptionException {
SubscriberCredentials subscriberCredentials =
SubscriberCredentials.process(authenticatedAccount, subscriberId, clock);
@ -232,7 +231,7 @@ public class SubscriptionController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> updateSubscriber(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId) throws SubscriptionException {
SubscriberCredentials subscriberCredentials =
SubscriberCredentials.process(authenticatedAccount, subscriberId, clock);
@ -248,7 +247,7 @@ public class SubscriptionController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> createPaymentMethod(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId,
@QueryParam("type") @DefaultValue("CARD") PaymentMethod paymentMethodType,
@HeaderParam(HttpHeaders.USER_AGENT) @Nullable final String userAgentString) throws SubscriptionException {
@ -284,7 +283,7 @@ public class SubscriptionController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> createPayPalPaymentMethod(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId,
@NotNull @Valid CreatePayPalBillingAgreementRequest request,
@Context ContainerRequestContext containerRequestContext,
@ -323,7 +322,7 @@ public class SubscriptionController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> setDefaultPaymentMethodWithProcessor(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId,
@PathParam("processor") PaymentProvider processor,
@PathParam("paymentMethodToken") @NotEmpty String paymentMethodToken) throws SubscriptionException {
@ -360,7 +359,7 @@ public class SubscriptionController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> setSubscriptionLevel(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId,
@PathParam("level") long level,
@PathParam("currency") String currency,
@ -432,7 +431,7 @@ public class SubscriptionController {
@ApiResponse(responseCode = "409", description = "subscriberId is already linked to a processor that does not support appstore payments. Delete this subscriberId and use a new one.")
@ApiResponse(responseCode = "429", description = "Rate limit exceeded.")
public CompletableFuture<SetSubscriptionLevelSuccessResponse> setAppStoreSubscription(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId,
@PathParam("originalTransactionId") String originalTransactionId) throws SubscriptionException {
final SubscriberCredentials subscriberCredentials =
@ -473,7 +472,7 @@ public class SubscriptionController {
@ApiResponse(responseCode = "404", description = "No such subscriberId exists or subscriberId is malformed or the purchaseToken does not exist")
@ApiResponse(responseCode = "409", description = "subscriberId is already linked to a processor that does not support Play Billing. Delete this subscriberId and use a new one.")
public CompletableFuture<SetSubscriptionLevelSuccessResponse> setPlayStoreSubscription(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId,
@PathParam("purchaseToken") String purchaseToken) throws SubscriptionException {
final SubscriberCredentials subscriberCredentials =
@ -627,7 +626,7 @@ public class SubscriptionController {
@ApiResponse(responseCode = "403", description = "subscriberId authentication failure OR account authentication is present")
@ApiResponse(responseCode = "404", description = "No such subscriberId exists or subscriberId is malformed")
public CompletableFuture<Response> getSubscriptionInformation(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId) throws SubscriptionException {
SubscriberCredentials subscriberCredentials =
SubscriberCredentials.process(authenticatedAccount, subscriberId, clock);
@ -662,7 +661,7 @@ public class SubscriptionController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> createSubscriptionReceiptCredentials(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@PathParam("subscriberId") String subscriberId,
@NotNull @Valid GetReceiptCredentialsRequest request) throws SubscriptionException {
@ -691,7 +690,7 @@ public class SubscriptionController {
@Path("/{subscriberId}/default_payment_method_for_ideal/{setupIntentId}")
@Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> setDefaultPaymentMethodForIdeal(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId,
@PathParam("setupIntentId") @NotEmpty String setupIntentId) throws SubscriptionException {
SubscriberCredentials subscriberCredentials =

View File

@ -5,9 +5,9 @@
package org.whispersystems.textsecuregcm.controllers;
import org.whispersystems.textsecuregcm.entities.RegistrationServiceSession;
import javax.annotation.Nullable;
import java.time.Duration;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.entities.RegistrationServiceSession;
public class VerificationSessionRateLimitExceededException extends RateLimitExceededException {

View File

@ -10,15 +10,14 @@ import com.webauthn4j.converter.jackson.deserializer.json.ByteArrayBase64Deseria
import io.micrometer.core.instrument.Metrics;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.AssertTrue;
import javax.annotation.Nullable;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import java.util.Arrays;
import java.util.Objects;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.storage.Account;
import java.util.Arrays;
import java.util.Objects;
public record IncomingMessage(int type,
byte destinationDeviceId,
@ -35,7 +34,7 @@ public record IncomingMessage(int type,
MetricsUtil.name(IncomingMessage.class, "rejectInvalidEnvelopeType");
public MessageProtos.Envelope toEnvelope(final ServiceIdentifier destinationIdentifier,
@Nullable Account sourceAccount,
@Nullable AciServiceIdentifier sourceServiceIdentifier,
@Nullable Byte sourceDeviceId,
final long timestamp,
final boolean story,
@ -54,9 +53,9 @@ public record IncomingMessage(int type,
.setEphemeral(ephemeral)
.setUrgent(urgent);
if (sourceAccount != null && sourceDeviceId != null) {
if (sourceServiceIdentifier != null && sourceDeviceId != null) {
envelopeBuilder
.setSourceServiceId(new AciServiceIdentifier(sourceAccount.getUuid()).toServiceIdentifierString())
.setSourceServiceId(sourceServiceIdentifier.toServiceIdentifierString())
.setSourceDevice(sourceDeviceId.intValue());
}

View File

@ -96,8 +96,8 @@ public class RestDeprecationFilter implements ContainerRequestFilter {
return false;
}
if (securityContext.getUserPrincipal() instanceof AuthenticatedDevice ad) {
return experimentEnrollmentManager.isEnrolled(ad.getAccount().getUuid(), AUTHENTICATED_EXPERIMENT_NAME);
if (securityContext.getUserPrincipal() instanceof AuthenticatedDevice authenticatedDevice) {
return experimentEnrollmentManager.isEnrolled(authenticatedDevice.getAccountIdentifier(), AUTHENTICATED_EXPERIMENT_NAME);
} else {
log.error("Security context was not null but user principal was of type {}", securityContext.getUserPrincipal().getClass().getName());
return false;

View File

@ -9,8 +9,6 @@ import java.util.Optional;
import jakarta.ws.rs.core.Response;
import org.signal.chat.messages.SendMessageResponse;
import org.signal.chat.messages.SendMultiRecipientMessageResponse;
import org.whispersystems.textsecuregcm.auth.AccountAndAuthenticatedDeviceHolder;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Account;
@ -31,7 +29,7 @@ public interface SpamChecker {
SpamCheckResult<Response> checkForIndividualRecipientSpamHttp(
final MessageType messageType,
final ContainerRequestContext requestContext,
final Optional<? extends AccountAndAuthenticatedDeviceHolder> maybeSource,
final Optional<org.whispersystems.textsecuregcm.auth.AuthenticatedDevice> maybeSource,
final Optional<Account> maybeDestination,
final ServiceIdentifier destinationIdentifier);
@ -79,7 +77,7 @@ public interface SpamChecker {
@Override
public SpamCheckResult<Response> checkForIndividualRecipientSpamHttp(final MessageType messageType,
final ContainerRequestContext requestContext,
final Optional<? extends AccountAndAuthenticatedDeviceHolder> maybeSource,
final Optional<org.whispersystems.textsecuregcm.auth.AuthenticatedDevice> maybeSource,
final Optional<Account> maybeDestination,
final ServiceIdentifier destinationIdentifier) {
@ -95,7 +93,7 @@ public interface SpamChecker {
@Override
public SpamCheckResult<GrpcResponse<SendMessageResponse>> checkForIndividualRecipientSpamGrpc(final MessageType messageType,
final Optional<AuthenticatedDevice> maybeSource,
final Optional<org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice> maybeSource,
final Optional<Account> maybeDestination,
final ServiceIdentifier destinationIdentifier) {

View File

@ -1,36 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.websocket.auth.PrincipalSupplier;
public class AccountPrincipalSupplier implements PrincipalSupplier<AuthenticatedDevice> {
private final AccountsManager accountsManager;
public AccountPrincipalSupplier(final AccountsManager accountsManager) {
this.accountsManager = accountsManager;
}
@Override
public AuthenticatedDevice refresh(final AuthenticatedDevice oldAccount) {
final Account account = accountsManager.getByAccountIdentifier(oldAccount.getAccount().getUuid())
.orElseThrow(() -> new RefreshingAccountNotFoundException("Could not find account"));
final Device device = account.getDevice(oldAccount.getAuthenticatedDevice().getId())
.orElseThrow(() -> new RefreshingAccountNotFoundException("Could not find device"));
return new AuthenticatedDevice(account, device);
}
@Override
public AuthenticatedDevice deepCopy(final AuthenticatedDevice authenticatedDevice) {
final Account cloned = AccountUtil.cloneAccountAsNotStale(authenticatedDevice.getAccount());
return new AuthenticatedDevice(
cloned,
cloned.getDevice(authenticatedDevice.getAuthenticatedDevice().getId())
.orElseThrow(() -> new IllegalStateException(
"Could not find device from a clone of an account where the device was present")));
}
}

View File

@ -600,13 +600,10 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
}
/**
* Unlink a device from the given account. The device will be immediately disconnected if it is
* connected to any chat frontend, but it is the caller's responsibility to make sure that the
* account's *other* devices are disconnected, either by use of
* {@link org.whispersystems.textsecuregcm.auth.LinkedDeviceRefreshRequirementProvider} or
* directly by calling {@link DeviceDisconnectionManager#requestDisconnection}.
* Unlink a device from the given account. The device will be immediately disconnected if it is connected to any chat
* frontend.
*
* @returns the updated Account
* @return the updated Account
*/
public CompletableFuture<Account> removeDevice(final Account account, final byte deviceId) {
if (deviceId == Device.PRIMARY_ID) {

View File

@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.websocket;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import io.micrometer.core.instrument.Tags;
import java.util.Optional;
import java.util.concurrent.ScheduledExecutorService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -20,7 +21,10 @@ import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import org.whispersystems.websocket.setup.WebSocketConnectListener;
@ -36,6 +40,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
private static final Logger log = LoggerFactory.getLogger(AuthenticatedConnectListener.class);
private final AccountsManager accountsManager;
private final ReceiptSender receiptSender;
private final MessagesManager messagesManager;
private final MessageMetrics messageMetrics;
@ -51,7 +56,9 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
private final OpenWebSocketCounter openAuthenticatedWebSocketCounter;
private final OpenWebSocketCounter openUnauthenticatedWebSocketCounter;
public AuthenticatedConnectListener(ReceiptSender receiptSender,
public AuthenticatedConnectListener(
AccountsManager accountsManager,
ReceiptSender receiptSender,
MessagesManager messagesManager,
MessageMetrics messageMetrics,
PushNotificationManager pushNotificationManager,
@ -62,6 +69,8 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
ClientReleaseManager clientReleaseManager,
MessageDeliveryLoopMonitor messageDeliveryLoopMonitor,
final ExperimentEnrollmentManager experimentEnrollmentManager) {
this.accountsManager = accountsManager;
this.receiptSender = receiptSender;
this.messagesManager = messagesManager;
this.messageMetrics = messageMetrics;
@ -82,7 +91,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
}
@Override
public void onWebSocketConnect(WebSocketSessionContext context) {
public void onWebSocketConnect(final WebSocketSessionContext context) {
final boolean authenticated = (context.getAuthenticated() != null);
final OpenWebSocketCounter openWebSocketCounter =
@ -92,12 +101,24 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
if (authenticated) {
final AuthenticatedDevice auth = context.getAuthenticated(AuthenticatedDevice.class);
final Optional<Account> maybeAuthenticatedAccount = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier());
final Optional<Device> maybeAuthenticatedDevice = maybeAuthenticatedAccount.flatMap(account -> account.getDevice(auth.getDeviceId()));;
if (maybeAuthenticatedAccount.isEmpty() || maybeAuthenticatedDevice.isEmpty()) {
log.warn("{}:{} not found when opening authenticated WebSocket", auth.getAccountIdentifier(), auth.getDeviceId());
context.getClient().close(1011, "Unexpected error initializing connection");
return;
}
final WebSocketConnection connection = new WebSocketConnection(receiptSender,
messagesManager,
messageMetrics,
pushNotificationManager,
pushNotificationScheduler,
auth,
maybeAuthenticatedAccount.get(),
maybeAuthenticatedDevice.get(),
context.getClient(),
scheduledExecutorService,
messageDeliveryScheduler,
@ -110,8 +131,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
// receive push notifications for inbound messages. We should do this first because, at this point, the
// connection has already closed and attempts to actually deliver a message via the connection will not succeed.
// It's preferable to start sending push notifications as soon as possible.
webSocketConnectionEventManager.handleClientDisconnected(auth.getAccount().getUuid(),
auth.getAuthenticatedDevice().getId());
webSocketConnectionEventManager.handleClientDisconnected(auth.getAccountIdentifier(), auth.getDeviceId());
// Finally, stop trying to deliver messages and send a push notification if the connection is aware of any
// undelivered messages.
@ -127,7 +147,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
// Finally, we register this client's presence, which suppresses push notifications. We do this last because
// receiving extra push notifications is generally preferable to missing out on a push notification.
webSocketConnectionEventManager.handleClientConnected(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(), connection);
webSocketConnectionEventManager.handleClientConnected(auth.getAccountIdentifier(), auth.getDeviceId(), connection);
} catch (final Exception e) {
log.warn("Failed to initialize websocket", e);
context.getClient().close(1011, "Unexpected error initializing connection");

View File

@ -9,41 +9,39 @@ import static org.whispersystems.textsecuregcm.util.HeaderUtils.basicCredentials
import com.google.common.net.HttpHeaders;
import javax.annotation.Nullable;
import io.dropwizard.auth.basic.BasicCredentials;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.auth.InvalidCredentialsException;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import java.util.Optional;
public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<AuthenticatedDevice> {
private static final ReusableAuth<AuthenticatedDevice> CREDENTIALS_NOT_PRESENTED = ReusableAuth.anonymous();
private final AccountAuthenticator accountAuthenticator;
private final PrincipalSupplier<AuthenticatedDevice> principalSupplier;
public WebSocketAccountAuthenticator(final AccountAuthenticator accountAuthenticator,
final PrincipalSupplier<AuthenticatedDevice> principalSupplier) {
public WebSocketAccountAuthenticator(final AccountAuthenticator accountAuthenticator) {
this.accountAuthenticator = accountAuthenticator;
this.principalSupplier = principalSupplier;
}
@Override
public ReusableAuth<AuthenticatedDevice> authenticate(final UpgradeRequest request)
public Optional<AuthenticatedDevice> authenticate(final UpgradeRequest request)
throws InvalidCredentialsException {
@Nullable final String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION);
if (authHeader == null) {
return CREDENTIALS_NOT_PRESENTED;
return Optional.empty();
}
return basicCredentialsFromAuthHeader(authHeader)
.flatMap(accountAuthenticator::authenticate)
.map(authenticatedAccount -> ReusableAuth.authenticated(authenticatedAccount, this.principalSupplier))
final BasicCredentials credentials = basicCredentialsFromAuthHeader(authHeader)
.orElseThrow(InvalidCredentialsException::new);
final AuthenticatedDevice authenticatedDevice = accountAuthenticator.authenticate(credentials)
.orElseThrow(InvalidCredentialsException::new);
return Optional.of(authenticatedDevice);
}
}

View File

@ -37,7 +37,6 @@ import org.eclipse.jetty.util.StaticException;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
@ -52,6 +51,7 @@ import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventListener;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
@ -123,7 +123,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor;
private final ExperimentEnrollmentManager experimentEnrollmentManager;
private final AuthenticatedDevice auth;
private final Account authenticatedAccount;
private final Device authenticatedDevice;
private final WebSocketClient client;
private final int sendFuturesTimeoutMillis;
@ -156,7 +157,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
MessageMetrics messageMetrics,
PushNotificationManager pushNotificationManager,
PushNotificationScheduler pushNotificationScheduler,
AuthenticatedDevice auth,
Account authenticatedAccount,
Device authenticatedDevice,
WebSocketClient client,
ScheduledExecutorService scheduledExecutorService,
Scheduler messageDeliveryScheduler,
@ -169,7 +171,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
messageMetrics,
pushNotificationManager,
pushNotificationScheduler,
auth,
authenticatedAccount,
authenticatedDevice,
client,
DEFAULT_SEND_FUTURES_TIMEOUT_MILLIS,
scheduledExecutorService,
@ -184,7 +187,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
MessageMetrics messageMetrics,
PushNotificationManager pushNotificationManager,
PushNotificationScheduler pushNotificationScheduler,
AuthenticatedDevice auth,
Account authenticatedAccount,
Device authenticatedDevice,
WebSocketClient client,
int sendFuturesTimeoutMillis,
ScheduledExecutorService scheduledExecutorService,
@ -198,7 +202,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
this.messageMetrics = messageMetrics;
this.pushNotificationManager = pushNotificationManager;
this.pushNotificationScheduler = pushNotificationScheduler;
this.auth = auth;
this.authenticatedAccount = authenticatedAccount;
this.authenticatedDevice = authenticatedDevice;
this.client = client;
this.sendFuturesTimeoutMillis = sendFuturesTimeoutMillis;
this.scheduledExecutorService = scheduledExecutorService;
@ -209,7 +214,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
}
public void start() {
pushNotificationManager.handleMessagesRetrieved(auth.getAccount(), auth.getAuthenticatedDevice(), client.getUserAgent());
pushNotificationManager.handleMessagesRetrieved(authenticatedAccount, authenticatedDevice, client.getUserAgent());
queueDrainStartTime.set(System.currentTimeMillis());
processStoredMessages();
}
@ -229,8 +234,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
client.close(1000, "OK");
if (storedMessageState.get() != StoredMessageState.EMPTY) {
pushNotificationScheduler.scheduleDelayedNotification(auth.getAccount(),
auth.getAuthenticatedDevice(),
pushNotificationScheduler.scheduleDelayedNotification(authenticatedAccount,
authenticatedDevice,
CLOSE_WITH_PENDING_MESSAGES_NOTIFICATION_DELAY);
}
}
@ -242,7 +247,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
sendMessageCounter.increment();
sentMessageCounter.increment();
bytesSentCounter.increment(body.map(bytes -> bytes.length).orElse(0));
messageMetrics.measureAccountEnvelopeUuidMismatches(auth.getAccount(), message);
messageMetrics.measureAccountEnvelopeUuidMismatches(authenticatedAccount, message);
// X-Signal-Key: false must be sent until Android stops assuming it missing means true
return client.sendRequest("PUT", "/api/v1/message",
@ -253,7 +258,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
} else {
messageMetrics.measureOutgoingMessageLatency(message.getServerTimestamp(),
"websocket",
auth.getAuthenticatedDevice().isPrimary(),
authenticatedDevice.isPrimary(),
message.getUrgent(),
message.getEphemeral(),
client.getUserAgent(),
@ -263,12 +268,12 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
final CompletableFuture<Void> result;
if (isSuccessResponse(response)) {
result = messagesManager.delete(auth.getAccount().getUuid(), auth.getAuthenticatedDevice(),
result = messagesManager.delete(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice,
storedMessageInfo.guid(), storedMessageInfo.serverTimestamp())
.thenApply(ignored -> null);
if (message.getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT) {
recordMessageDeliveryDuration(message.getServerTimestamp(), auth.getAuthenticatedDevice());
recordMessageDeliveryDuration(message.getServerTimestamp(), authenticatedDevice);
sendDeliveryReceiptFor(message);
}
} else {
@ -307,7 +312,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
try {
receiptSender.sendReceipt(ServiceIdentifier.valueOf(message.getDestinationServiceId()),
auth.getAuthenticatedDevice().getId(), AciServiceIdentifier.valueOf(message.getSourceServiceId()),
authenticatedDevice.getId(), AciServiceIdentifier.valueOf(message.getSourceServiceId()),
message.getClientTimestamp());
} catch (IllegalArgumentException e) {
logger.error("Could not parse UUID: {}", message.getSourceServiceId());
@ -338,7 +343,6 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
// Cleared the queue! Send a queue empty message if we need to
consecutiveRetries.set(0);
if (sentInitialQueueEmptyMessage.compareAndSet(false, true)) {
final Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent()));
final long drainDuration = System.currentTimeMillis() - queueDrainStartTime.get();
@ -399,7 +403,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
final CompletableFuture<Void> queueCleared = new CompletableFuture<>();
final Publisher<Envelope> messages =
messagesManager.getMessagesForDeviceReactive(auth.getAccount().getUuid(), auth.getAuthenticatedDevice(), cachedMessagesOnly);
messagesManager.getMessagesForDeviceReactive(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice, cachedMessagesOnly);
final AtomicBoolean hasSentFirstMessage = new AtomicBoolean();
final AtomicBoolean hasErrored = new AtomicBoolean();
@ -410,8 +414,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
.limitRate(MESSAGE_PUBLISHER_LIMIT_RATE)
.doOnNext(envelope -> {
if (hasSentFirstMessage.compareAndSet(false, true)) {
messageDeliveryLoopMonitor.recordDeliveryAttempt(auth.getAccount().getIdentifier(IdentityType.ACI),
auth.getAuthenticatedDevice().getId(),
messageDeliveryLoopMonitor.recordDeliveryAttempt(authenticatedAccount.getIdentifier(IdentityType.ACI),
authenticatedDevice.getId(),
UUID.fromString(envelope.getServerGuid()),
client.getUserAgent(),
"websocket");
@ -471,7 +475,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
final UUID messageGuid = UUID.fromString(envelope.getServerGuid());
if (envelope.getStory() && !client.shouldDeliverStories()) {
messagesManager.delete(auth.getAccount().getUuid(), auth.getAuthenticatedDevice(), messageGuid, envelope.getServerTimestamp());
messagesManager.delete(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice, messageGuid, envelope.getServerTimestamp());
return CompletableFuture.completedFuture(null);
} else {

View File

@ -21,6 +21,7 @@ import jakarta.ws.rs.core.MediaType;
import java.io.IOException;
import java.net.URI;
import java.util.EnumSet;
import java.util.Optional;
import org.apache.commons.lang3.RandomStringUtils;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
@ -34,9 +35,7 @@ import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ -78,8 +77,7 @@ public class WebsocketResourceProviderIntegrationTest {
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(testController);
webSocketEnvironment.jersey().register(new RemoteAddressFilter());
webSocketEnvironment.setAuthenticator(upgradeRequest ->
ReusableAuth.authenticated(mock(AuthenticatedDevice.class), PrincipalSupplier.forImmutablePrincipal()));
webSocketEnvironment.setAuthenticator(upgradeRequest -> Optional.of(mock(AuthenticatedDevice.class)));
webSocketEnvironment.jersey().property(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE);
webSocketEnvironment.setConnectListener(webSocketSessionContext -> {

View File

@ -1,279 +0,0 @@
package org.whispersystems.textsecuregcm;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME;
import io.dropwizard.auth.Auth;
import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.ServletRegistration;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.PathParam;
import java.io.IOException;
import java.net.URI;
import java.util.EnumSet;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.IntStream;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.glassfish.jersey.server.ManagedAsync;
import org.glassfish.jersey.server.ServerProperties;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.storage.RefreshingAccountNotFoundException;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.ReadOnly;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ExtendWith(DropwizardExtensionsSupport.class)
public class WebsocketReuseAuthIntegrationTest {
private static final AuthenticatedDevice ACCOUNT = mock(AuthenticatedDevice.class);
@SuppressWarnings("unchecked")
private static final PrincipalSupplier<AuthenticatedDevice> PRINCIPAL_SUPPLIER = mock(PrincipalSupplier.class);
private static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION =
new DropwizardAppExtension<>(TestApplication.class);
private WebSocketClient client;
@BeforeEach
void setUp() throws Exception {
reset(PRINCIPAL_SUPPLIER);
reset(ACCOUNT);
when(ACCOUNT.getName()).thenReturn("original");
client = new WebSocketClient();
client.start();
}
@AfterEach
void tearDown() throws Exception {
client.stop();
}
public static class TestApplication extends Application<Configuration> {
@Override
public void run(final Configuration configuration, final Environment environment) throws Exception {
final TestController testController = new TestController();
final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration();
final WebSocketEnvironment<AuthenticatedDevice> webSocketEnvironment =
new WebSocketEnvironment<>(environment, webSocketConfiguration);
environment.jersey().register(testController);
environment.servlets()
.addFilter("RemoteAddressFilter", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(testController);
webSocketEnvironment.jersey().register(new RemoteAddressFilter());
webSocketEnvironment.setAuthenticator(upgradeRequest -> ReusableAuth.authenticated(ACCOUNT, PRINCIPAL_SUPPLIER));
webSocketEnvironment.jersey().property(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE);
webSocketEnvironment.setConnectListener(webSocketSessionContext -> {
});
final WebSocketResourceProviderFactory<AuthenticatedDevice> webSocketServlet =
new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class,
webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
final ServletRegistration.Dynamic websocketServlet =
environment.servlets().addServlet("WebSocket", webSocketServlet);
websocketServlet.addMapping("/websocket");
websocketServlet.setAsyncSupported(true);
}
}
private WebSocketResponseMessage make1WebsocketRequest(final String requestPath) throws IOException {
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
return testWebsocketListener.doGet(requestPath).join();
}
@ParameterizedTest
@ValueSource(strings = {"/test/read-auth", "/test/optional-read-auth"})
public void readAuth(final String path) throws IOException {
final WebSocketResponseMessage response = make1WebsocketRequest(path);
assertThat(response.getStatus()).isEqualTo(200);
verifyNoMoreInteractions(PRINCIPAL_SUPPLIER);
}
@ParameterizedTest
@ValueSource(strings = {"/test/write-auth", "/test/optional-write-auth"})
public void writeAuth(final String path) throws IOException {
final AuthenticatedDevice copiedAccount = mock(AuthenticatedDevice.class);
when(copiedAccount.getName()).thenReturn("copy");
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(copiedAccount);
final WebSocketResponseMessage response = make1WebsocketRequest(path);
assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.getBody().map(String::new)).get().isEqualTo("copy");
verify(PRINCIPAL_SUPPLIER, times(1)).deepCopy(any());
verifyNoMoreInteractions(PRINCIPAL_SUPPLIER);
}
@Test
public void readAfterWrite() throws IOException {
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(ACCOUNT);
final AuthenticatedDevice account2 = mock(AuthenticatedDevice.class);
when(account2.getName()).thenReturn("refresh");
when(PRINCIPAL_SUPPLIER.refresh(any())).thenReturn(account2);
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
final WebSocketResponseMessage readResponse = testWebsocketListener.doGet("/test/read-auth").join();
assertThat(readResponse.getBody().map(String::new)).get().isEqualTo("original");
final WebSocketResponseMessage writeResponse = testWebsocketListener.doGet("/test/write-auth").join();
assertThat(writeResponse.getBody().map(String::new)).get().isEqualTo("original");
final WebSocketResponseMessage readResponse2 = testWebsocketListener.doGet("/test/read-auth").join();
assertThat(readResponse2.getBody().map(String::new)).get().isEqualTo("refresh");
}
@Test
public void readAfterWriteRefreshFails() throws IOException {
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(ACCOUNT);
when(PRINCIPAL_SUPPLIER.refresh(any())).thenThrow(RefreshingAccountNotFoundException.class);
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
final WebSocketResponseMessage writeResponse = testWebsocketListener.doGet("/test/write-auth").join();
assertThat(writeResponse.getBody().map(String::new)).get().isEqualTo("original");
final WebSocketResponseMessage readResponse2 = testWebsocketListener.doGet("/test/read-auth").join();
assertThat(readResponse2.getStatus()).isEqualTo(500);
}
@Test
public void readConcurrentWithWrite() throws IOException, ExecutionException, InterruptedException, TimeoutException {
final AuthenticatedDevice deepCopy = mock(AuthenticatedDevice.class);
when(deepCopy.getName()).thenReturn("deepCopy");
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(deepCopy);
final AuthenticatedDevice refresh = mock(AuthenticatedDevice.class);
when(refresh.getName()).thenReturn("refresh");
when(PRINCIPAL_SUPPLIER.refresh(any())).thenReturn(refresh);
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
// start a write request that takes a while to finish
final CompletableFuture<WebSocketResponseMessage> writeResponse =
testWebsocketListener.doGet("/test/start-delayed-write/foo");
// send a bunch of reads, they should reflect the original auth
final List<CompletableFuture<WebSocketResponseMessage>> futures = IntStream.range(0, 10)
.boxed().map(i -> testWebsocketListener.doGet("/test/read-auth"))
.toList();
CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)).join();
for (CompletableFuture<WebSocketResponseMessage> future : futures) {
assertThat(future.join().getBody().map(String::new)).get().isEqualTo("original");
}
assertThat(writeResponse.isDone()).isFalse();
// finish the delayed write request
testWebsocketListener.doGet("/test/finish-delayed-write/foo").get(1, TimeUnit.SECONDS);
assertThat(writeResponse.join().getBody().map(String::new)).get().isEqualTo("deepCopy");
// subsequent reads should have the refreshed auth
final WebSocketResponseMessage readResponse = testWebsocketListener.doGet("/test/read-auth").join();
assertThat(readResponse.getBody().map(String::new)).get().isEqualTo("refresh");
}
@Path("/test")
public static class TestController {
private final ConcurrentHashMap<String, CountDownLatch> delayedWriteLatches = new ConcurrentHashMap<>();
@GET
@Path("/read-auth")
@ManagedAsync
public String readAuth(@ReadOnly @Auth final AuthenticatedDevice account) {
return account.getName();
}
@GET
@Path("/optional-read-auth")
@ManagedAsync
public String optionalReadAuth(@ReadOnly @Auth final Optional<AuthenticatedDevice> account) {
return account.map(AuthenticatedDevice::getName).orElse("empty");
}
@GET
@Path("/write-auth")
@ManagedAsync
public String writeAuth(@Auth final AuthenticatedDevice account) {
return account.getName();
}
@GET
@Path("/optional-write-auth")
@ManagedAsync
public String optionalWriteAuth(@Auth final Optional<AuthenticatedDevice> account) {
return account.map(AuthenticatedDevice::getName).orElse("empty");
}
@GET
@Path("/start-delayed-write/{id}")
@ManagedAsync
public String startDelayedWrite(@Auth final AuthenticatedDevice account, @PathParam("id") String id)
throws InterruptedException {
delayedWriteLatches.computeIfAbsent(id, i -> new CountDownLatch(1)).await();
return account.getName();
}
@GET
@Path("/finish-delayed-write/{id}")
@ManagedAsync
public String finishDelayedWrite(@PathParam("id") String id) {
delayedWriteLatches.computeIfAbsent(id, i -> new CountDownLatch(1)).countDown();
return "ok";
}
}
}

View File

@ -18,7 +18,6 @@ import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
class CertificateGeneratorTest {
@ -37,7 +36,7 @@ class CertificateGeneratorTest {
@Test
void testCreateFor() throws IOException, InvalidKeyException, org.signal.libsignal.protocol.InvalidKeyException {
final Account account = mock(Account.class);
final Device device = mock(Device.class);
final byte deviceId = 4;
final CertificateGenerator certificateGenerator = new CertificateGenerator(
Base64.getDecoder().decode(SIGNING_CERTIFICATE),
Curve.decodePrivatePoint(Base64.getDecoder().decode(SIGNING_KEY)), 1);
@ -45,9 +44,8 @@ class CertificateGeneratorTest {
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(IDENTITY_KEY);
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(account.getNumber()).thenReturn("+18005551234");
when(device.getId()).thenReturn((byte) 4);
assertTrue(certificateGenerator.createFor(account, device, true).length > 0);
assertTrue(certificateGenerator.createFor(account, device, false).length > 0);
assertTrue(certificateGenerator.createFor(account, deviceId, true).length > 0);
assertTrue(certificateGenerator.createFor(account, deviceId, false).length > 0);
}
}

View File

@ -31,8 +31,6 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.util.TestClock;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.auth.PrincipalSupplier;
class IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilterTest {
@ -59,9 +57,9 @@ class IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilterTest {
final boolean expectPqKeyCheck,
@Nullable final String expectedAlertHeader) {
final ReusableAuth<AuthenticatedDevice> reusableAuth = authenticatedDevice != null
? ReusableAuth.authenticated(authenticatedDevice, PrincipalSupplier.forImmutablePrincipal())
: ReusableAuth.anonymous();
final Optional<AuthenticatedDevice> reusableAuth = authenticatedDevice != null
? Optional.of(authenticatedDevice)
: Optional.empty();
final JettyServerUpgradeResponse response = mock(JettyServerUpgradeResponse.class);
@ -88,20 +86,24 @@ class IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilterTest {
private static List<Arguments> handleAuthentication() {
final Device activePrimaryDevice = mock(Device.class);
when(activePrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(activePrimaryDevice.isPrimary()).thenReturn(true);
when(activePrimaryDevice.getLastSeen()).thenReturn(CLOCK.millis());
final Device minIdlePrimaryDevice = mock(Device.class);
when(minIdlePrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(minIdlePrimaryDevice.isPrimary()).thenReturn(true);
when(minIdlePrimaryDevice.getLastSeen())
.thenReturn(CLOCK.instant().minus(MIN_IDLE_DURATION).minusSeconds(1).toEpochMilli());
final Device longIdlePrimaryDevice = mock(Device.class);
when(longIdlePrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(longIdlePrimaryDevice.isPrimary()).thenReturn(true);
when(longIdlePrimaryDevice.getLastSeen())
.thenReturn(CLOCK.instant().minus(IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter.PQ_KEY_CHECK_THRESHOLD).minusSeconds(1).toEpochMilli());
final Device linkedDevice = mock(Device.class);
when(linkedDevice.getId()).thenReturn((byte) (Device.PRIMARY_ID + 1));
when(linkedDevice.isPrimary()).thenReturn(false);
final Account accountWithActivePrimaryDevice = mock(Account.class);

View File

@ -1,328 +0,0 @@
/*
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import io.dropwizard.auth.Auth;
import io.dropwizard.auth.AuthDynamicFeature;
import io.dropwizard.auth.AuthValueFactoryProvider;
import io.dropwizard.auth.basic.BasicCredentialAuthFilter;
import io.dropwizard.jersey.DropwizardResourceConfig;
import io.dropwizard.jersey.jackson.JacksonMessageBodyProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import jakarta.ws.rs.DELETE;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.PUT;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.PathParam;
import jakarta.ws.rs.client.Entity;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.Principal;
import java.time.Duration;
import java.util.Arrays;
import java.util.Base64;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.websocket.api.WriteCallback;
import org.glassfish.jersey.server.ApplicationHandler;
import org.glassfish.jersey.server.ResourceConfig;
import org.glassfish.jersey.server.monitoring.ApplicationEventListener;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketResourceProvider;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
import org.whispersystems.websocket.logging.WebsocketRequestLog;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
import org.whispersystems.websocket.messages.protobuf.SubProtocol;
import org.whispersystems.websocket.session.WebSocketSessionContextValueFactoryProvider;
@ExtendWith(DropwizardExtensionsSupport.class)
class LinkedDeviceRefreshRequirementProviderTest {
private final ApplicationEventListener applicationEventListener = mock(ApplicationEventListener.class);
private final Account account = new Account();
private final Device authenticatedDevice = DevicesHelper.createDevice(Device.PRIMARY_ID);
private final Supplier<Optional<TestPrincipal>> principalSupplier = () -> Optional.of(
new TestPrincipal("test", account, authenticatedDevice));
private final ResourceExtension resources = ResourceExtension.builder()
.addProvider(new AuthDynamicFeature(new BasicCredentialAuthFilter.Builder<TestPrincipal>()
.setAuthenticator(c -> principalSupplier.get()).buildAuthFilter()))
.addProvider(new AuthValueFactoryProvider.Binder<>(TestPrincipal.class))
.addProvider(applicationEventListener)
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new TestResource())
.build();
private AccountsManager accountsManager;
private DisconnectionRequestManager disconnectionRequestManager;
private LinkedDeviceRefreshRequirementProvider provider;
@BeforeEach
void setup() {
accountsManager = mock(AccountsManager.class);
disconnectionRequestManager = mock(DisconnectionRequestManager.class);
provider = new LinkedDeviceRefreshRequirementProvider(accountsManager);
final WebsocketRefreshRequestEventListener listener =
new WebsocketRefreshRequestEventListener(disconnectionRequestManager, provider);
when(applicationEventListener.onRequest(any())).thenReturn(listener);
final UUID uuid = UUID.randomUUID();
account.setUuid(uuid);
account.addDevice(authenticatedDevice);
IntStream.range(2, 4)
.forEach(deviceId -> account.addDevice(DevicesHelper.createDevice((byte) deviceId)));
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account));
}
@Test
void testDeviceAdded() {
final int initialDeviceCount = account.getDevices().size();
final List<String> addedDeviceNames = List.of(
Base64.getEncoder().encodeToString("newDevice1".getBytes(StandardCharsets.UTF_8)),
Base64.getEncoder().encodeToString("newDevice2".getBytes(StandardCharsets.UTF_8)));
final Response response = resources.getJerseyTest()
.target("/v1/test/account/devices")
.request()
.header("Authorization",
"Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
.put(Entity.entity(addedDeviceNames, MediaType.APPLICATION_JSON_PATCH_JSON));
assertEquals(200, response.getStatus());
assertEquals(initialDeviceCount + addedDeviceNames.size(), account.getDevices().size());
verify(disconnectionRequestManager).requestDisconnection(account.getUuid(), List.of((byte) 1));
verify(disconnectionRequestManager).requestDisconnection(account.getUuid(), List.of((byte) 2));
verify(disconnectionRequestManager).requestDisconnection(account.getUuid(), List.of((byte) 3));
}
@ParameterizedTest
@ValueSource(ints = {1, 2})
void testDeviceRemoved(final int removedDeviceCount) {
final List<Byte> initialDeviceIds = account.getDevices().stream().map(Device::getId).toList();
final List<Byte> deletedDeviceIds = account.getDevices().stream()
.map(Device::getId)
.filter(deviceId -> deviceId != Device.PRIMARY_ID)
.limit(removedDeviceCount)
.toList();
assert deletedDeviceIds.size() == removedDeviceCount;
final String deletedDeviceIdsParam = deletedDeviceIds.stream().map(String::valueOf)
.collect(Collectors.joining(","));
final Response response = resources.getJerseyTest()
.target("/v1/test/account/devices/" + deletedDeviceIdsParam)
.request()
.header("Authorization",
"Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
.delete();
assertEquals(200, response.getStatus());
initialDeviceIds.forEach(deviceId ->
verify(disconnectionRequestManager).requestDisconnection(account.getUuid(), List.of(deviceId)));
verifyNoMoreInteractions(disconnectionRequestManager);
}
@Test
void testOnEvent() {
Response response = resources.getJerseyTest()
.target("/v1/test/hello")
.request()
// no authorization required
.get();
assertEquals(200, response.getStatus());
response = resources.getJerseyTest()
.target("/v1/test/authorized")
.request()
.header("Authorization",
"Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
.get();
assertEquals(200, response.getStatus());
verify(accountsManager, never()).getByAccountIdentifier(any(UUID.class));
}
@Nested
class WebSocket {
private WebSocketResourceProvider<TestPrincipal> provider;
private RemoteEndpoint remoteEndpoint;
@BeforeEach
void setup() {
ResourceConfig resourceConfig = new DropwizardResourceConfig();
resourceConfig.register(applicationEventListener);
resourceConfig.register(new TestResource());
resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder());
resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class));
resourceConfig.register(new JacksonMessageBodyProvider(SystemMapper.jsonMapper()));
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
provider = new WebSocketResourceProvider<>("127.0.0.1", RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME,
applicationHandler, requestLog, TestPrincipal.reusableAuth("test", account, authenticatedDevice),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
remoteEndpoint = mock(RemoteEndpoint.class);
Session session = mock(Session.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getRemote()).thenReturn(remoteEndpoint);
when(session.getUpgradeRequest()).thenReturn(request);
provider.onWebSocketConnect(session);
}
@Test
void testOnEvent() throws Exception {
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello",
new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
final SubProtocol.WebSocketResponseMessage response = verifyAndGetResponse(remoteEndpoint);
assertEquals(200, response.getStatus());
}
private SubProtocol.WebSocketResponseMessage verifyAndGetResponse(final RemoteEndpoint remoteEndpoint)
throws IOException {
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
return SubProtocol.WebSocketMessage.parseFrom(responseBytesCaptor.getValue().array()).getResponse();
}
}
public static class TestPrincipal implements Principal, AccountAndAuthenticatedDeviceHolder {
private final String name;
private final Account account;
private final Device device;
private TestPrincipal(final String name, final Account account, final Device device) {
this.name = name;
this.account = account;
this.device = device;
}
@Override
public String getName() {
return name;
}
@Override
public Account getAccount() {
return account;
}
@Override
public Device getAuthenticatedDevice() {
return device;
}
public static ReusableAuth<TestPrincipal> reusableAuth(final String name, final Account account, final Device device) {
return ReusableAuth.authenticated(new TestPrincipal(name, account, device), PrincipalSupplier.forImmutablePrincipal());
}
}
@Path("/v1/test")
public static class TestResource {
@GET
@Path("/hello")
public String testGetHello() {
return "Hello!";
}
@GET
@Path("/authorized")
public String testAuth(@Auth TestPrincipal principal) {
return "Youre in!";
}
@PUT
@Path("/account/devices")
@ChangesLinkedDevices
public String addDevices(@Auth TestPrincipal auth, List<byte[]> deviceNames) {
deviceNames.forEach(name -> {
final Device device = DevicesHelper.createDevice(auth.getAccount().getNextDeviceId());
auth.getAccount().addDevice(device);
device.setName(name);
});
return "Added devices " + deviceNames;
}
@DELETE
@Path("/account/devices/{deviceIds}")
@ChangesLinkedDevices
public String removeDevices(@Auth TestPrincipal auth, @PathParam("deviceIds") String deviceIds) {
Arrays.stream(deviceIds.split(","))
.map(Byte::valueOf)
.forEach(auth.getAccount()::removeDevice);
return "Removed device(s) " + deviceIds;
}
}
}

View File

@ -1,294 +0,0 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME;
import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.Auth;
import io.dropwizard.auth.AuthDynamicFeature;
import io.dropwizard.auth.basic.BasicCredentialAuthFilter;
import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.ServletRegistration;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.client.Invocation;
import java.io.IOException;
import java.net.URI;
import java.util.Collections;
import java.util.EnumSet;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.glassfish.jersey.server.ManagedAsync;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.ReadOnly;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ExtendWith(DropwizardExtensionsSupport.class)
class PhoneNumberChangeRefreshRequirementProviderTest {
private static final String NUMBER = "+18005551234";
private static final String CHANGED_NUMBER = "+18005554321";
private static final String TEST_CRED_HEADER = HeaderUtils.basicAuthHeader("test", "password");
private static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION = new DropwizardAppExtension<>(
TestApplication.class);
private static final AccountAuthenticator AUTHENTICATOR = mock(AccountAuthenticator.class);
private static final AccountsManager ACCOUNTS_MANAGER = mock(AccountsManager.class);
private static final DisconnectionRequestManager DISCONNECTION_REQUEST_MANAGER =
mock(DisconnectionRequestManager.class);
private WebSocketClient client;
private final Account account1 = new Account();
private final Account account2 = new Account();
private final Device authenticatedDevice = DevicesHelper.createDevice(Device.PRIMARY_ID);
@BeforeEach
void setUp() throws Exception {
reset(AUTHENTICATOR, ACCOUNTS_MANAGER, DISCONNECTION_REQUEST_MANAGER);
client = new WebSocketClient();
client.start();
final UUID uuid = UUID.randomUUID();
account1.setUuid(uuid);
account1.addDevice(authenticatedDevice);
account1.setNumber(NUMBER, UUID.randomUUID());
account2.setUuid(uuid);
account2.addDevice(authenticatedDevice);
account2.setNumber(CHANGED_NUMBER, UUID.randomUUID());
}
@AfterEach
void tearDown() throws Exception {
client.stop();
}
public static class TestApplication extends Application<Configuration> {
@Override
public void run(final Configuration configuration, final Environment environment) throws Exception {
final TestController testController = new TestController();
final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration();
final WebSocketEnvironment<AuthenticatedDevice> webSocketEnvironment =
new WebSocketEnvironment<>(environment, webSocketConfiguration);
environment.jersey().register(testController);
webSocketEnvironment.jersey().register(testController);
environment.servlets()
.addFilter("RemoteAddressFilter", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(new RemoteAddressFilter());
webSocketEnvironment.jersey()
.register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, DISCONNECTION_REQUEST_MANAGER));
environment.jersey()
.register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, DISCONNECTION_REQUEST_MANAGER));
webSocketEnvironment.setConnectListener(webSocketSessionContext -> {
});
environment.jersey().register(new AuthDynamicFeature(new BasicCredentialAuthFilter.Builder<AuthenticatedDevice>()
.setAuthenticator(AUTHENTICATOR)
.buildAuthFilter()));
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(AUTHENTICATOR, mock(PrincipalSupplier.class)));
final WebSocketResourceProviderFactory<AuthenticatedDevice> webSocketServlet =
new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class,
webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
final ServletRegistration.Dynamic websocketServlet =
environment.servlets().addServlet("WebSocket", webSocketServlet);
websocketServlet.addMapping("/websocket");
websocketServlet.setAsyncSupported(true);
}
}
enum Protocol { HTTP, WEBSOCKET }
private void makeAnonymousRequest(final Protocol protocol, final String requestPath) throws IOException {
makeRequest(protocol, requestPath, true);
}
/*
* Make an authenticated request that will return account1 as the principal
*/
private void makeAuthenticatedRequest(
final Protocol protocol,
final String requestPath) throws IOException {
when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedDevice(account1, authenticatedDevice)));
makeRequest(protocol,requestPath, false);
}
private void makeRequest(final Protocol protocol, final String requestPath, final boolean anonymous) throws IOException {
switch (protocol) {
case WEBSOCKET -> {
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
if (!anonymous) {
upgradeRequest.setHeader(HttpHeaders.AUTHORIZATION, TEST_CRED_HEADER);
}
client.connect(
testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())),
upgradeRequest);
testWebsocketListener.sendRequest(requestPath, "GET", Collections.emptyList(), Optional.empty()).join();
}
case HTTP -> {
final Invocation.Builder request = DROPWIZARD_APP_EXTENSION.client()
.target("http://127.0.0.1:%s%s".formatted(DROPWIZARD_APP_EXTENSION.getLocalPort(), requestPath))
.request();
if (!anonymous) {
request.header(HttpHeaders.AUTHORIZATION, TEST_CRED_HEADER);
}
request.get();
}
}
}
@ParameterizedTest
@EnumSource(Protocol.class)
void handleRequestNoChange(final Protocol protocol) throws IOException {
when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account1));
makeAuthenticatedRequest(protocol, "/test/annotated");
// Event listeners can fire after responses are sent
verify(ACCOUNTS_MANAGER, timeout(5000).times(1)).getByAccountIdentifier(eq(account1.getUuid()));
verifyNoMoreInteractions(ACCOUNTS_MANAGER);
}
@ParameterizedTest
@EnumSource(Protocol.class)
void handleRequestChange(final Protocol protocol) throws IOException {
when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account2));
when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedDevice(account1, authenticatedDevice)));
makeAuthenticatedRequest(protocol, "/test/annotated");
// Make sure we disconnect the account if the account has changed numbers. Event listeners can fire after responses
// are sent, so use a timeout.
verify(DISCONNECTION_REQUEST_MANAGER, timeout(5000))
.requestDisconnection(account1.getUuid(), List.of(authenticatedDevice.getId()));
verifyNoMoreInteractions(DISCONNECTION_REQUEST_MANAGER);
}
@Test
void handleRequestChangeAsyncEndpoint() throws IOException {
when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account2));
when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedDevice(account1, authenticatedDevice)));
// Event listeners with asynchronous HTTP endpoints don't currently correctly maintain state between request and
// response
makeAuthenticatedRequest(Protocol.WEBSOCKET, "/test/async-annotated");
// Make sure we disconnect the account if the account has changed numbers. Event listeners can fire after responses
// are sent, so use a timeout.
verify(DISCONNECTION_REQUEST_MANAGER, timeout(5000))
.requestDisconnection(account1.getUuid(), List.of(authenticatedDevice.getId()));
verifyNoMoreInteractions(DISCONNECTION_REQUEST_MANAGER);
}
@ParameterizedTest
@EnumSource(Protocol.class)
void handleRequestNotAnnotated(final Protocol protocol) throws IOException, InterruptedException {
makeAuthenticatedRequest(protocol,"/test/not-annotated");
// Give a tick for event listeners to run. Racy, but should occasionally catch an errant running listener if one is
// introduced.
Thread.sleep(100);
// Shouldn't even read the account if the method has not been annotated
verifyNoMoreInteractions(ACCOUNTS_MANAGER);
}
@ParameterizedTest
@EnumSource(Protocol.class)
void handleRequestNotAuthenticated(final Protocol protocol) throws IOException, InterruptedException {
makeAnonymousRequest(protocol, "/test/not-authenticated");
// Give a tick for event listeners to run. Racy, but should occasionally catch an errant running listener if one is
// introduced.
Thread.sleep(100);
// Shouldn't even read the account if the method has not been annotated
verifyNoMoreInteractions(ACCOUNTS_MANAGER);
}
@Path("/test")
public static class TestController {
@GET
@Path("/annotated")
@ChangesPhoneNumber
public String annotated(@ReadOnly @Auth final AuthenticatedDevice account) {
return "ok";
}
@GET
@Path("/async-annotated")
@ChangesPhoneNumber
@ManagedAsync
public String asyncAnnotated(@ReadOnly @Auth final AuthenticatedDevice account) {
return "ok";
}
@GET
@Path("/not-authenticated")
@ChangesPhoneNumber
public String notAuthenticated() {
return "ok";
}
@GET
@Path("/not-annotated")
public String notAnnotated(@ReadOnly @Auth final AuthenticatedDevice account) {
return "ok";
}
}
}

View File

@ -211,6 +211,11 @@ class AccountControllerTest {
when(accountsManager.getByE164(eq(SENDER_HAS_STORAGE))).thenReturn(Optional.of(senderHasStorage));
when(accountsManager.getByE164(eq(SENDER_TRANSFER))).thenReturn(Optional.of(senderTransfer));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_TWO)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_TWO));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_3)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_3));
when(accountsManager.getByAccountIdentifier(AuthHelper.UNDISCOVERABLE_UUID)).thenReturn(Optional.of(AuthHelper.UNDISCOVERABLE_ACCOUNT));
doAnswer(invocation -> {
final byte[] proof = invocation.getArgument(0);
final byte[] hash = invocation.getArgument(1);

View File

@ -145,6 +145,8 @@ class AccountControllerV2Test {
void setUp() throws Exception {
when(rateLimiters.getRegistrationLimiter()).thenReturn(registrationLimiter);
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any(), any(), any())).thenAnswer(
(Answer<Account>) invocation -> {
final Account account = invocation.getArgument(0);
@ -607,6 +609,8 @@ class AccountControllerV2Test {
@BeforeEach
void setUp() throws Exception {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(changeNumberManager.updatePniKeys(any(), any(), any(), any(), any(), any(), any())).thenAnswer(
(Answer<Account>) invocation -> {
final Account account = invocation.getArgument(0);
@ -768,7 +772,9 @@ class AccountControllerV2Test {
@BeforeEach
void setup() {
AccountsHelper.setupMockUpdate(accountsManager);
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
}
@Test
void testSetPhoneNumberDiscoverability() {
Response response = resources.getJerseyTest()
@ -805,6 +811,7 @@ class AccountControllerV2Test {
@MethodSource
void testGetAccountDataReport(final Account account, final String expectedTextAfterHeader) throws Exception {
when(AuthHelper.ACCOUNTS_MANAGER.getByAccountIdentifier(account.getUuid())).thenReturn(Optional.of(account));
when(accountsManager.getByAccountIdentifier(account.getUuid())).thenReturn(Optional.of(account));
final Response response = resources.getJerseyTest()
.target("/v2/accounts/data_report")

View File

@ -73,6 +73,7 @@ import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.GrpcStatusRuntimeExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.metrics.BackupMetrics;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.EnumMapUtil;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -82,6 +83,7 @@ import reactor.core.publisher.Flux;
@ExtendWith(DropwizardExtensionsSupport.class)
public class ArchiveControllerTest {
private static final AccountsManager accountsManager = mock(AccountsManager.class);
private static final BackupAuthManager backupAuthManager = mock(BackupAuthManager.class);
private static final BackupManager backupManager = mock(BackupManager.class);
private final BackupAuthTestUtil backupAuthTestUtil = new BackupAuthTestUtil(Clock.systemUTC());
@ -95,7 +97,7 @@ public class ArchiveControllerTest {
.addProvider(new RateLimitExceededExceptionMapper())
.setMapper(SystemMapper.jsonMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new ArchiveController(backupAuthManager, backupManager, new BackupMetrics()))
.addResource(new ArchiveController(accountsManager, backupAuthManager, backupManager, new BackupMetrics()))
.build();
private final UUID aci = UUID.randomUUID();
@ -106,6 +108,9 @@ public class ArchiveControllerTest {
public void setUp() {
reset(backupAuthManager);
reset(backupManager);
when(accountsManager.getByAccountIdentifierAsync(AuthHelper.VALID_UUID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(AuthHelper.VALID_ACCOUNT)));
}
@ParameterizedTest

View File

@ -9,6 +9,8 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import io.dropwizard.auth.AuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
@ -21,9 +23,11 @@ import java.time.Instant;
import java.time.ZoneId;
import java.time.temporal.ChronoUnit;
import java.util.Base64;
import java.util.Optional;
import java.util.stream.Stream;
import org.apache.commons.lang3.StringUtils;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
@ -44,6 +48,7 @@ import org.whispersystems.textsecuregcm.entities.DeliveryCertificate;
import org.whispersystems.textsecuregcm.entities.GroupCredentials;
import org.whispersystems.textsecuregcm.entities.MessageProtos.SenderCertificate;
import org.whispersystems.textsecuregcm.entities.MessageProtos.ServerCertificate;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -66,6 +71,8 @@ class CertificateControllerTest {
private static final ServerZkAuthOperations serverZkAuthOperations;
private static final Clock clock = Clock.fixed(Instant.now(), ZoneId.systemDefault());
private static final AccountsManager accountsManager = mock(AccountsManager.class);
static {
try {
certificateGenerator = new CertificateGenerator(Base64.getDecoder().decode(signingCertificate),
@ -82,9 +89,14 @@ class CertificateControllerTest {
.addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class))
.setMapper(SystemMapper.jsonMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new CertificateController(certificateGenerator, serverZkAuthOperations, genericServerSecretParams, clock))
.addResource(new CertificateController(accountsManager, certificateGenerator, serverZkAuthOperations, genericServerSecretParams, clock))
.build();
@BeforeEach
void setUp() {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
}
@Test
void testValidCertificate() throws Exception {
DeliveryCertificate certificateObject = resources.getJerseyTest()

View File

@ -38,6 +38,7 @@ import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.spam.ChallengeConstraintChecker;
import org.whispersystems.textsecuregcm.spam.ChallengeConstraintChecker.ChallengeConstraints;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider;
@ -45,11 +46,12 @@ import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider;
@ExtendWith(DropwizardExtensionsSupport.class)
class ChallengeControllerTest {
private static final AccountsManager accountsManager = mock(AccountsManager.class);
private static final RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class);
private static final ChallengeConstraintChecker challengeConstraintChecker = mock(ChallengeConstraintChecker.class);
private static final ChallengeController challengeController =
new ChallengeController(rateLimitChallengeManager, challengeConstraintChecker);
new ChallengeController(accountsManager, rateLimitChallengeManager, challengeConstraintChecker);
private static final ResourceExtension EXTENSION = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
@ -63,6 +65,9 @@ class ChallengeControllerTest {
@BeforeEach
void setup() {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_TWO)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_TWO));
when(challengeConstraintChecker.challengeConstraints(any(), any()))
.thenReturn(new ChallengeConstraints(true, Optional.empty()));
}

View File

@ -27,6 +27,7 @@ import java.time.Duration;
import java.time.Instant;
import java.util.Base64;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.glassfish.jersey.server.ServerProperties;
@ -45,6 +46,7 @@ import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.GrpcStatusRuntimeExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.devicecheck.AppleDeviceCheckManager;
import org.whispersystems.textsecuregcm.storage.devicecheck.ChallengeNotFoundException;
import org.whispersystems.textsecuregcm.storage.devicecheck.DeviceCheckKeyIdNotFoundException;
@ -62,6 +64,7 @@ class DeviceCheckControllerTest {
private final static Duration REDEMPTION_DURATION = Duration.ofDays(5);
private final static long REDEMPTION_LEVEL = 201L;
private static final AccountsManager accountsManager = mock(AccountsManager.class);
private final static BackupAuthManager backupAuthManager = mock(BackupAuthManager.class);
private final static AppleDeviceCheckManager appleDeviceCheckManager = mock(AppleDeviceCheckManager.class);
private final static RateLimiters rateLimiters = mock(RateLimiters.class);
@ -76,7 +79,7 @@ class DeviceCheckControllerTest {
.addProvider(new RateLimitExceededExceptionMapper())
.setMapper(SystemMapper.jsonMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new DeviceCheckController(clock, backupAuthManager, appleDeviceCheckManager, rateLimiters,
.addResource(new DeviceCheckController(clock, accountsManager, backupAuthManager, appleDeviceCheckManager, rateLimiters,
REDEMPTION_LEVEL, REDEMPTION_DURATION))
.build();
@ -86,6 +89,8 @@ class DeviceCheckControllerTest {
reset(appleDeviceCheckManager);
reset(rateLimiters);
when(rateLimiters.forDescriptor(any())).thenReturn(mock(RateLimiter.class));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
}
@ParameterizedTest

View File

@ -42,14 +42,12 @@ import java.util.concurrent.CompletableFuture;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.commons.lang3.RandomStringUtils;
import org.apache.commons.lang3.StringUtils;
import org.glassfish.jersey.server.ServerProperties;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource;
@ -62,8 +60,6 @@ import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventListener;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest;
@ -116,7 +112,6 @@ class DeviceControllerTest {
private static final Account account = mock(Account.class);
private static final Account maxedAccount = mock(Account.class);
private static final Device primaryDevice = mock(Device.class);
private static final DisconnectionRequestManager disconnectionRequestManager = mock(DisconnectionRequestManager.class);
private static final Map<String, Integer> deviceConfiguration = new HashMap<>();
private static final TestClock testClock = TestClock.now();
@ -129,16 +124,12 @@ class DeviceControllerTest {
persistentTimer,
deviceConfiguration);
@RegisterExtension
public static final AuthHelper.AuthFilterExtension AUTH_FILTER_EXTENSION = new AuthHelper.AuthFilterExtension();
private static final ResourceExtension resources = ResourceExtension.builder()
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class))
.addProvider(new RateLimitExceededExceptionMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addProvider(new WebsocketRefreshApplicationEventListener(accountsManager, disconnectionRequestManager))
.addProvider(new DeviceLimitExceededExceptionMapper())
.addResource(deviceController)
.build();
@ -157,8 +148,15 @@ class DeviceControllerTest {
when(account.getNumber()).thenReturn(AuthHelper.VALID_NUMBER);
when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID);
when(account.getPhoneNumberIdentifier()).thenReturn(AuthHelper.VALID_PNI);
when(account.getPrimaryDevice()).thenReturn(primaryDevice);
when(account.getDevice(anyByte())).thenReturn(Optional.empty());
when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(primaryDevice));
when(account.getDevices()).thenReturn(List.of(primaryDevice));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
when(accountsManager.getByAccountIdentifierAsync(AuthHelper.VALID_UUID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
when(accountsManager.getByE164(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(account));
when(accountsManager.getByE164(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.of(maxedAccount));
@ -229,7 +227,7 @@ class DeviceControllerTest {
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
@ -310,7 +308,7 @@ class DeviceControllerTest {
final Device primaryDevice = mock(Device.class);
when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(primaryDevice));
when(account.getDevices()).thenReturn(List.of(primaryDevice));
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
@ -362,7 +360,7 @@ class DeviceControllerTest {
final Device primaryDevice = mock(Device.class);
when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(primaryDevice));
when(account.getDevices()).thenReturn(List.of(primaryDevice));
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
@ -398,7 +396,7 @@ class DeviceControllerTest {
void linkDeviceAtomicReusedToken() {
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
@ -447,7 +445,7 @@ class DeviceControllerTest {
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
@ -487,12 +485,12 @@ class DeviceControllerTest {
void linkDeviceAtomicConflictingChannel(final boolean fetchesMessages,
final Optional<ApnRegistrationId> apnRegistrationId,
final Optional<GcmRegistrationId> gcmRegistrationId) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
when(accountsManager.generateLinkDeviceToken(any())).thenReturn("test");
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
final LinkDeviceToken deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
@ -548,12 +546,12 @@ class DeviceControllerTest {
final KEMSignedPreKey aciPqLastResortPreKey,
final KEMSignedPreKey pniPqLastResortPreKey) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
when(accountsManager.generateLinkDeviceToken(any())).thenReturn("test");
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
final LinkDeviceToken deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
@ -613,11 +611,11 @@ class DeviceControllerTest {
aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
@ -647,11 +645,11 @@ class DeviceControllerTest {
final KEMSignedPreKey aciPqLastResortPreKey,
final KEMSignedPreKey pniPqLastResortPreKey) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(aciIdentityKey);
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(pniIdentityKey);
@ -698,7 +696,7 @@ class DeviceControllerTest {
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
@ -735,7 +733,7 @@ class DeviceControllerTest {
void linkDeviceRegistrationId(final int registrationId, final int pniRegistrationId, final int expectedStatusCode) {
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
@ -800,17 +798,16 @@ class DeviceControllerTest {
@Test
void maxDevicesTest() {
final AuthHelper.TestAccount testAccount = AUTH_FILTER_EXTENSION.createTestAccount();
final List<Device> devices = IntStream.range(0, DeviceController.MAX_DEVICES + 1)
.mapToObj(i -> mock(Device.class))
.toList();
when(testAccount.account.getDevices()).thenReturn(devices);
when(account.getDevices()).thenReturn(devices);
Response response = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
.header("Authorization", testAccount.getAuthHeader())
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get();
assertEquals(411, response.getStatus());
@ -829,7 +826,7 @@ class DeviceControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
assertThat(response.hasEntity()).isFalse();
verify(AuthHelper.VALID_DEVICE).setCapabilities(Set.of(DeviceCapability.DELETE_SYNC));
verify(primaryDevice).setCapabilities(Set.of(DeviceCapability.DELETE_SYNC));
}
}
@ -851,12 +848,12 @@ class DeviceControllerTest {
void removeDevice() {
// this is a static mock, so it might have previous invocations
clearInvocations(AuthHelper.VALID_ACCOUNT);
clearInvocations(account);
final byte deviceId = 2;
when(accountsManager.removeDevice(AuthHelper.VALID_ACCOUNT, deviceId))
.thenReturn(CompletableFuture.completedFuture(AuthHelper.VALID_ACCOUNT));
when(accountsManager.removeDevice(account, deviceId))
.thenReturn(CompletableFuture.completedFuture(account));
try (final Response response = resources
.getJerseyTest()
@ -869,14 +866,14 @@ class DeviceControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
assertThat(response.hasEntity()).isFalse();
verify(accountsManager).removeDevice(AuthHelper.VALID_ACCOUNT, deviceId);
verify(accountsManager).removeDevice(account, deviceId);
}
}
@Test
void unlinkPrimaryDevice() {
// this is a static mock, so it might have previous invocations
clearInvocations(AuthHelper.VALID_ACCOUNT);
clearInvocations(account);
try (final Response response = resources
.getJerseyTest()
@ -897,7 +894,10 @@ class DeviceControllerTest {
final byte deviceId = 2;
when(accountsManager.removeDevice(AuthHelper.VALID_ACCOUNT_3, deviceId))
.thenReturn(CompletableFuture.completedFuture(AuthHelper.VALID_ACCOUNT));
.thenReturn(CompletableFuture.completedFuture(account));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_3))
.thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_3));
try (final Response response = resources
.getJerseyTest()
@ -946,7 +946,7 @@ class DeviceControllerTest {
assertEquals(204, response.getStatus());
}
verify(clientPublicKeysManager).setPublicKey(AuthHelper.VALID_ACCOUNT, AuthHelper.VALID_DEVICE.getId(), request.publicKey());
verify(clientPublicKeysManager).setPublicKey(account, AuthHelper.VALID_DEVICE.getId(), request.publicKey());
}
@Test
@ -959,7 +959,7 @@ class DeviceControllerTest {
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);
when(accountsManager
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any()))
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(primaryDevice), eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(deviceInfo)));
when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null));
@ -985,7 +985,7 @@ class DeviceControllerTest {
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);
when(accountsManager
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any()))
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(primaryDevice), eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null));
@ -1005,7 +1005,7 @@ class DeviceControllerTest {
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);
when(accountsManager
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any()))
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(primaryDevice), eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.failedFuture(new IllegalArgumentException()));
when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null));
@ -1079,7 +1079,7 @@ class DeviceControllerTest {
new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8)));
when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null));
when(accountsManager.recordTransferArchiveUpload(AuthHelper.VALID_ACCOUNT, deviceId, deviceCreated, transferArchive))
when(accountsManager.recordTransferArchiveUpload(account, deviceId, deviceCreated, transferArchive))
.thenReturn(CompletableFuture.completedFuture(null));
try (final Response response = resources.getJerseyTest()
@ -1092,7 +1092,7 @@ class DeviceControllerTest {
assertEquals(204, response.getStatus());
verify(accountsManager)
.recordTransferArchiveUpload(AuthHelper.VALID_ACCOUNT, deviceId, deviceCreated, transferArchive);
.recordTransferArchiveUpload(account, deviceId, deviceCreated, transferArchive);
}
}
@ -1103,7 +1103,7 @@ class DeviceControllerTest {
final RemoteAttachmentError transferFailure = new RemoteAttachmentError(RemoteAttachmentError.ErrorType.CONTINUE_WITHOUT_UPLOAD);
when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null));
when(accountsManager.recordTransferArchiveUpload(AuthHelper.VALID_ACCOUNT, deviceId, deviceCreated, transferFailure))
when(accountsManager.recordTransferArchiveUpload(account, deviceId, deviceCreated, transferFailure))
.thenReturn(CompletableFuture.completedFuture(null));
try (final Response response = resources.getJerseyTest()
@ -1116,7 +1116,7 @@ class DeviceControllerTest {
assertEquals(204, response.getStatus());
verify(accountsManager)
.recordTransferArchiveUpload(AuthHelper.VALID_ACCOUNT, deviceId, deviceCreated, transferFailure);
.recordTransferArchiveUpload(account, deviceId, deviceCreated, transferFailure);
}
}
@ -1186,7 +1186,7 @@ class DeviceControllerTest {
new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8)));
when(rateLimiter.validateAsync(anyString())).thenReturn(CompletableFuture.completedFuture(null));
when(accountsManager.waitForTransferArchive(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_DEVICE), any()))
when(accountsManager.waitForTransferArchive(eq(account), eq(primaryDevice), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(transferArchive)));
try (final Response response = resources.getJerseyTest()
@ -1206,7 +1206,7 @@ class DeviceControllerTest {
new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8)));
when(rateLimiter.validateAsync(anyString())).thenReturn(CompletableFuture.completedFuture(null));
when(accountsManager.waitForTransferArchive(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_DEVICE), any()))
when(accountsManager.waitForTransferArchive(eq(account), eq(primaryDevice), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(transferArchive)));
try (final Response response = resources.getJerseyTest()
@ -1223,7 +1223,7 @@ class DeviceControllerTest {
@Test
void waitForTransferArchiveNoArchiveUploaded() {
when(rateLimiter.validateAsync(anyString())).thenReturn(CompletableFuture.completedFuture(null));
when(accountsManager.waitForTransferArchive(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_DEVICE), any()))
when(accountsManager.waitForTransferArchive(eq(account), eq(primaryDevice), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
try (final Response response = resources.getJerseyTest()

View File

@ -19,6 +19,7 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator;
import org.whispersystems.textsecuregcm.configuration.DirectoryV2ClientConfiguration;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
@ -35,7 +36,7 @@ class DirectoryControllerV2Test {
final Account account = mock(Account.class);
final UUID uuid = UUID.fromString("11111111-1111-1111-1111-111111111111");
when(account.getUuid()).thenReturn(uuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(uuid);
final ExternalServiceCredentials credentials = controller.getAuthToken(
new AuthenticatedDevice(account, mock(Device.class)));

View File

@ -140,6 +140,8 @@ class DonationControllerTest {
when(receiptCredentialPresentation.getReceiptExpirationTime()).thenReturn(receiptExpiration);
when(redeemedReceiptsManager.put(same(receiptSerial), eq(receiptExpiration), eq(receiptLevel), eq(AuthHelper.VALID_UUID))).thenReturn(
CompletableFuture.completedFuture(Boolean.FALSE));
when(accountsManager.getByAccountIdentifierAsync(eq(AuthHelper.VALID_UUID))).thenReturn(
CompletableFuture.completedFuture(Optional.of(AuthHelper.VALID_ACCOUNT)));
RedeemReceiptRequest request = new RedeemReceiptRequest(presentation, true, true);
Response response = resources.getJerseyTest()

View File

@ -242,10 +242,21 @@ class KeysControllerTest {
when(existsAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of("1337".getBytes()));
when(accounts.getByServiceIdentifier(any())).thenReturn(Optional.empty());
when(accounts.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(accounts.getByServiceIdentifier(new AciServiceIdentifier(EXISTS_UUID))).thenReturn(Optional.of(existsAccount));
when(accounts.getByServiceIdentifier(new PniServiceIdentifier(EXISTS_PNI))).thenReturn(Optional.of(existsAccount));
when(accounts.getByServiceIdentifierAsync(new AciServiceIdentifier(EXISTS_UUID)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(existsAccount)));
when(accounts.getByServiceIdentifierAsync(new PniServiceIdentifier(EXISTS_PNI)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(existsAccount)));
when(accounts.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accounts.getByAccountIdentifierAsync(AuthHelper.VALID_UUID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(AuthHelper.VALID_ACCOUNT)));
when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter);
when(KEYS.storeEcOneTimePreKeys(any(), anyByte(), any()))

View File

@ -236,6 +236,14 @@ class MessageControllerTest {
when(accountsManager.getByServiceIdentifierAsync(MULTI_DEVICE_PNI_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(multiDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(CompletableFuture.completedFuture(Optional.of(internationalAccount)));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.getByAccountIdentifierAsync(AuthHelper.VALID_UUID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(AuthHelper.VALID_ACCOUNT)));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_3)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_3));
when(accountsManager.getByAccountIdentifierAsync(AuthHelper.VALID_UUID_3))
.thenReturn(CompletableFuture.completedFuture(Optional.of(AuthHelper.VALID_ACCOUNT_3)));
when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getStoriesLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getInboundMessageBytes()).thenReturn(rateLimiter);

View File

@ -221,6 +221,8 @@ class ProfileControllerTest {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(capabilitiesAccount));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(AuthHelper.VALID_UUID))).thenReturn(Optional.of(capabilitiesAccount));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_TWO)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_TWO));
final byte[] name = TestRandomUtil.nextBytes(81);
final byte[] emoji = TestRandomUtil.nextBytes(60);
final byte[] about = TestRandomUtil.nextBytes(156);
@ -1155,6 +1157,7 @@ class ProfileControllerTest {
reset(accountsManager);
final int accountsManagerUpdateRetryCount = 2;
AccountsHelper.setupMockUpdateWithRetries(accountsManager, accountsManagerUpdateRetryCount);
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_TWO)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_TWO));
// set up two invocations -- one for each AccountsManager#update try
when(AuthHelper.VALID_ACCOUNT_TWO.getBadges())
.thenReturn(List.of(

View File

@ -170,7 +170,12 @@ class RemoteConfigControllerTest {
void testHashKeyLinkedConfigs() {
boolean allUnlinkedConfigsMatched = true;
for (AuthHelper.TestAccount testAccount : AuthHelper.TEST_ACCOUNTS) {
UserRemoteConfigList configuration = resources.getJerseyTest().target("/v1/config/").request().header("Authorization", testAccount.getAuthHeader()).get(UserRemoteConfigList.class);
UserRemoteConfigList configuration = resources.getJerseyTest()
.target("/v1/config/")
.request()
.header("Authorization", testAccount.getAuthHeader())
.get(UserRemoteConfigList.class);
assertThat(configuration.getConfig()).hasSize(11);
final UserRemoteConfig linkedConfig0 = configuration.getConfig().get(7);

View File

@ -7,7 +7,6 @@ package org.whispersystems.textsecuregcm.entities;
import static org.junit.jupiter.api.Assertions.assertEquals;
import java.util.Random;
import java.util.UUID;
import javax.annotation.Nullable;
import org.junit.jupiter.api.Test;
@ -16,7 +15,6 @@ import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
@ -65,15 +63,13 @@ class OutgoingMessageEntityTest {
@Test
void entityPreservesEnvelope() {
final byte[] reportSpamToken = TestRandomUtil.nextBytes(8);
final AciServiceIdentifier sourceServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
final Account account = new Account();
account.setUuid(UUID.randomUUID());
IncomingMessage message = new IncomingMessage(1, (byte) 44, 55, TestRandomUtil.nextBytes(4));
final IncomingMessage message = new IncomingMessage(1, (byte) 44, 55, TestRandomUtil.nextBytes(4));
MessageProtos.Envelope baseEnvelope = message.toEnvelope(
new AciServiceIdentifier(UUID.randomUUID()),
account,
sourceServiceIdentifier,
(byte) 123,
System.currentTimeMillis(),
false,

View File

@ -153,7 +153,7 @@ class MetricsRequestEventListenerTest {
final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
final WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, TestPrincipal.reusableAuth("foo"),
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, TestPrincipal.authenticatedTestPrincipal("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
final Session session = mock(Session.class);
@ -220,7 +220,7 @@ class MetricsRequestEventListenerTest {
final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
final WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, TestPrincipal.reusableAuth("foo"),
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, TestPrincipal.authenticatedTestPrincipal("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
final Session session = mock(Session.class);

View File

@ -220,6 +220,8 @@ public class AccountsHelper {
case "getBackupVoucher" -> when(updatedAccount.getBackupVoucher()).thenAnswer(stubbing);
case "getLastSeen" -> when(updatedAccount.getLastSeen()).thenAnswer(stubbing);
case "hasLockedCredentials" -> when(updatedAccount.hasLockedCredentials()).thenAnswer(stubbing);
case "getCurrentProfileVersion" -> when(updatedAccount.getCurrentProfileVersion()).thenAnswer(stubbing);
case "getUnidentifiedAccessKey" -> when(updatedAccount.getUnidentifiedAccessKey()).thenAnswer(stubbing);
default -> throw new IllegalArgumentException("unsupported method: Account#" + stubbing.getInvocation().getMethod().getName());
}
}

View File

@ -277,6 +277,7 @@ public class AuthHelper {
when(account.getPrimaryDevice()).thenReturn(device);
when(account.getNumber()).thenReturn(number);
when(account.getUuid()).thenReturn(uuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(uuid);
when(accountsManager.getByE164(number)).thenReturn(Optional.of(account));
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account));
}

View File

@ -5,8 +5,7 @@
package org.whispersystems.textsecuregcm.tests.util;
import java.security.Principal;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import java.util.Optional;
public class TestPrincipal implements Principal {
@ -21,7 +20,7 @@ public class TestPrincipal implements Principal {
return name;
}
public static ReusableAuth<TestPrincipal> reusableAuth(final String name) {
return ReusableAuth.authenticated(new TestPrincipal(name), PrincipalSupplier.forImmutablePrincipal());
public static Optional<TestPrincipal> authenticatedTestPrincipal(final String name) {
return Optional.of(new TestPrincipal(name));
}
}

View File

@ -176,7 +176,7 @@ class LoggingUnhandledExceptionMapperTest {
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog,
TestPrincipal.reusableAuth("foo"),
TestPrincipal.authenticatedTestPrincipal("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);

View File

@ -28,7 +28,6 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.websocket.auth.InvalidCredentialsException;
import org.whispersystems.websocket.auth.PrincipalSupplier;
class WebSocketAccountAuthenticatorTest {
@ -70,14 +69,12 @@ class WebSocketAccountAuthenticatorTest {
when(upgradeRequest.getHeader(eq(HttpHeaders.AUTHORIZATION))).thenReturn(authorizationHeaderValue);
}
final WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(
accountAuthenticator,
mock(PrincipalSupplier.class));
final WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator);
if (expectInvalid) {
assertThrows(InvalidCredentialsException.class, () -> webSocketAuthenticator.authenticate(upgradeRequest));
} else {
assertEquals(expectAccount, webSocketAuthenticator.authenticate(upgradeRequest).ref().isPresent());
assertEquals(expectAccount, webSocketAuthenticator.authenticate(upgradeRequest).isPresent());
}
}

View File

@ -43,11 +43,11 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
@ -111,7 +111,7 @@ class WebSocketConnectionIntegrationTest {
clientReleaseManager = mock(ClientReleaseManager.class);
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(account.getIdentifier(IdentityType.ACI)).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(Device.PRIMARY_ID);
}
@ -137,7 +137,8 @@ class WebSocketConnectionIntegrationTest {
new MessageMetrics(),
mock(PushNotificationManager.class),
mock(PushNotificationScheduler.class),
new AuthenticatedDevice(account, device),
account,
device,
webSocketClient,
scheduledExecutorService,
messageDeliveryScheduler,
@ -159,14 +160,14 @@ class WebSocketConnectionIntegrationTest {
expectedMessages.add(envelope);
}
messagesDynamoDb.store(persistedMessages, account.getUuid(), device);
messagesDynamoDb.store(persistedMessages, account.getIdentifier(IdentityType.ACI), device);
}
for (int i = 0; i < cachedMessageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join();
messagesCache.insert(messageGuid, account.getIdentifier(IdentityType.ACI), device.getId(), envelope).join();
expectedMessages.add(envelope);
}
@ -226,7 +227,8 @@ class WebSocketConnectionIntegrationTest {
new MessageMetrics(),
mock(PushNotificationManager.class),
mock(PushNotificationScheduler.class),
new AuthenticatedDevice(account, device),
account,
device,
webSocketClient,
scheduledExecutorService,
messageDeliveryScheduler,
@ -250,13 +252,13 @@ class WebSocketConnectionIntegrationTest {
expectedMessages.add(envelope);
}
messagesDynamoDb.store(persistedMessages, account.getUuid(), device);
messagesDynamoDb.store(persistedMessages, account.getIdentifier(IdentityType.ACI), device);
}
for (int i = 0; i < cachedMessageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join();
messagesCache.insert(messageGuid, account.getIdentifier(IdentityType.ACI), device.getId(), envelope).join();
expectedMessages.add(envelope);
}
@ -296,7 +298,8 @@ class WebSocketConnectionIntegrationTest {
new MessageMetrics(),
mock(PushNotificationManager.class),
mock(PushNotificationScheduler.class),
new AuthenticatedDevice(account, device),
account,
device,
webSocketClient,
100, // use a very short timeout, so that this test completes quickly
scheduledExecutorService,
@ -321,13 +324,13 @@ class WebSocketConnectionIntegrationTest {
expectedMessages.add(envelope);
}
messagesDynamoDb.store(persistedMessages, account.getUuid(), device);
messagesDynamoDb.store(persistedMessages, account.getIdentifier(IdentityType.ACI), device);
}
for (int i = 0; i < cachedMessageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join();
messagesCache.insert(messageGuid, account.getIdentifier(IdentityType.ACI), device.getId(), envelope).join();
expectedMessages.add(envelope);
}

View File

@ -56,6 +56,7 @@ import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
@ -67,9 +68,7 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import reactor.core.publisher.Flux;
@ -90,7 +89,6 @@ class WebSocketConnectionTest {
private AccountsManager accountsManager;
private Account account;
private Device device;
private AuthenticatedDevice auth;
private UpgradeRequest upgradeRequest;
private MessagesManager messagesManager;
private ReceiptSender receiptSender;
@ -104,7 +102,6 @@ class WebSocketConnectionTest {
accountsManager = mock(AccountsManager.class);
account = mock(Account.class);
device = mock(Device.class);
auth = new AuthenticatedDevice(account, device);
upgradeRequest = mock(UpgradeRequest.class);
messagesManager = mock(MessagesManager.class);
receiptSender = mock(ReceiptSender.class);
@ -122,8 +119,8 @@ class WebSocketConnectionTest {
@Test
void testCredentials() throws Exception {
WebSocketAccountAuthenticator webSocketAuthenticator =
new WebSocketAccountAuthenticator(accountAuthenticator, mock(PrincipalSupplier.class));
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, messagesManager,
new WebSocketAccountAuthenticator(accountAuthenticator);
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(accountsManager, receiptSender, messagesManager,
new MessageMetrics(), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class),
mock(WebSocketConnectionEventManager.class), retrySchedulingExecutor,
messageDeliveryScheduler, clientReleaseManager, mock(MessageDeliveryLoopMonitor.class),
@ -133,9 +130,9 @@ class WebSocketConnectionTest {
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
.thenReturn(Optional.of(new AuthenticatedDevice(account, device)));
ReusableAuth<AuthenticatedDevice> account = webSocketAuthenticator.authenticate(upgradeRequest);
when(sessionContext.getAuthenticated()).thenReturn(account.ref().orElse(null));
when(sessionContext.getAuthenticated(AuthenticatedDevice.class)).thenReturn(account.ref().orElse(null));
Optional<AuthenticatedDevice> account = webSocketAuthenticator.authenticate(upgradeRequest);
when(sessionContext.getAuthenticated()).thenReturn(account.orElse(null));
when(sessionContext.getAuthenticated(AuthenticatedDevice.class)).thenReturn(account.orElse(null));
final WebSocketClient webSocketClient = mock(WebSocketClient.class);
when(webSocketClient.getUserAgent()).thenReturn("Signal-Android/6.22.8");
@ -150,7 +147,7 @@ class WebSocketConnectionTest {
// unauthenticated
when(upgradeRequest.getParameterMap()).thenReturn(Map.of());
account = webSocketAuthenticator.authenticate(upgradeRequest);
assertFalse(account.ref().isPresent());
assertFalse(account.isPresent());
connectListener.onWebSocketConnect(sessionContext);
verify(sessionContext, times(2)).addWebsocketClosedListener(
@ -174,7 +171,7 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn(deviceId);
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
final Device sender1device = mock(Device.class);
@ -191,7 +188,7 @@ class WebSocketConnectionTest {
String userAgent = HttpHeaders.USER_AGENT;
when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false))
when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false))
.thenReturn(Flux.fromIterable(outgoingMessages));
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
@ -237,7 +234,7 @@ class WebSocketConnectionTest {
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@ -320,7 +317,7 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn(deviceId);
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
final Device sender1device = mock(Device.class);
@ -337,7 +334,7 @@ class WebSocketConnectionTest {
String userAgent = HttpHeaders.USER_AGENT;
when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false))
when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false))
.thenReturn(Flux.fromIterable(pendingMessages));
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
@ -364,7 +361,7 @@ class WebSocketConnectionTest {
futures.get(1).complete(response);
futures.get(0).completeExceptionally(new IOException());
verify(receiptSender, times(1)).sendReceipt(eq(new AciServiceIdentifier(account.getUuid())), eq(deviceId), eq(new AciServiceIdentifier(senderTwoUuid)),
verify(receiptSender, times(1)).sendReceipt(eq(new AciServiceIdentifier(account.getIdentifier(IdentityType.ACI))), eq(deviceId), eq(new AciServiceIdentifier(senderTwoUuid)),
eq(secondMessage.getClientTimestamp()));
connection.stop();
@ -377,7 +374,7 @@ class WebSocketConnectionTest {
final WebSocketConnection connection = webSocketConnection(client);
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(account.getIdentifier(IdentityType.ACI)).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@ -385,7 +382,7 @@ class WebSocketConnectionTest {
final AtomicBoolean returnMessageList = new AtomicBoolean(false);
when(
messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false))
messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false))
.thenAnswer(invocation -> {
synchronized (threadWaiting) {
threadWaiting.set(true);
@ -442,7 +439,7 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
final UUID accountUuid = UUID.randomUUID();
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@ -490,7 +487,7 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
final UUID accountUuid = UUID.randomUUID();
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@ -573,7 +570,7 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
final UUID accountUuid = UUID.randomUUID();
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@ -581,7 +578,7 @@ class WebSocketConnectionTest {
final List<Envelope> messages = List.of(
createMessage(senderUuid, UUID.randomUUID(), 1111L, "message the first"));
when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false))
when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false))
.thenReturn(Flux.fromIterable(messages))
.thenReturn(Flux.empty());
@ -629,7 +626,7 @@ class WebSocketConnectionTest {
private WebSocketConnection webSocketConnection(final WebSocketClient client) {
return new WebSocketConnection(receiptSender, messagesManager, new MessageMetrics(),
mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), auth, client,
mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), account, device, client,
retrySchedulingExecutor, Schedulers.immediate(), clientReleaseManager,
mock(MessageDeliveryLoopMonitor.class), mock(ExperimentEnrollmentManager.class));
}
@ -642,7 +639,7 @@ class WebSocketConnectionTest {
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@ -669,7 +666,7 @@ class WebSocketConnectionTest {
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@ -725,7 +722,7 @@ class WebSocketConnectionTest {
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@ -741,11 +738,11 @@ class WebSocketConnectionTest {
// anything.
connection.processStoredMessages();
verify(messagesManager).getMessagesForDeviceReactive(account.getUuid(), device, false);
verify(messagesManager).getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false);
connection.handleNewMessageAvailable();
verify(messagesManager).getMessagesForDeviceReactive(account.getUuid(), device, true);
verify(messagesManager).getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, true);
}
@Test
@ -756,7 +753,7 @@ class WebSocketConnectionTest {
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@ -773,7 +770,7 @@ class WebSocketConnectionTest {
connection.processStoredMessages();
connection.handleMessagesPersisted();
verify(messagesManager, times(2)).getMessagesForDeviceReactive(account.getUuid(), device, false);
verify(messagesManager, times(2)).getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false);
}
@Test
@ -783,9 +780,9 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn((byte) 2);
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false))
when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false))
.thenReturn(Flux.error(new RedisException("OH NO")));
when(retrySchedulingExecutor.schedule(any(Runnable.class), anyLong(), any())).thenAnswer(
@ -812,9 +809,9 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn((byte) 2);
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false))
when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false))
.thenReturn(Flux.error(new RedisException("OH NO")));
final WebSocketClient client = mock(WebSocketClient.class);
@ -835,7 +832,7 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn(deviceId);
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
final int totalMessages = 1000;
@ -884,7 +881,7 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn(deviceId);
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
final AtomicBoolean canceled = new AtomicBoolean();

View File

@ -1,149 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.websocket;
import java.security.Principal;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import org.whispersystems.websocket.auth.PrincipalSupplier;
/**
* This class holds a principal that can be reused across requests on a websocket. Since two requests may operate
* concurrently on the same principal, and some principals contain non thread-safe mutable state, appropriate use of
* this class ensures that no data races occur. It also ensures that after a principal is modified, a subsequent request
* gets the up-to-date principal
*
* @param <T> The underlying principal type
* @see PrincipalSupplier
*/
public abstract sealed class ReusableAuth<T extends Principal> {
/**
* Get a reference to the underlying principal that callers pledge not to modify.
* <p>
* The reference returned will potentially be provided to many threads concurrently accessing the principal. Callers
* should use this method only if they can ensure that they will not modify the in-memory principal object AND they do
* not intend to modify the underlying canonical representation of the principal.
* <p>
* For example, if a caller retrieves a reference to a principal, does not modify the in memory state, but updates a
* field on a database that should be reflected in subsequent retrievals of the principal, they will have met the
* first criteria, but not the second. In that case they should instead use {@link #mutableRef()}.
* <p>
* If other callers have modified the underlying principal by using {@link #mutableRef()}, this method may need to
* refresh the principal via {@link PrincipalSupplier#refresh} which could be a blocking operation.
*
* @return If authenticated, a reference to the underlying principal that should not be modified
*/
public abstract Optional<T> ref();
public interface MutableRef<T> {
T ref();
void close();
}
/**
* Get a reference to the underlying principal that may be modified.
* <p>
* The underlying principal can be safely modified. Multiple threads may operate on the same {@link ReusableAuth} so
* long as they each have their own {@link MutableRef}. After any modifications, the caller must call
* {@link MutableRef#close} to notify the principal has become dirty. Close should be called after modifications but
* before sending a response on the websocket. This ensures that a request that comes in after a modification response
* is received is guaranteed to see the modification.
*
* @return If authenticated, a reference to the underlying principal that may be modified
*/
public abstract Optional<MutableRef<T>> mutableRef();
/**
* @return A {@link ReusableAuth} indicating no credential were provided
*/
public static <T extends Principal> ReusableAuth<T> anonymous() {
//noinspection unchecked
return (ReusableAuth<T>) Anonymous.ANON_RESULT;
}
/**
* Create a successfully authenticated {@link ReusableAuth}
*
* @param principal The authenticated principal
* @param principalSupplier Instructions for how to refresh or copy this principal
* @param <T> The principal type
* @return A {@link ReusableAuth} for a successfully authenticated principal
*/
public static <T extends Principal> ReusableAuth<T> authenticated(T principal,
PrincipalSupplier<T> principalSupplier) {
return new Authenticated<>(principal, principalSupplier);
}
private static final class Anonymous<T extends Principal> extends ReusableAuth<T> {
@SuppressWarnings({"rawtypes"})
private static final ReusableAuth ANON_RESULT = new Anonymous();
@Override
public Optional<T> ref() {
return Optional.empty();
}
@Override
public Optional<MutableRef<T>> mutableRef() {
return Optional.empty();
}
}
private static final class Authenticated<T extends Principal> extends ReusableAuth<T> {
private T basePrincipal;
private final AtomicBoolean needRefresh = new AtomicBoolean(false);
private final PrincipalSupplier<T> principalSupplier;
Authenticated(final T basePrincipal, PrincipalSupplier<T> principalSupplier) {
this.basePrincipal = basePrincipal;
this.principalSupplier = principalSupplier;
}
@Override
public Optional<T> ref() {
maybeRefresh();
return Optional.of(basePrincipal);
}
@Override
public Optional<MutableRef<T>> mutableRef() {
maybeRefresh();
return Optional.of(new AuthenticatedMutableRef(principalSupplier.deepCopy(basePrincipal)));
}
private void maybeRefresh() {
if (needRefresh.compareAndSet(true, false)) {
basePrincipal = principalSupplier.refresh(basePrincipal);
}
}
private class AuthenticatedMutableRef implements MutableRef<T> {
final T ref;
private AuthenticatedMutableRef(T ref) {
this.ref = ref;
}
public T ref() {
return ref;
}
public void close() {
needRefresh.set(true);
}
}
}
private ReusableAuth() {
}
}

View File

@ -58,7 +58,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
private final Map<Long, CompletableFuture<WebSocketResponseMessage>> requestMap = new ConcurrentHashMap<>();
private final ReusableAuth<T> reusableAuth;
private final Optional<T> reusableAuth;
private final WebSocketMessageFactory messageFactory;
private final Optional<WebSocketConnectListener> connectListener;
private final ApplicationHandler jerseyHandler;
@ -77,7 +77,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
String remoteAddressPropertyName,
ApplicationHandler jerseyHandler,
WebsocketRequestLog requestLog,
ReusableAuth<T> authenticated,
Optional<T> authenticated,
WebSocketMessageFactory messageFactory,
Optional<WebSocketConnectListener> connectListener,
Duration idleTimeout) {
@ -97,7 +97,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
this.remoteEndpoint = session.getRemote();
this.context = new WebSocketSessionContext(
new WebSocketClient(session, remoteEndpoint, messageFactory, requestMap));
this.context.setAuthenticated(reusableAuth.ref().orElse(null));
this.context.setAuthenticated(reusableAuth.orElse(null));
this.session.setIdleTimeout(idleTimeout);
connectListener.ifPresent(listener -> listener.onWebSocketConnect(this.context));
@ -164,16 +164,10 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
/**
* The property name where {@link org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider} can find an
* {@link ReusableAuth} object that lives for the lifetime of the websocket
* authenticated principal that lives for the lifetime of the websocket
*/
public static final String REUSABLE_AUTH_PROPERTY = WebSocketResourceProvider.class.getName() + ".reusableAuth";
/**
* The property name where {@link org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider} can install a
* {@link org.whispersystems.websocket.ReusableAuth.MutableRef} for us to close when the request is finished
*/
public static final String RESOLVED_PRINCIPAL_PROPERTY = WebSocketResourceProvider.class.getName() + ".resolvedPrincipal";
/**
* The property name where request byte count is stored for metrics collection
*/
@ -205,16 +199,6 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
containerRequest, responseBody);
responseFuture
.whenComplete((ignoredResponse, ignoredError) -> {
// If the request ended up being one that mutates our principal, we have to close it to indicate we're done
// with the mutation operation
final Object resolvedPrincipal = containerRequest.getProperty(RESOLVED_PRINCIPAL_PROPERTY);
if (resolvedPrincipal instanceof ReusableAuth.MutableRef<?> ref) {
ref.close();
} else if (resolvedPrincipal != null) {
logger.warn("unexpected resolved principal type {} : {}", resolvedPrincipal.getClass(), resolvedPrincipal);
}
})
.thenAccept(response -> {
try {
final int responseBytes = responseBody.size();

View File

@ -64,9 +64,9 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends Jetty
try {
Optional<WebSocketAuthenticator<T>> authenticator = Optional.ofNullable(environment.getAuthenticator());
final ReusableAuth<T> authenticated = authenticator.isPresent()
final Optional<T> authenticated = authenticator.isPresent()
? authenticator.get().authenticate(request)
: ReusableAuth.anonymous();
: Optional.empty();
Optional.ofNullable(environment.getAuthenticatedWebSocketUpgradeFilter())
.ifPresent(filter -> filter.handleAuthentication(authenticated, request, response));

View File

@ -6,13 +6,13 @@
package org.whispersystems.websocket.auth;
import java.security.Principal;
import java.util.Optional;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
import org.whispersystems.websocket.ReusableAuth;
public interface AuthenticatedWebSocketUpgradeFilter<T extends Principal> {
void handleAuthentication(ReusableAuth<T> authenticated,
void handleAuthentication(@SuppressWarnings("OptionalUsedAsFieldOrParameterType") Optional<T> authenticated,
JettyServerUpgradeRequest request,
JettyServerUpgradeResponse response);
}

View File

@ -1,26 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.websocket.auth;
import io.dropwizard.auth.Auth;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* An @{@link Auth} object annotated with {@link Mutable} indicates that the consumer of the object
* will modify the object or its underlying canonical source.
*
* Note: An {@link Auth} object that does not specify @{@link ReadOnly} will be assumed to be @Mutable
*
* @see org.whispersystems.websocket.ReusableAuth
*/
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD, ElementType.PARAMETER})
public @interface Mutable {
}

View File

@ -1,58 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.websocket.auth;
/**
* Teach {@link org.whispersystems.websocket.ReusableAuth} how to make a deep copy of a principal (that is safe to
* concurrently modify while the original principal is being read), and how to refresh a principal after it has been
* potentially modified.
*
* @param <T> The underlying principal type
*/
public interface PrincipalSupplier<T> {
/**
* Re-fresh the principal after it has been modified.
* <p>
* If the principal is populated from a backing store, refresh should re-read it.
*
* @param t the potentially stale principal to refresh
* @return The up-to-date principal
*/
T refresh(T t);
/**
* Create a deep, in-memory copy of the principal. This should be identical to the original principal, but should
* share no mutable state with the original. It should be safe for two threads to independently write and read from
* two independent deep copies.
*
* @param t the principal to copy
* @return An in-memory copy of the principal
*/
T deepCopy(T t);
class ImmutablePrincipalSupplier<T> implements PrincipalSupplier<T> {
@SuppressWarnings({"rawtypes"})
private static final PrincipalSupplier INSTANCE = new ImmutablePrincipalSupplier();
@Override
public T refresh(final T t) {
return t;
}
@Override
public T deepCopy(final T t) {
return t;
}
}
/**
* @return A principal supplier that can be used if the principal type does not support modification.
*/
static <T> PrincipalSupplier<T> forImmutablePrincipal() {
//noinspection unchecked
return (PrincipalSupplier<T>) ImmutablePrincipalSupplier.INSTANCE;
}
}

View File

@ -1,25 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.websocket.auth;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* An @{@link io.dropwizard.auth.Auth} object annotated with {@link ReadOnly} indicates that the consumer of the object
* will never modify the object, nor its underlying canonical source.
* <p>
* For example, a consumer of a @ReadOnly AuthenticatedAccount promises to never modify the in-memory
* AuthenticatedAccount and to never modify the underlying Account database for the account.
*
* @see org.whispersystems.websocket.ReusableAuth
*/
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD, ElementType.PARAMETER})
public @interface ReadOnly {
}

View File

@ -5,9 +5,20 @@
package org.whispersystems.websocket.auth;
import java.security.Principal;
import java.util.Optional;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.whispersystems.websocket.ReusableAuth;
public interface WebSocketAuthenticator<T extends Principal> {
ReusableAuth<T> authenticate(UpgradeRequest request) throws InvalidCredentialsException;
/**
* Authenticates an account from credential headers provided in a WebSocket upgrade request.
*
* @param request the request from which to extract credentials
*
* @return the authenticated principal if credentials were provided and authenticated or empty if the caller is
* anonymous
*
* @throws InvalidCredentialsException if credentials were provided, but could not be authenticated
*/
Optional<T> authenticate(UpgradeRequest request) throws InvalidCredentialsException;
}

View File

@ -21,7 +21,6 @@ import org.glassfish.jersey.server.model.Parameter;
import org.glassfish.jersey.server.spi.internal.ValueParamProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketResourceProvider;
@Singleton
@ -43,36 +42,30 @@ public class WebsocketAuthValueFactoryProvider<T extends Principal> extends Abst
return null;
}
final boolean readOnly = parameter.isAnnotationPresent(ReadOnly.class);
final boolean readOnly = true;
if (parameter.getRawType() == Optional.class
&& ParameterizedType.class.isAssignableFrom(parameter.getType().getClass())
&& principalClass == ((ParameterizedType) parameter.getType()).getActualTypeArguments()[0]) {
return containerRequest -> createPrincipal(containerRequest, readOnly);
return this::createPrincipal;
} else if (principalClass.equals(parameter.getRawType())) {
return containerRequest ->
createPrincipal(containerRequest, readOnly)
createPrincipal(containerRequest)
.orElseThrow(() -> new WebApplicationException("Authenticated resource", 401));
} else {
throw new IllegalStateException("Can't inject unassignable principal: " + principalClass + " for parameter: " + parameter);
}
}
private Optional<? extends Principal> createPrincipal(final ContainerRequest request, final boolean readOnly) {
private Optional<? extends Principal> createPrincipal(final ContainerRequest request) {
final Object obj = request.getProperty(WebSocketResourceProvider.REUSABLE_AUTH_PROPERTY);
if (!(obj instanceof ReusableAuth<?>)) {
if (!(obj instanceof Optional<?>)) {
logger.warn("Unexpected reusable auth property type {} : {}", obj.getClass(), obj);
return Optional.empty();
}
@SuppressWarnings("unchecked") final ReusableAuth<T> reusableAuth = (ReusableAuth<T>) obj;
if (readOnly) {
return reusableAuth.ref();
} else {
return reusableAuth.mutableRef().map(writeRef -> {
request.setProperty(WebSocketResourceProvider.RESOLVED_PRINCIPAL_PROPERTY, writeRef);
return writeRef.ref();
});
}
//noinspection unchecked
return (Optional<T>) obj;
}
@Singleton

View File

@ -17,6 +17,7 @@ import io.dropwizard.jersey.DropwizardResourceConfig;
import jakarta.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.security.Principal;
import java.util.Optional;
import javax.security.auth.Subject;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
@ -26,7 +27,6 @@ import org.glassfish.jersey.server.ResourceConfig;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.websocket.auth.InvalidCredentialsException;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.AuthenticatedWebSocketUpgradeFilter;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
@ -75,7 +75,7 @@ public class WebSocketResourceProviderFactoryTest {
when(environment.getAuthenticator()).thenReturn(authenticator);
when(authenticator.authenticate(eq(request)))
.thenReturn(ReusableAuth.authenticated(account, PrincipalSupplier.forImmutablePrincipal()));
.thenReturn(Optional.of(account));
when(environment.jersey()).thenReturn(jerseyEnvironment);
final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class);
when(httpServletRequest.getAttribute(REMOTE_ADDRESS_PROPERTY_NAME)).thenReturn("127.0.0.1");
@ -129,8 +129,7 @@ public class WebSocketResourceProviderFactoryTest {
@Test
void testAuthenticatedWebSocketUpgradeFilter() throws InvalidCredentialsException {
final Account account = new Account();
final ReusableAuth<Account> reusableAuth =
ReusableAuth.authenticated(account, PrincipalSupplier.forImmutablePrincipal());
final Optional<Account> reusableAuth = Optional.of(account);
when(environment.getAuthenticator()).thenReturn(authenticator);
when(authenticator.authenticate(eq(request))).thenReturn(reusableAuth);

View File

@ -59,7 +59,6 @@ import org.glassfish.jersey.server.ResourceConfig;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
import org.whispersystems.websocket.logging.WebsocketRequestLog;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
@ -81,7 +80,7 @@ class WebSocketResourceProviderTest {
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME,
applicationHandler, requestLog,
immutableTestPrincipal("fooz"),
Optional.of(new TestPrincipal("fooz")),
new ProtobufWebSocketMessageFactory(),
Optional.of(connectListener),
Duration.ofMillis(30000));
@ -109,7 +108,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = mock(ApplicationHandler.class);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("foo")),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -186,7 +185,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = mock(ApplicationHandler.class);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("foo")),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -242,7 +241,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("foo")),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -282,7 +281,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("foo")),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -322,7 +321,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("authorizedUserName"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("authorizedUserName")),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -362,7 +361,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, ReusableAuth.anonymous(),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.empty(),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -401,7 +400,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("something"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("something")),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -441,7 +440,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, ReusableAuth.anonymous(),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.empty(),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -481,7 +480,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("gooduser")),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -522,7 +521,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("gooduser")),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -564,7 +563,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("gooduser")),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -604,7 +603,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("gooduser")),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -729,10 +728,6 @@ class WebSocketResourceProviderTest {
}
}
public static ReusableAuth<TestPrincipal> immutableTestPrincipal(final String name) {
return ReusableAuth.authenticated(new TestPrincipal(name), PrincipalSupplier.forImmutablePrincipal());
}
public static class TestException extends Exception {
public TestException(String message) {