Don't cache authenticated accounts in memory

This commit is contained in:
Jon Chambers 2025-06-23 08:40:05 -05:00 committed by GitHub
parent 9dfe51eac4
commit c952baa672
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
86 changed files with 961 additions and 2264 deletions

View File

@ -85,7 +85,6 @@ import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator
import org.whispersystems.textsecuregcm.auth.IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter;
import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager;
import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager;
import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventListener;
import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor;
import org.whispersystems.textsecuregcm.backup.BackupAuthManager;
@ -212,7 +211,6 @@ import org.whispersystems.textsecuregcm.spam.RegistrationRecoveryChecker;
import org.whispersystems.textsecuregcm.spam.SpamChecker;
import org.whispersystems.textsecuregcm.spam.SpamFilter;
import org.whispersystems.textsecuregcm.storage.AccountLockManager;
import org.whispersystems.textsecuregcm.storage.AccountPrincipalSupplier;
import org.whispersystems.textsecuregcm.storage.Accounts;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ChangeNumberManager;
@ -980,23 +978,19 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.jersey().register(MultiRecipientMessageProvider.class);
environment.jersey().register(new AuthDynamicFeature(accountAuthFilter));
environment.jersey().register(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class));
environment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager,
disconnectionRequestManager));
environment.jersey().register(new TimestampResponseFilter());
///
WebSocketEnvironment<AuthenticatedDevice> webSocketEnvironment = new WebSocketEnvironment<>(environment,
config.getWebSocketConfiguration(), Duration.ofMillis(90000));
webSocketEnvironment.jersey().register(new VirtualExecutorServiceProvider("managed-async-websocket-virtual-thread-"));
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator, new AccountPrincipalSupplier(accountsManager)));
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator));
webSocketEnvironment.setAuthenticatedWebSocketUpgradeFilter(new IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter(
keysManager, config.idlePrimaryDeviceReminderConfiguration().minIdleDuration(), Clock.systemUTC()));
webSocketEnvironment.setConnectListener(
new AuthenticatedConnectListener(receiptSender, messagesManager, messageMetrics, pushNotificationManager,
new AuthenticatedConnectListener(accountsManager, receiptSender, messagesManager, messageMetrics, pushNotificationManager,
pushNotificationScheduler, webSocketConnectionEventManager, websocketScheduledExecutor,
messageDeliveryScheduler, clientReleaseManager, messageDeliveryLoopMonitor, experimentEnrollmentManager));
webSocketEnvironment.jersey()
.register(new WebsocketRefreshApplicationEventListener(accountsManager, disconnectionRequestManager));
webSocketEnvironment.jersey().register(new RateLimitByIpFilter(rateLimiters));
webSocketEnvironment.jersey().register(new RequestStatisticsFilter(TrafficSource.WEBSOCKET));
webSocketEnvironment.jersey().register(MultiRecipientMessageProvider.class);
@ -1083,15 +1077,15 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
registrationLockVerificationManager, rateLimiters),
new AttachmentControllerV4(rateLimiters, gcsAttachmentGenerator, tusAttachmentGenerator,
experimentEnrollmentManager),
new ArchiveController(backupAuthManager, backupManager, backupMetrics),
new ArchiveController(accountsManager, backupAuthManager, backupManager, backupMetrics),
new CallRoutingControllerV2(rateLimiters, cloudflareTurnCredentialsManager),
new CallLinkController(rateLimiters, callingGenericZkSecretParams),
new CertificateController(new CertificateGenerator(config.getDeliveryCertificate().certificate(),
new CertificateController(accountsManager, new CertificateGenerator(config.getDeliveryCertificate().certificate(),
config.getDeliveryCertificate().ecPrivateKey(), config.getDeliveryCertificate().expiresDays()),
zkAuthOperations, callingGenericZkSecretParams, clock),
new ChallengeController(rateLimitChallengeManager, challengeConstraintChecker),
new ChallengeController(accountsManager, rateLimitChallengeManager, challengeConstraintChecker),
new DeviceController(accountsManager, clientPublicKeysManager, rateLimiters, persistentTimer, config.getMaxDevices()),
new DeviceCheckController(clock, backupAuthManager, appleDeviceCheckManager, rateLimiters,
new DeviceCheckController(clock, accountsManager, backupAuthManager, appleDeviceCheckManager, rateLimiters,
config.getDeviceCheck().backupRedemptionLevel(),
config.getDeviceCheck().backupRedemptionDuration()),
new DirectoryV2Controller(directoryV2CredentialsGenerator),
@ -1140,8 +1134,6 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
WebSocketEnvironment<AuthenticatedDevice> provisioningEnvironment = new WebSocketEnvironment<>(environment,
webSocketEnvironment.getRequestLog(), Duration.ofMillis(60000));
provisioningEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager,
disconnectionRequestManager));
provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(provisioningManager, provisioningWebsocketTimeoutExecutor, Duration.ofSeconds(90)));
provisioningEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET, clientReleaseManager));
provisioningEnvironment.jersey().register(new KeepAliveController(webSocketConnectionEventManager));

View File

@ -7,10 +7,20 @@ package org.whispersystems.textsecuregcm.auth;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import java.time.Instant;
import java.util.UUID;
public interface AccountAndAuthenticatedDeviceHolder {
UUID getAccountIdentifier();
byte getDeviceId();
Instant getPrimaryDeviceLastSeen();
@Deprecated(forRemoval = true)
Account getAccount();
@Deprecated(forRemoval = true)
Device getAuthenticatedDevice();
}

View File

@ -6,7 +6,10 @@
package org.whispersystems.textsecuregcm.auth;
import java.security.Principal;
import java.time.Instant;
import java.util.UUID;
import javax.security.auth.Subject;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
@ -30,6 +33,21 @@ public class AuthenticatedDevice implements Principal, AccountAndAuthenticatedDe
return device;
}
@Override
public UUID getAccountIdentifier() {
return account.getIdentifier(IdentityType.ACI);
}
@Override
public byte getDeviceId() {
return device.getId();
}
@Override
public Instant getPrimaryDeviceLastSeen() {
return Instant.ofEpochMilli(account.getPrimaryDevice().getLastSeen());
}
// Principal implementation
@Override

View File

@ -31,9 +31,9 @@ public class CertificateGenerator {
this.serverCertificate = ServerCertificate.parseFrom(serverCertificate);
}
public byte[] createFor(Account account, Device device, boolean includeE164) throws InvalidKeyException {
public byte[] createFor(final Account account, final byte deviceId, boolean includeE164) throws InvalidKeyException {
SenderCertificate.Certificate.Builder builder = SenderCertificate.Certificate.newBuilder()
.setSenderDevice(Math.toIntExact(device.getId()))
.setSenderDevice(Math.toIntExact(deviceId))
.setExpires(System.currentTimeMillis() + TimeUnit.DAYS.toMillis(expiresDays))
.setIdentityKey(ByteString.copyFrom(account.getIdentityKey(IdentityType.ACI).serialize()))
.setSigner(serverCertificate)

View File

@ -1,44 +0,0 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import jakarta.ws.rs.core.SecurityContext;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import org.glassfish.jersey.server.ContainerRequest;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
class ContainerRequestUtil {
/**
* A read-only subset of the authenticated Account object, to enforce that filter-based consumers do not perform
* account modifying operations.
*/
record AccountInfo(UUID accountId, String e164, Set<Byte> deviceIds) {
static AccountInfo fromAccount(final Account account) {
return new AccountInfo(
account.getUuid(),
account.getNumber(),
account.getDevices().stream().map(Device::getId).collect(Collectors.toSet()));
}
}
static Optional<AccountInfo> getAuthenticatedAccount(final ContainerRequest request) {
return Optional.ofNullable(request.getSecurityContext())
.map(SecurityContext::getUserPrincipal)
.map(principal -> {
if (principal instanceof AccountAndAuthenticatedDeviceHolder aaadh) {
return aaadh.getAccount();
}
return null;
})
.map(AccountInfo::fromAccount);
}
}

View File

@ -8,17 +8,16 @@ package org.whispersystems.textsecuregcm.auth;
import com.google.common.annotations.VisibleForTesting;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.auth.AuthenticatedWebSocketUpgradeFilter;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Optional;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.websocket.auth.AuthenticatedWebSocketUpgradeFilter;
public class IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter implements
AuthenticatedWebSocketUpgradeFilter<AuthenticatedDevice> {
@ -58,21 +57,19 @@ public class IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter implements
}
@Override
public void handleAuthentication(final ReusableAuth<AuthenticatedDevice> authenticated,
public void handleAuthentication(final Optional<AuthenticatedDevice> authenticated,
final JettyServerUpgradeRequest request,
final JettyServerUpgradeResponse response) {
// No action needed if the connection is unauthenticated (in which case we don't know when we've last seen the
// primary device) or if the authenticated device IS the primary device
authenticated.ref()
.filter(authenticatedDevice -> !authenticatedDevice.getAuthenticatedDevice().isPrimary())
authenticated
.filter(authenticatedDevice -> authenticatedDevice.getDeviceId() != Device.PRIMARY_ID)
.ifPresent(authenticatedDevice -> {
final Instant primaryDeviceLastSeen =
Instant.ofEpochMilli(authenticatedDevice.getAccount().getPrimaryDevice().getLastSeen());
final Instant primaryDeviceLastSeen = authenticatedDevice.getPrimaryDeviceLastSeen();
if (primaryDeviceLastSeen.isBefore(clock.instant().minus(PQ_KEY_CHECK_THRESHOLD)) &&
keysManager.getLastResort(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI), Device.PRIMARY_ID)
.join().isEmpty()) {
keysManager.getLastResort(authenticatedDevice.getAccountIdentifier(), Device.PRIMARY_ID).join().isEmpty()) {
response.addHeader(ALERT_HEADER, CRITICAL_IDLE_PRIMARY_DEVICE_ALERT);
CRITICAL_IDLE_PRIMARY_WARNING_COUNTER.increment();

View File

@ -1,96 +0,0 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import org.glassfish.jersey.server.ContainerRequest;
import org.glassfish.jersey.server.monitoring.RequestEvent;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.util.Pair;
/**
* This {@link WebsocketRefreshRequirementProvider} observes intra-request changes in devices linked to an
* {@link Account} and triggers a WebSocket refresh if that set changes. If a change in linked devices is observed, then
* any active WebSocket connections for the account must be closed in order for clients to get a refreshed
* {@link io.dropwizard.auth.Auth} object with a current device list.
*
* @see AuthenticatedDevice
*/
public class LinkedDeviceRefreshRequirementProvider implements WebsocketRefreshRequirementProvider {
private final AccountsManager accountsManager;
private static final Logger logger = LoggerFactory.getLogger(LinkedDeviceRefreshRequirementProvider.class);
private static final String ACCOUNT_UUID = LinkedDeviceRefreshRequirementProvider.class.getName() + ".accountUuid";
private static final String LINKED_DEVICE_IDS = LinkedDeviceRefreshRequirementProvider.class.getName() + ".deviceIds";
public LinkedDeviceRefreshRequirementProvider(final AccountsManager accountsManager) {
this.accountsManager = accountsManager;
}
@Override
public void handleRequestFiltered(final RequestEvent requestEvent) {
if (requestEvent.getUriInfo().getMatchedResourceMethod().getInvocable().getHandlingMethod().getAnnotation(
ChangesLinkedDevices.class) != null) {
// The authenticated principal, if any, will be available after filters have run. Now that the account is known,
// capture a snapshot of the account's linked devices before carrying out the requests business logic.
ContainerRequestUtil.getAuthenticatedAccount(requestEvent.getContainerRequest())
.ifPresent(account -> setAccount(requestEvent.getContainerRequest(), account));
}
}
public static void setAccount(final ContainerRequest containerRequest, final Account account) {
setAccount(containerRequest, ContainerRequestUtil.AccountInfo.fromAccount(account));
}
private static void setAccount(final ContainerRequest containerRequest, final ContainerRequestUtil.AccountInfo info) {
containerRequest.setProperty(ACCOUNT_UUID, info.accountId());
containerRequest.setProperty(LINKED_DEVICE_IDS, info.deviceIds());
}
@Override
public List<Pair<UUID, Byte>> handleRequestFinished(final RequestEvent requestEvent) {
// Now that the request is finished, check whether the set of linked devices has changed. If the value did change or
// if a devices was added or removed, all devices must disconnect and reauthenticate.
if (requestEvent.getContainerRequest().getProperty(LINKED_DEVICE_IDS) != null) {
@SuppressWarnings("unchecked") final Set<Byte> initialLinkedDeviceIds =
(Set<Byte>) requestEvent.getContainerRequest().getProperty(LINKED_DEVICE_IDS);
return accountsManager.getByAccountIdentifier((UUID) requestEvent.getContainerRequest().getProperty(ACCOUNT_UUID))
.map(ContainerRequestUtil.AccountInfo::fromAccount)
.map(accountInfo -> {
final Set<Byte> deviceIdsToDisplace;
final Set<Byte> currentLinkedDeviceIds = accountInfo.deviceIds();
if (!initialLinkedDeviceIds.equals(currentLinkedDeviceIds)) {
deviceIdsToDisplace = new HashSet<>(initialLinkedDeviceIds);
deviceIdsToDisplace.addAll(currentLinkedDeviceIds);
} else {
deviceIdsToDisplace = Collections.emptySet();
}
return deviceIdsToDisplace.stream()
.map(deviceId -> new Pair<>(accountInfo.accountId(), deviceId))
.collect(Collectors.toList());
}).orElseGet(() -> {
logger.error("Request had account, but it is no longer present");
return Collections.emptyList();
});
} else {
return Collections.emptyList();
}
}
}

View File

@ -1,56 +0,0 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;
import org.glassfish.jersey.server.monitoring.RequestEvent;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.util.Pair;
public class PhoneNumberChangeRefreshRequirementProvider implements WebsocketRefreshRequirementProvider {
private static final String ACCOUNT_UUID =
PhoneNumberChangeRefreshRequirementProvider.class.getName() + ".accountUuid";
private static final String INITIAL_NUMBER_KEY =
PhoneNumberChangeRefreshRequirementProvider.class.getName() + ".initialNumber";
private final AccountsManager accountsManager;
public PhoneNumberChangeRefreshRequirementProvider(final AccountsManager accountsManager) {
this.accountsManager = accountsManager;
}
@Override
public void handleRequestFiltered(final RequestEvent requestEvent) {
if (requestEvent.getUriInfo().getMatchedResourceMethod().getInvocable().getHandlingMethod()
.getAnnotation(ChangesPhoneNumber.class) == null) {
return;
}
ContainerRequestUtil.getAuthenticatedAccount(requestEvent.getContainerRequest())
.ifPresent(account -> {
requestEvent.getContainerRequest().setProperty(INITIAL_NUMBER_KEY, account.e164());
requestEvent.getContainerRequest().setProperty(ACCOUNT_UUID, account.accountId());
});
}
@Override
public List<Pair<UUID, Byte>> handleRequestFinished(final RequestEvent requestEvent) {
final String initialNumber = (String) requestEvent.getContainerRequest().getProperty(INITIAL_NUMBER_KEY);
if (initialNumber == null) {
return Collections.emptyList();
}
return accountsManager.getByAccountIdentifier((UUID) requestEvent.getContainerRequest().getProperty(ACCOUNT_UUID))
.filter(account -> !initialNumber.equals(account.getNumber()))
.map(account -> account.getDevices().stream()
.map(device -> new Pair<>(account.getUuid(), device.getId()))
.collect(Collectors.toList()))
.orElse(Collections.emptyList());
}
}

View File

@ -1,38 +0,0 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import org.glassfish.jersey.server.monitoring.ApplicationEvent;
import org.glassfish.jersey.server.monitoring.ApplicationEventListener;
import org.glassfish.jersey.server.monitoring.RequestEvent;
import org.glassfish.jersey.server.monitoring.RequestEventListener;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
/**
* Delegates request events to a listener that watches for intra-request changes that require websocket refreshes
*/
public class WebsocketRefreshApplicationEventListener implements ApplicationEventListener {
private final WebsocketRefreshRequestEventListener websocketRefreshRequestEventListener;
public WebsocketRefreshApplicationEventListener(final AccountsManager accountsManager,
final DisconnectionRequestManager disconnectionRequestManager) {
this.websocketRefreshRequestEventListener = new WebsocketRefreshRequestEventListener(
disconnectionRequestManager,
new LinkedDeviceRefreshRequirementProvider(accountsManager),
new PhoneNumberChangeRefreshRequirementProvider(accountsManager));
}
@Override
public void onEvent(final ApplicationEvent event) {
}
@Override
public RequestEventListener onRequest(final RequestEvent requestEvent) {
return websocketRefreshRequestEventListener;
}
}

View File

@ -1,74 +0,0 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import jakarta.ws.rs.container.ResourceInfo;
import jakarta.ws.rs.core.Context;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.glassfish.jersey.server.monitoring.RequestEvent;
import org.glassfish.jersey.server.monitoring.RequestEvent.Type;
import org.glassfish.jersey.server.monitoring.RequestEventListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class WebsocketRefreshRequestEventListener implements RequestEventListener {
private final DisconnectionRequestManager disconnectionRequestManager;
private final WebsocketRefreshRequirementProvider[] providers;
private static final Counter DISPLACED_ACCOUNTS = Metrics.counter(
name(WebsocketRefreshRequestEventListener.class, "displacedAccounts"));
private static final Counter DISPLACED_DEVICES = Metrics.counter(
name(WebsocketRefreshRequestEventListener.class, "displacedDevices"));
private static final Logger logger = LoggerFactory.getLogger(WebsocketRefreshRequestEventListener.class);
public WebsocketRefreshRequestEventListener(
final DisconnectionRequestManager disconnectionRequestManager,
final WebsocketRefreshRequirementProvider... providers) {
this.disconnectionRequestManager = disconnectionRequestManager;
this.providers = providers;
}
@Context
private ResourceInfo resourceInfo;
@Override
public void onEvent(final RequestEvent event) {
if (event.getType() == Type.REQUEST_FILTERED) {
for (final WebsocketRefreshRequirementProvider provider : providers) {
provider.handleRequestFiltered(event);
}
} else if (event.getType() == Type.FINISHED) {
final AtomicInteger displacedDevices = new AtomicInteger(0);
Arrays.stream(providers)
.flatMap(provider -> provider.handleRequestFinished(event).stream())
.distinct()
.forEach(pair -> {
try {
displacedDevices.incrementAndGet();
disconnectionRequestManager.requestDisconnection(pair.first(), List.of(pair.second()));
} catch (final Exception e) {
logger.error("Could not displace device presence", e);
}
});
if (displacedDevices.get() > 0) {
DISPLACED_ACCOUNTS.increment();
DISPLACED_DEVICES.increment(displacedDevices.get());
}
}
}
}

View File

@ -1,34 +0,0 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import java.util.List;
import java.util.UUID;
import org.glassfish.jersey.server.monitoring.RequestEvent;
import org.whispersystems.textsecuregcm.util.Pair;
/**
* A websocket refresh requirement provider watches for intra-request changes (e.g. to authentication status) that
* require a websocket refresh.
*/
public interface WebsocketRefreshRequirementProvider {
/**
* Processes a request after filters have run and the request has been mapped to a destination controller.
*
* @param requestEvent the request event to observe
*/
void handleRequestFiltered(RequestEvent requestEvent);
/**
* Processes a request after all normal request handling has been completed.
*
* @param requestEvent the request event to observe
* @return a list of pairs of account UUID/device ID pairs identifying websockets that need to be refreshed as a
* result of the observed request
*/
List<Pair<UUID, Byte>> handleRequestFinished(RequestEvent requestEvent);
}

View File

@ -66,8 +66,6 @@ import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.UsernameHashZkProofVerifier;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.websocket.auth.Mutable;
import org.whispersystems.websocket.auth.ReadOnly;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@Path("/v1/accounts")
@ -97,11 +95,14 @@ public class AccountController {
@Path("/gcm/")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public void setGcmRegistrationId(@Mutable @Auth AuthenticatedDevice auth,
public void setGcmRegistrationId(@Auth AuthenticatedDevice auth,
@NotNull @Valid GcmRegistrationId registrationId) {
final Account account = auth.getAccount();
final Device device = auth.getAuthenticatedDevice();
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
final Device device = account.getDevice(auth.getDeviceId())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
if (Objects.equals(device.getGcmId(), registrationId.gcmRegistrationId())) {
return;
@ -116,9 +117,12 @@ public class AccountController {
@DELETE
@Path("/gcm/")
public void deleteGcmRegistrationId(@Mutable @Auth AuthenticatedDevice auth) {
Account account = auth.getAccount();
Device device = auth.getAuthenticatedDevice();
public void deleteGcmRegistrationId(@Auth AuthenticatedDevice auth) {
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
final Device device = account.getDevice(auth.getDeviceId())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
accounts.updateDevice(account, device.getId(), d -> {
d.setGcmId(null);
@ -131,11 +135,14 @@ public class AccountController {
@Path("/apn/")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public void setApnRegistrationId(@Mutable @Auth AuthenticatedDevice auth,
public void setApnRegistrationId(@Auth AuthenticatedDevice auth,
@NotNull @Valid ApnRegistrationId registrationId) {
final Account account = auth.getAccount();
final Device device = auth.getAuthenticatedDevice();
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
final Device device = account.getDevice(auth.getDeviceId())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
// Unlike FCM tokens, we need current "last updated" timestamps for APNs tokens and so update device records
// unconditionally
@ -148,9 +155,12 @@ public class AccountController {
@DELETE
@Path("/apn/")
public void deleteApnRegistrationId(@Mutable @Auth AuthenticatedDevice auth) {
Account account = auth.getAccount();
Device device = auth.getAuthenticatedDevice();
public void deleteApnRegistrationId(@Auth AuthenticatedDevice auth) {
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
final Device device = account.getDevice(auth.getDeviceId())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
accounts.updateDevice(account, device.getId(), d -> {
d.setApnId(null);
@ -166,17 +176,23 @@ public class AccountController {
@PUT
@Produces(MediaType.APPLICATION_JSON)
@Path("/registration_lock")
public void setRegistrationLock(@Mutable @Auth AuthenticatedDevice auth, @NotNull @Valid RegistrationLock accountLock) {
SaltedTokenHash credentials = SaltedTokenHash.generateFor(accountLock.getRegistrationLock());
public void setRegistrationLock(@Auth AuthenticatedDevice auth, @NotNull @Valid RegistrationLock accountLock) {
final SaltedTokenHash credentials = SaltedTokenHash.generateFor(accountLock.getRegistrationLock());
accounts.update(auth.getAccount(),
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
accounts.update(account,
a -> a.setRegistrationLock(credentials.hash(), credentials.salt()));
}
@DELETE
@Path("/registration_lock")
public void removeRegistrationLock(@Mutable @Auth AuthenticatedDevice auth) {
accounts.update(auth.getAccount(), a -> a.setRegistrationLock(null, null));
public void removeRegistrationLock(@Auth AuthenticatedDevice auth) {
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
accounts.update(account, a -> a.setRegistrationLock(null, null));
}
@PUT
@ -190,7 +206,7 @@ public class AccountController {
@ApiResponse(responseCode = "204", description = "Device name changed successfully")
@ApiResponse(responseCode = "404", description = "No device found with the given ID")
@ApiResponse(responseCode = "403", description = "Not authorized to change the name of the device with the given ID")
public void setName(@Mutable @Auth final AuthenticatedDevice auth,
public void setName(@Auth final AuthenticatedDevice auth,
@NotNull @Valid final DeviceName deviceName,
@Nullable
@ -199,15 +215,16 @@ public class AccountController {
requiredMode = Schema.RequiredMode.NOT_REQUIRED)
final Byte deviceId) {
final Account account = auth.getAccount();
final byte targetDeviceId = deviceId == null ? auth.getAuthenticatedDevice().getId() : deviceId;
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
final byte targetDeviceId = deviceId == null ? auth.getDeviceId() : deviceId;
if (account.getDevice(targetDeviceId).isEmpty()) {
throw new NotFoundException();
}
final boolean mayChangeName = auth.getAuthenticatedDevice().isPrimary() ||
auth.getAuthenticatedDevice().getId() == targetDeviceId;
final boolean mayChangeName = auth.getDeviceId() == Device.PRIMARY_ID || auth.getDeviceId() == targetDeviceId;
if (!mayChangeName) {
throw new ForbiddenException();
@ -221,14 +238,14 @@ public class AccountController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public void setAccountAttributes(
@Mutable @Auth AuthenticatedDevice auth,
@Auth AuthenticatedDevice auth,
@HeaderParam(HeaderUtils.X_SIGNAL_AGENT) String userAgent,
@NotNull @Valid AccountAttributes attributes) {
final Account account = auth.getAccount();
final byte deviceId = auth.getAuthenticatedDevice().getId();
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
final Account updatedAccount = accounts.update(account, a -> {
a.getDevice(deviceId).ifPresent(d -> {
a.getDevice(auth.getDeviceId()).ifPresent(d -> {
d.setFetchesMessages(attributes.getFetchesMessages());
d.setName(attributes.getName());
d.setLastSeen(Util.todayInMillis());
@ -252,8 +269,11 @@ public class AccountController {
@GET
@Path("/whoami")
@Produces(MediaType.APPLICATION_JSON)
public AccountIdentityResponse whoAmI(@ReadOnly @Auth AuthenticatedDevice auth) {
return AccountIdentityResponseBuilder.fromAccount(auth.getAccount());
public AccountIdentityResponse whoAmI(@Auth final AuthenticatedDevice auth) {
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
return AccountIdentityResponseBuilder.fromAccount(account);
}
@DELETE
@ -267,8 +287,11 @@ public class AccountController {
)
@ApiResponse(responseCode = "204", description = "Username successfully deleted.", useReturnTypeSchema = true)
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
public CompletableFuture<Response> deleteUsernameHash(@Mutable @Auth final AuthenticatedDevice auth) {
return accounts.clearUsernameHash(auth.getAccount())
public CompletableFuture<Response> deleteUsernameHash(@Auth final AuthenticatedDevice auth) {
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
return accounts.clearUsernameHash(account)
.thenApply(Util.ASYNC_EMPTY_RESPONSE);
}
@ -289,10 +312,13 @@ public class AccountController {
@ApiResponse(responseCode = "422", description = "Invalid request format.")
@ApiResponse(responseCode = "429", description = "Ratelimited.")
public CompletableFuture<ReserveUsernameHashResponse> reserveUsernameHash(
@Mutable @Auth final AuthenticatedDevice auth,
@Auth final AuthenticatedDevice auth,
@NotNull @Valid final ReserveUsernameHashRequest usernameRequest) throws RateLimitExceededException {
rateLimiters.getUsernameReserveLimiter().validate(auth.getAccount().getUuid());
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
rateLimiters.getUsernameReserveLimiter().validate(auth.getAccountIdentifier());
for (final byte[] hash : usernameRequest.usernameHashes()) {
if (hash.length != USERNAME_HASH_LENGTH) {
@ -300,7 +326,7 @@ public class AccountController {
}
}
return accounts.reserveUsernameHash(auth.getAccount(), usernameRequest.usernameHashes())
return accounts.reserveUsernameHash(account, usernameRequest.usernameHashes())
.thenApply(reservation -> new ReserveUsernameHashResponse(reservation.reservedUsernameHash()))
.exceptionally(throwable -> {
if (ExceptionUtils.unwrap(throwable) instanceof UsernameHashNotAvailableException) {
@ -329,18 +355,21 @@ public class AccountController {
@ApiResponse(responseCode = "422", description = "Invalid request format.")
@ApiResponse(responseCode = "429", description = "Ratelimited.")
public CompletableFuture<UsernameHashResponse> confirmUsernameHash(
@Mutable @Auth final AuthenticatedDevice auth,
@Auth final AuthenticatedDevice auth,
@NotNull @Valid final ConfirmUsernameHashRequest confirmRequest) {
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
try {
usernameHashZkProofVerifier.verifyProof(confirmRequest.zkProof(), confirmRequest.usernameHash());
} catch (final BaseUsernameException e) {
throw new WebApplicationException(Response.status(422).build());
}
return rateLimiters.getUsernameSetLimiter().validateAsync(auth.getAccount().getUuid())
return rateLimiters.getUsernameSetLimiter().validateAsync(account.getUuid())
.thenCompose(ignored -> accounts.confirmReservedUsernameHash(
auth.getAccount(),
account,
confirmRequest.usernameHash(),
confirmRequest.encryptedUsername()))
.thenApply(updatedAccount -> new UsernameHashResponse(updatedAccount.getUsernameHash()
@ -374,7 +403,7 @@ public class AccountController {
@ApiResponse(responseCode = "400", description = "Request must not be authenticated.")
@ApiResponse(responseCode = "404", description = "Account not found for the given username.")
public CompletableFuture<AccountIdentifierResponse> lookupUsernameHash(
@ReadOnly @Auth final Optional<AuthenticatedDevice> maybeAuthenticatedAccount,
@Auth final Optional<AuthenticatedDevice> maybeAuthenticatedAccount,
@PathParam("usernameHash") final String usernameHash) {
requireNotAuthenticated(maybeAuthenticatedAccount);
@ -413,12 +442,14 @@ public class AccountController {
@ApiResponse(responseCode = "422", description = "Invalid request format.")
@ApiResponse(responseCode = "429", description = "Ratelimited.")
public UsernameLinkHandle updateUsernameLink(
@Mutable @Auth final AuthenticatedDevice auth,
@Auth final AuthenticatedDevice auth,
@NotNull @Valid final EncryptedUsername encryptedUsername) throws RateLimitExceededException {
// check ratelimiter for username link operations
rateLimiters.forDescriptor(RateLimiters.For.USERNAME_LINK_OPERATION).validate(auth.getAccount().getUuid());
final Account account = auth.getAccount();
// check ratelimiter for username link operations
rateLimiters.forDescriptor(RateLimiters.For.USERNAME_LINK_OPERATION).validate(auth.getAccountIdentifier());
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
// check if username hash is set for the account
if (account.getUsernameHash().isEmpty()) {
@ -431,7 +462,7 @@ public class AccountController {
} else {
usernameLinkHandle = UUID.randomUUID();
}
updateUsernameLink(auth.getAccount(), usernameLinkHandle, encryptedUsername.usernameLinkEncryptedValue());
updateUsernameLink(account, usernameLinkHandle, encryptedUsername.usernameLinkEncryptedValue());
return new UsernameLinkHandle(usernameLinkHandle);
}
@ -447,10 +478,14 @@ public class AccountController {
@ApiResponse(responseCode = "204", description = "Username Link successfully deleted.", useReturnTypeSchema = true)
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
@ApiResponse(responseCode = "429", description = "Ratelimited.")
public void deleteUsernameLink(@Mutable @Auth final AuthenticatedDevice auth) throws RateLimitExceededException {
public void deleteUsernameLink(@Auth final AuthenticatedDevice auth) throws RateLimitExceededException {
// check ratelimiter for username link operations
rateLimiters.forDescriptor(RateLimiters.For.USERNAME_LINK_OPERATION).validate(auth.getAccount().getUuid());
clearUsernameLink(auth.getAccount());
rateLimiters.forDescriptor(RateLimiters.For.USERNAME_LINK_OPERATION).validate(auth.getAccountIdentifier());
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
clearUsernameLink(account);
}
@GET
@ -470,7 +505,7 @@ public class AccountController {
@ApiResponse(responseCode = "422", description = "Invalid request format.")
@ApiResponse(responseCode = "429", description = "Ratelimited.")
public CompletableFuture<EncryptedUsername> lookupUsernameLink(
@ReadOnly @Auth final Optional<AuthenticatedDevice> maybeAuthenticatedAccount,
@Auth final Optional<AuthenticatedDevice> maybeAuthenticatedAccount,
@PathParam("uuid") final UUID usernameLinkHandle) {
requireNotAuthenticated(maybeAuthenticatedAccount);
@ -496,7 +531,7 @@ public class AccountController {
@Path("/account/{identifier}")
@RateLimitedByIp(RateLimiters.For.CHECK_ACCOUNT_EXISTENCE)
public Response accountExists(
@ReadOnly @Auth final Optional<AuthenticatedDevice> authenticatedAccount,
@Auth final Optional<AuthenticatedDevice> authenticatedAccount,
@Parameter(description = "An ACI or PNI account identifier to check")
@PathParam("identifier") final ServiceIdentifier accountIdentifier) {
@ -511,8 +546,11 @@ public class AccountController {
@DELETE
@Path("/me")
public CompletableFuture<Response> deleteAccount(@Mutable @Auth AuthenticatedDevice auth) {
return accounts.delete(auth.getAccount(), AccountsManager.DeletionReason.USER_REQUEST).thenApply(Util.ASYNC_EMPTY_RESPONSE);
public CompletableFuture<Response> deleteAccount(@Auth AuthenticatedDevice auth) {
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
return accounts.delete(account, AccountsManager.DeletionReason.USER_REQUEST).thenApply(Util.ASYNC_EMPTY_RESPONSE);
}
private void clearUsernameLink(final Account account) {

View File

@ -55,8 +55,7 @@ import org.whispersystems.textsecuregcm.push.MessageTooLargeException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ChangeNumberManager;
import org.whispersystems.websocket.auth.Mutable;
import org.whispersystems.websocket.auth.ReadOnly;
import org.whispersystems.textsecuregcm.storage.Device;
@Path("/v2/accounts")
@io.swagger.v3.oas.annotations.tags.Tag(name = "Account")
@ -101,12 +100,12 @@ public class AccountControllerV2 {
@ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header(
name = "Retry-After",
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public AccountIdentityResponse changeNumber(@Mutable @Auth final AuthenticatedDevice authenticatedDevice,
public AccountIdentityResponse changeNumber(@Auth final AuthenticatedDevice authenticatedDevice,
@NotNull @Valid final ChangeNumberRequest request,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgentString,
@Context final ContainerRequestContext requestContext) throws RateLimitExceededException, InterruptedException {
if (!authenticatedDevice.getAuthenticatedDevice().isPrimary()) {
if (authenticatedDevice.getDeviceId() != Device.PRIMARY_ID) {
throw new ForbiddenException();
}
@ -116,8 +115,11 @@ public class AccountControllerV2 {
final String number = request.number();
final Account account = accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
// Only verify and check reglock if there's a data change to be made...
if (!authenticatedDevice.getAccount().getNumber().equals(number)) {
if (!account.getNumber().equals(number)) {
rateLimiters.getRegistrationLimiter().validate(number);
@ -139,7 +141,7 @@ public class AccountControllerV2 {
// ...but always attempt to make the change in case a client retries and needs to re-send messages
try {
final Account updatedAccount = changeNumberManager.changeNumber(
authenticatedDevice.getAccount(),
account,
request.number(),
request.pniIdentityKey(),
request.devicePniSignedPrekeys(),
@ -185,11 +187,11 @@ public class AccountControllerV2 {
content = @Content(schema = @Schema(implementation = StaleDevicesResponse.class)))
@ApiResponse(responseCode = "413", description = "One or more device messages was too large")
public AccountIdentityResponse distributePhoneNumberIdentityKeys(
@Mutable @Auth final AuthenticatedDevice authenticatedDevice,
@Auth final AuthenticatedDevice authenticatedDevice,
@HeaderParam(HttpHeaders.USER_AGENT) @Nullable final String userAgentString,
@NotNull @Valid final PhoneNumberIdentityKeyDistributionRequest request) {
if (!authenticatedDevice.getAuthenticatedDevice().isPrimary()) {
if (authenticatedDevice.getDeviceId() != Device.PRIMARY_ID) {
throw new ForbiddenException();
}
@ -197,9 +199,12 @@ public class AccountControllerV2 {
throw new WebApplicationException("Invalid signature", 422);
}
final Account account = accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
try {
final Account updatedAccount = changeNumberManager.updatePniKeys(
authenticatedDevice.getAccount(),
account,
request.pniIdentityKey(),
request.devicePniSignedPrekeys(),
request.devicePniPqLastResortPrekeys(),
@ -235,10 +240,13 @@ public class AccountControllerV2 {
@Operation(summary = "Sets whether the account should be discoverable by phone number in the directory.")
@ApiResponse(responseCode = "204", description = "The setting was successfully updated.")
public void setPhoneNumberDiscoverability(
@Mutable @Auth AuthenticatedDevice auth,
@NotNull @Valid PhoneNumberDiscoverabilityRequest phoneNumberDiscoverability
) {
accountsManager.update(auth.getAccount(), a -> a.setDiscoverableByPhoneNumber(
@Auth AuthenticatedDevice auth,
@NotNull @Valid PhoneNumberDiscoverabilityRequest phoneNumberDiscoverability) {
final Account account = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
accountsManager.update(account, a -> a.setDiscoverableByPhoneNumber(
phoneNumberDiscoverability.discoverableByPhoneNumber()));
}
@ -249,9 +257,10 @@ public class AccountControllerV2 {
@ApiResponse(responseCode = "200",
description = "Response with data report. A plain text representation is a field in the response.",
useReturnTypeSchema = true)
public AccountDataReportResponse getAccountDataReport(@ReadOnly @Auth final AuthenticatedDevice auth) {
public AccountDataReportResponse getAccountDataReport(@Auth final AuthenticatedDevice auth) {
final Account account = auth.getAccount();
final Account account = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
return new AccountDataReportResponse(UUID.randomUUID(), Instant.now(),
new AccountDataReportResponse.AccountAndDevicesDataReport(

View File

@ -37,6 +37,7 @@ import jakarta.ws.rs.PUT;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.QueryParam;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import java.io.IOException;
@ -71,14 +72,15 @@ import org.whispersystems.textsecuregcm.backup.MediaEncryptionParameters;
import org.whispersystems.textsecuregcm.entities.RemoteAttachment;
import org.whispersystems.textsecuregcm.metrics.BackupMetrics;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.BackupAuthCredentialAdapter;
import org.whispersystems.textsecuregcm.util.ByteArrayAdapter;
import org.whispersystems.textsecuregcm.util.ByteArrayBase64UrlAdapter;
import org.whispersystems.textsecuregcm.util.ECPublicKeyAdapter;
import org.whispersystems.textsecuregcm.util.ExactlySize;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.websocket.auth.Mutable;
import org.whispersystems.websocket.auth.ReadOnly;
import reactor.core.publisher.Mono;
@Path("/v1/archives")
@ -88,14 +90,18 @@ public class ArchiveController {
public final static String X_SIGNAL_ZK_AUTH = "X-Signal-ZK-Auth";
public final static String X_SIGNAL_ZK_AUTH_SIGNATURE = "X-Signal-ZK-Auth-Signature";
private final AccountsManager accountsManager;
private final BackupAuthManager backupAuthManager;
private final BackupManager backupManager;
private final BackupMetrics backupMetrics;
public ArchiveController(
final AccountsManager accountsManager,
final BackupAuthManager backupAuthManager,
final BackupManager backupManager,
final BackupMetrics backupMetrics) {
this.accountsManager = accountsManager;
this.backupAuthManager = backupAuthManager;
this.backupManager = backupManager;
this.backupMetrics = backupMetrics;
@ -138,13 +144,22 @@ public class ArchiveController {
@ApiResponse(responseCode = "403", description = "The device did not have permission to set the backup-id. Only the primary device can set the backup-id for an account")
@ApiResponse(responseCode = "429", description = "Rate limited. Too many attempts to change the backup-id have been made")
public CompletionStage<Response> setBackupId(
@Mutable @Auth final AuthenticatedDevice account,
@Auth final AuthenticatedDevice authenticatedDevice,
@Valid @NotNull final SetBackupIdRequest setBackupIdRequest) throws RateLimitExceededException {
return this.backupAuthManager
.commitBackupId(account.getAccount(), account.getAuthenticatedDevice(),
setBackupIdRequest.messagesBackupAuthCredentialRequest,
setBackupIdRequest.mediaBackupAuthCredentialRequest)
.thenApply(Util.ASYNC_EMPTY_RESPONSE);
return accountsManager.getByAccountIdentifierAsync(authenticatedDevice.getAccountIdentifier())
.thenCompose(maybeAccount -> {
final Account account = maybeAccount
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
final Device device = account.getDevice(authenticatedDevice.getDeviceId())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
return backupAuthManager
.commitBackupId(account, device, setBackupIdRequest.messagesBackupAuthCredentialRequest,
setBackupIdRequest.mediaBackupAuthCredentialRequest)
.thenApply(Util.ASYNC_EMPTY_RESPONSE);
});
}
public record RedeemBackupReceiptRequest(
@ -188,12 +203,17 @@ public class ArchiveController {
@ApiResponse(responseCode = "409", description = "The target account does not have a backup-id commitment")
@ApiResponse(responseCode = "429", description = "Rate limited.")
public CompletionStage<Response> redeemReceipt(
@Mutable @Auth final AuthenticatedDevice account,
@Auth final AuthenticatedDevice authenticatedDevice,
@Valid @NotNull final RedeemBackupReceiptRequest redeemBackupReceiptRequest) {
return this.backupAuthManager.redeemReceipt(
account.getAccount(),
redeemBackupReceiptRequest.receiptCredentialPresentation())
.thenApply(Util.ASYNC_EMPTY_RESPONSE);
return accountsManager.getByAccountIdentifierAsync(authenticatedDevice.getAccountIdentifier())
.thenCompose(maybeAccount -> {
final Account account = maybeAccount
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
return backupAuthManager.redeemReceipt(account, redeemBackupReceiptRequest.receiptCredentialPresentation())
.thenApply(Util.ASYNC_EMPTY_RESPONSE);
});
}
public record BackupAuthCredentialsResponse(
@ -252,7 +272,7 @@ public class ArchiveController {
@ApiResponse(responseCode = "404", description = "Could not find an existing blinded backup id")
@ApiResponse(responseCode = "429", description = "Rate limited.")
public CompletionStage<BackupAuthCredentialsResponse> getBackupZKCredentials(
@Mutable @Auth AuthenticatedDevice auth,
@Auth AuthenticatedDevice authenticatedDevice,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@NotNull @QueryParam("redemptionStartSeconds") Long startSeconds,
@NotNull @QueryParam("redemptionEndSeconds") Long endSeconds) {
@ -260,27 +280,33 @@ public class ArchiveController {
final Map<BackupCredentialType, List<BackupAuthCredentialsResponse.BackupAuthCredential>> credentialsByType =
new ConcurrentHashMap<>();
return CompletableFuture.allOf(Arrays.stream(BackupCredentialType.values())
.map(credentialType -> this.backupAuthManager.getBackupAuthCredentials(
auth.getAccount(),
credentialType,
Instant.ofEpochSecond(startSeconds), Instant.ofEpochSecond(endSeconds))
.thenAccept(credentials -> {
backupMetrics.updateGetCredentialCounter(
UserAgentTagUtil.getPlatformTag(userAgent),
credentialType,
credentials.size());
credentialsByType.put(credentialType, credentials.stream()
.map(credential -> new BackupAuthCredentialsResponse.BackupAuthCredential(
credential.credential().serialize(),
credential.redemptionTime().getEpochSecond()))
.toList());
}))
.toArray(CompletableFuture[]::new))
.thenApply(ignored -> new BackupAuthCredentialsResponse(credentialsByType.entrySet().stream()
.collect(Collectors.toMap(
e -> BackupAuthCredentialsResponse.CredentialType.fromLibsignalType(e.getKey()),
Map.Entry::getValue))));
return accountsManager.getByAccountIdentifierAsync(authenticatedDevice.getAccountIdentifier())
.thenCompose(maybeAccount -> {
final Account account = maybeAccount
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
return CompletableFuture.allOf(Arrays.stream(BackupCredentialType.values())
.map(credentialType -> this.backupAuthManager.getBackupAuthCredentials(
account,
credentialType,
Instant.ofEpochSecond(startSeconds), Instant.ofEpochSecond(endSeconds))
.thenAccept(credentials -> {
backupMetrics.updateGetCredentialCounter(
UserAgentTagUtil.getPlatformTag(userAgent),
credentialType,
credentials.size());
credentialsByType.put(credentialType, credentials.stream()
.map(credential -> new BackupAuthCredentialsResponse.BackupAuthCredential(
credential.credential().serialize(),
credential.redemptionTime().getEpochSecond()))
.toList());
}))
.toArray(CompletableFuture[]::new))
.thenApply(ignored -> new BackupAuthCredentialsResponse(credentialsByType.entrySet().stream()
.collect(Collectors.toMap(
e -> BackupAuthCredentialsResponse.CredentialType.fromLibsignalType(e.getKey()),
Map.Entry::getValue))));
});
}
@ -343,7 +369,7 @@ public class ArchiveController {
@ApiResponse(responseCode = "429", description = "Rate limited.")
@ApiResponseZkAuth
public CompletionStage<ReadAuthResponse> readAuth(
@ReadOnly @Auth final Optional<AuthenticatedDevice> account,
@Auth final Optional<AuthenticatedDevice> account,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class))
@ -395,7 +421,7 @@ public class ArchiveController {
@ApiResponse(responseCode = "429", description = "Rate limited.")
@ApiResponseZkAuth
public CompletionStage<BackupInfoResponse> backupInfo(
@ReadOnly @Auth final Optional<AuthenticatedDevice> account,
@Auth final Optional<AuthenticatedDevice> account,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class))
@ -441,7 +467,7 @@ public class ArchiveController {
@ApiResponse(responseCode = "204", description = "The public key was set")
@ApiResponse(responseCode = "429", description = "Rate limited.")
public CompletionStage<Response> setPublicKey(
@ReadOnly @Auth final Optional<AuthenticatedDevice> account,
@Auth final Optional<AuthenticatedDevice> account,
@Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class))
@NotNull
@ -481,7 +507,7 @@ public class ArchiveController {
@ApiResponse(responseCode = "429", description = "Rate limited.")
@ApiResponseZkAuth
public CompletionStage<UploadDescriptorResponse> backup(
@ReadOnly @Auth final Optional<AuthenticatedDevice> account,
@Auth final Optional<AuthenticatedDevice> account,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class))
@ -518,7 +544,7 @@ public class ArchiveController {
@ApiResponse(responseCode = "429", description = "Rate limited.")
@ApiResponseZkAuth
public CompletionStage<UploadDescriptorResponse> uploadTemporaryAttachment(
@ReadOnly @Auth final Optional<AuthenticatedDevice> account,
@Auth final Optional<AuthenticatedDevice> account,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@ -606,7 +632,7 @@ public class ArchiveController {
@ApiResponse(responseCode = "429", description = "Rate limited.")
@ApiResponseZkAuth
public CompletionStage<CopyMediaResponse> copyMedia(
@ReadOnly @Auth final Optional<AuthenticatedDevice> account,
@Auth final Optional<AuthenticatedDevice> account,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class))
@ -705,7 +731,7 @@ public class ArchiveController {
@ApiResponse(responseCode = "429", description = "Rate limited.")
@ApiResponseZkAuth
public CompletionStage<Response> copyMedia(
@ReadOnly @Auth final Optional<AuthenticatedDevice> account,
@Auth final Optional<AuthenticatedDevice> account,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class))
@ -744,7 +770,7 @@ public class ArchiveController {
@ApiResponse(responseCode = "429", description = "Rate limited.")
@ApiResponseZkAuth
public CompletionStage<Response> refresh(
@ReadOnly @Auth final Optional<AuthenticatedDevice> account,
@Auth final Optional<AuthenticatedDevice> account,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class))
@ -811,7 +837,7 @@ public class ArchiveController {
@ApiResponse(responseCode = "429", description = "Rate limited.")
@ApiResponseZkAuth
public CompletionStage<ListResponse> listMedia(
@ReadOnly @Auth final Optional<AuthenticatedDevice> account,
@Auth final Optional<AuthenticatedDevice> account,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class))
@ -867,7 +893,7 @@ public class ArchiveController {
@ApiResponse(responseCode = "429", description = "Rate limited.")
@ApiResponseZkAuth
public CompletionStage<Response> deleteMedia(
@ReadOnly @Auth final Optional<AuthenticatedDevice> account,
@Auth final Optional<AuthenticatedDevice> account,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class))
@ -904,7 +930,7 @@ public class ArchiveController {
@ApiResponse(responseCode = "429", description = "Rate limited.")
@ApiResponseZkAuth
public CompletionStage<Response> deleteBackup(
@ReadOnly @Auth final Optional<AuthenticatedDevice> account,
@Auth final Optional<AuthenticatedDevice> account,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class))

View File

@ -26,7 +26,6 @@ import org.whispersystems.textsecuregcm.entities.AttachmentDescriptorV3;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.websocket.auth.ReadOnly;
/**
@ -78,11 +77,11 @@ public class AttachmentControllerV4 {
@ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header(
name = "Retry-After",
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public AttachmentDescriptorV3 getAttachmentUploadForm(@ReadOnly @Auth AuthenticatedDevice auth)
public AttachmentDescriptorV3 getAttachmentUploadForm(@Auth AuthenticatedDevice auth)
throws RateLimitExceededException {
rateLimiter.validate(auth.getAccount().getUuid());
rateLimiter.validate(auth.getAccountIdentifier());
final String key = generateAttachmentKey();
final boolean useCdn3 = this.experimentEnrollmentManager.isEnrolled(auth.getAccount().getUuid(), CDN3_EXPERIMENT_NAME);
final boolean useCdn3 = this.experimentEnrollmentManager.isEnrolled(auth.getAccountIdentifier(), CDN3_EXPERIMENT_NAME);
int cdn = useCdn3 ? 3 : 2;
final AttachmentGenerator.Descriptor descriptor = this.attachmentGenerators.get(cdn).generateAttachment(key);
return new AttachmentDescriptorV3(cdn, key, descriptor.headers(), descriptor.signedUploadLocation());

View File

@ -20,7 +20,6 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.entities.CreateCallLinkCredential;
import org.whispersystems.textsecuregcm.entities.GetCreateCallLinkCredentialsRequest;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v1/call-link")
@io.swagger.v3.oas.annotations.tags.Tag(name = "CallLink")
@ -52,11 +51,11 @@ public class CallLinkController {
@ApiResponse(responseCode = "422", description = "Invalid request format.")
@ApiResponse(responseCode = "429", description = "Ratelimited.")
public CreateCallLinkCredential getCreateAuth(
final @ReadOnly @Auth AuthenticatedDevice auth,
final @Auth AuthenticatedDevice auth,
final @NotNull @Valid GetCreateCallLinkCredentialsRequest request
) throws RateLimitExceededException {
rateLimiters.getCreateCallLinkLimiter().validate(auth.getAccount().getUuid());
rateLimiters.getCreateCallLinkLimiter().validate(auth.getAccountIdentifier());
final Instant truncatedDayTimestamp = Instant.now().truncatedTo(ChronoUnit.DAYS);
@ -68,7 +67,7 @@ public class CallLinkController {
}
return new CreateCallLinkCredential(
createCallLinkCredentialRequest.issueCredential(new ServiceId.Aci(auth.getAccount().getUuid()), truncatedDayTimestamp, genericServerSecretParams).serialize(),
createCallLinkCredentialRequest.issueCredential(new ServiceId.Aci(auth.getAccountIdentifier()), truncatedDayTimestamp, genericServerSecretParams).serialize(),
truncatedDayTimestamp.getEpochSecond()
);
}

View File

@ -18,11 +18,9 @@ import jakarta.ws.rs.Produces;
import jakarta.ws.rs.core.MediaType;
import java.io.IOException;
import java.util.List;
import java.util.UUID;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.CloudflareTurnCredentialsManager;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.websocket.auth.ReadOnly;
@io.swagger.v3.oas.annotations.tags.Tag(name = "Calling")
@Path("/v2/calling")
@ -56,11 +54,10 @@ public class CallRoutingControllerV2 {
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
@ApiResponse(responseCode = "422", description = "Invalid request format.")
@ApiResponse(responseCode = "429", description = "Rate limited.")
public GetCallingRelaysResponse getCallingRelays(final @ReadOnly @Auth AuthenticatedDevice auth)
public GetCallingRelaysResponse getCallingRelays(final @Auth AuthenticatedDevice auth)
throws RateLimitExceededException, IOException {
final UUID aci = auth.getAccount().getUuid();
rateLimiters.getCallEndpointLimiter().validate(aci);
rateLimiters.getCallEndpointLimiter().validate(auth.getAccountIdentifier());
try {
return new GetCallingRelaysResponse(List.of(cloudflareTurnCredentialsManager.retrieveFromCloudflare()));

View File

@ -8,18 +8,18 @@ package org.whispersystems.textsecuregcm.controllers;
import static com.codahale.metrics.MetricRegistry.name;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.Auth;
import io.micrometer.core.instrument.Metrics;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.ws.rs.BadRequestException;
import jakarta.ws.rs.DefaultValue;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.HeaderParam;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.QueryParam;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import java.security.InvalidKeyException;
import java.time.Clock;
import java.time.Duration;
@ -38,13 +38,16 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.CertificateGenerator;
import org.whispersystems.textsecuregcm.entities.DeliveryCertificate;
import org.whispersystems.textsecuregcm.entities.GroupCredentials;
import org.whispersystems.websocket.auth.ReadOnly;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@Path("/v1/certificate")
@Tag(name = "Certificate")
public class CertificateController {
private final AccountsManager accountsManager;
private final CertificateGenerator certificateGenerator;
private final ServerZkAuthOperations serverZkAuthOperations;
private final GenericServerSecretParams genericServerSecretParams;
@ -56,10 +59,13 @@ public class CertificateController {
private static final String INCLUDE_E164_TAG_NAME = "includeE164";
public CertificateController(
final AccountsManager accountsManager,
@Nonnull CertificateGenerator certificateGenerator,
@Nonnull ServerZkAuthOperations serverZkAuthOperations,
@Nonnull GenericServerSecretParams genericServerSecretParams,
@Nonnull Clock clock) {
this.accountsManager = accountsManager;
this.certificateGenerator = Objects.requireNonNull(certificateGenerator);
this.serverZkAuthOperations = Objects.requireNonNull(serverZkAuthOperations);
this.genericServerSecretParams = genericServerSecretParams;
@ -69,23 +75,25 @@ public class CertificateController {
@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/delivery")
public DeliveryCertificate getDeliveryCertificate(@ReadOnly @Auth AuthenticatedDevice auth,
public DeliveryCertificate getDeliveryCertificate(@Auth AuthenticatedDevice auth,
@QueryParam("includeE164") @DefaultValue("true") boolean includeE164)
throws InvalidKeyException {
Metrics.counter(GENERATE_DELIVERY_CERTIFICATE_COUNTER_NAME, INCLUDE_E164_TAG_NAME, String.valueOf(includeE164))
.increment();
final Account account = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
return new DeliveryCertificate(
certificateGenerator.createFor(auth.getAccount(), auth.getAuthenticatedDevice(), includeE164));
certificateGenerator.createFor(account, auth.getDeviceId(), includeE164));
}
@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/auth/group")
public GroupCredentials getGroupAuthenticationCredentials(
@ReadOnly @Auth AuthenticatedDevice auth,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent,
@Auth AuthenticatedDevice auth,
@QueryParam("redemptionStartSeconds") long startSeconds,
@QueryParam("redemptionEndSeconds") long endSeconds) {
@ -102,13 +110,16 @@ public class CertificateController {
throw new BadRequestException();
}
final Account account = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
final List<GroupCredentials.GroupCredential> credentials = new ArrayList<>();
final List<GroupCredentials.CallLinkAuthCredential> callLinkAuthCredentials = new ArrayList<>();
Instant redemption = redemptionStart;
ServiceId.Aci aci = new ServiceId.Aci(auth.getAccount().getUuid());
ServiceId.Pni pni = new ServiceId.Pni(auth.getAccount().getPhoneNumberIdentifier());
final ServiceId.Aci aci = new ServiceId.Aci(account.getIdentifier(IdentityType.ACI));
final ServiceId.Pni pni = new ServiceId.Pni(account.getIdentifier(IdentityType.PNI));
while (!redemption.isAfter(redemptionEnd)) {
AuthCredentialWithPniResponse authCredentialWithPni = serverZkAuthOperations.issueAuthCredentialWithPniZkc(aci, pni, redemption);

View File

@ -25,6 +25,7 @@ import jakarta.ws.rs.POST;
import jakarta.ws.rs.PUT;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.container.ContainerRequestContext;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.MediaType;
@ -40,12 +41,14 @@ import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.spam.ChallengeConstraintChecker;
import org.whispersystems.textsecuregcm.spam.ChallengeConstraintChecker.ChallengeConstraints;
import org.whispersystems.websocket.auth.ReadOnly;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
@Path("/v1/challenge")
@Tag(name = "Challenge")
public class ChallengeController {
private final AccountsManager accountsManager;
private final RateLimitChallengeManager rateLimitChallengeManager;
private final ChallengeConstraintChecker challengeConstraintChecker;
@ -53,8 +56,10 @@ public class ChallengeController {
private static final String CHALLENGE_TYPE_TAG = "type";
public ChallengeController(
final AccountsManager accountsManager,
final RateLimitChallengeManager rateLimitChallengeManager,
final ChallengeConstraintChecker challengeConstraintChecker) {
this.accountsManager = accountsManager;
this.rateLimitChallengeManager = rateLimitChallengeManager;
this.challengeConstraintChecker = challengeConstraintChecker;
}
@ -77,15 +82,18 @@ public class ChallengeController {
@ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header(
name = "Retry-After",
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public Response handleChallengeResponse(@ReadOnly @Auth final AuthenticatedDevice auth,
public Response handleChallengeResponse(@Auth final AuthenticatedDevice auth,
@Valid final AnswerChallengeRequest answerRequest,
@Context ContainerRequestContext requestContext,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) throws RateLimitExceededException, IOException {
final Account account = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(userAgent));
final ChallengeConstraints constraints = challengeConstraintChecker.challengeConstraints(
requestContext, auth.getAccount());
requestContext, account);
try {
if (answerRequest instanceof final AnswerPushChallengeRequest pushChallengeRequest) {
tags = tags.and(CHALLENGE_TYPE_TAG, "push");
@ -93,14 +101,14 @@ public class ChallengeController {
if (!constraints.pushPermitted()) {
return Response.status(429).build();
}
rateLimitChallengeManager.answerPushChallenge(auth.getAccount(), pushChallengeRequest.getChallenge());
rateLimitChallengeManager.answerPushChallenge(account, pushChallengeRequest.getChallenge());
} else if (answerRequest instanceof AnswerCaptchaChallengeRequest captchaChallengeRequest) {
tags = tags.and(CHALLENGE_TYPE_TAG, "captcha");
final String remoteAddress = (String) requestContext.getProperty(
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
boolean success = rateLimitChallengeManager.answerCaptchaChallenge(
auth.getAccount(),
account,
captchaChallengeRequest.getCaptcha(),
remoteAddress,
userAgent,
@ -126,7 +134,7 @@ public class ChallengeController {
summary = "Request a push challenge",
description = """
Clients may proactively request a push challenge by making an empty POST request. Push challenges will only be
sent to the requesting accounts main device. When the push is received it may be provided as proof of completed
sent to the requesting accounts main device. When the push is received it may be provided as proof of completed
challenge to /v1/challenge.
APNs challenge payloads will be formatted as follows:
```
@ -140,12 +148,12 @@ public class ChallengeController {
"rateLimitChallenge": "{CHALLENGE_TOKEN}"
}
```
FCM challenge payloads will be formatted as follows:
FCM challenge payloads will be formatted as follows:
```
{"rateLimitChallenge": "{CHALLENGE_TOKEN}"}
```
Clients may retry the PUT in the event of an HTTP/5xx response (except HTTP/508) from the server, but must
Clients may retry the PUT in the event of an HTTP/5xx response (except HTTP/508) from the server, but must
implement an exponential back-off system and limit the total number of retries.
"""
)
@ -163,15 +171,18 @@ public class ChallengeController {
@ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header(
name = "Retry-After",
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public Response requestPushChallenge(@ReadOnly @Auth final AuthenticatedDevice auth,
public Response requestPushChallenge(@Auth final AuthenticatedDevice auth,
@Context ContainerRequestContext requestContext) {
final ChallengeConstraints constraints = challengeConstraintChecker.challengeConstraints(
requestContext, auth.getAccount());
final Account account = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
final ChallengeConstraints constraints = challengeConstraintChecker.challengeConstraints(requestContext, account);
if (!constraints.pushPermitted()) {
return Response.status(429).build();
}
try {
rateLimitChallengeManager.sendPushChallenge(auth.getAccount());
rateLimitChallengeManager.sendPushChallenge(account);
return Response.status(200).build();
} catch (final NotPushRegisteredException e) {
return Response.status(404).build();

View File

@ -33,6 +33,7 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.backup.BackupAuthManager;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.devicecheck.AppleDeviceCheckManager;
import org.whispersystems.textsecuregcm.storage.devicecheck.ChallengeNotFoundException;
import org.whispersystems.textsecuregcm.storage.devicecheck.DeviceCheckKeyIdNotFoundException;
@ -41,7 +42,6 @@ import org.whispersystems.textsecuregcm.storage.devicecheck.DuplicatePublicKeyEx
import org.whispersystems.textsecuregcm.storage.devicecheck.RequestReuseException;
import org.whispersystems.textsecuregcm.storage.devicecheck.TooManyKeysException;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.websocket.auth.ReadOnly;
/**
* Process platform device attestations.
@ -55,6 +55,7 @@ import org.whispersystems.websocket.auth.ReadOnly;
public class DeviceCheckController {
private final Clock clock;
private final AccountsManager accountsManager;
private final BackupAuthManager backupAuthManager;
private final AppleDeviceCheckManager deviceCheckManager;
private final RateLimiters rateLimiters;
@ -63,12 +64,14 @@ public class DeviceCheckController {
public DeviceCheckController(
final Clock clock,
final AccountsManager accountsManager,
final BackupAuthManager backupAuthManager,
final AppleDeviceCheckManager deviceCheckManager,
final RateLimiters rateLimiters,
final long backupRedemptionLevel,
final Duration backupRedemptionDuration) {
this.clock = clock;
this.accountsManager = accountsManager;
this.backupAuthManager = backupAuthManager;
this.deviceCheckManager = deviceCheckManager;
this.backupRedemptionLevel = backupRedemptionLevel;
@ -94,14 +97,17 @@ public class DeviceCheckController {
@ApiResponse(responseCode = "200", description = "The response body includes a challenge")
@ApiResponse(responseCode = "429", description = "Ratelimited.")
@ManagedAsync
public ChallengeResponse attestChallenge(@ReadOnly @Auth AuthenticatedDevice authenticatedDevice)
public ChallengeResponse attestChallenge(@Auth AuthenticatedDevice authenticatedDevice)
throws RateLimitExceededException {
rateLimiters.forDescriptor(RateLimiters.For.DEVICE_CHECK_CHALLENGE)
.validate(authenticatedDevice.getAccount().getUuid());
.validate(authenticatedDevice.getAccountIdentifier());
final Account account = accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
return new ChallengeResponse(deviceCheckManager.createChallenge(
AppleDeviceCheckManager.ChallengeType.ATTEST,
authenticatedDevice.getAccount()));
account));
}
@PUT
@ -125,7 +131,7 @@ public class DeviceCheckController {
@ApiResponse(responseCode = "409", description = "The provided keyId has already been registered to a different account")
@ManagedAsync
public void attest(
@ReadOnly @Auth final AuthenticatedDevice authenticatedDevice,
@Auth final AuthenticatedDevice authenticatedDevice,
@Valid
@NotNull
@ -135,8 +141,11 @@ public class DeviceCheckController {
@RequestBody(description = "The attestation data, created by [attestKey](https://developer.apple.com/documentation/devicecheck/dcappattestservice/attestkey(_:clientdatahash:completionhandler:))")
@NotNull final byte[] attestation) {
final Account account = accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
try {
deviceCheckManager.registerAttestation(authenticatedDevice.getAccount(), parseKeyId(keyId), attestation);
deviceCheckManager.registerAttestation(account, parseKeyId(keyId), attestation);
} catch (TooManyKeysException e) {
throw new WebApplicationException(Response.status(413).build());
} catch (ChallengeNotFoundException e) {
@ -166,17 +175,19 @@ public class DeviceCheckController {
@ApiResponse(responseCode = "429", description = "Ratelimited.")
@ManagedAsync
public ChallengeResponse assertChallenge(
@ReadOnly @Auth AuthenticatedDevice authenticatedDevice,
@Auth AuthenticatedDevice authenticatedDevice,
@Parameter(schema = @Schema(description = "The type of action you will make an assertion for",
allowableValues = {"backup"},
implementation = String.class))
@QueryParam("action") Action action) throws RateLimitExceededException {
rateLimiters.forDescriptor(RateLimiters.For.DEVICE_CHECK_CHALLENGE)
.validate(authenticatedDevice.getAccount().getUuid());
return new ChallengeResponse(
deviceCheckManager.createChallenge(toChallengeType(action),
authenticatedDevice.getAccount()));
.validate(authenticatedDevice.getAccountIdentifier());
final Account account = accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
return new ChallengeResponse(deviceCheckManager.createChallenge(toChallengeType(action), account));
}
@POST
@ -199,7 +210,7 @@ public class DeviceCheckController {
@ApiResponse(responseCode = "401", description = "The assertion could not be verified")
@ManagedAsync
public void assertion(
@ReadOnly @Auth final AuthenticatedDevice authenticatedDevice,
@Auth final AuthenticatedDevice authenticatedDevice,
@Valid
@NotNull
@ -218,9 +229,12 @@ public class DeviceCheckController {
@RequestBody(description = "The assertion created by [generateAssertion](https://developer.apple.com/documentation/devicecheck/dcappattestservice/generateassertion(_:clientdatahash:completionhandler:))")
@NotNull final byte[] assertion) {
final Account account = accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
try {
deviceCheckManager.validateAssert(
authenticatedDevice.getAccount(),
account,
parseKeyId(keyId),
toChallengeType(request.assertionRequest().action()),
request.assertionRequest().challenge(),
@ -237,7 +251,7 @@ public class DeviceCheckController {
// The request assertion was validated, execute it
switch (request.assertionRequest().action()) {
case BACKUP -> backupAuthManager.extendBackupVoucher(
authenticatedDevice.getAccount(),
account,
new Account.BackupVoucher(backupRedemptionLevel, clock.instant().plus(backupRedemptionDuration)))
.join();
}

View File

@ -34,7 +34,6 @@ import jakarta.ws.rs.PathParam;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.QueryParam;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import java.time.Duration;
@ -51,11 +50,9 @@ import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.glassfish.jersey.server.ContainerRequest;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.BasicAuthorizationHeader;
import org.whispersystems.textsecuregcm.auth.ChangesLinkedDevices;
import org.whispersystems.textsecuregcm.auth.LinkedDeviceRefreshRequirementProvider;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest;
import org.whispersystems.textsecuregcm.entities.DeviceInfo;
@ -87,11 +84,10 @@ import org.whispersystems.textsecuregcm.util.DeviceCapabilityAdapter;
import org.whispersystems.textsecuregcm.util.EnumMapUtil;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.LinkDeviceToken;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import org.whispersystems.websocket.auth.Mutable;
import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v1/devices")
@Tag(name = "Devices")
@ -152,10 +148,10 @@ public class DeviceController {
@GET
@Produces(MediaType.APPLICATION_JSON)
public DeviceInfoList getDevices(@ReadOnly @Auth AuthenticatedDevice auth) {
public DeviceInfoList getDevices(@Auth AuthenticatedDevice auth) {
// Devices may change their own names (and primary devices may change the names of linked devices) and so the device
// state associated with the authenticated account may be stale. Fetch a fresh copy to compensate.
return accounts.getByAccountIdentifier(auth.getAccount().getIdentifier(IdentityType.ACI))
return accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.map(account -> new DeviceInfoList(account.getDevices().stream()
.map(DeviceInfo::forDevice)
.toList()))
@ -166,9 +162,8 @@ public class DeviceController {
@Produces(MediaType.APPLICATION_JSON)
@Path("/{device_id}")
@ChangesLinkedDevices
public void removeDevice(@Mutable @Auth AuthenticatedDevice auth, @PathParam("device_id") byte deviceId) {
if (auth.getAuthenticatedDevice().getId() != Device.PRIMARY_ID &&
auth.getAuthenticatedDevice().getId() != deviceId) {
public void removeDevice(@Auth AuthenticatedDevice auth, @PathParam("device_id") byte deviceId) {
if (auth.getDeviceId() != Device.PRIMARY_ID && auth.getDeviceId() != deviceId) {
throw new WebApplicationException(Response.Status.UNAUTHORIZED);
}
@ -176,13 +171,16 @@ public class DeviceController {
throw new ForbiddenException();
}
accounts.removeDevice(auth.getAccount(), deviceId).join();
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
accounts.removeDevice(account, deviceId).join();
}
/**
* Generates a signed device-linking token. Generally, primary devices will include the signed device-linking token in
* a provisioning message to a new device, and then the new device will include the token in its request to
* {@link #linkDevice(BasicAuthorizationHeader, String, LinkDeviceRequest, ContainerRequest)}.
* {@link #linkDevice(BasicAuthorizationHeader, String, LinkDeviceRequest)}.
*
* @param auth the authenticated account/device
*
@ -207,10 +205,11 @@ public class DeviceController {
@ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header(
name = "Retry-After",
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public LinkDeviceToken createDeviceToken(@ReadOnly @Auth AuthenticatedDevice auth)
public LinkDeviceToken createDeviceToken(@Auth AuthenticatedDevice auth)
throws RateLimitExceededException, DeviceLimitExceededException {
final Account account = auth.getAccount();
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
rateLimiters.getAllocateDeviceLimiter().validate(account.getUuid());
@ -224,7 +223,7 @@ public class DeviceController {
throw new DeviceLimitExceededException(account.getDevices().size(), maxDeviceLimit);
}
if (auth.getAuthenticatedDevice().getId() != Device.PRIMARY_ID) {
if (auth.getDeviceId() != Device.PRIMARY_ID) {
throw new WebApplicationException(Response.Status.UNAUTHORIZED);
}
@ -252,8 +251,7 @@ public class DeviceController {
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public LinkDeviceResponse linkDevice(@HeaderParam(HttpHeaders.AUTHORIZATION) BasicAuthorizationHeader authorizationHeader,
@HeaderParam(HttpHeaders.USER_AGENT) @Nullable String userAgent,
@NotNull @Valid LinkDeviceRequest linkDeviceRequest,
@Context ContainerRequest containerRequest)
@NotNull @Valid LinkDeviceRequest linkDeviceRequest)
throws RateLimitExceededException, DeviceLimitExceededException {
final Account account = accounts.checkDeviceLinkingToken(linkDeviceRequest.verificationCode())
@ -279,11 +277,6 @@ public class DeviceController {
throw new WebApplicationException(Response.status(422).build());
}
// Normally, the "do we need to refresh somebody's websockets" listener can do this on its own. In this case,
// we're not using the conventional authentication system, and so we need to give it a hint so it knows who the
// active user is and what their device states look like.
LinkedDeviceRefreshRequirementProvider.setAccount(containerRequest, account);
final int maxDeviceLimit = maxDeviceConfiguration.getOrDefault(account.getNumber(), MAX_DEVICES);
if (account.getDevices().size() >= maxDeviceLimit) {
@ -351,7 +344,7 @@ public class DeviceController {
@ApiResponse(responseCode = "400", description = "The given token identifier or timeout was invalid")
@ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay")
public CompletionStage<Response> waitForLinkedDevice(
@ReadOnly @Auth final AuthenticatedDevice authenticatedDevice,
@Auth final AuthenticatedDevice authenticatedDevice,
@PathParam("tokenIdentifier")
@Schema(description = "A 'link device' token identifier provided by the 'create link device token' endpoint")
@ -374,12 +367,18 @@ public class DeviceController {
final AtomicInteger linkedDeviceListenerCounter = getCounterForLinkedDeviceListeners(userAgent);
linkedDeviceListenerCounter.incrementAndGet();
return rateLimiters.getWaitForLinkedDeviceLimiter()
.validateAsync(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI))
.thenCompose(ignored -> persistentTimer.start(WAIT_FOR_LINKED_DEVICE_TIMER_NAMESPACE, tokenIdentifier))
.thenCompose(sample -> accounts.waitForNewLinkedDevice(
authenticatedDevice.getAccount().getUuid(),
authenticatedDevice.getAuthenticatedDevice(),
return rateLimiters.getWaitForLinkedDeviceLimiter().validateAsync(authenticatedDevice.getAccountIdentifier())
.thenCompose(ignored -> accounts.getByAccountIdentifierAsync(authenticatedDevice.getAccountIdentifier()))
.thenCompose(maybeAccount -> {
final Account account = maybeAccount.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
return persistentTimer.start(WAIT_FOR_LINKED_DEVICE_TIMER_NAMESPACE, tokenIdentifier)
.thenApply(sample -> new Pair<>(account, sample));
})
.thenCompose(accountAndSample -> accounts.waitForNewLinkedDevice(
authenticatedDevice.getAccountIdentifier(),
accountAndSample.first().getDevice(authenticatedDevice.getDeviceId())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)),
tokenIdentifier,
Duration.ofSeconds(timeoutSeconds))
.thenApply(maybeDeviceInfo -> maybeDeviceInfo
@ -391,7 +390,7 @@ public class DeviceController {
linkedDeviceListenerCounter.decrementAndGet();
if (response != null && response.getStatus() == Response.Status.OK.getStatusCode()) {
sample.stop(Timer.builder(WAIT_FOR_LINKED_DEVICE_TIMER_NAME)
accountAndSample.second().stop(Timer.builder(WAIT_FOR_LINKED_DEVICE_TIMER_NAME)
.publishPercentileHistogram(true)
.tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent)))
.register(Metrics.globalRegistry));
@ -410,14 +409,15 @@ public class DeviceController {
@PUT
@Produces(MediaType.APPLICATION_JSON)
@Path("/capabilities")
public void setCapabilities(@Mutable @Auth final AuthenticatedDevice auth,
public void setCapabilities(@Auth final AuthenticatedDevice auth,
@NotNull
final Map<String, Boolean> capabilities) {
assert (auth.getAuthenticatedDevice() != null);
final byte deviceId = auth.getAuthenticatedDevice().getId();
accounts.updateDevice(auth.getAccount(), deviceId,
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
accounts.updateDevice(account, auth.getDeviceId(),
d -> d.setCapabilities(DeviceCapabilityAdapter.mapToSet(capabilities)));
}
@ -435,12 +435,13 @@ public class DeviceController {
@ApiResponse(responseCode = "200", description = "Public key stored successfully")
@ApiResponse(responseCode = "401", description = "Account authentication check failed")
@ApiResponse(responseCode = "422", description = "Invalid request format")
public CompletableFuture<Void> setPublicKey(@Mutable @Auth final AuthenticatedDevice auth,
public CompletableFuture<Void> setPublicKey(@Auth final AuthenticatedDevice auth,
final SetPublicKeyRequest setPublicKeyRequest) {
return clientPublicKeysManager.setPublicKey(auth.getAccount(),
auth.getAuthenticatedDevice().getId(),
setPublicKeyRequest.publicKey());
final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
return clientPublicKeysManager.setPublicKey(account, auth.getDeviceId(), setPublicKeyRequest.publicKey());
}
private static boolean isCapabilityDowngrade(final Account account, final Set<DeviceCapability> capabilities) {
@ -531,15 +532,21 @@ public class DeviceController {
@ApiResponse(responseCode = "204", description = "Success")
@ApiResponse(responseCode = "422", description = "The request object could not be parsed or was otherwise invalid")
@ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay")
public CompletionStage<Void> recordTransferArchiveUploaded(@ReadOnly @Auth final AuthenticatedDevice authenticatedDevice,
public CompletionStage<Void> recordTransferArchiveUploaded(@Auth final AuthenticatedDevice authenticatedDevice,
@NotNull @Valid final TransferArchiveUploadedRequest transferArchiveUploadedRequest) {
return rateLimiters.getUploadTransferArchiveLimiter()
.validateAsync(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI))
.thenCompose(ignored -> accounts.recordTransferArchiveUpload(authenticatedDevice.getAccount(),
transferArchiveUploadedRequest.destinationDeviceId(),
Instant.ofEpochMilli(transferArchiveUploadedRequest.destinationDeviceCreated()),
transferArchiveUploadedRequest.transferArchive()));
.validateAsync(authenticatedDevice.getAccountIdentifier())
.thenCompose(ignored -> accounts.getByAccountIdentifierAsync(authenticatedDevice.getAccountIdentifier()))
.thenCompose(maybeAccount -> {
final Account account = maybeAccount.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
return accounts.recordTransferArchiveUpload(account,
transferArchiveUploadedRequest.destinationDeviceId(),
Instant.ofEpochMilli(transferArchiveUploadedRequest.destinationDeviceCreated()),
transferArchiveUploadedRequest.transferArchive());
});
}
@GET
@ -558,7 +565,7 @@ public class DeviceController {
@ApiResponse(responseCode = "204", description = "No transfer archive was uploaded before the call completed; clients may repeat the call to continue waiting")
@ApiResponse(responseCode = "400", description = "The given timeout was invalid")
@ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay")
public CompletionStage<Response> waitForTransferArchive(@ReadOnly @Auth final AuthenticatedDevice authenticatedDevice,
public CompletionStage<Response> waitForTransferArchive(@Auth final AuthenticatedDevice authenticatedDevice,
@QueryParam("timeout")
@DefaultValue("30")
@ -575,24 +582,30 @@ public class DeviceController {
@HeaderParam(HttpHeaders.USER_AGENT) @Nullable String userAgent) {
final String rateLimiterKey = authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI) +
":" + authenticatedDevice.getAuthenticatedDevice().getId();
final String rateLimiterKey = authenticatedDevice.getAccountIdentifier() + ":" + authenticatedDevice.getDeviceId();
return rateLimiters.getWaitForTransferArchiveLimiter().validateAsync(rateLimiterKey)
.thenCompose(ignored -> persistentTimer.start(WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAMESPACE, rateLimiterKey))
.thenCompose(sample -> accounts.waitForTransferArchive(authenticatedDevice.getAccount(),
authenticatedDevice.getAuthenticatedDevice(),
.thenCompose(ignored -> accounts.getByAccountIdentifierAsync(authenticatedDevice.getAccountIdentifier()))
.thenCompose(maybeAccount -> {
final Account account = maybeAccount.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
return persistentTimer.start(WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAMESPACE, rateLimiterKey)
.thenApply(sample -> new Pair<>(account, sample));
})
.thenCompose(accountAndSample -> accounts.waitForTransferArchive(accountAndSample.first(),
accountAndSample.first().getDevice(authenticatedDevice.getDeviceId())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)),
Duration.ofSeconds(timeoutSeconds))
.thenApply(maybeTransferArchive -> maybeTransferArchive
.map(transferArchive -> Response.status(Response.Status.OK).entity(transferArchive).build())
.orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build()))
.whenComplete((response, throwable) -> {
if (response != null && response.getStatus() == Response.Status.OK.getStatusCode()) {
sample.stop(Timer.builder(WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAME)
accountAndSample.second().stop(Timer.builder(WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAME)
.publishPercentileHistogram(true)
.tags(Tags.of(
UserAgentTagUtil.getPlatformTag(userAgent),
primaryPlatformTag(authenticatedDevice.getAccount())))
primaryPlatformTag(accountAndSample.first())))
.register(Metrics.globalRegistry));
}
}));

View File

@ -14,12 +14,10 @@ import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.core.MediaType;
import java.time.Clock;
import java.util.UUID;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator;
import org.whispersystems.textsecuregcm.configuration.DirectoryV2ClientConfiguration;
import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v2/directory")
@Tag(name = "Directory")
@ -57,8 +55,7 @@ public class DirectoryV2Controller {
"""
)
@ApiResponse(responseCode = "200", description = "`JSON` with generated credentials.", useReturnTypeSchema = true)
public ExternalServiceCredentials getAuthToken(final @ReadOnly @Auth AuthenticatedDevice auth) {
final UUID uuid = auth.getAccount().getUuid();
return directoryServiceTokenGenerator.generateForUuid(uuid);
public ExternalServiceCredentials getAuthToken(final @Auth AuthenticatedDevice auth) {
return directoryServiceTokenGenerator.generateForUuid(auth.getAccountIdentifier());
}
}

View File

@ -15,6 +15,7 @@ import jakarta.ws.rs.Consumes;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.Response.Status;
@ -33,10 +34,10 @@ import org.signal.libsignal.zkgroup.receipts.ServerZkReceiptOperations;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.configuration.BadgesConfiguration;
import org.whispersystems.textsecuregcm.entities.RedeemReceiptRequest;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountBadge;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.RedeemedReceiptsManager;
import org.whispersystems.websocket.auth.Mutable;
@Path("/v1/donation")
@Tag(name = "Donations")
@ -86,7 +87,7 @@ public class DonationController {
""")
@ApiResponse(responseCode = "429", description = "Rate limited.")
public CompletionStage<Response> redeemReceipt(
@Mutable @Auth final AuthenticatedDevice auth,
@Auth final AuthenticatedDevice auth,
@NotNull @Valid final RedeemReceiptRequest request) {
return CompletableFuture.supplyAsync(() -> {
ReceiptCredentialPresentation receiptCredentialPresentation;
@ -118,23 +119,29 @@ public class DonationController {
.type(MediaType.TEXT_PLAIN_TYPE)
.build());
}
return redeemedReceiptsManager.put(
receiptSerial, receiptExpiration.getEpochSecond(), receiptLevel, auth.getAccount().getUuid())
.thenCompose(receiptMatched -> {
if (!receiptMatched) {
return CompletableFuture.completedFuture(Response.status(Status.BAD_REQUEST)
.entity("receipt serial is already redeemed")
.type(MediaType.TEXT_PLAIN_TYPE)
.build());
}
return accountsManager.updateAsync(auth.getAccount(), a -> {
a.addBadge(clock, new AccountBadge(badgeId, receiptExpiration, request.isVisible()));
if (request.isPrimary()) {
a.makeBadgePrimaryIfExists(clock, badgeId);
return accountsManager.getByAccountIdentifierAsync(auth.getAccountIdentifier())
.thenCompose(maybeAccount -> {
final Account account = maybeAccount.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
return redeemedReceiptsManager.put(
receiptSerial, receiptExpiration.getEpochSecond(), receiptLevel, auth.getAccountIdentifier())
.thenCompose(receiptMatched -> {
if (!receiptMatched) {
return CompletableFuture.completedFuture(Response.status(Status.BAD_REQUEST)
.entity("receipt serial is already redeemed")
.type(MediaType.TEXT_PLAIN_TYPE)
.build());
}
})
.thenApply(ignored -> Response.ok().build());
return accountsManager.updateAsync(account, a -> {
a.addBadge(clock, new AccountBadge(badgeId, receiptExpiration, request.isVisible()));
if (request.isPrimary()) {
a.makeBadgePrimaryIfExists(clock, badgeId);
}
})
.thenApply(ignored -> Response.ok().build());
});
});
}).thenCompose(Function.identity());
}

View File

@ -5,9 +5,8 @@
package org.whispersystems.textsecuregcm.controllers;
import org.whispersystems.textsecuregcm.auth.TurnToken;
import java.util.List;
import org.whispersystems.textsecuregcm.auth.TurnToken;
public record GetCallingRelaysResponse(List<TurnToken> relays) {
}

View File

@ -23,7 +23,6 @@ import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager;
import org.whispersystems.websocket.auth.ReadOnly;
import org.whispersystems.websocket.session.WebSocketSession;
import org.whispersystems.websocket.session.WebSocketSessionContext;
@ -45,16 +44,16 @@ public class KeepAliveController {
}
@GET
public Response getKeepAlive(@ReadOnly @Auth Optional<AuthenticatedDevice> maybeAuth,
public Response getKeepAlive(@Auth Optional<AuthenticatedDevice> maybeAuth,
@WebSocketSession WebSocketSessionContext context) {
maybeAuth.ifPresent(auth -> {
if (!webSocketConnectionEventManager.isLocallyPresent(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId())) {
if (!webSocketConnectionEventManager.isLocallyPresent(auth.getAccountIdentifier(), auth.getDeviceId())) {
final Duration age = Duration.between(context.getClient().getCreated(), Instant.now());
logger.debug("***** No local subscription found for {}::{}; age = {}ms, User-Agent = {}",
auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(), age.toMillis(),
auth.getAccountIdentifier(), auth.getDeviceId(), age.toMillis(),
context.getClient().getUserAgent());
context.getClient().close(1000, "OK");

View File

@ -49,7 +49,6 @@ import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceCl
import org.whispersystems.textsecuregcm.limits.RateLimitedByIp;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v1/key-transparency")
@Tag(name = "KeyTransparency")
@ -90,7 +89,7 @@ public class KeyTransparencyController {
@RateLimitedByIp(RateLimiters.For.KEY_TRANSPARENCY_SEARCH_PER_IP)
@Produces(MediaType.APPLICATION_JSON)
public KeyTransparencySearchResponse search(
@ReadOnly @Auth final Optional<AuthenticatedDevice> authenticatedAccount,
@Auth final Optional<AuthenticatedDevice> authenticatedAccount,
@NotNull @Valid final KeyTransparencySearchRequest request) {
// Disallow clients from making authenticated requests to this endpoint
@ -142,7 +141,7 @@ public class KeyTransparencyController {
@RateLimitedByIp(RateLimiters.For.KEY_TRANSPARENCY_MONITOR_PER_IP)
@Produces(MediaType.APPLICATION_JSON)
public KeyTransparencyMonitorResponse monitor(
@ReadOnly @Auth final Optional<AuthenticatedDevice> authenticatedAccount,
@Auth final Optional<AuthenticatedDevice> authenticatedAccount,
@NotNull @Valid final KeyTransparencyMonitorRequest request) {
// Disallow clients from making authenticated requests to this endpoint
@ -204,7 +203,7 @@ public class KeyTransparencyController {
@RateLimitedByIp(RateLimiters.For.KEY_TRANSPARENCY_DISTINGUISHED_PER_IP)
@Produces(MediaType.APPLICATION_JSON)
public KeyTransparencyDistinguishedKeyResponse getDistinguishedKey(
@ReadOnly @Auth final Optional<AuthenticatedDevice> authenticatedAccount,
@Auth final Optional<AuthenticatedDevice> authenticatedAccount,
@Parameter(description = "The distinguished tree head size returned by a previously verified call")
@QueryParam("lastTreeHeadSize") @Valid final Optional<@Positive Long> lastTreeHeadSize) {

View File

@ -74,7 +74,6 @@ import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.websocket.auth.ReadOnly;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@Path("/v2/keys")
@ -111,16 +110,21 @@ public class KeysController {
description = "Gets the number of one-time prekeys uploaded for this device and still available")
@ApiResponse(responseCode = "200", description = "Body contains the number of available one-time prekeys for the device.", useReturnTypeSchema = true)
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
public CompletableFuture<PreKeyCount> getStatus(@ReadOnly @Auth final AuthenticatedDevice auth,
public CompletableFuture<PreKeyCount> getStatus(@Auth final AuthenticatedDevice auth,
@QueryParam("identity") @DefaultValue("aci") final IdentityType identityType) {
final CompletableFuture<Integer> ecCountFuture =
keysManager.getEcCount(auth.getAccount().getIdentifier(identityType), auth.getAuthenticatedDevice().getId());
return accounts.getByAccountIdentifierAsync(auth.getAccountIdentifier())
.thenCompose(maybeAccount -> {
final Account account = maybeAccount.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
final CompletableFuture<Integer> pqCountFuture =
keysManager.getPqCount(auth.getAccount().getIdentifier(identityType), auth.getAuthenticatedDevice().getId());
final CompletableFuture<Integer> ecCountFuture =
keysManager.getEcCount(account.getIdentifier(identityType), auth.getDeviceId());
return ecCountFuture.thenCombine(pqCountFuture, PreKeyCount::new);
final CompletableFuture<Integer> pqCountFuture =
keysManager.getPqCount(account.getIdentifier(identityType), auth.getDeviceId());
return ecCountFuture.thenCombine(pqCountFuture, PreKeyCount::new);
});
}
@PUT
@ -132,7 +136,7 @@ public class KeysController {
@ApiResponse(responseCode = "403", description = "Attempt to change identity key from a non-primary device.")
@ApiResponse(responseCode = "422", description = "Invalid request format.")
public CompletableFuture<Response> setKeys(
@ReadOnly @Auth final AuthenticatedDevice auth,
@Auth final AuthenticatedDevice auth,
@RequestBody @NotNull @Valid final SetKeysRequest setKeysRequest,
@Parameter(allowEmptyValue=true)
@ -143,63 +147,70 @@ public class KeysController {
@QueryParam("identity") @DefaultValue("aci") final IdentityType identityType,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) {
final Account account = auth.getAccount();
final Device device = auth.getAuthenticatedDevice();
final UUID identifier = account.getIdentifier(identityType);
return accounts.getByAccountIdentifierAsync(auth.getAccountIdentifier())
.thenCompose(maybeAccount -> {
final Account account = maybeAccount
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
checkSignedPreKeySignatures(setKeysRequest, account.getIdentityKey(identityType), userAgent);
final Device device = account.getDevice(auth.getDeviceId())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
final Tag platformTag = UserAgentTagUtil.getPlatformTag(userAgent);
final Tag primaryDeviceTag = Tag.of(PRIMARY_DEVICE_TAG_NAME, String.valueOf(auth.getAuthenticatedDevice().isPrimary()));
final Tag identityTypeTag = Tag.of(IDENTITY_TYPE_TAG_NAME, identityType.name());
final UUID identifier = account.getIdentifier(identityType);
final List<CompletableFuture<Void>> storeFutures = new ArrayList<>(4);
checkSignedPreKeySignatures(setKeysRequest, account.getIdentityKey(identityType), userAgent);
if (!setKeysRequest.preKeys().isEmpty()) {
final Tags tags = Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "ec"));
final Tag platformTag = UserAgentTagUtil.getPlatformTag(userAgent);
final Tag primaryDeviceTag = Tag.of(PRIMARY_DEVICE_TAG_NAME, String.valueOf(auth.getDeviceId() == Device.PRIMARY_ID));
final Tag identityTypeTag = Tag.of(IDENTITY_TYPE_TAG_NAME, identityType.name());
Metrics.counter(STORE_KEYS_COUNTER_NAME, tags).increment();
final List<CompletableFuture<Void>> storeFutures = new ArrayList<>(4);
DistributionSummary.builder(STORE_KEY_BUNDLE_SIZE_DISTRIBUTION_NAME)
.tags(tags)
.publishPercentileHistogram()
.register(Metrics.globalRegistry)
.record(setKeysRequest.preKeys().size());
if (!setKeysRequest.preKeys().isEmpty()) {
final Tags tags = Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "ec"));
storeFutures.add(keysManager.storeEcOneTimePreKeys(identifier, device.getId(), setKeysRequest.preKeys()));
}
Metrics.counter(STORE_KEYS_COUNTER_NAME, tags).increment();
if (setKeysRequest.signedPreKey() != null) {
Metrics.counter(STORE_KEYS_COUNTER_NAME,
Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "ec-signed")))
.increment();
DistributionSummary.builder(STORE_KEY_BUNDLE_SIZE_DISTRIBUTION_NAME)
.tags(tags)
.publishPercentileHistogram()
.register(Metrics.globalRegistry)
.record(setKeysRequest.preKeys().size());
storeFutures.add(keysManager.storeEcSignedPreKeys(identifier, device.getId(), setKeysRequest.signedPreKey()));
}
storeFutures.add(keysManager.storeEcOneTimePreKeys(identifier, device.getId(), setKeysRequest.preKeys()));
}
if (!setKeysRequest.pqPreKeys().isEmpty()) {
final Tags tags = Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "kyber"));
Metrics.counter(STORE_KEYS_COUNTER_NAME, tags).increment();
if (setKeysRequest.signedPreKey() != null) {
Metrics.counter(STORE_KEYS_COUNTER_NAME,
Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "ec-signed")))
.increment();
DistributionSummary.builder(STORE_KEY_BUNDLE_SIZE_DISTRIBUTION_NAME)
.tags(tags)
.publishPercentileHistogram()
.register(Metrics.globalRegistry)
.record(setKeysRequest.pqPreKeys().size());
storeFutures.add(keysManager.storeEcSignedPreKeys(identifier, device.getId(), setKeysRequest.signedPreKey()));
}
storeFutures.add(keysManager.storeKemOneTimePreKeys(identifier, device.getId(), setKeysRequest.pqPreKeys()));
}
if (!setKeysRequest.pqPreKeys().isEmpty()) {
final Tags tags = Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "kyber"));
Metrics.counter(STORE_KEYS_COUNTER_NAME, tags).increment();
if (setKeysRequest.pqLastResortPreKey() != null) {
Metrics.counter(STORE_KEYS_COUNTER_NAME,
Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "kyber-last-resort")))
.increment();
DistributionSummary.builder(STORE_KEY_BUNDLE_SIZE_DISTRIBUTION_NAME)
.tags(tags)
.publishPercentileHistogram()
.register(Metrics.globalRegistry)
.record(setKeysRequest.pqPreKeys().size());
storeFutures.add(keysManager.storePqLastResort(identifier, device.getId(), setKeysRequest.pqLastResortPreKey()));
}
storeFutures.add(keysManager.storeKemOneTimePreKeys(identifier, device.getId(), setKeysRequest.pqPreKeys()));
}
return CompletableFuture.allOf(storeFutures.toArray(EMPTY_FUTURE_ARRAY))
.thenApply(Util.ASYNC_EMPTY_RESPONSE);
if (setKeysRequest.pqLastResortPreKey() != null) {
Metrics.counter(STORE_KEYS_COUNTER_NAME,
Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "kyber-last-resort")))
.increment();
storeFutures.add(keysManager.storePqLastResort(identifier, device.getId(), setKeysRequest.pqLastResortPreKey()));
}
return CompletableFuture.allOf(storeFutures.toArray(EMPTY_FUTURE_ARRAY))
.thenApply(Util.ASYNC_EMPTY_RESPONSE);
});
}
private void checkSignedPreKeySignatures(final SetKeysRequest setKeysRequest,
@ -253,64 +264,69 @@ public class KeysController {
""")
@ApiResponse(responseCode = "422", description = "Invalid request format")
public CompletableFuture<Response> checkKeys(
@ReadOnly @Auth final AuthenticatedDevice auth,
@Auth final AuthenticatedDevice auth,
@RequestBody @NotNull @Valid final CheckKeysRequest checkKeysRequest) {
final UUID identifier = auth.getAccount().getIdentifier(checkKeysRequest.identityType());
final byte deviceId = auth.getAuthenticatedDevice().getId();
return accounts.getByAccountIdentifierAsync(auth.getAccountIdentifier())
.thenCompose(maybeAccount -> {
final Account account = maybeAccount.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
final CompletableFuture<Optional<ECSignedPreKey>> ecSignedPreKeyFuture =
keysManager.getEcSignedPreKey(identifier, deviceId);
final UUID identifier = account.getIdentifier(checkKeysRequest.identityType());
final byte deviceId = auth.getDeviceId();
final CompletableFuture<Optional<KEMSignedPreKey>> lastResortKeyFuture =
keysManager.getLastResort(identifier, deviceId);
final CompletableFuture<Optional<ECSignedPreKey>> ecSignedPreKeyFuture =
keysManager.getEcSignedPreKey(identifier, deviceId);
return CompletableFuture.allOf(ecSignedPreKeyFuture, lastResortKeyFuture)
.thenApply(ignored -> {
final Optional<ECSignedPreKey> maybeSignedPreKey = ecSignedPreKeyFuture.join();
final Optional<KEMSignedPreKey> maybeLastResortKey = lastResortKeyFuture.join();
final CompletableFuture<Optional<KEMSignedPreKey>> lastResortKeyFuture =
keysManager.getLastResort(identifier, deviceId);
final boolean digestsMatch;
return CompletableFuture.allOf(ecSignedPreKeyFuture, lastResortKeyFuture)
.thenApply(ignored -> {
final Optional<ECSignedPreKey> maybeSignedPreKey = ecSignedPreKeyFuture.join();
final Optional<KEMSignedPreKey> maybeLastResortKey = lastResortKeyFuture.join();
if (maybeSignedPreKey.isPresent() && maybeLastResortKey.isPresent()) {
final IdentityKey identityKey = auth.getAccount().getIdentityKey(checkKeysRequest.identityType());
final ECSignedPreKey ecSignedPreKey = maybeSignedPreKey.get();
final KEMSignedPreKey lastResortKey = maybeLastResortKey.get();
final boolean digestsMatch;
final MessageDigest messageDigest;
if (maybeSignedPreKey.isPresent() && maybeLastResortKey.isPresent()) {
final IdentityKey identityKey = account.getIdentityKey(checkKeysRequest.identityType());
final ECSignedPreKey ecSignedPreKey = maybeSignedPreKey.get();
final KEMSignedPreKey lastResortKey = maybeLastResortKey.get();
try {
messageDigest = MessageDigest.getInstance("SHA-256");
} catch (final NoSuchAlgorithmException e) {
throw new AssertionError("Every implementation of the Java platform is required to support SHA-256", e);
}
final MessageDigest messageDigest;
messageDigest.update(identityKey.serialize());
try {
messageDigest = MessageDigest.getInstance("SHA-256");
} catch (final NoSuchAlgorithmException e) {
throw new AssertionError("Every implementation of the Java platform is required to support SHA-256", e);
}
{
final ByteBuffer ecSignedPreKeyIdBuffer = ByteBuffer.allocate(Long.BYTES);
ecSignedPreKeyIdBuffer.putLong(ecSignedPreKey.keyId());
ecSignedPreKeyIdBuffer.flip();
messageDigest.update(identityKey.serialize());
messageDigest.update(ecSignedPreKeyIdBuffer);
messageDigest.update(ecSignedPreKey.serializedPublicKey());
}
{
final ByteBuffer ecSignedPreKeyIdBuffer = ByteBuffer.allocate(Long.BYTES);
ecSignedPreKeyIdBuffer.putLong(ecSignedPreKey.keyId());
ecSignedPreKeyIdBuffer.flip();
{
final ByteBuffer lastResortKeyIdBuffer = ByteBuffer.allocate(Long.BYTES);
lastResortKeyIdBuffer.putLong(lastResortKey.keyId());
lastResortKeyIdBuffer.flip();
messageDigest.update(ecSignedPreKeyIdBuffer);
messageDigest.update(ecSignedPreKey.serializedPublicKey());
}
messageDigest.update(lastResortKeyIdBuffer);
messageDigest.update(lastResortKey.serializedPublicKey());
}
{
final ByteBuffer lastResortKeyIdBuffer = ByteBuffer.allocate(Long.BYTES);
lastResortKeyIdBuffer.putLong(lastResortKey.keyId());
lastResortKeyIdBuffer.flip();
digestsMatch = MessageDigest.isEqual(messageDigest.digest(), checkKeysRequest.digest());
} else {
digestsMatch = false;
}
messageDigest.update(lastResortKeyIdBuffer);
messageDigest.update(lastResortKey.serializedPublicKey());
}
return Response.status(digestsMatch ? Response.Status.OK : Response.Status.CONFLICT).build();
digestsMatch = MessageDigest.isEqual(messageDigest.digest(), checkKeysRequest.digest());
} else {
digestsMatch = false;
}
return Response.status(digestsMatch ? Response.Status.OK : Response.Status.CONFLICT).build();
});
});
}
@ -327,7 +343,7 @@ public class KeysController {
name = "Retry-After",
description = "If present, a positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public PreKeyResponse getDeviceKeys(
@ReadOnly @Auth Optional<AuthenticatedDevice> auth,
@Auth Optional<AuthenticatedDevice> maybeAuthenticatedDevice,
@HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional<Anonymous> accessKey,
@HeaderParam(HeaderUtils.GROUP_SEND_TOKEN) Optional<GroupSendTokenHeader> groupSendToken,
@ -340,15 +356,18 @@ public class KeysController {
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent)
throws RateLimitExceededException {
if (auth.isEmpty() && accessKey.isEmpty() && groupSendToken.isEmpty()) {
if (maybeAuthenticatedDevice.isEmpty() && accessKey.isEmpty() && groupSendToken.isEmpty()) {
throw new WebApplicationException(Response.Status.UNAUTHORIZED);
}
final Optional<Account> account = auth.map(AuthenticatedDevice::getAccount);
final Optional<Account> account = maybeAuthenticatedDevice
.map(authenticatedDevice -> accounts.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)));
final Optional<Account> maybeTarget = accounts.getByServiceIdentifier(targetIdentifier);
if (groupSendToken.isPresent()) {
if (auth.isPresent() || accessKey.isPresent()) {
if (maybeAuthenticatedDevice.isPresent() || accessKey.isPresent()) {
throw new BadRequestException();
}
try {
@ -364,7 +383,7 @@ public class KeysController {
if (account.isPresent()) {
rateLimiters.getPreKeysLimiter().validate(
account.get().getUuid() + "." + auth.get().getAuthenticatedDevice().getId() + "__" + targetIdentifier.uuid()
account.get().getUuid() + "." + maybeAuthenticatedDevice.get().getDeviceId() + "__" + targetIdentifier.uuid()
+ "." + deviceId);
}

View File

@ -105,6 +105,7 @@ import org.whispersystems.textsecuregcm.spam.SpamChecker;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
@ -112,7 +113,6 @@ import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.websocket.WebSocketConnection;
import org.whispersystems.websocket.WebsocketHeaders;
import org.whispersystems.websocket.auth.ReadOnly;
import reactor.core.scheduler.Scheduler;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@ -236,7 +236,7 @@ public class MessageController {
@ApiResponse(
responseCode="428",
description="The sender should complete a challenge before proceeding")
public Response sendMessage(@ReadOnly @Auth final Optional<AuthenticatedDevice> source,
public Response sendMessage(@Auth final Optional<AuthenticatedDevice> source,
@Parameter(description="The recipient's unidentified access key")
@HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) final Optional<Anonymous> accessKey,
@ -274,12 +274,14 @@ public class MessageController {
sendStoryMessage(destinationIdentifier, messages, context);
} else if (source.isPresent()) {
final AuthenticatedDevice authenticatedDevice = source.get();
final Account account = accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
if (authenticatedDevice.getAccount().isIdentifiedBy(destinationIdentifier)) {
if (account.isIdentifiedBy(destinationIdentifier)) {
needsSync = false;
sendSyncMessage(source.get(), destinationIdentifier, messages, context);
sendSyncMessage(source.get(), account, destinationIdentifier, messages, context);
} else {
needsSync = authenticatedDevice.getAccount().getDevices().size() > 1;
needsSync = account.getDevices().size() > 1;
sendIdentifiedSenderIndividualMessage(authenticatedDevice, destinationIdentifier, messages, context);
}
} else {
@ -302,7 +304,7 @@ public class MessageController {
final Account destination =
accountsManager.getByServiceIdentifier(destinationIdentifier).orElseThrow(NotFoundException::new);
rateLimiters.getMessagesLimiter().validate(source.getAccount().getUuid(), destination.getUuid());
rateLimiters.getMessagesLimiter().validate(source.getAccountIdentifier(), destination.getUuid());
sendIndividualMessage(destination,
destinationIdentifier,
@ -314,6 +316,7 @@ public class MessageController {
}
private void sendSyncMessage(final AuthenticatedDevice source,
final Account sourceAccount,
final ServiceIdentifier destinationIdentifier,
final IncomingMessageList messages,
final ContainerRequestContext context)
@ -323,7 +326,7 @@ public class MessageController {
throw new WebApplicationException(Status.FORBIDDEN);
}
sendIndividualMessage(source.getAccount(),
sendIndividualMessage(sourceAccount,
destinationIdentifier,
source,
messages,
@ -420,8 +423,8 @@ public class MessageController {
try {
return message.toEnvelope(
destinationIdentifier,
sender != null ? sender.getAccount() : null,
sender != null ? sender.getAuthenticatedDevice().getId() : null,
sender != null ? new AciServiceIdentifier(sender.getAccountIdentifier()) : null,
sender != null ? sender.getDeviceId() : null,
messages.timestamp() == 0 ? System.currentTimeMillis() : messages.timestamp(),
isStory,
messages.online(),
@ -437,7 +440,7 @@ public class MessageController {
.collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId));
final Optional<Byte> syncMessageSenderDeviceId = messageType == MessageType.SYNC
? Optional.ofNullable(sender).map(authenticatedDevice -> authenticatedDevice.getAuthenticatedDevice().getId())
? Optional.ofNullable(sender).map(AuthenticatedDevice::getDeviceId)
: Optional.empty();
try {
@ -755,31 +758,37 @@ public class MessageController {
@Timed
@GET
@Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<OutgoingMessageEntityList> getPendingMessages(@ReadOnly @Auth AuthenticatedDevice auth,
public CompletableFuture<OutgoingMessageEntityList> getPendingMessages(@Auth AuthenticatedDevice auth,
@HeaderParam(WebsocketHeaders.X_SIGNAL_RECEIVE_STORIES) String receiveStoriesHeader,
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent) {
boolean shouldReceiveStories = WebsocketHeaders.parseReceiveStoriesHeader(receiveStoriesHeader);
return accountsManager.getByAccountIdentifierAsync(auth.getAccountIdentifier())
.thenCompose(maybeAccount -> {
final Account account = maybeAccount.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
final Device device = account.getDevice(auth.getDeviceId())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
pushNotificationManager.handleMessagesRetrieved(auth.getAccount(), auth.getAuthenticatedDevice(), userAgent);
final boolean shouldReceiveStories = WebsocketHeaders.parseReceiveStoriesHeader(receiveStoriesHeader);
return messagesManager.getMessagesForDevice(
auth.getAccount().getUuid(),
auth.getAuthenticatedDevice(),
false)
.map(messagesAndHasMore -> {
Stream<Envelope> envelopes = messagesAndHasMore.first().stream();
if (!shouldReceiveStories) {
envelopes = envelopes.filter(e -> !e.getStory());
}
pushNotificationManager.handleMessagesRetrieved(account, device, userAgent);
return messagesManager.getMessagesForDevice(
auth.getAccountIdentifier(),
device,
false)
.map(messagesAndHasMore -> {
Stream<Envelope> envelopes = messagesAndHasMore.first().stream();
if (!shouldReceiveStories) {
envelopes = envelopes.filter(e -> !e.getStory());
}
final OutgoingMessageEntityList messages = new OutgoingMessageEntityList(envelopes
.map(OutgoingMessageEntity::fromEnvelope)
.peek(outgoingMessageEntity -> {
messageMetrics.measureAccountOutgoingMessageUuidMismatches(auth.getAccount(), outgoingMessageEntity);
messageMetrics.measureAccountOutgoingMessageUuidMismatches(account, outgoingMessageEntity);
messageMetrics.measureOutgoingMessageLatency(outgoingMessageEntity.serverTimestamp(),
"rest",
auth.getAuthenticatedDevice().isPrimary(),
auth.getDeviceId() == Device.PRIMARY_ID,
outgoingMessageEntity.urgent(),
// Messages fetched via this endpoint (as opposed to WebSocketConnection) are never ephemeral
// because, by definition, the client doesn't have a "live" connection via which to receive
@ -791,26 +800,27 @@ public class MessageController {
.collect(Collectors.toList()),
messagesAndHasMore.second());
Metrics.summary(OUTGOING_MESSAGE_LIST_SIZE_BYTES_DISTRIBUTION_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent)))
.record(estimateMessageListSizeBytes(messages));
Metrics.summary(OUTGOING_MESSAGE_LIST_SIZE_BYTES_DISTRIBUTION_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent)))
.record(estimateMessageListSizeBytes(messages));
if (!messages.messages().isEmpty()) {
messageDeliveryLoopMonitor.recordDeliveryAttempt(auth.getAccount().getIdentifier(IdentityType.ACI),
auth.getAuthenticatedDevice().getId(),
messages.messages().getFirst().guid(),
userAgent,
"rest");
}
if (!messages.messages().isEmpty()) {
messageDeliveryLoopMonitor.recordDeliveryAttempt(auth.getAccountIdentifier(),
auth.getDeviceId(),
messages.messages().getFirst().guid(),
userAgent,
"rest");
}
if (messagesAndHasMore.second()) {
pushNotificationScheduler.scheduleDelayedNotification(auth.getAccount(), auth.getAuthenticatedDevice(), NOTIFY_FOR_REMAINING_MESSAGES_DELAY);
}
if (messagesAndHasMore.second()) {
pushNotificationScheduler.scheduleDelayedNotification(account, device, NOTIFY_FOR_REMAINING_MESSAGES_DELAY);
}
return messages;
})
.timeout(Duration.ofSeconds(5))
.subscribeOn(messageDeliveryScheduler)
.toFuture();
return messages;
})
.timeout(Duration.ofSeconds(5))
.subscribeOn(messageDeliveryScheduler)
.toFuture();
});
}
private static long estimateMessageListSizeBytes(final OutgoingMessageEntityList messageList) {
@ -827,22 +837,27 @@ public class MessageController {
@Timed
@DELETE
@Path("/uuid/{uuid}")
public CompletableFuture<Response> removePendingMessage(@ReadOnly @Auth AuthenticatedDevice auth, @PathParam("uuid") UUID uuid) {
public CompletableFuture<Response> removePendingMessage(@Auth AuthenticatedDevice auth, @PathParam("uuid") UUID uuid) {
final Account account = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
final Device device = account.getDevice(auth.getDeviceId())
.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED));
return messagesManager.delete(
auth.getAccount().getUuid(),
auth.getAuthenticatedDevice(),
auth.getAccountIdentifier(),
device,
uuid,
null)
.thenAccept(maybeRemovedMessage -> maybeRemovedMessage.ifPresent(removedMessage -> {
WebSocketConnection.recordMessageDeliveryDuration(removedMessage.serverTimestamp(),
auth.getAuthenticatedDevice());
WebSocketConnection.recordMessageDeliveryDuration(removedMessage.serverTimestamp(), device);
if (removedMessage.sourceServiceId().isPresent()
&& removedMessage.envelopeType() != Type.SERVER_DELIVERY_RECEIPT) {
if (removedMessage.sourceServiceId().get() instanceof AciServiceIdentifier aciServiceIdentifier) {
try {
receiptSender.sendReceipt(removedMessage.destinationServiceId(), auth.getAuthenticatedDevice().getId(),
receiptSender.sendReceipt(removedMessage.destinationServiceId(), auth.getDeviceId(),
aciServiceIdentifier, removedMessage.clientTimestamp());
} catch (Exception e) {
logger.warn("Failed to send delivery receipt", e);
@ -863,7 +878,7 @@ public class MessageController {
@Consumes(MediaType.APPLICATION_JSON)
@Path("/report/{source}/{messageGuid}")
public Response reportSpamMessage(
@ReadOnly @Auth AuthenticatedDevice auth,
@Auth AuthenticatedDevice auth,
@PathParam("source") String source,
@PathParam("messageGuid") UUID messageGuid,
@Nullable SpamReport spamReport,
@ -899,7 +914,7 @@ public class MessageController {
}
}
UUID spamReporterUuid = auth.getAccount().getUuid();
UUID spamReporterUuid = auth.getAccountIdentifier();
// spam report token is optional, but if provided ensure it is non-empty.
final Optional<byte[]> maybeSpamReportToken =

View File

@ -5,8 +5,8 @@
package org.whispersystems.textsecuregcm.controllers;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import java.util.Map;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
public class MultiRecipientMismatchedDevicesException extends Exception {

View File

@ -69,7 +69,6 @@ import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import org.whispersystems.websocket.auth.ReadOnly;
/**
@ -163,7 +162,7 @@ public class OneTimeDonationController {
@StringToClassMapItem(key = "error", value = String.class)
})))
public CompletableFuture<Response> createBoostPaymentIntent(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@NotNull @Valid CreateBoostRequest request,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) {
@ -249,7 +248,7 @@ public class OneTimeDonationController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> createPayPalBoost(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@NotNull @Valid CreatePayPalBoostRequest request,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@Context ContainerRequestContext containerRequestContext) {
@ -296,7 +295,7 @@ public class OneTimeDonationController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> confirmPayPalBoost(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@NotNull @Valid ConfirmPayPalBoostRequest request,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) {
@ -342,7 +341,7 @@ public class OneTimeDonationController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> createBoostReceiptCredentials(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@NotNull @Valid final CreateBoostReceiptCredentialsRequest request,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) {

View File

@ -17,7 +17,6 @@ import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator
import org.whispersystems.textsecuregcm.configuration.PaymentsServiceConfiguration;
import org.whispersystems.textsecuregcm.currency.CurrencyConversionManager;
import org.whispersystems.textsecuregcm.entities.CurrencyConversionEntityList;
import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v1/payments")
@Tag(name = "Payments")
@ -43,14 +42,14 @@ public class PaymentsController {
@GET
@Path("/auth")
@Produces(MediaType.APPLICATION_JSON)
public ExternalServiceCredentials getAuth(final @ReadOnly @Auth AuthenticatedDevice auth) {
return paymentsServiceCredentialsGenerator.generateForUuid(auth.getAccount().getUuid());
public ExternalServiceCredentials getAuth(final @Auth AuthenticatedDevice auth) {
return paymentsServiceCredentialsGenerator.generateForUuid(auth.getAccountIdentifier());
}
@GET
@Path("/conversions")
@Produces(MediaType.APPLICATION_JSON)
public CurrencyConversionEntityList getConversions(final @ReadOnly @Auth AuthenticatedDevice auth) {
public CurrencyConversionEntityList getConversions(final @Auth AuthenticatedDevice auth) {
return currencyManager.getCurrencyConversions().orElseThrow();
}
}

View File

@ -26,6 +26,7 @@ import jakarta.ws.rs.Path;
import jakarta.ws.rs.PathParam;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.QueryParam;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.container.ContainerRequestContext;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.HttpHeaders;
@ -94,8 +95,6 @@ import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.ProfileHelper;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.websocket.auth.Mutable;
import org.whispersystems.websocket.auth.ReadOnly;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@Path("/v1/profile")
@ -152,15 +151,18 @@ public class ProfileController {
@PUT
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
public Response setProfile(@Mutable @Auth AuthenticatedDevice auth, @NotNull @Valid CreateProfileRequest request) {
public Response setProfile(@Auth AuthenticatedDevice auth, @NotNull @Valid CreateProfileRequest request) {
final Optional<VersionedProfile> currentProfile = profilesManager.get(auth.getAccount().getUuid(),
request.version());
final Account account = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED));
final Optional<VersionedProfile> currentProfile =
profilesManager.get(auth.getAccountIdentifier(), request.version());
if (request.paymentAddress() != null && request.paymentAddress().length != 0) {
final boolean hasDisallowedPrefix =
dynamicConfigurationManager.getConfiguration().getPaymentsConfiguration().getDisallowedPrefixes().stream()
.anyMatch(prefix -> auth.getAccount().getNumber().startsWith(prefix));
.anyMatch(prefix -> account.getNumber().startsWith(prefix));
if (hasDisallowedPrefix && currentProfile.map(VersionedProfile::paymentAddress).isEmpty()) {
return Response.status(Response.Status.FORBIDDEN).build();
@ -179,7 +181,7 @@ public class ProfileController {
case UPDATE -> ProfileHelper.generateAvatarObjectName();
};
profilesManager.set(auth.getAccount().getUuid(),
profilesManager.set(auth.getAccountIdentifier(),
new VersionedProfile(
request.version(),
request.name(),
@ -194,7 +196,7 @@ public class ProfileController {
currentAvatar.ifPresent(s -> profilesManager.deleteAvatar(s).join());
}
accountsManager.update(auth.getAccount(), a -> {
accountsManager.update(account, a -> {
final List<AccountBadge> updatedBadges = request.badges()
.map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges, a.getBadges()))
@ -216,7 +218,7 @@ public class ProfileController {
@Path("/{identifier}/{version}")
@ManagedAsync
public VersionedProfileResponse getProfile(
@ReadOnly @Auth Optional<AuthenticatedDevice> auth,
@Auth Optional<AuthenticatedDevice> maybeAuthenticatedDevice,
@HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional<Anonymous> accessKey,
@Context ContainerRequestContext containerRequestContext,
@PathParam("identifier") AciServiceIdentifier accountIdentifier,
@ -224,7 +226,11 @@ public class ProfileController {
@HeaderParam(HttpHeaders.USER_AGENT) String userAgent)
throws RateLimitExceededException {
final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount);
final Optional<Account> maybeRequester =
maybeAuthenticatedDevice.map(
authenticatedDevice -> accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)));
final Account targetAccount = verifyPermissionToReceiveProfile(maybeRequester, accessKey, accountIdentifier, "getVersionedProfile", userAgent);
return buildVersionedProfileResponse(targetAccount,
@ -238,7 +244,7 @@ public class ProfileController {
@Produces(MediaType.APPLICATION_JSON)
@Path("/{identifier}/{version}/{credentialRequest}")
public CredentialProfileResponse getProfile(
@ReadOnly @Auth Optional<AuthenticatedDevice> auth,
@Auth Optional<AuthenticatedDevice> maybeAuthenticatedDevice,
@HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional<Anonymous> accessKey,
@Context ContainerRequestContext containerRequestContext,
@PathParam("identifier") AciServiceIdentifier accountIdentifier,
@ -252,7 +258,11 @@ public class ProfileController {
throw new BadRequestException();
}
final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount);
final Optional<Account> maybeRequester =
maybeAuthenticatedDevice.map(
authenticatedDevice -> accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)));
final Account targetAccount = verifyPermissionToReceiveProfile(maybeRequester, accessKey, accountIdentifier, "credentialRequest", userAgent);
final boolean isSelf = maybeRequester.map(requester -> ProfileHelper.isSelfProfileRequest(requester.getUuid(), accountIdentifier)).orElse(false);
@ -270,7 +280,7 @@ public class ProfileController {
@Path("/{identifier}")
@ManagedAsync
public BaseProfileResponse getUnversionedProfile(
@ReadOnly @Auth Optional<AuthenticatedDevice> auth,
@Auth Optional<AuthenticatedDevice> maybeAuthenticatedDevice,
@HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional<Anonymous> accessKey,
@HeaderParam(HeaderUtils.GROUP_SEND_TOKEN) Optional<GroupSendTokenHeader> groupSendToken,
@Context ContainerRequestContext containerRequestContext,
@ -278,7 +288,10 @@ public class ProfileController {
@PathParam("identifier") ServiceIdentifier identifier)
throws RateLimitExceededException {
final Optional<Account> maybeRequester = auth.map(AuthenticatedDevice::getAccount);
final Optional<Account> maybeRequester =
maybeAuthenticatedDevice.map(
authenticatedDevice -> accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier())
.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)));
final Account targetAccount;
if (groupSendToken.isPresent()) {

View File

@ -34,7 +34,6 @@ import org.whispersystems.textsecuregcm.entities.ProvisioningMessage;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.ProvisioningManager;
import org.whispersystems.websocket.auth.ReadOnly;
/**
* The provisioning controller facilitates transmission of provisioning messages from the primary device associated with
@ -77,7 +76,7 @@ public class ProvisioningController {
@ApiResponse(responseCode="204", description="The provisioning message was delivered to the given provisioning address")
@ApiResponse(responseCode="400", description="The provisioning message was too large")
@ApiResponse(responseCode="404", description="No device with the given provisioning address was connected at the time of the request")
public void sendProvisioningMessage(@ReadOnly @Auth final AuthenticatedDevice auth,
public void sendProvisioningMessage(@Auth final AuthenticatedDevice auth,
@Parameter(description = "The temporary provisioning address to which to send a provisioning message")
@PathParam("destination") final String provisioningAddress,
@ -93,7 +92,7 @@ public class ProvisioningController {
throw new WebApplicationException(Response.Status.BAD_REQUEST);
}
rateLimiters.getMessagesLimiter().validate(auth.getAccount().getUuid());
rateLimiters.getMessagesLimiter().validate(auth.getAccountIdentifier());
final boolean subscriberPresent =
provisioningManager.sendProvisioningMessage(provisioningAddress, Base64.getMimeDecoder().decode(message.body()));

View File

@ -30,7 +30,6 @@ import org.whispersystems.textsecuregcm.entities.UserRemoteConfigList;
import org.whispersystems.textsecuregcm.storage.RemoteConfigsManager;
import org.whispersystems.textsecuregcm.util.Conversions;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v1/config")
@Tag(name = "Remote Config")
@ -64,7 +63,7 @@ public class RemoteConfigController {
"""
)
@ApiResponse(responseCode = "200", description = "Remote configuration values for the authenticated user", useReturnTypeSchema = true)
public UserRemoteConfigList getAll(@ReadOnly @Auth AuthenticatedDevice auth) {
public UserRemoteConfigList getAll(@Auth AuthenticatedDevice auth) {
try {
MessageDigest digest = MessageDigest.getInstance("SHA1");
@ -73,7 +72,7 @@ public class RemoteConfigController {
return new UserRemoteConfigList(Stream.concat(remoteConfigsManager.getAll().stream().map(config -> {
final byte[] hashKey = config.getHashKey() != null ? config.getHashKey().getBytes(StandardCharsets.UTF_8)
: config.getName().getBytes(StandardCharsets.UTF_8);
boolean inBucket = isInBucket(digest, auth.getAccount().getUuid(), hashKey, config.getPercentage(),
boolean inBucket = isInBucket(digest, auth.getAccountIdentifier(), hashKey, config.getPercentage(),
config.getUuids());
return new UserRemoteConfig(config.getName(), inBucket,
inBucket ? config.getValue() : config.getDefaultValue());

View File

@ -17,7 +17,6 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator;
import org.whispersystems.textsecuregcm.configuration.SecureStorageServiceConfiguration;
import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v1/storage")
@Tag(name = "Secure Storage")
@ -47,7 +46,7 @@ public class SecureStorageController {
"""
)
@ApiResponse(responseCode = "200", description = "`JSON` with generated credentials.", useReturnTypeSchema = true)
public ExternalServiceCredentials getAuth(@ReadOnly @Auth AuthenticatedDevice auth) {
return storageServiceCredentialsGenerator.generateForUuid(auth.getAccount().getUuid());
public ExternalServiceCredentials getAuth(@Auth AuthenticatedDevice auth) {
return storageServiceCredentialsGenerator.generateForUuid(auth.getAccountIdentifier());
}
}

View File

@ -34,7 +34,6 @@ import org.whispersystems.textsecuregcm.limits.RateLimitedByIp;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v2/backup")
@Tag(name = "Secure Value Recovery")
@ -78,8 +77,8 @@ public class SecureValueRecovery2Controller {
)
@ApiResponse(responseCode = "200", description = "`JSON` with generated credentials.", useReturnTypeSchema = true)
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
public ExternalServiceCredentials getAuth(@ReadOnly @Auth final AuthenticatedDevice auth) {
return backupServiceCredentialGenerator.generateFor(auth.getAccount().getUuid().toString());
public ExternalServiceCredentials getAuth(@Auth final AuthenticatedDevice auth) {
return backupServiceCredentialGenerator.generateFor(auth.getAccountIdentifier().toString());
}

View File

@ -28,7 +28,6 @@ import org.whispersystems.textsecuregcm.s3.PolicySigner;
import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v1/sticker")
@Tag(name = "Stickers")
@ -47,10 +46,10 @@ public class StickerController {
@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/pack/form/{count}")
public StickerPackFormUploadAttributes getStickersForm(@ReadOnly @Auth AuthenticatedDevice auth,
public StickerPackFormUploadAttributes getStickersForm(@Auth AuthenticatedDevice auth,
@PathParam("count") @Min(1) @Max(201) int stickerCount)
throws RateLimitExceededException {
rateLimiters.getStickerPackLimiter().validate(auth.getAccount().getUuid());
rateLimiters.getStickerPackLimiter().validate(auth.getAccountIdentifier());
ZonedDateTime now = ZonedDateTime.now(ZoneOffset.UTC);
String packId = generatePackId();

View File

@ -88,7 +88,6 @@ import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v1/subscription")
@io.swagger.v3.oas.annotations.tags.Tag(name = "Subscriptions")
@ -220,7 +219,7 @@ public class SubscriptionController {
@Path("/{subscriberId}")
@Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> deleteSubscriber(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId) throws SubscriptionException {
SubscriberCredentials subscriberCredentials =
SubscriberCredentials.process(authenticatedAccount, subscriberId, clock);
@ -232,7 +231,7 @@ public class SubscriptionController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> updateSubscriber(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId) throws SubscriptionException {
SubscriberCredentials subscriberCredentials =
SubscriberCredentials.process(authenticatedAccount, subscriberId, clock);
@ -248,7 +247,7 @@ public class SubscriptionController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> createPaymentMethod(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId,
@QueryParam("type") @DefaultValue("CARD") PaymentMethod paymentMethodType,
@HeaderParam(HttpHeaders.USER_AGENT) @Nullable final String userAgentString) throws SubscriptionException {
@ -284,7 +283,7 @@ public class SubscriptionController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> createPayPalPaymentMethod(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId,
@NotNull @Valid CreatePayPalBillingAgreementRequest request,
@Context ContainerRequestContext containerRequestContext,
@ -323,7 +322,7 @@ public class SubscriptionController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> setDefaultPaymentMethodWithProcessor(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId,
@PathParam("processor") PaymentProvider processor,
@PathParam("paymentMethodToken") @NotEmpty String paymentMethodToken) throws SubscriptionException {
@ -360,7 +359,7 @@ public class SubscriptionController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> setSubscriptionLevel(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId,
@PathParam("level") long level,
@PathParam("currency") String currency,
@ -432,7 +431,7 @@ public class SubscriptionController {
@ApiResponse(responseCode = "409", description = "subscriberId is already linked to a processor that does not support appstore payments. Delete this subscriberId and use a new one.")
@ApiResponse(responseCode = "429", description = "Rate limit exceeded.")
public CompletableFuture<SetSubscriptionLevelSuccessResponse> setAppStoreSubscription(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId,
@PathParam("originalTransactionId") String originalTransactionId) throws SubscriptionException {
final SubscriberCredentials subscriberCredentials =
@ -473,7 +472,7 @@ public class SubscriptionController {
@ApiResponse(responseCode = "404", description = "No such subscriberId exists or subscriberId is malformed or the purchaseToken does not exist")
@ApiResponse(responseCode = "409", description = "subscriberId is already linked to a processor that does not support Play Billing. Delete this subscriberId and use a new one.")
public CompletableFuture<SetSubscriptionLevelSuccessResponse> setPlayStoreSubscription(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId,
@PathParam("purchaseToken") String purchaseToken) throws SubscriptionException {
final SubscriberCredentials subscriberCredentials =
@ -627,7 +626,7 @@ public class SubscriptionController {
@ApiResponse(responseCode = "403", description = "subscriberId authentication failure OR account authentication is present")
@ApiResponse(responseCode = "404", description = "No such subscriberId exists or subscriberId is malformed")
public CompletableFuture<Response> getSubscriptionInformation(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId) throws SubscriptionException {
SubscriberCredentials subscriberCredentials =
SubscriberCredentials.process(authenticatedAccount, subscriberId, clock);
@ -662,7 +661,7 @@ public class SubscriptionController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> createSubscriptionReceiptCredentials(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@PathParam("subscriberId") String subscriberId,
@NotNull @Valid GetReceiptCredentialsRequest request) throws SubscriptionException {
@ -691,7 +690,7 @@ public class SubscriptionController {
@Path("/{subscriberId}/default_payment_method_for_ideal/{setupIntentId}")
@Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> setDefaultPaymentMethodForIdeal(
@ReadOnly @Auth Optional<AuthenticatedDevice> authenticatedAccount,
@Auth Optional<AuthenticatedDevice> authenticatedAccount,
@PathParam("subscriberId") String subscriberId,
@PathParam("setupIntentId") @NotEmpty String setupIntentId) throws SubscriptionException {
SubscriberCredentials subscriberCredentials =

View File

@ -5,9 +5,9 @@
package org.whispersystems.textsecuregcm.controllers;
import org.whispersystems.textsecuregcm.entities.RegistrationServiceSession;
import javax.annotation.Nullable;
import java.time.Duration;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.entities.RegistrationServiceSession;
public class VerificationSessionRateLimitExceededException extends RateLimitExceededException {

View File

@ -10,15 +10,14 @@ import com.webauthn4j.converter.jackson.deserializer.json.ByteArrayBase64Deseria
import io.micrometer.core.instrument.Metrics;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.AssertTrue;
import javax.annotation.Nullable;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import java.util.Arrays;
import java.util.Objects;
import javax.annotation.Nullable;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.storage.Account;
import java.util.Arrays;
import java.util.Objects;
public record IncomingMessage(int type,
byte destinationDeviceId,
@ -35,7 +34,7 @@ public record IncomingMessage(int type,
MetricsUtil.name(IncomingMessage.class, "rejectInvalidEnvelopeType");
public MessageProtos.Envelope toEnvelope(final ServiceIdentifier destinationIdentifier,
@Nullable Account sourceAccount,
@Nullable AciServiceIdentifier sourceServiceIdentifier,
@Nullable Byte sourceDeviceId,
final long timestamp,
final boolean story,
@ -54,9 +53,9 @@ public record IncomingMessage(int type,
.setEphemeral(ephemeral)
.setUrgent(urgent);
if (sourceAccount != null && sourceDeviceId != null) {
if (sourceServiceIdentifier != null && sourceDeviceId != null) {
envelopeBuilder
.setSourceServiceId(new AciServiceIdentifier(sourceAccount.getUuid()).toServiceIdentifierString())
.setSourceServiceId(sourceServiceIdentifier.toServiceIdentifierString())
.setSourceDevice(sourceDeviceId.intValue());
}

View File

@ -96,8 +96,8 @@ public class RestDeprecationFilter implements ContainerRequestFilter {
return false;
}
if (securityContext.getUserPrincipal() instanceof AuthenticatedDevice ad) {
return experimentEnrollmentManager.isEnrolled(ad.getAccount().getUuid(), AUTHENTICATED_EXPERIMENT_NAME);
if (securityContext.getUserPrincipal() instanceof AuthenticatedDevice authenticatedDevice) {
return experimentEnrollmentManager.isEnrolled(authenticatedDevice.getAccountIdentifier(), AUTHENTICATED_EXPERIMENT_NAME);
} else {
log.error("Security context was not null but user principal was of type {}", securityContext.getUserPrincipal().getClass().getName());
return false;

View File

@ -9,8 +9,6 @@ import java.util.Optional;
import jakarta.ws.rs.core.Response;
import org.signal.chat.messages.SendMessageResponse;
import org.signal.chat.messages.SendMultiRecipientMessageResponse;
import org.whispersystems.textsecuregcm.auth.AccountAndAuthenticatedDeviceHolder;
import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Account;
@ -31,7 +29,7 @@ public interface SpamChecker {
SpamCheckResult<Response> checkForIndividualRecipientSpamHttp(
final MessageType messageType,
final ContainerRequestContext requestContext,
final Optional<? extends AccountAndAuthenticatedDeviceHolder> maybeSource,
final Optional<org.whispersystems.textsecuregcm.auth.AuthenticatedDevice> maybeSource,
final Optional<Account> maybeDestination,
final ServiceIdentifier destinationIdentifier);
@ -79,7 +77,7 @@ public interface SpamChecker {
@Override
public SpamCheckResult<Response> checkForIndividualRecipientSpamHttp(final MessageType messageType,
final ContainerRequestContext requestContext,
final Optional<? extends AccountAndAuthenticatedDeviceHolder> maybeSource,
final Optional<org.whispersystems.textsecuregcm.auth.AuthenticatedDevice> maybeSource,
final Optional<Account> maybeDestination,
final ServiceIdentifier destinationIdentifier) {
@ -95,7 +93,7 @@ public interface SpamChecker {
@Override
public SpamCheckResult<GrpcResponse<SendMessageResponse>> checkForIndividualRecipientSpamGrpc(final MessageType messageType,
final Optional<AuthenticatedDevice> maybeSource,
final Optional<org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice> maybeSource,
final Optional<Account> maybeDestination,
final ServiceIdentifier destinationIdentifier) {

View File

@ -1,36 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.websocket.auth.PrincipalSupplier;
public class AccountPrincipalSupplier implements PrincipalSupplier<AuthenticatedDevice> {
private final AccountsManager accountsManager;
public AccountPrincipalSupplier(final AccountsManager accountsManager) {
this.accountsManager = accountsManager;
}
@Override
public AuthenticatedDevice refresh(final AuthenticatedDevice oldAccount) {
final Account account = accountsManager.getByAccountIdentifier(oldAccount.getAccount().getUuid())
.orElseThrow(() -> new RefreshingAccountNotFoundException("Could not find account"));
final Device device = account.getDevice(oldAccount.getAuthenticatedDevice().getId())
.orElseThrow(() -> new RefreshingAccountNotFoundException("Could not find device"));
return new AuthenticatedDevice(account, device);
}
@Override
public AuthenticatedDevice deepCopy(final AuthenticatedDevice authenticatedDevice) {
final Account cloned = AccountUtil.cloneAccountAsNotStale(authenticatedDevice.getAccount());
return new AuthenticatedDevice(
cloned,
cloned.getDevice(authenticatedDevice.getAuthenticatedDevice().getId())
.orElseThrow(() -> new IllegalStateException(
"Could not find device from a clone of an account where the device was present")));
}
}

View File

@ -600,13 +600,10 @@ public class AccountsManager extends RedisPubSubAdapter<String, String> implemen
}
/**
* Unlink a device from the given account. The device will be immediately disconnected if it is
* connected to any chat frontend, but it is the caller's responsibility to make sure that the
* account's *other* devices are disconnected, either by use of
* {@link org.whispersystems.textsecuregcm.auth.LinkedDeviceRefreshRequirementProvider} or
* directly by calling {@link DeviceDisconnectionManager#requestDisconnection}.
* Unlink a device from the given account. The device will be immediately disconnected if it is connected to any chat
* frontend.
*
* @returns the updated Account
* @return the updated Account
*/
public CompletableFuture<Account> removeDevice(final Account account, final byte deviceId) {
if (deviceId == Device.PRIMARY_ID) {

View File

@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.websocket;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import io.micrometer.core.instrument.Tags;
import java.util.Optional;
import java.util.concurrent.ScheduledExecutorService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -20,7 +21,10 @@ import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import org.whispersystems.websocket.setup.WebSocketConnectListener;
@ -36,6 +40,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
private static final Logger log = LoggerFactory.getLogger(AuthenticatedConnectListener.class);
private final AccountsManager accountsManager;
private final ReceiptSender receiptSender;
private final MessagesManager messagesManager;
private final MessageMetrics messageMetrics;
@ -51,7 +56,9 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
private final OpenWebSocketCounter openAuthenticatedWebSocketCounter;
private final OpenWebSocketCounter openUnauthenticatedWebSocketCounter;
public AuthenticatedConnectListener(ReceiptSender receiptSender,
public AuthenticatedConnectListener(
AccountsManager accountsManager,
ReceiptSender receiptSender,
MessagesManager messagesManager,
MessageMetrics messageMetrics,
PushNotificationManager pushNotificationManager,
@ -62,6 +69,8 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
ClientReleaseManager clientReleaseManager,
MessageDeliveryLoopMonitor messageDeliveryLoopMonitor,
final ExperimentEnrollmentManager experimentEnrollmentManager) {
this.accountsManager = accountsManager;
this.receiptSender = receiptSender;
this.messagesManager = messagesManager;
this.messageMetrics = messageMetrics;
@ -82,7 +91,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
}
@Override
public void onWebSocketConnect(WebSocketSessionContext context) {
public void onWebSocketConnect(final WebSocketSessionContext context) {
final boolean authenticated = (context.getAuthenticated() != null);
final OpenWebSocketCounter openWebSocketCounter =
@ -92,12 +101,24 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
if (authenticated) {
final AuthenticatedDevice auth = context.getAuthenticated(AuthenticatedDevice.class);
final Optional<Account> maybeAuthenticatedAccount = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier());
final Optional<Device> maybeAuthenticatedDevice = maybeAuthenticatedAccount.flatMap(account -> account.getDevice(auth.getDeviceId()));;
if (maybeAuthenticatedAccount.isEmpty() || maybeAuthenticatedDevice.isEmpty()) {
log.warn("{}:{} not found when opening authenticated WebSocket", auth.getAccountIdentifier(), auth.getDeviceId());
context.getClient().close(1011, "Unexpected error initializing connection");
return;
}
final WebSocketConnection connection = new WebSocketConnection(receiptSender,
messagesManager,
messageMetrics,
pushNotificationManager,
pushNotificationScheduler,
auth,
maybeAuthenticatedAccount.get(),
maybeAuthenticatedDevice.get(),
context.getClient(),
scheduledExecutorService,
messageDeliveryScheduler,
@ -110,8 +131,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
// receive push notifications for inbound messages. We should do this first because, at this point, the
// connection has already closed and attempts to actually deliver a message via the connection will not succeed.
// It's preferable to start sending push notifications as soon as possible.
webSocketConnectionEventManager.handleClientDisconnected(auth.getAccount().getUuid(),
auth.getAuthenticatedDevice().getId());
webSocketConnectionEventManager.handleClientDisconnected(auth.getAccountIdentifier(), auth.getDeviceId());
// Finally, stop trying to deliver messages and send a push notification if the connection is aware of any
// undelivered messages.
@ -127,7 +147,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
// Finally, we register this client's presence, which suppresses push notifications. We do this last because
// receiving extra push notifications is generally preferable to missing out on a push notification.
webSocketConnectionEventManager.handleClientConnected(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(), connection);
webSocketConnectionEventManager.handleClientConnected(auth.getAccountIdentifier(), auth.getDeviceId(), connection);
} catch (final Exception e) {
log.warn("Failed to initialize websocket", e);
context.getClient().close(1011, "Unexpected error initializing connection");

View File

@ -9,41 +9,39 @@ import static org.whispersystems.textsecuregcm.util.HeaderUtils.basicCredentials
import com.google.common.net.HttpHeaders;
import javax.annotation.Nullable;
import io.dropwizard.auth.basic.BasicCredentials;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.auth.InvalidCredentialsException;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import java.util.Optional;
public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<AuthenticatedDevice> {
private static final ReusableAuth<AuthenticatedDevice> CREDENTIALS_NOT_PRESENTED = ReusableAuth.anonymous();
private final AccountAuthenticator accountAuthenticator;
private final PrincipalSupplier<AuthenticatedDevice> principalSupplier;
public WebSocketAccountAuthenticator(final AccountAuthenticator accountAuthenticator,
final PrincipalSupplier<AuthenticatedDevice> principalSupplier) {
public WebSocketAccountAuthenticator(final AccountAuthenticator accountAuthenticator) {
this.accountAuthenticator = accountAuthenticator;
this.principalSupplier = principalSupplier;
}
@Override
public ReusableAuth<AuthenticatedDevice> authenticate(final UpgradeRequest request)
public Optional<AuthenticatedDevice> authenticate(final UpgradeRequest request)
throws InvalidCredentialsException {
@Nullable final String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION);
if (authHeader == null) {
return CREDENTIALS_NOT_PRESENTED;
return Optional.empty();
}
return basicCredentialsFromAuthHeader(authHeader)
.flatMap(accountAuthenticator::authenticate)
.map(authenticatedAccount -> ReusableAuth.authenticated(authenticatedAccount, this.principalSupplier))
final BasicCredentials credentials = basicCredentialsFromAuthHeader(authHeader)
.orElseThrow(InvalidCredentialsException::new);
final AuthenticatedDevice authenticatedDevice = accountAuthenticator.authenticate(credentials)
.orElseThrow(InvalidCredentialsException::new);
return Optional.of(authenticatedDevice);
}
}

View File

@ -37,7 +37,6 @@ import org.eclipse.jetty.util.StaticException;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
@ -52,6 +51,7 @@ import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventListener;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
@ -123,7 +123,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor;
private final ExperimentEnrollmentManager experimentEnrollmentManager;
private final AuthenticatedDevice auth;
private final Account authenticatedAccount;
private final Device authenticatedDevice;
private final WebSocketClient client;
private final int sendFuturesTimeoutMillis;
@ -156,7 +157,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
MessageMetrics messageMetrics,
PushNotificationManager pushNotificationManager,
PushNotificationScheduler pushNotificationScheduler,
AuthenticatedDevice auth,
Account authenticatedAccount,
Device authenticatedDevice,
WebSocketClient client,
ScheduledExecutorService scheduledExecutorService,
Scheduler messageDeliveryScheduler,
@ -169,7 +171,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
messageMetrics,
pushNotificationManager,
pushNotificationScheduler,
auth,
authenticatedAccount,
authenticatedDevice,
client,
DEFAULT_SEND_FUTURES_TIMEOUT_MILLIS,
scheduledExecutorService,
@ -184,7 +187,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
MessageMetrics messageMetrics,
PushNotificationManager pushNotificationManager,
PushNotificationScheduler pushNotificationScheduler,
AuthenticatedDevice auth,
Account authenticatedAccount,
Device authenticatedDevice,
WebSocketClient client,
int sendFuturesTimeoutMillis,
ScheduledExecutorService scheduledExecutorService,
@ -198,7 +202,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
this.messageMetrics = messageMetrics;
this.pushNotificationManager = pushNotificationManager;
this.pushNotificationScheduler = pushNotificationScheduler;
this.auth = auth;
this.authenticatedAccount = authenticatedAccount;
this.authenticatedDevice = authenticatedDevice;
this.client = client;
this.sendFuturesTimeoutMillis = sendFuturesTimeoutMillis;
this.scheduledExecutorService = scheduledExecutorService;
@ -209,7 +214,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
}
public void start() {
pushNotificationManager.handleMessagesRetrieved(auth.getAccount(), auth.getAuthenticatedDevice(), client.getUserAgent());
pushNotificationManager.handleMessagesRetrieved(authenticatedAccount, authenticatedDevice, client.getUserAgent());
queueDrainStartTime.set(System.currentTimeMillis());
processStoredMessages();
}
@ -229,8 +234,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
client.close(1000, "OK");
if (storedMessageState.get() != StoredMessageState.EMPTY) {
pushNotificationScheduler.scheduleDelayedNotification(auth.getAccount(),
auth.getAuthenticatedDevice(),
pushNotificationScheduler.scheduleDelayedNotification(authenticatedAccount,
authenticatedDevice,
CLOSE_WITH_PENDING_MESSAGES_NOTIFICATION_DELAY);
}
}
@ -242,7 +247,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
sendMessageCounter.increment();
sentMessageCounter.increment();
bytesSentCounter.increment(body.map(bytes -> bytes.length).orElse(0));
messageMetrics.measureAccountEnvelopeUuidMismatches(auth.getAccount(), message);
messageMetrics.measureAccountEnvelopeUuidMismatches(authenticatedAccount, message);
// X-Signal-Key: false must be sent until Android stops assuming it missing means true
return client.sendRequest("PUT", "/api/v1/message",
@ -253,7 +258,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
} else {
messageMetrics.measureOutgoingMessageLatency(message.getServerTimestamp(),
"websocket",
auth.getAuthenticatedDevice().isPrimary(),
authenticatedDevice.isPrimary(),
message.getUrgent(),
message.getEphemeral(),
client.getUserAgent(),
@ -263,12 +268,12 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
final CompletableFuture<Void> result;
if (isSuccessResponse(response)) {
result = messagesManager.delete(auth.getAccount().getUuid(), auth.getAuthenticatedDevice(),
result = messagesManager.delete(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice,
storedMessageInfo.guid(), storedMessageInfo.serverTimestamp())
.thenApply(ignored -> null);
if (message.getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT) {
recordMessageDeliveryDuration(message.getServerTimestamp(), auth.getAuthenticatedDevice());
recordMessageDeliveryDuration(message.getServerTimestamp(), authenticatedDevice);
sendDeliveryReceiptFor(message);
}
} else {
@ -307,7 +312,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
try {
receiptSender.sendReceipt(ServiceIdentifier.valueOf(message.getDestinationServiceId()),
auth.getAuthenticatedDevice().getId(), AciServiceIdentifier.valueOf(message.getSourceServiceId()),
authenticatedDevice.getId(), AciServiceIdentifier.valueOf(message.getSourceServiceId()),
message.getClientTimestamp());
} catch (IllegalArgumentException e) {
logger.error("Could not parse UUID: {}", message.getSourceServiceId());
@ -338,7 +343,6 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
// Cleared the queue! Send a queue empty message if we need to
consecutiveRetries.set(0);
if (sentInitialQueueEmptyMessage.compareAndSet(false, true)) {
final Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent()));
final long drainDuration = System.currentTimeMillis() - queueDrainStartTime.get();
@ -399,7 +403,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
final CompletableFuture<Void> queueCleared = new CompletableFuture<>();
final Publisher<Envelope> messages =
messagesManager.getMessagesForDeviceReactive(auth.getAccount().getUuid(), auth.getAuthenticatedDevice(), cachedMessagesOnly);
messagesManager.getMessagesForDeviceReactive(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice, cachedMessagesOnly);
final AtomicBoolean hasSentFirstMessage = new AtomicBoolean();
final AtomicBoolean hasErrored = new AtomicBoolean();
@ -410,8 +414,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
.limitRate(MESSAGE_PUBLISHER_LIMIT_RATE)
.doOnNext(envelope -> {
if (hasSentFirstMessage.compareAndSet(false, true)) {
messageDeliveryLoopMonitor.recordDeliveryAttempt(auth.getAccount().getIdentifier(IdentityType.ACI),
auth.getAuthenticatedDevice().getId(),
messageDeliveryLoopMonitor.recordDeliveryAttempt(authenticatedAccount.getIdentifier(IdentityType.ACI),
authenticatedDevice.getId(),
UUID.fromString(envelope.getServerGuid()),
client.getUserAgent(),
"websocket");
@ -471,7 +475,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener {
final UUID messageGuid = UUID.fromString(envelope.getServerGuid());
if (envelope.getStory() && !client.shouldDeliverStories()) {
messagesManager.delete(auth.getAccount().getUuid(), auth.getAuthenticatedDevice(), messageGuid, envelope.getServerTimestamp());
messagesManager.delete(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice, messageGuid, envelope.getServerTimestamp());
return CompletableFuture.completedFuture(null);
} else {

View File

@ -21,6 +21,7 @@ import jakarta.ws.rs.core.MediaType;
import java.io.IOException;
import java.net.URI;
import java.util.EnumSet;
import java.util.Optional;
import org.apache.commons.lang3.RandomStringUtils;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
@ -34,9 +35,7 @@ import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ -78,8 +77,7 @@ public class WebsocketResourceProviderIntegrationTest {
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(testController);
webSocketEnvironment.jersey().register(new RemoteAddressFilter());
webSocketEnvironment.setAuthenticator(upgradeRequest ->
ReusableAuth.authenticated(mock(AuthenticatedDevice.class), PrincipalSupplier.forImmutablePrincipal()));
webSocketEnvironment.setAuthenticator(upgradeRequest -> Optional.of(mock(AuthenticatedDevice.class)));
webSocketEnvironment.jersey().property(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE);
webSocketEnvironment.setConnectListener(webSocketSessionContext -> {

View File

@ -1,279 +0,0 @@
package org.whispersystems.textsecuregcm;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME;
import io.dropwizard.auth.Auth;
import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.ServletRegistration;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.PathParam;
import java.io.IOException;
import java.net.URI;
import java.util.EnumSet;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.IntStream;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.glassfish.jersey.server.ManagedAsync;
import org.glassfish.jersey.server.ServerProperties;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.storage.RefreshingAccountNotFoundException;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.ReadOnly;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ExtendWith(DropwizardExtensionsSupport.class)
public class WebsocketReuseAuthIntegrationTest {
private static final AuthenticatedDevice ACCOUNT = mock(AuthenticatedDevice.class);
@SuppressWarnings("unchecked")
private static final PrincipalSupplier<AuthenticatedDevice> PRINCIPAL_SUPPLIER = mock(PrincipalSupplier.class);
private static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION =
new DropwizardAppExtension<>(TestApplication.class);
private WebSocketClient client;
@BeforeEach
void setUp() throws Exception {
reset(PRINCIPAL_SUPPLIER);
reset(ACCOUNT);
when(ACCOUNT.getName()).thenReturn("original");
client = new WebSocketClient();
client.start();
}
@AfterEach
void tearDown() throws Exception {
client.stop();
}
public static class TestApplication extends Application<Configuration> {
@Override
public void run(final Configuration configuration, final Environment environment) throws Exception {
final TestController testController = new TestController();
final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration();
final WebSocketEnvironment<AuthenticatedDevice> webSocketEnvironment =
new WebSocketEnvironment<>(environment, webSocketConfiguration);
environment.jersey().register(testController);
environment.servlets()
.addFilter("RemoteAddressFilter", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(testController);
webSocketEnvironment.jersey().register(new RemoteAddressFilter());
webSocketEnvironment.setAuthenticator(upgradeRequest -> ReusableAuth.authenticated(ACCOUNT, PRINCIPAL_SUPPLIER));
webSocketEnvironment.jersey().property(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE);
webSocketEnvironment.setConnectListener(webSocketSessionContext -> {
});
final WebSocketResourceProviderFactory<AuthenticatedDevice> webSocketServlet =
new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class,
webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
final ServletRegistration.Dynamic websocketServlet =
environment.servlets().addServlet("WebSocket", webSocketServlet);
websocketServlet.addMapping("/websocket");
websocketServlet.setAsyncSupported(true);
}
}
private WebSocketResponseMessage make1WebsocketRequest(final String requestPath) throws IOException {
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
return testWebsocketListener.doGet(requestPath).join();
}
@ParameterizedTest
@ValueSource(strings = {"/test/read-auth", "/test/optional-read-auth"})
public void readAuth(final String path) throws IOException {
final WebSocketResponseMessage response = make1WebsocketRequest(path);
assertThat(response.getStatus()).isEqualTo(200);
verifyNoMoreInteractions(PRINCIPAL_SUPPLIER);
}
@ParameterizedTest
@ValueSource(strings = {"/test/write-auth", "/test/optional-write-auth"})
public void writeAuth(final String path) throws IOException {
final AuthenticatedDevice copiedAccount = mock(AuthenticatedDevice.class);
when(copiedAccount.getName()).thenReturn("copy");
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(copiedAccount);
final WebSocketResponseMessage response = make1WebsocketRequest(path);
assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.getBody().map(String::new)).get().isEqualTo("copy");
verify(PRINCIPAL_SUPPLIER, times(1)).deepCopy(any());
verifyNoMoreInteractions(PRINCIPAL_SUPPLIER);
}
@Test
public void readAfterWrite() throws IOException {
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(ACCOUNT);
final AuthenticatedDevice account2 = mock(AuthenticatedDevice.class);
when(account2.getName()).thenReturn("refresh");
when(PRINCIPAL_SUPPLIER.refresh(any())).thenReturn(account2);
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
final WebSocketResponseMessage readResponse = testWebsocketListener.doGet("/test/read-auth").join();
assertThat(readResponse.getBody().map(String::new)).get().isEqualTo("original");
final WebSocketResponseMessage writeResponse = testWebsocketListener.doGet("/test/write-auth").join();
assertThat(writeResponse.getBody().map(String::new)).get().isEqualTo("original");
final WebSocketResponseMessage readResponse2 = testWebsocketListener.doGet("/test/read-auth").join();
assertThat(readResponse2.getBody().map(String::new)).get().isEqualTo("refresh");
}
@Test
public void readAfterWriteRefreshFails() throws IOException {
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(ACCOUNT);
when(PRINCIPAL_SUPPLIER.refresh(any())).thenThrow(RefreshingAccountNotFoundException.class);
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
final WebSocketResponseMessage writeResponse = testWebsocketListener.doGet("/test/write-auth").join();
assertThat(writeResponse.getBody().map(String::new)).get().isEqualTo("original");
final WebSocketResponseMessage readResponse2 = testWebsocketListener.doGet("/test/read-auth").join();
assertThat(readResponse2.getStatus()).isEqualTo(500);
}
@Test
public void readConcurrentWithWrite() throws IOException, ExecutionException, InterruptedException, TimeoutException {
final AuthenticatedDevice deepCopy = mock(AuthenticatedDevice.class);
when(deepCopy.getName()).thenReturn("deepCopy");
when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(deepCopy);
final AuthenticatedDevice refresh = mock(AuthenticatedDevice.class);
when(refresh.getName()).thenReturn("refresh");
when(PRINCIPAL_SUPPLIER.refresh(any())).thenReturn(refresh);
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
client.connect(testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())));
// start a write request that takes a while to finish
final CompletableFuture<WebSocketResponseMessage> writeResponse =
testWebsocketListener.doGet("/test/start-delayed-write/foo");
// send a bunch of reads, they should reflect the original auth
final List<CompletableFuture<WebSocketResponseMessage>> futures = IntStream.range(0, 10)
.boxed().map(i -> testWebsocketListener.doGet("/test/read-auth"))
.toList();
CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)).join();
for (CompletableFuture<WebSocketResponseMessage> future : futures) {
assertThat(future.join().getBody().map(String::new)).get().isEqualTo("original");
}
assertThat(writeResponse.isDone()).isFalse();
// finish the delayed write request
testWebsocketListener.doGet("/test/finish-delayed-write/foo").get(1, TimeUnit.SECONDS);
assertThat(writeResponse.join().getBody().map(String::new)).get().isEqualTo("deepCopy");
// subsequent reads should have the refreshed auth
final WebSocketResponseMessage readResponse = testWebsocketListener.doGet("/test/read-auth").join();
assertThat(readResponse.getBody().map(String::new)).get().isEqualTo("refresh");
}
@Path("/test")
public static class TestController {
private final ConcurrentHashMap<String, CountDownLatch> delayedWriteLatches = new ConcurrentHashMap<>();
@GET
@Path("/read-auth")
@ManagedAsync
public String readAuth(@ReadOnly @Auth final AuthenticatedDevice account) {
return account.getName();
}
@GET
@Path("/optional-read-auth")
@ManagedAsync
public String optionalReadAuth(@ReadOnly @Auth final Optional<AuthenticatedDevice> account) {
return account.map(AuthenticatedDevice::getName).orElse("empty");
}
@GET
@Path("/write-auth")
@ManagedAsync
public String writeAuth(@Auth final AuthenticatedDevice account) {
return account.getName();
}
@GET
@Path("/optional-write-auth")
@ManagedAsync
public String optionalWriteAuth(@Auth final Optional<AuthenticatedDevice> account) {
return account.map(AuthenticatedDevice::getName).orElse("empty");
}
@GET
@Path("/start-delayed-write/{id}")
@ManagedAsync
public String startDelayedWrite(@Auth final AuthenticatedDevice account, @PathParam("id") String id)
throws InterruptedException {
delayedWriteLatches.computeIfAbsent(id, i -> new CountDownLatch(1)).await();
return account.getName();
}
@GET
@Path("/finish-delayed-write/{id}")
@ManagedAsync
public String finishDelayedWrite(@PathParam("id") String id) {
delayedWriteLatches.computeIfAbsent(id, i -> new CountDownLatch(1)).countDown();
return "ok";
}
}
}

View File

@ -18,7 +18,6 @@ import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
class CertificateGeneratorTest {
@ -37,7 +36,7 @@ class CertificateGeneratorTest {
@Test
void testCreateFor() throws IOException, InvalidKeyException, org.signal.libsignal.protocol.InvalidKeyException {
final Account account = mock(Account.class);
final Device device = mock(Device.class);
final byte deviceId = 4;
final CertificateGenerator certificateGenerator = new CertificateGenerator(
Base64.getDecoder().decode(SIGNING_CERTIFICATE),
Curve.decodePrivatePoint(Base64.getDecoder().decode(SIGNING_KEY)), 1);
@ -45,9 +44,8 @@ class CertificateGeneratorTest {
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(IDENTITY_KEY);
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(account.getNumber()).thenReturn("+18005551234");
when(device.getId()).thenReturn((byte) 4);
assertTrue(certificateGenerator.createFor(account, device, true).length > 0);
assertTrue(certificateGenerator.createFor(account, device, false).length > 0);
assertTrue(certificateGenerator.createFor(account, deviceId, true).length > 0);
assertTrue(certificateGenerator.createFor(account, deviceId, false).length > 0);
}
}

View File

@ -31,8 +31,6 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.util.TestClock;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.auth.PrincipalSupplier;
class IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilterTest {
@ -59,9 +57,9 @@ class IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilterTest {
final boolean expectPqKeyCheck,
@Nullable final String expectedAlertHeader) {
final ReusableAuth<AuthenticatedDevice> reusableAuth = authenticatedDevice != null
? ReusableAuth.authenticated(authenticatedDevice, PrincipalSupplier.forImmutablePrincipal())
: ReusableAuth.anonymous();
final Optional<AuthenticatedDevice> reusableAuth = authenticatedDevice != null
? Optional.of(authenticatedDevice)
: Optional.empty();
final JettyServerUpgradeResponse response = mock(JettyServerUpgradeResponse.class);
@ -88,20 +86,24 @@ class IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilterTest {
private static List<Arguments> handleAuthentication() {
final Device activePrimaryDevice = mock(Device.class);
when(activePrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(activePrimaryDevice.isPrimary()).thenReturn(true);
when(activePrimaryDevice.getLastSeen()).thenReturn(CLOCK.millis());
final Device minIdlePrimaryDevice = mock(Device.class);
when(minIdlePrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(minIdlePrimaryDevice.isPrimary()).thenReturn(true);
when(minIdlePrimaryDevice.getLastSeen())
.thenReturn(CLOCK.instant().minus(MIN_IDLE_DURATION).minusSeconds(1).toEpochMilli());
final Device longIdlePrimaryDevice = mock(Device.class);
when(longIdlePrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(longIdlePrimaryDevice.isPrimary()).thenReturn(true);
when(longIdlePrimaryDevice.getLastSeen())
.thenReturn(CLOCK.instant().minus(IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter.PQ_KEY_CHECK_THRESHOLD).minusSeconds(1).toEpochMilli());
final Device linkedDevice = mock(Device.class);
when(linkedDevice.getId()).thenReturn((byte) (Device.PRIMARY_ID + 1));
when(linkedDevice.isPrimary()).thenReturn(false);
final Account accountWithActivePrimaryDevice = mock(Account.class);

View File

@ -1,328 +0,0 @@
/*
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import io.dropwizard.auth.Auth;
import io.dropwizard.auth.AuthDynamicFeature;
import io.dropwizard.auth.AuthValueFactoryProvider;
import io.dropwizard.auth.basic.BasicCredentialAuthFilter;
import io.dropwizard.jersey.DropwizardResourceConfig;
import io.dropwizard.jersey.jackson.JacksonMessageBodyProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import jakarta.ws.rs.DELETE;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.PUT;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.PathParam;
import jakarta.ws.rs.client.Entity;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.Principal;
import java.time.Duration;
import java.util.Arrays;
import java.util.Base64;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.websocket.api.WriteCallback;
import org.glassfish.jersey.server.ApplicationHandler;
import org.glassfish.jersey.server.ResourceConfig;
import org.glassfish.jersey.server.monitoring.ApplicationEventListener;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketResourceProvider;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
import org.whispersystems.websocket.logging.WebsocketRequestLog;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
import org.whispersystems.websocket.messages.protobuf.SubProtocol;
import org.whispersystems.websocket.session.WebSocketSessionContextValueFactoryProvider;
@ExtendWith(DropwizardExtensionsSupport.class)
class LinkedDeviceRefreshRequirementProviderTest {
private final ApplicationEventListener applicationEventListener = mock(ApplicationEventListener.class);
private final Account account = new Account();
private final Device authenticatedDevice = DevicesHelper.createDevice(Device.PRIMARY_ID);
private final Supplier<Optional<TestPrincipal>> principalSupplier = () -> Optional.of(
new TestPrincipal("test", account, authenticatedDevice));
private final ResourceExtension resources = ResourceExtension.builder()
.addProvider(new AuthDynamicFeature(new BasicCredentialAuthFilter.Builder<TestPrincipal>()
.setAuthenticator(c -> principalSupplier.get()).buildAuthFilter()))
.addProvider(new AuthValueFactoryProvider.Binder<>(TestPrincipal.class))
.addProvider(applicationEventListener)
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new TestResource())
.build();
private AccountsManager accountsManager;
private DisconnectionRequestManager disconnectionRequestManager;
private LinkedDeviceRefreshRequirementProvider provider;
@BeforeEach
void setup() {
accountsManager = mock(AccountsManager.class);
disconnectionRequestManager = mock(DisconnectionRequestManager.class);
provider = new LinkedDeviceRefreshRequirementProvider(accountsManager);
final WebsocketRefreshRequestEventListener listener =
new WebsocketRefreshRequestEventListener(disconnectionRequestManager, provider);
when(applicationEventListener.onRequest(any())).thenReturn(listener);
final UUID uuid = UUID.randomUUID();
account.setUuid(uuid);
account.addDevice(authenticatedDevice);
IntStream.range(2, 4)
.forEach(deviceId -> account.addDevice(DevicesHelper.createDevice((byte) deviceId)));
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account));
}
@Test
void testDeviceAdded() {
final int initialDeviceCount = account.getDevices().size();
final List<String> addedDeviceNames = List.of(
Base64.getEncoder().encodeToString("newDevice1".getBytes(StandardCharsets.UTF_8)),
Base64.getEncoder().encodeToString("newDevice2".getBytes(StandardCharsets.UTF_8)));
final Response response = resources.getJerseyTest()
.target("/v1/test/account/devices")
.request()
.header("Authorization",
"Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
.put(Entity.entity(addedDeviceNames, MediaType.APPLICATION_JSON_PATCH_JSON));
assertEquals(200, response.getStatus());
assertEquals(initialDeviceCount + addedDeviceNames.size(), account.getDevices().size());
verify(disconnectionRequestManager).requestDisconnection(account.getUuid(), List.of((byte) 1));
verify(disconnectionRequestManager).requestDisconnection(account.getUuid(), List.of((byte) 2));
verify(disconnectionRequestManager).requestDisconnection(account.getUuid(), List.of((byte) 3));
}
@ParameterizedTest
@ValueSource(ints = {1, 2})
void testDeviceRemoved(final int removedDeviceCount) {
final List<Byte> initialDeviceIds = account.getDevices().stream().map(Device::getId).toList();
final List<Byte> deletedDeviceIds = account.getDevices().stream()
.map(Device::getId)
.filter(deviceId -> deviceId != Device.PRIMARY_ID)
.limit(removedDeviceCount)
.toList();
assert deletedDeviceIds.size() == removedDeviceCount;
final String deletedDeviceIdsParam = deletedDeviceIds.stream().map(String::valueOf)
.collect(Collectors.joining(","));
final Response response = resources.getJerseyTest()
.target("/v1/test/account/devices/" + deletedDeviceIdsParam)
.request()
.header("Authorization",
"Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
.delete();
assertEquals(200, response.getStatus());
initialDeviceIds.forEach(deviceId ->
verify(disconnectionRequestManager).requestDisconnection(account.getUuid(), List.of(deviceId)));
verifyNoMoreInteractions(disconnectionRequestManager);
}
@Test
void testOnEvent() {
Response response = resources.getJerseyTest()
.target("/v1/test/hello")
.request()
// no authorization required
.get();
assertEquals(200, response.getStatus());
response = resources.getJerseyTest()
.target("/v1/test/authorized")
.request()
.header("Authorization",
"Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
.get();
assertEquals(200, response.getStatus());
verify(accountsManager, never()).getByAccountIdentifier(any(UUID.class));
}
@Nested
class WebSocket {
private WebSocketResourceProvider<TestPrincipal> provider;
private RemoteEndpoint remoteEndpoint;
@BeforeEach
void setup() {
ResourceConfig resourceConfig = new DropwizardResourceConfig();
resourceConfig.register(applicationEventListener);
resourceConfig.register(new TestResource());
resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder());
resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class));
resourceConfig.register(new JacksonMessageBodyProvider(SystemMapper.jsonMapper()));
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
provider = new WebSocketResourceProvider<>("127.0.0.1", RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME,
applicationHandler, requestLog, TestPrincipal.reusableAuth("test", account, authenticatedDevice),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
remoteEndpoint = mock(RemoteEndpoint.class);
Session session = mock(Session.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getRemote()).thenReturn(remoteEndpoint);
when(session.getUpgradeRequest()).thenReturn(request);
provider.onWebSocketConnect(session);
}
@Test
void testOnEvent() throws Exception {
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello",
new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
final SubProtocol.WebSocketResponseMessage response = verifyAndGetResponse(remoteEndpoint);
assertEquals(200, response.getStatus());
}
private SubProtocol.WebSocketResponseMessage verifyAndGetResponse(final RemoteEndpoint remoteEndpoint)
throws IOException {
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class));
return SubProtocol.WebSocketMessage.parseFrom(responseBytesCaptor.getValue().array()).getResponse();
}
}
public static class TestPrincipal implements Principal, AccountAndAuthenticatedDeviceHolder {
private final String name;
private final Account account;
private final Device device;
private TestPrincipal(final String name, final Account account, final Device device) {
this.name = name;
this.account = account;
this.device = device;
}
@Override
public String getName() {
return name;
}
@Override
public Account getAccount() {
return account;
}
@Override
public Device getAuthenticatedDevice() {
return device;
}
public static ReusableAuth<TestPrincipal> reusableAuth(final String name, final Account account, final Device device) {
return ReusableAuth.authenticated(new TestPrincipal(name, account, device), PrincipalSupplier.forImmutablePrincipal());
}
}
@Path("/v1/test")
public static class TestResource {
@GET
@Path("/hello")
public String testGetHello() {
return "Hello!";
}
@GET
@Path("/authorized")
public String testAuth(@Auth TestPrincipal principal) {
return "Youre in!";
}
@PUT
@Path("/account/devices")
@ChangesLinkedDevices
public String addDevices(@Auth TestPrincipal auth, List<byte[]> deviceNames) {
deviceNames.forEach(name -> {
final Device device = DevicesHelper.createDevice(auth.getAccount().getNextDeviceId());
auth.getAccount().addDevice(device);
device.setName(name);
});
return "Added devices " + deviceNames;
}
@DELETE
@Path("/account/devices/{deviceIds}")
@ChangesLinkedDevices
public String removeDevices(@Auth TestPrincipal auth, @PathParam("deviceIds") String deviceIds) {
Arrays.stream(deviceIds.split(","))
.map(Byte::valueOf)
.forEach(auth.getAccount()::removeDevice);
return "Removed device(s) " + deviceIds;
}
}
}

View File

@ -1,294 +0,0 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME;
import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.Auth;
import io.dropwizard.auth.AuthDynamicFeature;
import io.dropwizard.auth.basic.BasicCredentialAuthFilter;
import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment;
import io.dropwizard.testing.junit5.DropwizardAppExtension;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import jakarta.servlet.DispatcherType;
import jakarta.servlet.ServletRegistration;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.client.Invocation;
import java.io.IOException;
import java.net.URI;
import java.util.Collections;
import java.util.EnumSet;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.glassfish.jersey.server.ManagedAsync;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.tests.util.DevicesHelper;
import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.ReadOnly;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
import org.whispersystems.websocket.setup.WebSocketEnvironment;
@ExtendWith(DropwizardExtensionsSupport.class)
class PhoneNumberChangeRefreshRequirementProviderTest {
private static final String NUMBER = "+18005551234";
private static final String CHANGED_NUMBER = "+18005554321";
private static final String TEST_CRED_HEADER = HeaderUtils.basicAuthHeader("test", "password");
private static final DropwizardAppExtension<Configuration> DROPWIZARD_APP_EXTENSION = new DropwizardAppExtension<>(
TestApplication.class);
private static final AccountAuthenticator AUTHENTICATOR = mock(AccountAuthenticator.class);
private static final AccountsManager ACCOUNTS_MANAGER = mock(AccountsManager.class);
private static final DisconnectionRequestManager DISCONNECTION_REQUEST_MANAGER =
mock(DisconnectionRequestManager.class);
private WebSocketClient client;
private final Account account1 = new Account();
private final Account account2 = new Account();
private final Device authenticatedDevice = DevicesHelper.createDevice(Device.PRIMARY_ID);
@BeforeEach
void setUp() throws Exception {
reset(AUTHENTICATOR, ACCOUNTS_MANAGER, DISCONNECTION_REQUEST_MANAGER);
client = new WebSocketClient();
client.start();
final UUID uuid = UUID.randomUUID();
account1.setUuid(uuid);
account1.addDevice(authenticatedDevice);
account1.setNumber(NUMBER, UUID.randomUUID());
account2.setUuid(uuid);
account2.addDevice(authenticatedDevice);
account2.setNumber(CHANGED_NUMBER, UUID.randomUUID());
}
@AfterEach
void tearDown() throws Exception {
client.stop();
}
public static class TestApplication extends Application<Configuration> {
@Override
public void run(final Configuration configuration, final Environment environment) throws Exception {
final TestController testController = new TestController();
final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration();
final WebSocketEnvironment<AuthenticatedDevice> webSocketEnvironment =
new WebSocketEnvironment<>(environment, webSocketConfiguration);
environment.jersey().register(testController);
webSocketEnvironment.jersey().register(testController);
environment.servlets()
.addFilter("RemoteAddressFilter", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(new RemoteAddressFilter());
webSocketEnvironment.jersey()
.register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, DISCONNECTION_REQUEST_MANAGER));
environment.jersey()
.register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, DISCONNECTION_REQUEST_MANAGER));
webSocketEnvironment.setConnectListener(webSocketSessionContext -> {
});
environment.jersey().register(new AuthDynamicFeature(new BasicCredentialAuthFilter.Builder<AuthenticatedDevice>()
.setAuthenticator(AUTHENTICATOR)
.buildAuthFilter()));
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(AUTHENTICATOR, mock(PrincipalSupplier.class)));
final WebSocketResourceProviderFactory<AuthenticatedDevice> webSocketServlet =
new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class,
webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME);
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
final ServletRegistration.Dynamic websocketServlet =
environment.servlets().addServlet("WebSocket", webSocketServlet);
websocketServlet.addMapping("/websocket");
websocketServlet.setAsyncSupported(true);
}
}
enum Protocol { HTTP, WEBSOCKET }
private void makeAnonymousRequest(final Protocol protocol, final String requestPath) throws IOException {
makeRequest(protocol, requestPath, true);
}
/*
* Make an authenticated request that will return account1 as the principal
*/
private void makeAuthenticatedRequest(
final Protocol protocol,
final String requestPath) throws IOException {
when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedDevice(account1, authenticatedDevice)));
makeRequest(protocol,requestPath, false);
}
private void makeRequest(final Protocol protocol, final String requestPath, final boolean anonymous) throws IOException {
switch (protocol) {
case WEBSOCKET -> {
final TestWebsocketListener testWebsocketListener = new TestWebsocketListener();
final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
if (!anonymous) {
upgradeRequest.setHeader(HttpHeaders.AUTHORIZATION, TEST_CRED_HEADER);
}
client.connect(
testWebsocketListener,
URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())),
upgradeRequest);
testWebsocketListener.sendRequest(requestPath, "GET", Collections.emptyList(), Optional.empty()).join();
}
case HTTP -> {
final Invocation.Builder request = DROPWIZARD_APP_EXTENSION.client()
.target("http://127.0.0.1:%s%s".formatted(DROPWIZARD_APP_EXTENSION.getLocalPort(), requestPath))
.request();
if (!anonymous) {
request.header(HttpHeaders.AUTHORIZATION, TEST_CRED_HEADER);
}
request.get();
}
}
}
@ParameterizedTest
@EnumSource(Protocol.class)
void handleRequestNoChange(final Protocol protocol) throws IOException {
when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account1));
makeAuthenticatedRequest(protocol, "/test/annotated");
// Event listeners can fire after responses are sent
verify(ACCOUNTS_MANAGER, timeout(5000).times(1)).getByAccountIdentifier(eq(account1.getUuid()));
verifyNoMoreInteractions(ACCOUNTS_MANAGER);
}
@ParameterizedTest
@EnumSource(Protocol.class)
void handleRequestChange(final Protocol protocol) throws IOException {
when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account2));
when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedDevice(account1, authenticatedDevice)));
makeAuthenticatedRequest(protocol, "/test/annotated");
// Make sure we disconnect the account if the account has changed numbers. Event listeners can fire after responses
// are sent, so use a timeout.
verify(DISCONNECTION_REQUEST_MANAGER, timeout(5000))
.requestDisconnection(account1.getUuid(), List.of(authenticatedDevice.getId()));
verifyNoMoreInteractions(DISCONNECTION_REQUEST_MANAGER);
}
@Test
void handleRequestChangeAsyncEndpoint() throws IOException {
when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account2));
when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedDevice(account1, authenticatedDevice)));
// Event listeners with asynchronous HTTP endpoints don't currently correctly maintain state between request and
// response
makeAuthenticatedRequest(Protocol.WEBSOCKET, "/test/async-annotated");
// Make sure we disconnect the account if the account has changed numbers. Event listeners can fire after responses
// are sent, so use a timeout.
verify(DISCONNECTION_REQUEST_MANAGER, timeout(5000))
.requestDisconnection(account1.getUuid(), List.of(authenticatedDevice.getId()));
verifyNoMoreInteractions(DISCONNECTION_REQUEST_MANAGER);
}
@ParameterizedTest
@EnumSource(Protocol.class)
void handleRequestNotAnnotated(final Protocol protocol) throws IOException, InterruptedException {
makeAuthenticatedRequest(protocol,"/test/not-annotated");
// Give a tick for event listeners to run. Racy, but should occasionally catch an errant running listener if one is
// introduced.
Thread.sleep(100);
// Shouldn't even read the account if the method has not been annotated
verifyNoMoreInteractions(ACCOUNTS_MANAGER);
}
@ParameterizedTest
@EnumSource(Protocol.class)
void handleRequestNotAuthenticated(final Protocol protocol) throws IOException, InterruptedException {
makeAnonymousRequest(protocol, "/test/not-authenticated");
// Give a tick for event listeners to run. Racy, but should occasionally catch an errant running listener if one is
// introduced.
Thread.sleep(100);
// Shouldn't even read the account if the method has not been annotated
verifyNoMoreInteractions(ACCOUNTS_MANAGER);
}
@Path("/test")
public static class TestController {
@GET
@Path("/annotated")
@ChangesPhoneNumber
public String annotated(@ReadOnly @Auth final AuthenticatedDevice account) {
return "ok";
}
@GET
@Path("/async-annotated")
@ChangesPhoneNumber
@ManagedAsync
public String asyncAnnotated(@ReadOnly @Auth final AuthenticatedDevice account) {
return "ok";
}
@GET
@Path("/not-authenticated")
@ChangesPhoneNumber
public String notAuthenticated() {
return "ok";
}
@GET
@Path("/not-annotated")
public String notAnnotated(@ReadOnly @Auth final AuthenticatedDevice account) {
return "ok";
}
}
}

View File

@ -211,6 +211,11 @@ class AccountControllerTest {
when(accountsManager.getByE164(eq(SENDER_HAS_STORAGE))).thenReturn(Optional.of(senderHasStorage));
when(accountsManager.getByE164(eq(SENDER_TRANSFER))).thenReturn(Optional.of(senderTransfer));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_TWO)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_TWO));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_3)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_3));
when(accountsManager.getByAccountIdentifier(AuthHelper.UNDISCOVERABLE_UUID)).thenReturn(Optional.of(AuthHelper.UNDISCOVERABLE_ACCOUNT));
doAnswer(invocation -> {
final byte[] proof = invocation.getArgument(0);
final byte[] hash = invocation.getArgument(1);

View File

@ -145,6 +145,8 @@ class AccountControllerV2Test {
void setUp() throws Exception {
when(rateLimiters.getRegistrationLimiter()).thenReturn(registrationLimiter);
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any(), any(), any())).thenAnswer(
(Answer<Account>) invocation -> {
final Account account = invocation.getArgument(0);
@ -607,6 +609,8 @@ class AccountControllerV2Test {
@BeforeEach
void setUp() throws Exception {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(changeNumberManager.updatePniKeys(any(), any(), any(), any(), any(), any(), any())).thenAnswer(
(Answer<Account>) invocation -> {
final Account account = invocation.getArgument(0);
@ -768,7 +772,9 @@ class AccountControllerV2Test {
@BeforeEach
void setup() {
AccountsHelper.setupMockUpdate(accountsManager);
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
}
@Test
void testSetPhoneNumberDiscoverability() {
Response response = resources.getJerseyTest()
@ -805,6 +811,7 @@ class AccountControllerV2Test {
@MethodSource
void testGetAccountDataReport(final Account account, final String expectedTextAfterHeader) throws Exception {
when(AuthHelper.ACCOUNTS_MANAGER.getByAccountIdentifier(account.getUuid())).thenReturn(Optional.of(account));
when(accountsManager.getByAccountIdentifier(account.getUuid())).thenReturn(Optional.of(account));
final Response response = resources.getJerseyTest()
.target("/v2/accounts/data_report")

View File

@ -73,6 +73,7 @@ import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.GrpcStatusRuntimeExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.metrics.BackupMetrics;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.EnumMapUtil;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -82,6 +83,7 @@ import reactor.core.publisher.Flux;
@ExtendWith(DropwizardExtensionsSupport.class)
public class ArchiveControllerTest {
private static final AccountsManager accountsManager = mock(AccountsManager.class);
private static final BackupAuthManager backupAuthManager = mock(BackupAuthManager.class);
private static final BackupManager backupManager = mock(BackupManager.class);
private final BackupAuthTestUtil backupAuthTestUtil = new BackupAuthTestUtil(Clock.systemUTC());
@ -95,7 +97,7 @@ public class ArchiveControllerTest {
.addProvider(new RateLimitExceededExceptionMapper())
.setMapper(SystemMapper.jsonMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new ArchiveController(backupAuthManager, backupManager, new BackupMetrics()))
.addResource(new ArchiveController(accountsManager, backupAuthManager, backupManager, new BackupMetrics()))
.build();
private final UUID aci = UUID.randomUUID();
@ -106,6 +108,9 @@ public class ArchiveControllerTest {
public void setUp() {
reset(backupAuthManager);
reset(backupManager);
when(accountsManager.getByAccountIdentifierAsync(AuthHelper.VALID_UUID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(AuthHelper.VALID_ACCOUNT)));
}
@ParameterizedTest

View File

@ -9,6 +9,8 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import io.dropwizard.auth.AuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
@ -21,9 +23,11 @@ import java.time.Instant;
import java.time.ZoneId;
import java.time.temporal.ChronoUnit;
import java.util.Base64;
import java.util.Optional;
import java.util.stream.Stream;
import org.apache.commons.lang3.StringUtils;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
@ -44,6 +48,7 @@ import org.whispersystems.textsecuregcm.entities.DeliveryCertificate;
import org.whispersystems.textsecuregcm.entities.GroupCredentials;
import org.whispersystems.textsecuregcm.entities.MessageProtos.SenderCertificate;
import org.whispersystems.textsecuregcm.entities.MessageProtos.ServerCertificate;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -66,6 +71,8 @@ class CertificateControllerTest {
private static final ServerZkAuthOperations serverZkAuthOperations;
private static final Clock clock = Clock.fixed(Instant.now(), ZoneId.systemDefault());
private static final AccountsManager accountsManager = mock(AccountsManager.class);
static {
try {
certificateGenerator = new CertificateGenerator(Base64.getDecoder().decode(signingCertificate),
@ -82,9 +89,14 @@ class CertificateControllerTest {
.addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class))
.setMapper(SystemMapper.jsonMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new CertificateController(certificateGenerator, serverZkAuthOperations, genericServerSecretParams, clock))
.addResource(new CertificateController(accountsManager, certificateGenerator, serverZkAuthOperations, genericServerSecretParams, clock))
.build();
@BeforeEach
void setUp() {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
}
@Test
void testValidCertificate() throws Exception {
DeliveryCertificate certificateObject = resources.getJerseyTest()

View File

@ -38,6 +38,7 @@ import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.spam.ChallengeConstraintChecker;
import org.whispersystems.textsecuregcm.spam.ChallengeConstraintChecker.ChallengeConstraints;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider;
@ -45,11 +46,12 @@ import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider;
@ExtendWith(DropwizardExtensionsSupport.class)
class ChallengeControllerTest {
private static final AccountsManager accountsManager = mock(AccountsManager.class);
private static final RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class);
private static final ChallengeConstraintChecker challengeConstraintChecker = mock(ChallengeConstraintChecker.class);
private static final ChallengeController challengeController =
new ChallengeController(rateLimitChallengeManager, challengeConstraintChecker);
new ChallengeController(accountsManager, rateLimitChallengeManager, challengeConstraintChecker);
private static final ResourceExtension EXTENSION = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
@ -63,6 +65,9 @@ class ChallengeControllerTest {
@BeforeEach
void setup() {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_TWO)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_TWO));
when(challengeConstraintChecker.challengeConstraints(any(), any()))
.thenReturn(new ChallengeConstraints(true, Optional.empty()));
}

View File

@ -27,6 +27,7 @@ import java.time.Duration;
import java.time.Instant;
import java.util.Base64;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import org.glassfish.jersey.server.ServerProperties;
@ -45,6 +46,7 @@ import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.GrpcStatusRuntimeExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.devicecheck.AppleDeviceCheckManager;
import org.whispersystems.textsecuregcm.storage.devicecheck.ChallengeNotFoundException;
import org.whispersystems.textsecuregcm.storage.devicecheck.DeviceCheckKeyIdNotFoundException;
@ -62,6 +64,7 @@ class DeviceCheckControllerTest {
private final static Duration REDEMPTION_DURATION = Duration.ofDays(5);
private final static long REDEMPTION_LEVEL = 201L;
private static final AccountsManager accountsManager = mock(AccountsManager.class);
private final static BackupAuthManager backupAuthManager = mock(BackupAuthManager.class);
private final static AppleDeviceCheckManager appleDeviceCheckManager = mock(AppleDeviceCheckManager.class);
private final static RateLimiters rateLimiters = mock(RateLimiters.class);
@ -76,7 +79,7 @@ class DeviceCheckControllerTest {
.addProvider(new RateLimitExceededExceptionMapper())
.setMapper(SystemMapper.jsonMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new DeviceCheckController(clock, backupAuthManager, appleDeviceCheckManager, rateLimiters,
.addResource(new DeviceCheckController(clock, accountsManager, backupAuthManager, appleDeviceCheckManager, rateLimiters,
REDEMPTION_LEVEL, REDEMPTION_DURATION))
.build();
@ -86,6 +89,8 @@ class DeviceCheckControllerTest {
reset(appleDeviceCheckManager);
reset(rateLimiters);
when(rateLimiters.forDescriptor(any())).thenReturn(mock(RateLimiter.class));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
}
@ParameterizedTest

View File

@ -42,14 +42,12 @@ import java.util.concurrent.CompletableFuture;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.commons.lang3.RandomStringUtils;
import org.apache.commons.lang3.StringUtils;
import org.glassfish.jersey.server.ServerProperties;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource;
@ -62,8 +60,6 @@ import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager;
import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventListener;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.ApnRegistrationId;
import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest;
@ -116,7 +112,6 @@ class DeviceControllerTest {
private static final Account account = mock(Account.class);
private static final Account maxedAccount = mock(Account.class);
private static final Device primaryDevice = mock(Device.class);
private static final DisconnectionRequestManager disconnectionRequestManager = mock(DisconnectionRequestManager.class);
private static final Map<String, Integer> deviceConfiguration = new HashMap<>();
private static final TestClock testClock = TestClock.now();
@ -129,16 +124,12 @@ class DeviceControllerTest {
persistentTimer,
deviceConfiguration);
@RegisterExtension
public static final AuthHelper.AuthFilterExtension AUTH_FILTER_EXTENSION = new AuthHelper.AuthFilterExtension();
private static final ResourceExtension resources = ResourceExtension.builder()
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class))
.addProvider(new RateLimitExceededExceptionMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addProvider(new WebsocketRefreshApplicationEventListener(accountsManager, disconnectionRequestManager))
.addProvider(new DeviceLimitExceededExceptionMapper())
.addResource(deviceController)
.build();
@ -157,8 +148,15 @@ class DeviceControllerTest {
when(account.getNumber()).thenReturn(AuthHelper.VALID_NUMBER);
when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID);
when(account.getPhoneNumberIdentifier()).thenReturn(AuthHelper.VALID_PNI);
when(account.getPrimaryDevice()).thenReturn(primaryDevice);
when(account.getDevice(anyByte())).thenReturn(Optional.empty());
when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(primaryDevice));
when(account.getDevices()).thenReturn(List.of(primaryDevice));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
when(accountsManager.getByAccountIdentifierAsync(AuthHelper.VALID_UUID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(account)));
when(accountsManager.getByE164(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(account));
when(accountsManager.getByE164(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.of(maxedAccount));
@ -229,7 +227,7 @@ class DeviceControllerTest {
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
@ -310,7 +308,7 @@ class DeviceControllerTest {
final Device primaryDevice = mock(Device.class);
when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(primaryDevice));
when(account.getDevices()).thenReturn(List.of(primaryDevice));
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
@ -362,7 +360,7 @@ class DeviceControllerTest {
final Device primaryDevice = mock(Device.class);
when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(primaryDevice));
when(account.getDevices()).thenReturn(List.of(primaryDevice));
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
@ -398,7 +396,7 @@ class DeviceControllerTest {
void linkDeviceAtomicReusedToken() {
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
@ -447,7 +445,7 @@ class DeviceControllerTest {
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
@ -487,12 +485,12 @@ class DeviceControllerTest {
void linkDeviceAtomicConflictingChannel(final boolean fetchesMessages,
final Optional<ApnRegistrationId> apnRegistrationId,
final Optional<GcmRegistrationId> gcmRegistrationId) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
when(accountsManager.generateLinkDeviceToken(any())).thenReturn("test");
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
final LinkDeviceToken deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
@ -548,12 +546,12 @@ class DeviceControllerTest {
final KEMSignedPreKey aciPqLastResortPreKey,
final KEMSignedPreKey pniPqLastResortPreKey) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
when(accountsManager.generateLinkDeviceToken(any())).thenReturn("test");
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
final LinkDeviceToken deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
@ -613,11 +611,11 @@ class DeviceControllerTest {
aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair);
pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair);
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey()));
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey()));
@ -647,11 +645,11 @@ class DeviceControllerTest {
final KEMSignedPreKey aciPqLastResortPreKey,
final KEMSignedPreKey pniPqLastResortPreKey) {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account));
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
when(account.getIdentityKey(IdentityType.ACI)).thenReturn(aciIdentityKey);
when(account.getIdentityKey(IdentityType.PNI)).thenReturn(pniIdentityKey);
@ -698,7 +696,7 @@ class DeviceControllerTest {
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
final ECSignedPreKey aciSignedPreKey;
final ECSignedPreKey pniSignedPreKey;
@ -735,7 +733,7 @@ class DeviceControllerTest {
void linkDeviceRegistrationId(final int registrationId, final int pniRegistrationId, final int expectedStatusCode) {
final Device existingDevice = mock(Device.class);
when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID);
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice));
when(account.getDevices()).thenReturn(List.of(existingDevice));
final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair();
final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair();
@ -800,17 +798,16 @@ class DeviceControllerTest {
@Test
void maxDevicesTest() {
final AuthHelper.TestAccount testAccount = AUTH_FILTER_EXTENSION.createTestAccount();
final List<Device> devices = IntStream.range(0, DeviceController.MAX_DEVICES + 1)
.mapToObj(i -> mock(Device.class))
.toList();
when(testAccount.account.getDevices()).thenReturn(devices);
when(account.getDevices()).thenReturn(devices);
Response response = resources.getJerseyTest()
.target("/v1/devices/provisioning/code")
.request()
.header("Authorization", testAccount.getAuthHeader())
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get();
assertEquals(411, response.getStatus());
@ -829,7 +826,7 @@ class DeviceControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
assertThat(response.hasEntity()).isFalse();
verify(AuthHelper.VALID_DEVICE).setCapabilities(Set.of(DeviceCapability.DELETE_SYNC));
verify(primaryDevice).setCapabilities(Set.of(DeviceCapability.DELETE_SYNC));
}
}
@ -851,12 +848,12 @@ class DeviceControllerTest {
void removeDevice() {
// this is a static mock, so it might have previous invocations
clearInvocations(AuthHelper.VALID_ACCOUNT);
clearInvocations(account);
final byte deviceId = 2;
when(accountsManager.removeDevice(AuthHelper.VALID_ACCOUNT, deviceId))
.thenReturn(CompletableFuture.completedFuture(AuthHelper.VALID_ACCOUNT));
when(accountsManager.removeDevice(account, deviceId))
.thenReturn(CompletableFuture.completedFuture(account));
try (final Response response = resources
.getJerseyTest()
@ -869,14 +866,14 @@ class DeviceControllerTest {
assertThat(response.getStatus()).isEqualTo(204);
assertThat(response.hasEntity()).isFalse();
verify(accountsManager).removeDevice(AuthHelper.VALID_ACCOUNT, deviceId);
verify(accountsManager).removeDevice(account, deviceId);
}
}
@Test
void unlinkPrimaryDevice() {
// this is a static mock, so it might have previous invocations
clearInvocations(AuthHelper.VALID_ACCOUNT);
clearInvocations(account);
try (final Response response = resources
.getJerseyTest()
@ -897,7 +894,10 @@ class DeviceControllerTest {
final byte deviceId = 2;
when(accountsManager.removeDevice(AuthHelper.VALID_ACCOUNT_3, deviceId))
.thenReturn(CompletableFuture.completedFuture(AuthHelper.VALID_ACCOUNT));
.thenReturn(CompletableFuture.completedFuture(account));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_3))
.thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_3));
try (final Response response = resources
.getJerseyTest()
@ -946,7 +946,7 @@ class DeviceControllerTest {
assertEquals(204, response.getStatus());
}
verify(clientPublicKeysManager).setPublicKey(AuthHelper.VALID_ACCOUNT, AuthHelper.VALID_DEVICE.getId(), request.publicKey());
verify(clientPublicKeysManager).setPublicKey(account, AuthHelper.VALID_DEVICE.getId(), request.publicKey());
}
@Test
@ -959,7 +959,7 @@ class DeviceControllerTest {
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);
when(accountsManager
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any()))
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(primaryDevice), eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(deviceInfo)));
when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null));
@ -985,7 +985,7 @@ class DeviceControllerTest {
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);
when(accountsManager
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any()))
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(primaryDevice), eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null));
@ -1005,7 +1005,7 @@ class DeviceControllerTest {
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);
when(accountsManager
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any()))
.waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(primaryDevice), eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.failedFuture(new IllegalArgumentException()));
when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null));
@ -1079,7 +1079,7 @@ class DeviceControllerTest {
new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8)));
when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null));
when(accountsManager.recordTransferArchiveUpload(AuthHelper.VALID_ACCOUNT, deviceId, deviceCreated, transferArchive))
when(accountsManager.recordTransferArchiveUpload(account, deviceId, deviceCreated, transferArchive))
.thenReturn(CompletableFuture.completedFuture(null));
try (final Response response = resources.getJerseyTest()
@ -1092,7 +1092,7 @@ class DeviceControllerTest {
assertEquals(204, response.getStatus());
verify(accountsManager)
.recordTransferArchiveUpload(AuthHelper.VALID_ACCOUNT, deviceId, deviceCreated, transferArchive);
.recordTransferArchiveUpload(account, deviceId, deviceCreated, transferArchive);
}
}
@ -1103,7 +1103,7 @@ class DeviceControllerTest {
final RemoteAttachmentError transferFailure = new RemoteAttachmentError(RemoteAttachmentError.ErrorType.CONTINUE_WITHOUT_UPLOAD);
when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null));
when(accountsManager.recordTransferArchiveUpload(AuthHelper.VALID_ACCOUNT, deviceId, deviceCreated, transferFailure))
when(accountsManager.recordTransferArchiveUpload(account, deviceId, deviceCreated, transferFailure))
.thenReturn(CompletableFuture.completedFuture(null));
try (final Response response = resources.getJerseyTest()
@ -1116,7 +1116,7 @@ class DeviceControllerTest {
assertEquals(204, response.getStatus());
verify(accountsManager)
.recordTransferArchiveUpload(AuthHelper.VALID_ACCOUNT, deviceId, deviceCreated, transferFailure);
.recordTransferArchiveUpload(account, deviceId, deviceCreated, transferFailure);
}
}
@ -1186,7 +1186,7 @@ class DeviceControllerTest {
new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8)));
when(rateLimiter.validateAsync(anyString())).thenReturn(CompletableFuture.completedFuture(null));
when(accountsManager.waitForTransferArchive(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_DEVICE), any()))
when(accountsManager.waitForTransferArchive(eq(account), eq(primaryDevice), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(transferArchive)));
try (final Response response = resources.getJerseyTest()
@ -1206,7 +1206,7 @@ class DeviceControllerTest {
new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8)));
when(rateLimiter.validateAsync(anyString())).thenReturn(CompletableFuture.completedFuture(null));
when(accountsManager.waitForTransferArchive(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_DEVICE), any()))
when(accountsManager.waitForTransferArchive(eq(account), eq(primaryDevice), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(transferArchive)));
try (final Response response = resources.getJerseyTest()
@ -1223,7 +1223,7 @@ class DeviceControllerTest {
@Test
void waitForTransferArchiveNoArchiveUploaded() {
when(rateLimiter.validateAsync(anyString())).thenReturn(CompletableFuture.completedFuture(null));
when(accountsManager.waitForTransferArchive(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_DEVICE), any()))
when(accountsManager.waitForTransferArchive(eq(account), eq(primaryDevice), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
try (final Response response = resources.getJerseyTest()

View File

@ -19,6 +19,7 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator;
import org.whispersystems.textsecuregcm.configuration.DirectoryV2ClientConfiguration;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
@ -35,7 +36,7 @@ class DirectoryControllerV2Test {
final Account account = mock(Account.class);
final UUID uuid = UUID.fromString("11111111-1111-1111-1111-111111111111");
when(account.getUuid()).thenReturn(uuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(uuid);
final ExternalServiceCredentials credentials = controller.getAuthToken(
new AuthenticatedDevice(account, mock(Device.class)));

View File

@ -140,6 +140,8 @@ class DonationControllerTest {
when(receiptCredentialPresentation.getReceiptExpirationTime()).thenReturn(receiptExpiration);
when(redeemedReceiptsManager.put(same(receiptSerial), eq(receiptExpiration), eq(receiptLevel), eq(AuthHelper.VALID_UUID))).thenReturn(
CompletableFuture.completedFuture(Boolean.FALSE));
when(accountsManager.getByAccountIdentifierAsync(eq(AuthHelper.VALID_UUID))).thenReturn(
CompletableFuture.completedFuture(Optional.of(AuthHelper.VALID_ACCOUNT)));
RedeemReceiptRequest request = new RedeemReceiptRequest(presentation, true, true);
Response response = resources.getJerseyTest()

View File

@ -242,10 +242,21 @@ class KeysControllerTest {
when(existsAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of("1337".getBytes()));
when(accounts.getByServiceIdentifier(any())).thenReturn(Optional.empty());
when(accounts.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(accounts.getByServiceIdentifier(new AciServiceIdentifier(EXISTS_UUID))).thenReturn(Optional.of(existsAccount));
when(accounts.getByServiceIdentifier(new PniServiceIdentifier(EXISTS_PNI))).thenReturn(Optional.of(existsAccount));
when(accounts.getByServiceIdentifierAsync(new AciServiceIdentifier(EXISTS_UUID)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(existsAccount)));
when(accounts.getByServiceIdentifierAsync(new PniServiceIdentifier(EXISTS_PNI)))
.thenReturn(CompletableFuture.completedFuture(Optional.of(existsAccount)));
when(accounts.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accounts.getByAccountIdentifierAsync(AuthHelper.VALID_UUID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(AuthHelper.VALID_ACCOUNT)));
when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter);
when(KEYS.storeEcOneTimePreKeys(any(), anyByte(), any()))

View File

@ -236,6 +236,14 @@ class MessageControllerTest {
when(accountsManager.getByServiceIdentifierAsync(MULTI_DEVICE_PNI_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(multiDeviceAccount)));
when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(CompletableFuture.completedFuture(Optional.of(internationalAccount)));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT));
when(accountsManager.getByAccountIdentifierAsync(AuthHelper.VALID_UUID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(AuthHelper.VALID_ACCOUNT)));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_3)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_3));
when(accountsManager.getByAccountIdentifierAsync(AuthHelper.VALID_UUID_3))
.thenReturn(CompletableFuture.completedFuture(Optional.of(AuthHelper.VALID_ACCOUNT_3)));
when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getStoriesLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getInboundMessageBytes()).thenReturn(rateLimiter);

View File

@ -221,6 +221,8 @@ class ProfileControllerTest {
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(capabilitiesAccount));
when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(AuthHelper.VALID_UUID))).thenReturn(Optional.of(capabilitiesAccount));
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_TWO)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_TWO));
final byte[] name = TestRandomUtil.nextBytes(81);
final byte[] emoji = TestRandomUtil.nextBytes(60);
final byte[] about = TestRandomUtil.nextBytes(156);
@ -1155,6 +1157,7 @@ class ProfileControllerTest {
reset(accountsManager);
final int accountsManagerUpdateRetryCount = 2;
AccountsHelper.setupMockUpdateWithRetries(accountsManager, accountsManagerUpdateRetryCount);
when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_TWO)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_TWO));
// set up two invocations -- one for each AccountsManager#update try
when(AuthHelper.VALID_ACCOUNT_TWO.getBadges())
.thenReturn(List.of(

View File

@ -170,7 +170,12 @@ class RemoteConfigControllerTest {
void testHashKeyLinkedConfigs() {
boolean allUnlinkedConfigsMatched = true;
for (AuthHelper.TestAccount testAccount : AuthHelper.TEST_ACCOUNTS) {
UserRemoteConfigList configuration = resources.getJerseyTest().target("/v1/config/").request().header("Authorization", testAccount.getAuthHeader()).get(UserRemoteConfigList.class);
UserRemoteConfigList configuration = resources.getJerseyTest()
.target("/v1/config/")
.request()
.header("Authorization", testAccount.getAuthHeader())
.get(UserRemoteConfigList.class);
assertThat(configuration.getConfig()).hasSize(11);
final UserRemoteConfig linkedConfig0 = configuration.getConfig().get(7);

View File

@ -7,7 +7,6 @@ package org.whispersystems.textsecuregcm.entities;
import static org.junit.jupiter.api.Assertions.assertEquals;
import java.util.Random;
import java.util.UUID;
import javax.annotation.Nullable;
import org.junit.jupiter.api.Test;
@ -16,7 +15,6 @@ import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
@ -65,15 +63,13 @@ class OutgoingMessageEntityTest {
@Test
void entityPreservesEnvelope() {
final byte[] reportSpamToken = TestRandomUtil.nextBytes(8);
final AciServiceIdentifier sourceServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID());
final Account account = new Account();
account.setUuid(UUID.randomUUID());
IncomingMessage message = new IncomingMessage(1, (byte) 44, 55, TestRandomUtil.nextBytes(4));
final IncomingMessage message = new IncomingMessage(1, (byte) 44, 55, TestRandomUtil.nextBytes(4));
MessageProtos.Envelope baseEnvelope = message.toEnvelope(
new AciServiceIdentifier(UUID.randomUUID()),
account,
sourceServiceIdentifier,
(byte) 123,
System.currentTimeMillis(),
false,

View File

@ -153,7 +153,7 @@ class MetricsRequestEventListenerTest {
final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
final WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, TestPrincipal.reusableAuth("foo"),
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, TestPrincipal.authenticatedTestPrincipal("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
final Session session = mock(Session.class);
@ -220,7 +220,7 @@ class MetricsRequestEventListenerTest {
final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
final WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, TestPrincipal.reusableAuth("foo"),
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, TestPrincipal.authenticatedTestPrincipal("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
final Session session = mock(Session.class);

View File

@ -220,6 +220,8 @@ public class AccountsHelper {
case "getBackupVoucher" -> when(updatedAccount.getBackupVoucher()).thenAnswer(stubbing);
case "getLastSeen" -> when(updatedAccount.getLastSeen()).thenAnswer(stubbing);
case "hasLockedCredentials" -> when(updatedAccount.hasLockedCredentials()).thenAnswer(stubbing);
case "getCurrentProfileVersion" -> when(updatedAccount.getCurrentProfileVersion()).thenAnswer(stubbing);
case "getUnidentifiedAccessKey" -> when(updatedAccount.getUnidentifiedAccessKey()).thenAnswer(stubbing);
default -> throw new IllegalArgumentException("unsupported method: Account#" + stubbing.getInvocation().getMethod().getName());
}
}

View File

@ -277,6 +277,7 @@ public class AuthHelper {
when(account.getPrimaryDevice()).thenReturn(device);
when(account.getNumber()).thenReturn(number);
when(account.getUuid()).thenReturn(uuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(uuid);
when(accountsManager.getByE164(number)).thenReturn(Optional.of(account));
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account));
}

View File

@ -5,8 +5,7 @@
package org.whispersystems.textsecuregcm.tests.util;
import java.security.Principal;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import java.util.Optional;
public class TestPrincipal implements Principal {
@ -21,7 +20,7 @@ public class TestPrincipal implements Principal {
return name;
}
public static ReusableAuth<TestPrincipal> reusableAuth(final String name) {
return ReusableAuth.authenticated(new TestPrincipal(name), PrincipalSupplier.forImmutablePrincipal());
public static Optional<TestPrincipal> authenticatedTestPrincipal(final String name) {
return Optional.of(new TestPrincipal(name));
}
}

View File

@ -176,7 +176,7 @@ class LoggingUnhandledExceptionMapperTest {
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog,
TestPrincipal.reusableAuth("foo"),
TestPrincipal.authenticatedTestPrincipal("foo"),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);

View File

@ -28,7 +28,6 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.websocket.auth.InvalidCredentialsException;
import org.whispersystems.websocket.auth.PrincipalSupplier;
class WebSocketAccountAuthenticatorTest {
@ -70,14 +69,12 @@ class WebSocketAccountAuthenticatorTest {
when(upgradeRequest.getHeader(eq(HttpHeaders.AUTHORIZATION))).thenReturn(authorizationHeaderValue);
}
final WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(
accountAuthenticator,
mock(PrincipalSupplier.class));
final WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator);
if (expectInvalid) {
assertThrows(InvalidCredentialsException.class, () -> webSocketAuthenticator.authenticate(upgradeRequest));
} else {
assertEquals(expectAccount, webSocketAuthenticator.authenticate(upgradeRequest).ref().isPresent());
assertEquals(expectAccount, webSocketAuthenticator.authenticate(upgradeRequest).isPresent());
}
}

View File

@ -43,11 +43,11 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
@ -111,7 +111,7 @@ class WebSocketConnectionIntegrationTest {
clientReleaseManager = mock(ClientReleaseManager.class);
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(account.getIdentifier(IdentityType.ACI)).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(Device.PRIMARY_ID);
}
@ -137,7 +137,8 @@ class WebSocketConnectionIntegrationTest {
new MessageMetrics(),
mock(PushNotificationManager.class),
mock(PushNotificationScheduler.class),
new AuthenticatedDevice(account, device),
account,
device,
webSocketClient,
scheduledExecutorService,
messageDeliveryScheduler,
@ -159,14 +160,14 @@ class WebSocketConnectionIntegrationTest {
expectedMessages.add(envelope);
}
messagesDynamoDb.store(persistedMessages, account.getUuid(), device);
messagesDynamoDb.store(persistedMessages, account.getIdentifier(IdentityType.ACI), device);
}
for (int i = 0; i < cachedMessageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join();
messagesCache.insert(messageGuid, account.getIdentifier(IdentityType.ACI), device.getId(), envelope).join();
expectedMessages.add(envelope);
}
@ -226,7 +227,8 @@ class WebSocketConnectionIntegrationTest {
new MessageMetrics(),
mock(PushNotificationManager.class),
mock(PushNotificationScheduler.class),
new AuthenticatedDevice(account, device),
account,
device,
webSocketClient,
scheduledExecutorService,
messageDeliveryScheduler,
@ -250,13 +252,13 @@ class WebSocketConnectionIntegrationTest {
expectedMessages.add(envelope);
}
messagesDynamoDb.store(persistedMessages, account.getUuid(), device);
messagesDynamoDb.store(persistedMessages, account.getIdentifier(IdentityType.ACI), device);
}
for (int i = 0; i < cachedMessageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join();
messagesCache.insert(messageGuid, account.getIdentifier(IdentityType.ACI), device.getId(), envelope).join();
expectedMessages.add(envelope);
}
@ -296,7 +298,8 @@ class WebSocketConnectionIntegrationTest {
new MessageMetrics(),
mock(PushNotificationManager.class),
mock(PushNotificationScheduler.class),
new AuthenticatedDevice(account, device),
account,
device,
webSocketClient,
100, // use a very short timeout, so that this test completes quickly
scheduledExecutorService,
@ -321,13 +324,13 @@ class WebSocketConnectionIntegrationTest {
expectedMessages.add(envelope);
}
messagesDynamoDb.store(persistedMessages, account.getUuid(), device);
messagesDynamoDb.store(persistedMessages, account.getIdentifier(IdentityType.ACI), device);
}
for (int i = 0; i < cachedMessageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid);
messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join();
messagesCache.insert(messageGuid, account.getIdentifier(IdentityType.ACI), device.getId(), envelope).join();
expectedMessages.add(envelope);
}

View File

@ -56,6 +56,7 @@ import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
@ -67,9 +68,7 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import reactor.core.publisher.Flux;
@ -90,7 +89,6 @@ class WebSocketConnectionTest {
private AccountsManager accountsManager;
private Account account;
private Device device;
private AuthenticatedDevice auth;
private UpgradeRequest upgradeRequest;
private MessagesManager messagesManager;
private ReceiptSender receiptSender;
@ -104,7 +102,6 @@ class WebSocketConnectionTest {
accountsManager = mock(AccountsManager.class);
account = mock(Account.class);
device = mock(Device.class);
auth = new AuthenticatedDevice(account, device);
upgradeRequest = mock(UpgradeRequest.class);
messagesManager = mock(MessagesManager.class);
receiptSender = mock(ReceiptSender.class);
@ -122,8 +119,8 @@ class WebSocketConnectionTest {
@Test
void testCredentials() throws Exception {
WebSocketAccountAuthenticator webSocketAuthenticator =
new WebSocketAccountAuthenticator(accountAuthenticator, mock(PrincipalSupplier.class));
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, messagesManager,
new WebSocketAccountAuthenticator(accountAuthenticator);
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(accountsManager, receiptSender, messagesManager,
new MessageMetrics(), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class),
mock(WebSocketConnectionEventManager.class), retrySchedulingExecutor,
messageDeliveryScheduler, clientReleaseManager, mock(MessageDeliveryLoopMonitor.class),
@ -133,9 +130,9 @@ class WebSocketConnectionTest {
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
.thenReturn(Optional.of(new AuthenticatedDevice(account, device)));
ReusableAuth<AuthenticatedDevice> account = webSocketAuthenticator.authenticate(upgradeRequest);
when(sessionContext.getAuthenticated()).thenReturn(account.ref().orElse(null));
when(sessionContext.getAuthenticated(AuthenticatedDevice.class)).thenReturn(account.ref().orElse(null));
Optional<AuthenticatedDevice> account = webSocketAuthenticator.authenticate(upgradeRequest);
when(sessionContext.getAuthenticated()).thenReturn(account.orElse(null));
when(sessionContext.getAuthenticated(AuthenticatedDevice.class)).thenReturn(account.orElse(null));
final WebSocketClient webSocketClient = mock(WebSocketClient.class);
when(webSocketClient.getUserAgent()).thenReturn("Signal-Android/6.22.8");
@ -150,7 +147,7 @@ class WebSocketConnectionTest {
// unauthenticated
when(upgradeRequest.getParameterMap()).thenReturn(Map.of());
account = webSocketAuthenticator.authenticate(upgradeRequest);
assertFalse(account.ref().isPresent());
assertFalse(account.isPresent());
connectListener.onWebSocketConnect(sessionContext);
verify(sessionContext, times(2)).addWebsocketClosedListener(
@ -174,7 +171,7 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn(deviceId);
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
final Device sender1device = mock(Device.class);
@ -191,7 +188,7 @@ class WebSocketConnectionTest {
String userAgent = HttpHeaders.USER_AGENT;
when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false))
when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false))
.thenReturn(Flux.fromIterable(outgoingMessages));
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
@ -237,7 +234,7 @@ class WebSocketConnectionTest {
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@ -320,7 +317,7 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn(deviceId);
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
final Device sender1device = mock(Device.class);
@ -337,7 +334,7 @@ class WebSocketConnectionTest {
String userAgent = HttpHeaders.USER_AGENT;
when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false))
when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false))
.thenReturn(Flux.fromIterable(pendingMessages));
final List<CompletableFuture<WebSocketResponseMessage>> futures = new LinkedList<>();
@ -364,7 +361,7 @@ class WebSocketConnectionTest {
futures.get(1).complete(response);
futures.get(0).completeExceptionally(new IOException());
verify(receiptSender, times(1)).sendReceipt(eq(new AciServiceIdentifier(account.getUuid())), eq(deviceId), eq(new AciServiceIdentifier(senderTwoUuid)),
verify(receiptSender, times(1)).sendReceipt(eq(new AciServiceIdentifier(account.getIdentifier(IdentityType.ACI))), eq(deviceId), eq(new AciServiceIdentifier(senderTwoUuid)),
eq(secondMessage.getClientTimestamp()));
connection.stop();
@ -377,7 +374,7 @@ class WebSocketConnectionTest {
final WebSocketConnection connection = webSocketConnection(client);
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(account.getIdentifier(IdentityType.ACI)).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@ -385,7 +382,7 @@ class WebSocketConnectionTest {
final AtomicBoolean returnMessageList = new AtomicBoolean(false);
when(
messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false))
messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false))
.thenAnswer(invocation -> {
synchronized (threadWaiting) {
threadWaiting.set(true);
@ -442,7 +439,7 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
final UUID accountUuid = UUID.randomUUID();
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@ -490,7 +487,7 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
final UUID accountUuid = UUID.randomUUID();
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@ -573,7 +570,7 @@ class WebSocketConnectionTest {
when(account.getNumber()).thenReturn("+18005551234");
final UUID accountUuid = UUID.randomUUID();
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@ -581,7 +578,7 @@ class WebSocketConnectionTest {
final List<Envelope> messages = List.of(
createMessage(senderUuid, UUID.randomUUID(), 1111L, "message the first"));
when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false))
when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false))
.thenReturn(Flux.fromIterable(messages))
.thenReturn(Flux.empty());
@ -629,7 +626,7 @@ class WebSocketConnectionTest {
private WebSocketConnection webSocketConnection(final WebSocketClient client) {
return new WebSocketConnection(receiptSender, messagesManager, new MessageMetrics(),
mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), auth, client,
mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), account, device, client,
retrySchedulingExecutor, Schedulers.immediate(), clientReleaseManager,
mock(MessageDeliveryLoopMonitor.class), mock(ExperimentEnrollmentManager.class));
}
@ -642,7 +639,7 @@ class WebSocketConnectionTest {
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@ -669,7 +666,7 @@ class WebSocketConnectionTest {
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@ -725,7 +722,7 @@ class WebSocketConnectionTest {
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@ -741,11 +738,11 @@ class WebSocketConnectionTest {
// anything.
connection.processStoredMessages();
verify(messagesManager).getMessagesForDeviceReactive(account.getUuid(), device, false);
verify(messagesManager).getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false);
connection.handleNewMessageAvailable();
verify(messagesManager).getMessagesForDeviceReactive(account.getUuid(), device, true);
verify(messagesManager).getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, true);
}
@Test
@ -756,7 +753,7 @@ class WebSocketConnectionTest {
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(device.getId()).thenReturn(Device.PRIMARY_ID);
when(client.isOpen()).thenReturn(true);
@ -773,7 +770,7 @@ class WebSocketConnectionTest {
connection.processStoredMessages();
connection.handleMessagesPersisted();
verify(messagesManager, times(2)).getMessagesForDeviceReactive(account.getUuid(), device, false);
verify(messagesManager, times(2)).getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false);
}
@Test
@ -783,9 +780,9 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn((byte) 2);
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false))
when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false))
.thenReturn(Flux.error(new RedisException("OH NO")));
when(retrySchedulingExecutor.schedule(any(Runnable.class), anyLong(), any())).thenAnswer(
@ -812,9 +809,9 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn((byte) 2);
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false))
when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false))
.thenReturn(Flux.error(new RedisException("OH NO")));
final WebSocketClient client = mock(WebSocketClient.class);
@ -835,7 +832,7 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn(deviceId);
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
final int totalMessages = 1000;
@ -884,7 +881,7 @@ class WebSocketConnectionTest {
when(device.getId()).thenReturn(deviceId);
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid);
final AtomicBoolean canceled = new AtomicBoolean();

View File

@ -1,149 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.websocket;
import java.security.Principal;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import org.whispersystems.websocket.auth.PrincipalSupplier;
/**
* This class holds a principal that can be reused across requests on a websocket. Since two requests may operate
* concurrently on the same principal, and some principals contain non thread-safe mutable state, appropriate use of
* this class ensures that no data races occur. It also ensures that after a principal is modified, a subsequent request
* gets the up-to-date principal
*
* @param <T> The underlying principal type
* @see PrincipalSupplier
*/
public abstract sealed class ReusableAuth<T extends Principal> {
/**
* Get a reference to the underlying principal that callers pledge not to modify.
* <p>
* The reference returned will potentially be provided to many threads concurrently accessing the principal. Callers
* should use this method only if they can ensure that they will not modify the in-memory principal object AND they do
* not intend to modify the underlying canonical representation of the principal.
* <p>
* For example, if a caller retrieves a reference to a principal, does not modify the in memory state, but updates a
* field on a database that should be reflected in subsequent retrievals of the principal, they will have met the
* first criteria, but not the second. In that case they should instead use {@link #mutableRef()}.
* <p>
* If other callers have modified the underlying principal by using {@link #mutableRef()}, this method may need to
* refresh the principal via {@link PrincipalSupplier#refresh} which could be a blocking operation.
*
* @return If authenticated, a reference to the underlying principal that should not be modified
*/
public abstract Optional<T> ref();
public interface MutableRef<T> {
T ref();
void close();
}
/**
* Get a reference to the underlying principal that may be modified.
* <p>
* The underlying principal can be safely modified. Multiple threads may operate on the same {@link ReusableAuth} so
* long as they each have their own {@link MutableRef}. After any modifications, the caller must call
* {@link MutableRef#close} to notify the principal has become dirty. Close should be called after modifications but
* before sending a response on the websocket. This ensures that a request that comes in after a modification response
* is received is guaranteed to see the modification.
*
* @return If authenticated, a reference to the underlying principal that may be modified
*/
public abstract Optional<MutableRef<T>> mutableRef();
/**
* @return A {@link ReusableAuth} indicating no credential were provided
*/
public static <T extends Principal> ReusableAuth<T> anonymous() {
//noinspection unchecked
return (ReusableAuth<T>) Anonymous.ANON_RESULT;
}
/**
* Create a successfully authenticated {@link ReusableAuth}
*
* @param principal The authenticated principal
* @param principalSupplier Instructions for how to refresh or copy this principal
* @param <T> The principal type
* @return A {@link ReusableAuth} for a successfully authenticated principal
*/
public static <T extends Principal> ReusableAuth<T> authenticated(T principal,
PrincipalSupplier<T> principalSupplier) {
return new Authenticated<>(principal, principalSupplier);
}
private static final class Anonymous<T extends Principal> extends ReusableAuth<T> {
@SuppressWarnings({"rawtypes"})
private static final ReusableAuth ANON_RESULT = new Anonymous();
@Override
public Optional<T> ref() {
return Optional.empty();
}
@Override
public Optional<MutableRef<T>> mutableRef() {
return Optional.empty();
}
}
private static final class Authenticated<T extends Principal> extends ReusableAuth<T> {
private T basePrincipal;
private final AtomicBoolean needRefresh = new AtomicBoolean(false);
private final PrincipalSupplier<T> principalSupplier;
Authenticated(final T basePrincipal, PrincipalSupplier<T> principalSupplier) {
this.basePrincipal = basePrincipal;
this.principalSupplier = principalSupplier;
}
@Override
public Optional<T> ref() {
maybeRefresh();
return Optional.of(basePrincipal);
}
@Override
public Optional<MutableRef<T>> mutableRef() {
maybeRefresh();
return Optional.of(new AuthenticatedMutableRef(principalSupplier.deepCopy(basePrincipal)));
}
private void maybeRefresh() {
if (needRefresh.compareAndSet(true, false)) {
basePrincipal = principalSupplier.refresh(basePrincipal);
}
}
private class AuthenticatedMutableRef implements MutableRef<T> {
final T ref;
private AuthenticatedMutableRef(T ref) {
this.ref = ref;
}
public T ref() {
return ref;
}
public void close() {
needRefresh.set(true);
}
}
}
private ReusableAuth() {
}
}

View File

@ -58,7 +58,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
private final Map<Long, CompletableFuture<WebSocketResponseMessage>> requestMap = new ConcurrentHashMap<>();
private final ReusableAuth<T> reusableAuth;
private final Optional<T> reusableAuth;
private final WebSocketMessageFactory messageFactory;
private final Optional<WebSocketConnectListener> connectListener;
private final ApplicationHandler jerseyHandler;
@ -77,7 +77,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
String remoteAddressPropertyName,
ApplicationHandler jerseyHandler,
WebsocketRequestLog requestLog,
ReusableAuth<T> authenticated,
Optional<T> authenticated,
WebSocketMessageFactory messageFactory,
Optional<WebSocketConnectListener> connectListener,
Duration idleTimeout) {
@ -97,7 +97,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
this.remoteEndpoint = session.getRemote();
this.context = new WebSocketSessionContext(
new WebSocketClient(session, remoteEndpoint, messageFactory, requestMap));
this.context.setAuthenticated(reusableAuth.ref().orElse(null));
this.context.setAuthenticated(reusableAuth.orElse(null));
this.session.setIdleTimeout(idleTimeout);
connectListener.ifPresent(listener -> listener.onWebSocketConnect(this.context));
@ -164,16 +164,10 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
/**
* The property name where {@link org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider} can find an
* {@link ReusableAuth} object that lives for the lifetime of the websocket
* authenticated principal that lives for the lifetime of the websocket
*/
public static final String REUSABLE_AUTH_PROPERTY = WebSocketResourceProvider.class.getName() + ".reusableAuth";
/**
* The property name where {@link org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider} can install a
* {@link org.whispersystems.websocket.ReusableAuth.MutableRef} for us to close when the request is finished
*/
public static final String RESOLVED_PRINCIPAL_PROPERTY = WebSocketResourceProvider.class.getName() + ".resolvedPrincipal";
/**
* The property name where request byte count is stored for metrics collection
*/
@ -205,16 +199,6 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
containerRequest, responseBody);
responseFuture
.whenComplete((ignoredResponse, ignoredError) -> {
// If the request ended up being one that mutates our principal, we have to close it to indicate we're done
// with the mutation operation
final Object resolvedPrincipal = containerRequest.getProperty(RESOLVED_PRINCIPAL_PROPERTY);
if (resolvedPrincipal instanceof ReusableAuth.MutableRef<?> ref) {
ref.close();
} else if (resolvedPrincipal != null) {
logger.warn("unexpected resolved principal type {} : {}", resolvedPrincipal.getClass(), resolvedPrincipal);
}
})
.thenAccept(response -> {
try {
final int responseBytes = responseBody.size();

View File

@ -64,9 +64,9 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends Jetty
try {
Optional<WebSocketAuthenticator<T>> authenticator = Optional.ofNullable(environment.getAuthenticator());
final ReusableAuth<T> authenticated = authenticator.isPresent()
final Optional<T> authenticated = authenticator.isPresent()
? authenticator.get().authenticate(request)
: ReusableAuth.anonymous();
: Optional.empty();
Optional.ofNullable(environment.getAuthenticatedWebSocketUpgradeFilter())
.ifPresent(filter -> filter.handleAuthentication(authenticated, request, response));

View File

@ -6,13 +6,13 @@
package org.whispersystems.websocket.auth;
import java.security.Principal;
import java.util.Optional;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
import org.whispersystems.websocket.ReusableAuth;
public interface AuthenticatedWebSocketUpgradeFilter<T extends Principal> {
void handleAuthentication(ReusableAuth<T> authenticated,
void handleAuthentication(@SuppressWarnings("OptionalUsedAsFieldOrParameterType") Optional<T> authenticated,
JettyServerUpgradeRequest request,
JettyServerUpgradeResponse response);
}

View File

@ -1,26 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.websocket.auth;
import io.dropwizard.auth.Auth;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* An @{@link Auth} object annotated with {@link Mutable} indicates that the consumer of the object
* will modify the object or its underlying canonical source.
*
* Note: An {@link Auth} object that does not specify @{@link ReadOnly} will be assumed to be @Mutable
*
* @see org.whispersystems.websocket.ReusableAuth
*/
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD, ElementType.PARAMETER})
public @interface Mutable {
}

View File

@ -1,58 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.websocket.auth;
/**
* Teach {@link org.whispersystems.websocket.ReusableAuth} how to make a deep copy of a principal (that is safe to
* concurrently modify while the original principal is being read), and how to refresh a principal after it has been
* potentially modified.
*
* @param <T> The underlying principal type
*/
public interface PrincipalSupplier<T> {
/**
* Re-fresh the principal after it has been modified.
* <p>
* If the principal is populated from a backing store, refresh should re-read it.
*
* @param t the potentially stale principal to refresh
* @return The up-to-date principal
*/
T refresh(T t);
/**
* Create a deep, in-memory copy of the principal. This should be identical to the original principal, but should
* share no mutable state with the original. It should be safe for two threads to independently write and read from
* two independent deep copies.
*
* @param t the principal to copy
* @return An in-memory copy of the principal
*/
T deepCopy(T t);
class ImmutablePrincipalSupplier<T> implements PrincipalSupplier<T> {
@SuppressWarnings({"rawtypes"})
private static final PrincipalSupplier INSTANCE = new ImmutablePrincipalSupplier();
@Override
public T refresh(final T t) {
return t;
}
@Override
public T deepCopy(final T t) {
return t;
}
}
/**
* @return A principal supplier that can be used if the principal type does not support modification.
*/
static <T> PrincipalSupplier<T> forImmutablePrincipal() {
//noinspection unchecked
return (PrincipalSupplier<T>) ImmutablePrincipalSupplier.INSTANCE;
}
}

View File

@ -1,25 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.websocket.auth;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* An @{@link io.dropwizard.auth.Auth} object annotated with {@link ReadOnly} indicates that the consumer of the object
* will never modify the object, nor its underlying canonical source.
* <p>
* For example, a consumer of a @ReadOnly AuthenticatedAccount promises to never modify the in-memory
* AuthenticatedAccount and to never modify the underlying Account database for the account.
*
* @see org.whispersystems.websocket.ReusableAuth
*/
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD, ElementType.PARAMETER})
public @interface ReadOnly {
}

View File

@ -5,9 +5,20 @@
package org.whispersystems.websocket.auth;
import java.security.Principal;
import java.util.Optional;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.whispersystems.websocket.ReusableAuth;
public interface WebSocketAuthenticator<T extends Principal> {
ReusableAuth<T> authenticate(UpgradeRequest request) throws InvalidCredentialsException;
/**
* Authenticates an account from credential headers provided in a WebSocket upgrade request.
*
* @param request the request from which to extract credentials
*
* @return the authenticated principal if credentials were provided and authenticated or empty if the caller is
* anonymous
*
* @throws InvalidCredentialsException if credentials were provided, but could not be authenticated
*/
Optional<T> authenticate(UpgradeRequest request) throws InvalidCredentialsException;
}

View File

@ -21,7 +21,6 @@ import org.glassfish.jersey.server.model.Parameter;
import org.glassfish.jersey.server.spi.internal.ValueParamProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.websocket.ReusableAuth;
import org.whispersystems.websocket.WebSocketResourceProvider;
@Singleton
@ -43,36 +42,30 @@ public class WebsocketAuthValueFactoryProvider<T extends Principal> extends Abst
return null;
}
final boolean readOnly = parameter.isAnnotationPresent(ReadOnly.class);
final boolean readOnly = true;
if (parameter.getRawType() == Optional.class
&& ParameterizedType.class.isAssignableFrom(parameter.getType().getClass())
&& principalClass == ((ParameterizedType) parameter.getType()).getActualTypeArguments()[0]) {
return containerRequest -> createPrincipal(containerRequest, readOnly);
return this::createPrincipal;
} else if (principalClass.equals(parameter.getRawType())) {
return containerRequest ->
createPrincipal(containerRequest, readOnly)
createPrincipal(containerRequest)
.orElseThrow(() -> new WebApplicationException("Authenticated resource", 401));
} else {
throw new IllegalStateException("Can't inject unassignable principal: " + principalClass + " for parameter: " + parameter);
}
}
private Optional<? extends Principal> createPrincipal(final ContainerRequest request, final boolean readOnly) {
private Optional<? extends Principal> createPrincipal(final ContainerRequest request) {
final Object obj = request.getProperty(WebSocketResourceProvider.REUSABLE_AUTH_PROPERTY);
if (!(obj instanceof ReusableAuth<?>)) {
if (!(obj instanceof Optional<?>)) {
logger.warn("Unexpected reusable auth property type {} : {}", obj.getClass(), obj);
return Optional.empty();
}
@SuppressWarnings("unchecked") final ReusableAuth<T> reusableAuth = (ReusableAuth<T>) obj;
if (readOnly) {
return reusableAuth.ref();
} else {
return reusableAuth.mutableRef().map(writeRef -> {
request.setProperty(WebSocketResourceProvider.RESOLVED_PRINCIPAL_PROPERTY, writeRef);
return writeRef.ref();
});
}
//noinspection unchecked
return (Optional<T>) obj;
}
@Singleton

View File

@ -17,6 +17,7 @@ import io.dropwizard.jersey.DropwizardResourceConfig;
import jakarta.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.security.Principal;
import java.util.Optional;
import javax.security.auth.Subject;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
@ -26,7 +27,6 @@ import org.glassfish.jersey.server.ResourceConfig;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.websocket.auth.InvalidCredentialsException;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.AuthenticatedWebSocketUpgradeFilter;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
@ -75,7 +75,7 @@ public class WebSocketResourceProviderFactoryTest {
when(environment.getAuthenticator()).thenReturn(authenticator);
when(authenticator.authenticate(eq(request)))
.thenReturn(ReusableAuth.authenticated(account, PrincipalSupplier.forImmutablePrincipal()));
.thenReturn(Optional.of(account));
when(environment.jersey()).thenReturn(jerseyEnvironment);
final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class);
when(httpServletRequest.getAttribute(REMOTE_ADDRESS_PROPERTY_NAME)).thenReturn("127.0.0.1");
@ -129,8 +129,7 @@ public class WebSocketResourceProviderFactoryTest {
@Test
void testAuthenticatedWebSocketUpgradeFilter() throws InvalidCredentialsException {
final Account account = new Account();
final ReusableAuth<Account> reusableAuth =
ReusableAuth.authenticated(account, PrincipalSupplier.forImmutablePrincipal());
final Optional<Account> reusableAuth = Optional.of(account);
when(environment.getAuthenticator()).thenReturn(authenticator);
when(authenticator.authenticate(eq(request))).thenReturn(reusableAuth);

View File

@ -59,7 +59,6 @@ import org.glassfish.jersey.server.ResourceConfig;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.whispersystems.websocket.auth.PrincipalSupplier;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
import org.whispersystems.websocket.logging.WebsocketRequestLog;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
@ -81,7 +80,7 @@ class WebSocketResourceProviderTest {
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME,
applicationHandler, requestLog,
immutableTestPrincipal("fooz"),
Optional.of(new TestPrincipal("fooz")),
new ProtobufWebSocketMessageFactory(),
Optional.of(connectListener),
Duration.ofMillis(30000));
@ -109,7 +108,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = mock(ApplicationHandler.class);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("foo")),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -186,7 +185,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = mock(ApplicationHandler.class);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("foo")),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -242,7 +241,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("foo")),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -282,7 +281,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("foo")),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -322,7 +321,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("authorizedUserName"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("authorizedUserName")),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -362,7 +361,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, ReusableAuth.anonymous(),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.empty(),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -401,7 +400,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("something"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("something")),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -441,7 +440,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, ReusableAuth.anonymous(),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.empty(),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -481,7 +480,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("gooduser")),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -522,7 +521,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("gooduser")),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -564,7 +563,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("gooduser")),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -604,7 +603,7 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"),
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("gooduser")),
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
@ -729,10 +728,6 @@ class WebSocketResourceProviderTest {
}
}
public static ReusableAuth<TestPrincipal> immutableTestPrincipal(final String name) {
return ReusableAuth.authenticated(new TestPrincipal(name), PrincipalSupplier.forImmutablePrincipal());
}
public static class TestException extends Exception {
public TestException(String message) {