From c952baa672af87a6aee90e64b78dd855c3080b06 Mon Sep 17 00:00:00 2001 From: Jon Chambers <63609320+jon-signal@users.noreply.github.com> Date: Mon, 23 Jun 2025 08:40:05 -0500 Subject: [PATCH] Don't cache authenticated accounts in memory --- .../textsecuregcm/WhisperServerService.java | 20 +- .../AccountAndAuthenticatedDeviceHolder.java | 10 + .../auth/AuthenticatedDevice.java | 18 + .../auth/CertificateGenerator.java | 4 +- .../auth/ContainerRequestUtil.java | 44 --- ...ceAuthenticatedWebSocketUpgradeFilter.java | 27 +- ...inkedDeviceRefreshRequirementProvider.java | 96 ----- ...umberChangeRefreshRequirementProvider.java | 56 --- ...socketRefreshApplicationEventListener.java | 38 -- .../WebsocketRefreshRequestEventListener.java | 74 ---- .../WebsocketRefreshRequirementProvider.java | 34 -- .../controllers/AccountController.java | 140 +++++--- .../controllers/AccountControllerV2.java | 39 ++- .../controllers/ArchiveController.java | 118 ++++--- .../controllers/AttachmentControllerV4.java | 7 +- .../controllers/CallLinkController.java | 7 +- .../controllers/CallRoutingControllerV2.java | 7 +- .../controllers/CertificateController.java | 29 +- .../controllers/ChallengeController.java | 35 +- .../controllers/DeviceCheckController.java | 42 ++- .../controllers/DeviceController.java | 117 ++++--- .../controllers/DirectoryV2Controller.java | 7 +- .../controllers/DonationController.java | 41 ++- .../controllers/GetCallingRelaysResponse.java | 3 +- .../controllers/KeepAliveController.java | 7 +- .../KeyTransparencyController.java | 7 +- .../controllers/KeysController.java | 213 ++++++------ .../controllers/MessageController.java | 113 +++--- ...tiRecipientMismatchedDevicesException.java | 2 +- .../OneTimeDonationController.java | 9 +- .../controllers/PaymentsController.java | 7 +- .../controllers/ProfileController.java | 41 ++- .../controllers/ProvisioningController.java | 5 +- .../controllers/RemoteConfigController.java | 5 +- .../controllers/SecureStorageController.java | 5 +- .../SecureValueRecovery2Controller.java | 5 +- .../controllers/StickerController.java | 5 +- .../controllers/SubscriptionController.java | 23 +- ...tionSessionRateLimitExceededException.java | 4 +- .../entities/IncomingMessage.java | 13 +- .../filters/RestDeprecationFilter.java | 4 +- .../textsecuregcm/spam/SpamChecker.java | 8 +- .../storage/AccountPrincipalSupplier.java | 36 -- .../storage/AccountsManager.java | 9 +- .../AuthenticatedConnectListener.java | 32 +- .../WebSocketAccountAuthenticator.java | 24 +- .../websocket/WebSocketConnection.java | 42 ++- ...socketResourceProviderIntegrationTest.java | 6 +- .../WebsocketReuseAuthIntegrationTest.java | 279 --------------- .../auth/CertificateGeneratorTest.java | 8 +- ...thenticatedWebSocketUpgradeFilterTest.java | 12 +- ...dDeviceRefreshRequirementProviderTest.java | 328 ------------------ ...rChangeRefreshRequirementProviderTest.java | 294 ---------------- .../controllers/AccountControllerTest.java | 5 + .../controllers/AccountControllerV2Test.java | 7 + .../controllers/ArchiveControllerTest.java | 7 +- .../CertificateControllerTest.java | 14 +- .../controllers/ChallengeControllerTest.java | 7 +- .../DeviceCheckControllerTest.java | 7 +- .../controllers/DeviceControllerTest.java | 92 ++--- .../DirectoryControllerV2Test.java | 3 +- .../controllers/DonationControllerTest.java | 2 + .../controllers/KeysControllerTest.java | 11 + .../controllers/MessageControllerTest.java | 8 + .../controllers/ProfileControllerTest.java | 3 + .../RemoteConfigControllerTest.java | 7 +- .../entities/OutgoingMessageEntityTest.java | 10 +- .../MetricsRequestEventListenerTest.java | 4 +- .../tests/util/AccountsHelper.java | 2 + .../textsecuregcm/tests/util/AuthHelper.java | 1 + .../tests/util/TestPrincipal.java | 7 +- .../LoggingUnhandledExceptionMapperTest.java | 2 +- .../WebSocketAccountAuthenticatorTest.java | 7 +- .../WebSocketConnectionIntegrationTest.java | 25 +- .../websocket/WebSocketConnectionTest.java | 69 ++-- .../websocket/ReusableAuth.java | 149 -------- .../websocket/WebSocketResourceProvider.java | 24 +- .../WebSocketResourceProviderFactory.java | 4 +- .../AuthenticatedWebSocketUpgradeFilter.java | 4 +- .../websocket/auth/Mutable.java | 26 -- .../websocket/auth/PrincipalSupplier.java | 58 ---- .../websocket/auth/ReadOnly.java | 25 -- .../auth/WebSocketAuthenticator.java | 15 +- .../WebsocketAuthValueFactoryProvider.java | 23 +- .../WebSocketResourceProviderFactoryTest.java | 7 +- .../WebSocketResourceProviderTest.java | 31 +- 86 files changed, 961 insertions(+), 2264 deletions(-) delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/auth/ContainerRequestUtil.java delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/auth/LinkedDeviceRefreshRequirementProvider.java delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProvider.java delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshApplicationEventListener.java delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequestEventListener.java delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequirementProvider.java delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountPrincipalSupplier.java delete mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/WebsocketReuseAuthIntegrationTest.java delete mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/auth/LinkedDeviceRefreshRequirementProviderTest.java delete mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProviderTest.java delete mode 100644 websocket-resources/src/main/java/org/whispersystems/websocket/ReusableAuth.java delete mode 100644 websocket-resources/src/main/java/org/whispersystems/websocket/auth/Mutable.java delete mode 100644 websocket-resources/src/main/java/org/whispersystems/websocket/auth/PrincipalSupplier.java delete mode 100644 websocket-resources/src/main/java/org/whispersystems/websocket/auth/ReadOnly.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 9adda66e8..8b461225c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -85,7 +85,6 @@ import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator import org.whispersystems.textsecuregcm.auth.IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter; import org.whispersystems.textsecuregcm.auth.PhoneVerificationTokenManager; import org.whispersystems.textsecuregcm.auth.RegistrationLockVerificationManager; -import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventListener; import org.whispersystems.textsecuregcm.auth.grpc.ProhibitAuthenticationInterceptor; import org.whispersystems.textsecuregcm.auth.grpc.RequireAuthenticationInterceptor; import org.whispersystems.textsecuregcm.backup.BackupAuthManager; @@ -212,7 +211,6 @@ import org.whispersystems.textsecuregcm.spam.RegistrationRecoveryChecker; import org.whispersystems.textsecuregcm.spam.SpamChecker; import org.whispersystems.textsecuregcm.spam.SpamFilter; import org.whispersystems.textsecuregcm.storage.AccountLockManager; -import org.whispersystems.textsecuregcm.storage.AccountPrincipalSupplier; import org.whispersystems.textsecuregcm.storage.Accounts; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.ChangeNumberManager; @@ -980,23 +978,19 @@ public class WhisperServerService extends Application(AuthenticatedDevice.class)); - environment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, - disconnectionRequestManager)); environment.jersey().register(new TimestampResponseFilter()); /// WebSocketEnvironment webSocketEnvironment = new WebSocketEnvironment<>(environment, config.getWebSocketConfiguration(), Duration.ofMillis(90000)); webSocketEnvironment.jersey().register(new VirtualExecutorServiceProvider("managed-async-websocket-virtual-thread-")); - webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator, new AccountPrincipalSupplier(accountsManager))); + webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator)); webSocketEnvironment.setAuthenticatedWebSocketUpgradeFilter(new IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter( keysManager, config.idlePrimaryDeviceReminderConfiguration().minIdleDuration(), Clock.systemUTC())); webSocketEnvironment.setConnectListener( - new AuthenticatedConnectListener(receiptSender, messagesManager, messageMetrics, pushNotificationManager, + new AuthenticatedConnectListener(accountsManager, receiptSender, messagesManager, messageMetrics, pushNotificationManager, pushNotificationScheduler, webSocketConnectionEventManager, websocketScheduledExecutor, messageDeliveryScheduler, clientReleaseManager, messageDeliveryLoopMonitor, experimentEnrollmentManager)); - webSocketEnvironment.jersey() - .register(new WebsocketRefreshApplicationEventListener(accountsManager, disconnectionRequestManager)); webSocketEnvironment.jersey().register(new RateLimitByIpFilter(rateLimiters)); webSocketEnvironment.jersey().register(new RequestStatisticsFilter(TrafficSource.WEBSOCKET)); webSocketEnvironment.jersey().register(MultiRecipientMessageProvider.class); @@ -1083,15 +1077,15 @@ public class WhisperServerService extends Application provisioningEnvironment = new WebSocketEnvironment<>(environment, webSocketEnvironment.getRequestLog(), Duration.ofMillis(60000)); - provisioningEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, - disconnectionRequestManager)); provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(provisioningManager, provisioningWebsocketTimeoutExecutor, Duration.ofSeconds(90))); provisioningEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET, clientReleaseManager)); provisioningEnvironment.jersey().register(new KeepAliveController(webSocketConnectionEventManager)); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AccountAndAuthenticatedDeviceHolder.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AccountAndAuthenticatedDeviceHolder.java index bf10bd657..293cb62df 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AccountAndAuthenticatedDeviceHolder.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AccountAndAuthenticatedDeviceHolder.java @@ -7,10 +7,20 @@ package org.whispersystems.textsecuregcm.auth; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; +import java.time.Instant; +import java.util.UUID; public interface AccountAndAuthenticatedDeviceHolder { + UUID getAccountIdentifier(); + + byte getDeviceId(); + + Instant getPrimaryDeviceLastSeen(); + + @Deprecated(forRemoval = true) Account getAccount(); + @Deprecated(forRemoval = true) Device getAuthenticatedDevice(); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthenticatedDevice.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthenticatedDevice.java index 7193d5dad..f9f122209 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthenticatedDevice.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthenticatedDevice.java @@ -6,7 +6,10 @@ package org.whispersystems.textsecuregcm.auth; import java.security.Principal; +import java.time.Instant; +import java.util.UUID; import javax.security.auth.Subject; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; @@ -30,6 +33,21 @@ public class AuthenticatedDevice implements Principal, AccountAndAuthenticatedDe return device; } + @Override + public UUID getAccountIdentifier() { + return account.getIdentifier(IdentityType.ACI); + } + + @Override + public byte getDeviceId() { + return device.getId(); + } + + @Override + public Instant getPrimaryDeviceLastSeen() { + return Instant.ofEpochMilli(account.getPrimaryDevice().getLastSeen()); + } + // Principal implementation @Override diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/CertificateGenerator.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/CertificateGenerator.java index 93a568ae6..123b2b59e 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/CertificateGenerator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/CertificateGenerator.java @@ -31,9 +31,9 @@ public class CertificateGenerator { this.serverCertificate = ServerCertificate.parseFrom(serverCertificate); } - public byte[] createFor(Account account, Device device, boolean includeE164) throws InvalidKeyException { + public byte[] createFor(final Account account, final byte deviceId, boolean includeE164) throws InvalidKeyException { SenderCertificate.Certificate.Builder builder = SenderCertificate.Certificate.newBuilder() - .setSenderDevice(Math.toIntExact(device.getId())) + .setSenderDevice(Math.toIntExact(deviceId)) .setExpires(System.currentTimeMillis() + TimeUnit.DAYS.toMillis(expiresDays)) .setIdentityKey(ByteString.copyFrom(account.getIdentityKey(IdentityType.ACI).serialize())) .setSigner(serverCertificate) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/ContainerRequestUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/ContainerRequestUtil.java deleted file mode 100644 index af1d15bbb..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/ContainerRequestUtil.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright 2013-2021 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.auth; - -import jakarta.ws.rs.core.SecurityContext; -import java.util.Optional; -import java.util.Set; -import java.util.UUID; -import java.util.stream.Collectors; -import org.glassfish.jersey.server.ContainerRequest; -import org.whispersystems.textsecuregcm.storage.Account; -import org.whispersystems.textsecuregcm.storage.Device; - -class ContainerRequestUtil { - - /** - * A read-only subset of the authenticated Account object, to enforce that filter-based consumers do not perform - * account modifying operations. - */ - record AccountInfo(UUID accountId, String e164, Set deviceIds) { - - static AccountInfo fromAccount(final Account account) { - return new AccountInfo( - account.getUuid(), - account.getNumber(), - account.getDevices().stream().map(Device::getId).collect(Collectors.toSet())); - } - } - - static Optional getAuthenticatedAccount(final ContainerRequest request) { - return Optional.ofNullable(request.getSecurityContext()) - .map(SecurityContext::getUserPrincipal) - .map(principal -> { - if (principal instanceof AccountAndAuthenticatedDeviceHolder aaadh) { - return aaadh.getAccount(); - } - return null; - }) - .map(AccountInfo::fromAccount); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter.java index d090df0a0..531b61e64 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter.java @@ -8,17 +8,16 @@ package org.whispersystems.textsecuregcm.auth; import com.google.common.annotations.VisibleForTesting; import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Metrics; -import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest; -import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse; -import org.whispersystems.textsecuregcm.identity.IdentityType; -import org.whispersystems.textsecuregcm.metrics.MetricsUtil; -import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.storage.KeysManager; -import org.whispersystems.websocket.ReusableAuth; -import org.whispersystems.websocket.auth.AuthenticatedWebSocketUpgradeFilter; import java.time.Clock; import java.time.Duration; import java.time.Instant; +import java.util.Optional; +import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest; +import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.KeysManager; +import org.whispersystems.websocket.auth.AuthenticatedWebSocketUpgradeFilter; public class IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter implements AuthenticatedWebSocketUpgradeFilter { @@ -58,21 +57,19 @@ public class IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter implements } @Override - public void handleAuthentication(final ReusableAuth authenticated, + public void handleAuthentication(final Optional authenticated, final JettyServerUpgradeRequest request, final JettyServerUpgradeResponse response) { // No action needed if the connection is unauthenticated (in which case we don't know when we've last seen the // primary device) or if the authenticated device IS the primary device - authenticated.ref() - .filter(authenticatedDevice -> !authenticatedDevice.getAuthenticatedDevice().isPrimary()) + authenticated + .filter(authenticatedDevice -> authenticatedDevice.getDeviceId() != Device.PRIMARY_ID) .ifPresent(authenticatedDevice -> { - final Instant primaryDeviceLastSeen = - Instant.ofEpochMilli(authenticatedDevice.getAccount().getPrimaryDevice().getLastSeen()); + final Instant primaryDeviceLastSeen = authenticatedDevice.getPrimaryDeviceLastSeen(); if (primaryDeviceLastSeen.isBefore(clock.instant().minus(PQ_KEY_CHECK_THRESHOLD)) && - keysManager.getLastResort(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI), Device.PRIMARY_ID) - .join().isEmpty()) { + keysManager.getLastResort(authenticatedDevice.getAccountIdentifier(), Device.PRIMARY_ID).join().isEmpty()) { response.addHeader(ALERT_HEADER, CRITICAL_IDLE_PRIMARY_DEVICE_ALERT); CRITICAL_IDLE_PRIMARY_WARNING_COUNTER.increment(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/LinkedDeviceRefreshRequirementProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/LinkedDeviceRefreshRequirementProvider.java deleted file mode 100644 index 01eb00fe1..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/LinkedDeviceRefreshRequirementProvider.java +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Copyright 2021 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.auth; - -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Set; -import java.util.UUID; -import java.util.stream.Collectors; -import org.glassfish.jersey.server.ContainerRequest; -import org.glassfish.jersey.server.monitoring.RequestEvent; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.storage.Account; -import org.whispersystems.textsecuregcm.storage.AccountsManager; -import org.whispersystems.textsecuregcm.util.Pair; - -/** - * This {@link WebsocketRefreshRequirementProvider} observes intra-request changes in devices linked to an - * {@link Account} and triggers a WebSocket refresh if that set changes. If a change in linked devices is observed, then - * any active WebSocket connections for the account must be closed in order for clients to get a refreshed - * {@link io.dropwizard.auth.Auth} object with a current device list. - * - * @see AuthenticatedDevice - */ -public class LinkedDeviceRefreshRequirementProvider implements WebsocketRefreshRequirementProvider { - - private final AccountsManager accountsManager; - - private static final Logger logger = LoggerFactory.getLogger(LinkedDeviceRefreshRequirementProvider.class); - - private static final String ACCOUNT_UUID = LinkedDeviceRefreshRequirementProvider.class.getName() + ".accountUuid"; - private static final String LINKED_DEVICE_IDS = LinkedDeviceRefreshRequirementProvider.class.getName() + ".deviceIds"; - - public LinkedDeviceRefreshRequirementProvider(final AccountsManager accountsManager) { - this.accountsManager = accountsManager; - } - - @Override - public void handleRequestFiltered(final RequestEvent requestEvent) { - if (requestEvent.getUriInfo().getMatchedResourceMethod().getInvocable().getHandlingMethod().getAnnotation( - ChangesLinkedDevices.class) != null) { - // The authenticated principal, if any, will be available after filters have run. Now that the account is known, - // capture a snapshot of the account's linked devices before carrying out the request’s business logic. - ContainerRequestUtil.getAuthenticatedAccount(requestEvent.getContainerRequest()) - .ifPresent(account -> setAccount(requestEvent.getContainerRequest(), account)); - } - } - - public static void setAccount(final ContainerRequest containerRequest, final Account account) { - setAccount(containerRequest, ContainerRequestUtil.AccountInfo.fromAccount(account)); - } - - private static void setAccount(final ContainerRequest containerRequest, final ContainerRequestUtil.AccountInfo info) { - containerRequest.setProperty(ACCOUNT_UUID, info.accountId()); - containerRequest.setProperty(LINKED_DEVICE_IDS, info.deviceIds()); - } - - @Override - public List> handleRequestFinished(final RequestEvent requestEvent) { - // Now that the request is finished, check whether the set of linked devices has changed. If the value did change or - // if a devices was added or removed, all devices must disconnect and reauthenticate. - if (requestEvent.getContainerRequest().getProperty(LINKED_DEVICE_IDS) != null) { - - @SuppressWarnings("unchecked") final Set initialLinkedDeviceIds = - (Set) requestEvent.getContainerRequest().getProperty(LINKED_DEVICE_IDS); - - return accountsManager.getByAccountIdentifier((UUID) requestEvent.getContainerRequest().getProperty(ACCOUNT_UUID)) - .map(ContainerRequestUtil.AccountInfo::fromAccount) - .map(accountInfo -> { - final Set deviceIdsToDisplace; - final Set currentLinkedDeviceIds = accountInfo.deviceIds(); - - if (!initialLinkedDeviceIds.equals(currentLinkedDeviceIds)) { - deviceIdsToDisplace = new HashSet<>(initialLinkedDeviceIds); - deviceIdsToDisplace.addAll(currentLinkedDeviceIds); - } else { - deviceIdsToDisplace = Collections.emptySet(); - } - - return deviceIdsToDisplace.stream() - .map(deviceId -> new Pair<>(accountInfo.accountId(), deviceId)) - .collect(Collectors.toList()); - }).orElseGet(() -> { - logger.error("Request had account, but it is no longer present"); - return Collections.emptyList(); - }); - } else { - return Collections.emptyList(); - } - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProvider.java deleted file mode 100644 index db1c35ae9..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProvider.java +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright 2013-2021 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.auth; - -import java.util.Collections; -import java.util.List; -import java.util.UUID; -import java.util.stream.Collectors; -import org.glassfish.jersey.server.monitoring.RequestEvent; -import org.whispersystems.textsecuregcm.storage.AccountsManager; -import org.whispersystems.textsecuregcm.util.Pair; - -public class PhoneNumberChangeRefreshRequirementProvider implements WebsocketRefreshRequirementProvider { - - private static final String ACCOUNT_UUID = - PhoneNumberChangeRefreshRequirementProvider.class.getName() + ".accountUuid"; - - private static final String INITIAL_NUMBER_KEY = - PhoneNumberChangeRefreshRequirementProvider.class.getName() + ".initialNumber"; - private final AccountsManager accountsManager; - - public PhoneNumberChangeRefreshRequirementProvider(final AccountsManager accountsManager) { - this.accountsManager = accountsManager; - } - - @Override - public void handleRequestFiltered(final RequestEvent requestEvent) { - if (requestEvent.getUriInfo().getMatchedResourceMethod().getInvocable().getHandlingMethod() - .getAnnotation(ChangesPhoneNumber.class) == null) { - return; - } - ContainerRequestUtil.getAuthenticatedAccount(requestEvent.getContainerRequest()) - .ifPresent(account -> { - requestEvent.getContainerRequest().setProperty(INITIAL_NUMBER_KEY, account.e164()); - requestEvent.getContainerRequest().setProperty(ACCOUNT_UUID, account.accountId()); - }); - } - - @Override - public List> handleRequestFinished(final RequestEvent requestEvent) { - final String initialNumber = (String) requestEvent.getContainerRequest().getProperty(INITIAL_NUMBER_KEY); - - if (initialNumber == null) { - return Collections.emptyList(); - } - return accountsManager.getByAccountIdentifier((UUID) requestEvent.getContainerRequest().getProperty(ACCOUNT_UUID)) - .filter(account -> !initialNumber.equals(account.getNumber())) - .map(account -> account.getDevices().stream() - .map(device -> new Pair<>(account.getUuid(), device.getId())) - .collect(Collectors.toList())) - .orElse(Collections.emptyList()); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshApplicationEventListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshApplicationEventListener.java deleted file mode 100644 index 10944cc46..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshApplicationEventListener.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright 2021 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.auth; - -import org.glassfish.jersey.server.monitoring.ApplicationEvent; -import org.glassfish.jersey.server.monitoring.ApplicationEventListener; -import org.glassfish.jersey.server.monitoring.RequestEvent; -import org.glassfish.jersey.server.monitoring.RequestEventListener; -import org.whispersystems.textsecuregcm.storage.AccountsManager; - -/** - * Delegates request events to a listener that watches for intra-request changes that require websocket refreshes - */ -public class WebsocketRefreshApplicationEventListener implements ApplicationEventListener { - - private final WebsocketRefreshRequestEventListener websocketRefreshRequestEventListener; - - public WebsocketRefreshApplicationEventListener(final AccountsManager accountsManager, - final DisconnectionRequestManager disconnectionRequestManager) { - - this.websocketRefreshRequestEventListener = new WebsocketRefreshRequestEventListener( - disconnectionRequestManager, - new LinkedDeviceRefreshRequirementProvider(accountsManager), - new PhoneNumberChangeRefreshRequirementProvider(accountsManager)); - } - - @Override - public void onEvent(final ApplicationEvent event) { - } - - @Override - public RequestEventListener onRequest(final RequestEvent requestEvent) { - return websocketRefreshRequestEventListener; - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequestEventListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequestEventListener.java deleted file mode 100644 index 9f6a9bd0e..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequestEventListener.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright 2013-2021 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.auth; - -import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; - -import io.micrometer.core.instrument.Counter; -import io.micrometer.core.instrument.Metrics; -import jakarta.ws.rs.container.ResourceInfo; -import jakarta.ws.rs.core.Context; -import java.util.Arrays; -import java.util.List; -import java.util.concurrent.atomic.AtomicInteger; -import org.glassfish.jersey.server.monitoring.RequestEvent; -import org.glassfish.jersey.server.monitoring.RequestEvent.Type; -import org.glassfish.jersey.server.monitoring.RequestEventListener; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class WebsocketRefreshRequestEventListener implements RequestEventListener { - - private final DisconnectionRequestManager disconnectionRequestManager; - private final WebsocketRefreshRequirementProvider[] providers; - - private static final Counter DISPLACED_ACCOUNTS = Metrics.counter( - name(WebsocketRefreshRequestEventListener.class, "displacedAccounts")); - - private static final Counter DISPLACED_DEVICES = Metrics.counter( - name(WebsocketRefreshRequestEventListener.class, "displacedDevices")); - - private static final Logger logger = LoggerFactory.getLogger(WebsocketRefreshRequestEventListener.class); - - public WebsocketRefreshRequestEventListener( - final DisconnectionRequestManager disconnectionRequestManager, - final WebsocketRefreshRequirementProvider... providers) { - - this.disconnectionRequestManager = disconnectionRequestManager; - this.providers = providers; - } - - @Context - private ResourceInfo resourceInfo; - - @Override - public void onEvent(final RequestEvent event) { - if (event.getType() == Type.REQUEST_FILTERED) { - for (final WebsocketRefreshRequirementProvider provider : providers) { - provider.handleRequestFiltered(event); - } - } else if (event.getType() == Type.FINISHED) { - final AtomicInteger displacedDevices = new AtomicInteger(0); - - Arrays.stream(providers) - .flatMap(provider -> provider.handleRequestFinished(event).stream()) - .distinct() - .forEach(pair -> { - try { - displacedDevices.incrementAndGet(); - disconnectionRequestManager.requestDisconnection(pair.first(), List.of(pair.second())); - } catch (final Exception e) { - logger.error("Could not displace device presence", e); - } - }); - - if (displacedDevices.get() > 0) { - DISPLACED_ACCOUNTS.increment(); - DISPLACED_DEVICES.increment(displacedDevices.get()); - } - } - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequirementProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequirementProvider.java deleted file mode 100644 index 5e020381f..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/WebsocketRefreshRequirementProvider.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright 2013-2021 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.auth; - -import java.util.List; -import java.util.UUID; -import org.glassfish.jersey.server.monitoring.RequestEvent; -import org.whispersystems.textsecuregcm.util.Pair; - -/** - * A websocket refresh requirement provider watches for intra-request changes (e.g. to authentication status) that - * require a websocket refresh. - */ -public interface WebsocketRefreshRequirementProvider { - - /** - * Processes a request after filters have run and the request has been mapped to a destination controller. - * - * @param requestEvent the request event to observe - */ - void handleRequestFiltered(RequestEvent requestEvent); - - /** - * Processes a request after all normal request handling has been completed. - * - * @param requestEvent the request event to observe - * @return a list of pairs of account UUID/device ID pairs identifying websockets that need to be refreshed as a - * result of the observed request - */ - List> handleRequestFinished(RequestEvent requestEvent); -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java index 1127cdbe4..621073e16 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java @@ -66,8 +66,6 @@ import org.whispersystems.textsecuregcm.util.ExceptionUtils; import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.UsernameHashZkProofVerifier; import org.whispersystems.textsecuregcm.util.Util; -import org.whispersystems.websocket.auth.Mutable; -import org.whispersystems.websocket.auth.ReadOnly; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") @Path("/v1/accounts") @@ -97,11 +95,14 @@ public class AccountController { @Path("/gcm/") @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) - public void setGcmRegistrationId(@Mutable @Auth AuthenticatedDevice auth, + public void setGcmRegistrationId(@Auth AuthenticatedDevice auth, @NotNull @Valid GcmRegistrationId registrationId) { - final Account account = auth.getAccount(); - final Device device = auth.getAuthenticatedDevice(); + final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); + + final Device device = account.getDevice(auth.getDeviceId()) + .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); if (Objects.equals(device.getGcmId(), registrationId.gcmRegistrationId())) { return; @@ -116,9 +117,12 @@ public class AccountController { @DELETE @Path("/gcm/") - public void deleteGcmRegistrationId(@Mutable @Auth AuthenticatedDevice auth) { - Account account = auth.getAccount(); - Device device = auth.getAuthenticatedDevice(); + public void deleteGcmRegistrationId(@Auth AuthenticatedDevice auth) { + final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); + + final Device device = account.getDevice(auth.getDeviceId()) + .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); accounts.updateDevice(account, device.getId(), d -> { d.setGcmId(null); @@ -131,11 +135,14 @@ public class AccountController { @Path("/apn/") @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) - public void setApnRegistrationId(@Mutable @Auth AuthenticatedDevice auth, + public void setApnRegistrationId(@Auth AuthenticatedDevice auth, @NotNull @Valid ApnRegistrationId registrationId) { - final Account account = auth.getAccount(); - final Device device = auth.getAuthenticatedDevice(); + final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); + + final Device device = account.getDevice(auth.getDeviceId()) + .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); // Unlike FCM tokens, we need current "last updated" timestamps for APNs tokens and so update device records // unconditionally @@ -148,9 +155,12 @@ public class AccountController { @DELETE @Path("/apn/") - public void deleteApnRegistrationId(@Mutable @Auth AuthenticatedDevice auth) { - Account account = auth.getAccount(); - Device device = auth.getAuthenticatedDevice(); + public void deleteApnRegistrationId(@Auth AuthenticatedDevice auth) { + final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); + + final Device device = account.getDevice(auth.getDeviceId()) + .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); accounts.updateDevice(account, device.getId(), d -> { d.setApnId(null); @@ -166,17 +176,23 @@ public class AccountController { @PUT @Produces(MediaType.APPLICATION_JSON) @Path("/registration_lock") - public void setRegistrationLock(@Mutable @Auth AuthenticatedDevice auth, @NotNull @Valid RegistrationLock accountLock) { - SaltedTokenHash credentials = SaltedTokenHash.generateFor(accountLock.getRegistrationLock()); + public void setRegistrationLock(@Auth AuthenticatedDevice auth, @NotNull @Valid RegistrationLock accountLock) { + final SaltedTokenHash credentials = SaltedTokenHash.generateFor(accountLock.getRegistrationLock()); - accounts.update(auth.getAccount(), + final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); + + accounts.update(account, a -> a.setRegistrationLock(credentials.hash(), credentials.salt())); } @DELETE @Path("/registration_lock") - public void removeRegistrationLock(@Mutable @Auth AuthenticatedDevice auth) { - accounts.update(auth.getAccount(), a -> a.setRegistrationLock(null, null)); + public void removeRegistrationLock(@Auth AuthenticatedDevice auth) { + final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); + + accounts.update(account, a -> a.setRegistrationLock(null, null)); } @PUT @@ -190,7 +206,7 @@ public class AccountController { @ApiResponse(responseCode = "204", description = "Device name changed successfully") @ApiResponse(responseCode = "404", description = "No device found with the given ID") @ApiResponse(responseCode = "403", description = "Not authorized to change the name of the device with the given ID") - public void setName(@Mutable @Auth final AuthenticatedDevice auth, + public void setName(@Auth final AuthenticatedDevice auth, @NotNull @Valid final DeviceName deviceName, @Nullable @@ -199,15 +215,16 @@ public class AccountController { requiredMode = Schema.RequiredMode.NOT_REQUIRED) final Byte deviceId) { - final Account account = auth.getAccount(); - final byte targetDeviceId = deviceId == null ? auth.getAuthenticatedDevice().getId() : deviceId; + final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); + + final byte targetDeviceId = deviceId == null ? auth.getDeviceId() : deviceId; if (account.getDevice(targetDeviceId).isEmpty()) { throw new NotFoundException(); } - final boolean mayChangeName = auth.getAuthenticatedDevice().isPrimary() || - auth.getAuthenticatedDevice().getId() == targetDeviceId; + final boolean mayChangeName = auth.getDeviceId() == Device.PRIMARY_ID || auth.getDeviceId() == targetDeviceId; if (!mayChangeName) { throw new ForbiddenException(); @@ -221,14 +238,14 @@ public class AccountController { @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) public void setAccountAttributes( - @Mutable @Auth AuthenticatedDevice auth, + @Auth AuthenticatedDevice auth, @HeaderParam(HeaderUtils.X_SIGNAL_AGENT) String userAgent, @NotNull @Valid AccountAttributes attributes) { - final Account account = auth.getAccount(); - final byte deviceId = auth.getAuthenticatedDevice().getId(); + final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); final Account updatedAccount = accounts.update(account, a -> { - a.getDevice(deviceId).ifPresent(d -> { + a.getDevice(auth.getDeviceId()).ifPresent(d -> { d.setFetchesMessages(attributes.getFetchesMessages()); d.setName(attributes.getName()); d.setLastSeen(Util.todayInMillis()); @@ -252,8 +269,11 @@ public class AccountController { @GET @Path("/whoami") @Produces(MediaType.APPLICATION_JSON) - public AccountIdentityResponse whoAmI(@ReadOnly @Auth AuthenticatedDevice auth) { - return AccountIdentityResponseBuilder.fromAccount(auth.getAccount()); + public AccountIdentityResponse whoAmI(@Auth final AuthenticatedDevice auth) { + final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); + + return AccountIdentityResponseBuilder.fromAccount(account); } @DELETE @@ -267,8 +287,11 @@ public class AccountController { ) @ApiResponse(responseCode = "204", description = "Username successfully deleted.", useReturnTypeSchema = true) @ApiResponse(responseCode = "401", description = "Account authentication check failed.") - public CompletableFuture deleteUsernameHash(@Mutable @Auth final AuthenticatedDevice auth) { - return accounts.clearUsernameHash(auth.getAccount()) + public CompletableFuture deleteUsernameHash(@Auth final AuthenticatedDevice auth) { + final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); + + return accounts.clearUsernameHash(account) .thenApply(Util.ASYNC_EMPTY_RESPONSE); } @@ -289,10 +312,13 @@ public class AccountController { @ApiResponse(responseCode = "422", description = "Invalid request format.") @ApiResponse(responseCode = "429", description = "Ratelimited.") public CompletableFuture reserveUsernameHash( - @Mutable @Auth final AuthenticatedDevice auth, + @Auth final AuthenticatedDevice auth, @NotNull @Valid final ReserveUsernameHashRequest usernameRequest) throws RateLimitExceededException { - rateLimiters.getUsernameReserveLimiter().validate(auth.getAccount().getUuid()); + final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); + + rateLimiters.getUsernameReserveLimiter().validate(auth.getAccountIdentifier()); for (final byte[] hash : usernameRequest.usernameHashes()) { if (hash.length != USERNAME_HASH_LENGTH) { @@ -300,7 +326,7 @@ public class AccountController { } } - return accounts.reserveUsernameHash(auth.getAccount(), usernameRequest.usernameHashes()) + return accounts.reserveUsernameHash(account, usernameRequest.usernameHashes()) .thenApply(reservation -> new ReserveUsernameHashResponse(reservation.reservedUsernameHash())) .exceptionally(throwable -> { if (ExceptionUtils.unwrap(throwable) instanceof UsernameHashNotAvailableException) { @@ -329,18 +355,21 @@ public class AccountController { @ApiResponse(responseCode = "422", description = "Invalid request format.") @ApiResponse(responseCode = "429", description = "Ratelimited.") public CompletableFuture confirmUsernameHash( - @Mutable @Auth final AuthenticatedDevice auth, + @Auth final AuthenticatedDevice auth, @NotNull @Valid final ConfirmUsernameHashRequest confirmRequest) { + final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); + try { usernameHashZkProofVerifier.verifyProof(confirmRequest.zkProof(), confirmRequest.usernameHash()); } catch (final BaseUsernameException e) { throw new WebApplicationException(Response.status(422).build()); } - return rateLimiters.getUsernameSetLimiter().validateAsync(auth.getAccount().getUuid()) + return rateLimiters.getUsernameSetLimiter().validateAsync(account.getUuid()) .thenCompose(ignored -> accounts.confirmReservedUsernameHash( - auth.getAccount(), + account, confirmRequest.usernameHash(), confirmRequest.encryptedUsername())) .thenApply(updatedAccount -> new UsernameHashResponse(updatedAccount.getUsernameHash() @@ -374,7 +403,7 @@ public class AccountController { @ApiResponse(responseCode = "400", description = "Request must not be authenticated.") @ApiResponse(responseCode = "404", description = "Account not found for the given username.") public CompletableFuture lookupUsernameHash( - @ReadOnly @Auth final Optional maybeAuthenticatedAccount, + @Auth final Optional maybeAuthenticatedAccount, @PathParam("usernameHash") final String usernameHash) { requireNotAuthenticated(maybeAuthenticatedAccount); @@ -413,12 +442,14 @@ public class AccountController { @ApiResponse(responseCode = "422", description = "Invalid request format.") @ApiResponse(responseCode = "429", description = "Ratelimited.") public UsernameLinkHandle updateUsernameLink( - @Mutable @Auth final AuthenticatedDevice auth, + @Auth final AuthenticatedDevice auth, @NotNull @Valid final EncryptedUsername encryptedUsername) throws RateLimitExceededException { - // check ratelimiter for username link operations - rateLimiters.forDescriptor(RateLimiters.For.USERNAME_LINK_OPERATION).validate(auth.getAccount().getUuid()); - final Account account = auth.getAccount(); + // check ratelimiter for username link operations + rateLimiters.forDescriptor(RateLimiters.For.USERNAME_LINK_OPERATION).validate(auth.getAccountIdentifier()); + + final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); // check if username hash is set for the account if (account.getUsernameHash().isEmpty()) { @@ -431,7 +462,7 @@ public class AccountController { } else { usernameLinkHandle = UUID.randomUUID(); } - updateUsernameLink(auth.getAccount(), usernameLinkHandle, encryptedUsername.usernameLinkEncryptedValue()); + updateUsernameLink(account, usernameLinkHandle, encryptedUsername.usernameLinkEncryptedValue()); return new UsernameLinkHandle(usernameLinkHandle); } @@ -447,10 +478,14 @@ public class AccountController { @ApiResponse(responseCode = "204", description = "Username Link successfully deleted.", useReturnTypeSchema = true) @ApiResponse(responseCode = "401", description = "Account authentication check failed.") @ApiResponse(responseCode = "429", description = "Ratelimited.") - public void deleteUsernameLink(@Mutable @Auth final AuthenticatedDevice auth) throws RateLimitExceededException { + public void deleteUsernameLink(@Auth final AuthenticatedDevice auth) throws RateLimitExceededException { // check ratelimiter for username link operations - rateLimiters.forDescriptor(RateLimiters.For.USERNAME_LINK_OPERATION).validate(auth.getAccount().getUuid()); - clearUsernameLink(auth.getAccount()); + rateLimiters.forDescriptor(RateLimiters.For.USERNAME_LINK_OPERATION).validate(auth.getAccountIdentifier()); + + final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); + + clearUsernameLink(account); } @GET @@ -470,7 +505,7 @@ public class AccountController { @ApiResponse(responseCode = "422", description = "Invalid request format.") @ApiResponse(responseCode = "429", description = "Ratelimited.") public CompletableFuture lookupUsernameLink( - @ReadOnly @Auth final Optional maybeAuthenticatedAccount, + @Auth final Optional maybeAuthenticatedAccount, @PathParam("uuid") final UUID usernameLinkHandle) { requireNotAuthenticated(maybeAuthenticatedAccount); @@ -496,7 +531,7 @@ public class AccountController { @Path("/account/{identifier}") @RateLimitedByIp(RateLimiters.For.CHECK_ACCOUNT_EXISTENCE) public Response accountExists( - @ReadOnly @Auth final Optional authenticatedAccount, + @Auth final Optional authenticatedAccount, @Parameter(description = "An ACI or PNI account identifier to check") @PathParam("identifier") final ServiceIdentifier accountIdentifier) { @@ -511,8 +546,11 @@ public class AccountController { @DELETE @Path("/me") - public CompletableFuture deleteAccount(@Mutable @Auth AuthenticatedDevice auth) { - return accounts.delete(auth.getAccount(), AccountsManager.DeletionReason.USER_REQUEST).thenApply(Util.ASYNC_EMPTY_RESPONSE); + public CompletableFuture deleteAccount(@Auth AuthenticatedDevice auth) { + final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); + + return accounts.delete(account, AccountsManager.DeletionReason.USER_REQUEST).thenApply(Util.ASYNC_EMPTY_RESPONSE); } private void clearUsernameLink(final Account account) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2.java index 95bc58536..c069aa103 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2.java @@ -55,8 +55,7 @@ import org.whispersystems.textsecuregcm.push.MessageTooLargeException; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.ChangeNumberManager; -import org.whispersystems.websocket.auth.Mutable; -import org.whispersystems.websocket.auth.ReadOnly; +import org.whispersystems.textsecuregcm.storage.Device; @Path("/v2/accounts") @io.swagger.v3.oas.annotations.tags.Tag(name = "Account") @@ -101,12 +100,12 @@ public class AccountControllerV2 { @ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header( name = "Retry-After", description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed")) - public AccountIdentityResponse changeNumber(@Mutable @Auth final AuthenticatedDevice authenticatedDevice, + public AccountIdentityResponse changeNumber(@Auth final AuthenticatedDevice authenticatedDevice, @NotNull @Valid final ChangeNumberRequest request, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgentString, @Context final ContainerRequestContext requestContext) throws RateLimitExceededException, InterruptedException { - if (!authenticatedDevice.getAuthenticatedDevice().isPrimary()) { + if (authenticatedDevice.getDeviceId() != Device.PRIMARY_ID) { throw new ForbiddenException(); } @@ -116,8 +115,11 @@ public class AccountControllerV2 { final String number = request.number(); + final Account account = accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); + // Only verify and check reglock if there's a data change to be made... - if (!authenticatedDevice.getAccount().getNumber().equals(number)) { + if (!account.getNumber().equals(number)) { rateLimiters.getRegistrationLimiter().validate(number); @@ -139,7 +141,7 @@ public class AccountControllerV2 { // ...but always attempt to make the change in case a client retries and needs to re-send messages try { final Account updatedAccount = changeNumberManager.changeNumber( - authenticatedDevice.getAccount(), + account, request.number(), request.pniIdentityKey(), request.devicePniSignedPrekeys(), @@ -185,11 +187,11 @@ public class AccountControllerV2 { content = @Content(schema = @Schema(implementation = StaleDevicesResponse.class))) @ApiResponse(responseCode = "413", description = "One or more device messages was too large") public AccountIdentityResponse distributePhoneNumberIdentityKeys( - @Mutable @Auth final AuthenticatedDevice authenticatedDevice, + @Auth final AuthenticatedDevice authenticatedDevice, @HeaderParam(HttpHeaders.USER_AGENT) @Nullable final String userAgentString, @NotNull @Valid final PhoneNumberIdentityKeyDistributionRequest request) { - if (!authenticatedDevice.getAuthenticatedDevice().isPrimary()) { + if (authenticatedDevice.getDeviceId() != Device.PRIMARY_ID) { throw new ForbiddenException(); } @@ -197,9 +199,12 @@ public class AccountControllerV2 { throw new WebApplicationException("Invalid signature", 422); } + final Account account = accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); + try { final Account updatedAccount = changeNumberManager.updatePniKeys( - authenticatedDevice.getAccount(), + account, request.pniIdentityKey(), request.devicePniSignedPrekeys(), request.devicePniPqLastResortPrekeys(), @@ -235,10 +240,13 @@ public class AccountControllerV2 { @Operation(summary = "Sets whether the account should be discoverable by phone number in the directory.") @ApiResponse(responseCode = "204", description = "The setting was successfully updated.") public void setPhoneNumberDiscoverability( - @Mutable @Auth AuthenticatedDevice auth, - @NotNull @Valid PhoneNumberDiscoverabilityRequest phoneNumberDiscoverability - ) { - accountsManager.update(auth.getAccount(), a -> a.setDiscoverableByPhoneNumber( + @Auth AuthenticatedDevice auth, + @NotNull @Valid PhoneNumberDiscoverabilityRequest phoneNumberDiscoverability) { + + final Account account = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); + + accountsManager.update(account, a -> a.setDiscoverableByPhoneNumber( phoneNumberDiscoverability.discoverableByPhoneNumber())); } @@ -249,9 +257,10 @@ public class AccountControllerV2 { @ApiResponse(responseCode = "200", description = "Response with data report. A plain text representation is a field in the response.", useReturnTypeSchema = true) - public AccountDataReportResponse getAccountDataReport(@ReadOnly @Auth final AuthenticatedDevice auth) { + public AccountDataReportResponse getAccountDataReport(@Auth final AuthenticatedDevice auth) { - final Account account = auth.getAccount(); + final Account account = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); return new AccountDataReportResponse(UUID.randomUUID(), Instant.now(), new AccountDataReportResponse.AccountAndDevicesDataReport( diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ArchiveController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ArchiveController.java index d1411d721..28a34ca3b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ArchiveController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ArchiveController.java @@ -37,6 +37,7 @@ import jakarta.ws.rs.PUT; import jakarta.ws.rs.Path; import jakarta.ws.rs.Produces; import jakarta.ws.rs.QueryParam; +import jakarta.ws.rs.WebApplicationException; import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.Response; import java.io.IOException; @@ -71,14 +72,15 @@ import org.whispersystems.textsecuregcm.backup.MediaEncryptionParameters; import org.whispersystems.textsecuregcm.entities.RemoteAttachment; import org.whispersystems.textsecuregcm.metrics.BackupMetrics; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; +import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.util.BackupAuthCredentialAdapter; import org.whispersystems.textsecuregcm.util.ByteArrayAdapter; import org.whispersystems.textsecuregcm.util.ByteArrayBase64UrlAdapter; import org.whispersystems.textsecuregcm.util.ECPublicKeyAdapter; import org.whispersystems.textsecuregcm.util.ExactlySize; import org.whispersystems.textsecuregcm.util.Util; -import org.whispersystems.websocket.auth.Mutable; -import org.whispersystems.websocket.auth.ReadOnly; import reactor.core.publisher.Mono; @Path("/v1/archives") @@ -88,14 +90,18 @@ public class ArchiveController { public final static String X_SIGNAL_ZK_AUTH = "X-Signal-ZK-Auth"; public final static String X_SIGNAL_ZK_AUTH_SIGNATURE = "X-Signal-ZK-Auth-Signature"; + private final AccountsManager accountsManager; private final BackupAuthManager backupAuthManager; private final BackupManager backupManager; private final BackupMetrics backupMetrics; public ArchiveController( + final AccountsManager accountsManager, final BackupAuthManager backupAuthManager, final BackupManager backupManager, final BackupMetrics backupMetrics) { + + this.accountsManager = accountsManager; this.backupAuthManager = backupAuthManager; this.backupManager = backupManager; this.backupMetrics = backupMetrics; @@ -138,13 +144,22 @@ public class ArchiveController { @ApiResponse(responseCode = "403", description = "The device did not have permission to set the backup-id. Only the primary device can set the backup-id for an account") @ApiResponse(responseCode = "429", description = "Rate limited. Too many attempts to change the backup-id have been made") public CompletionStage setBackupId( - @Mutable @Auth final AuthenticatedDevice account, + @Auth final AuthenticatedDevice authenticatedDevice, @Valid @NotNull final SetBackupIdRequest setBackupIdRequest) throws RateLimitExceededException { - return this.backupAuthManager - .commitBackupId(account.getAccount(), account.getAuthenticatedDevice(), - setBackupIdRequest.messagesBackupAuthCredentialRequest, - setBackupIdRequest.mediaBackupAuthCredentialRequest) - .thenApply(Util.ASYNC_EMPTY_RESPONSE); + + return accountsManager.getByAccountIdentifierAsync(authenticatedDevice.getAccountIdentifier()) + .thenCompose(maybeAccount -> { + final Account account = maybeAccount + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); + + final Device device = account.getDevice(authenticatedDevice.getDeviceId()) + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); + + return backupAuthManager + .commitBackupId(account, device, setBackupIdRequest.messagesBackupAuthCredentialRequest, + setBackupIdRequest.mediaBackupAuthCredentialRequest) + .thenApply(Util.ASYNC_EMPTY_RESPONSE); + }); } public record RedeemBackupReceiptRequest( @@ -188,12 +203,17 @@ public class ArchiveController { @ApiResponse(responseCode = "409", description = "The target account does not have a backup-id commitment") @ApiResponse(responseCode = "429", description = "Rate limited.") public CompletionStage redeemReceipt( - @Mutable @Auth final AuthenticatedDevice account, + @Auth final AuthenticatedDevice authenticatedDevice, @Valid @NotNull final RedeemBackupReceiptRequest redeemBackupReceiptRequest) { - return this.backupAuthManager.redeemReceipt( - account.getAccount(), - redeemBackupReceiptRequest.receiptCredentialPresentation()) - .thenApply(Util.ASYNC_EMPTY_RESPONSE); + + return accountsManager.getByAccountIdentifierAsync(authenticatedDevice.getAccountIdentifier()) + .thenCompose(maybeAccount -> { + final Account account = maybeAccount + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); + + return backupAuthManager.redeemReceipt(account, redeemBackupReceiptRequest.receiptCredentialPresentation()) + .thenApply(Util.ASYNC_EMPTY_RESPONSE); + }); } public record BackupAuthCredentialsResponse( @@ -252,7 +272,7 @@ public class ArchiveController { @ApiResponse(responseCode = "404", description = "Could not find an existing blinded backup id") @ApiResponse(responseCode = "429", description = "Rate limited.") public CompletionStage getBackupZKCredentials( - @Mutable @Auth AuthenticatedDevice auth, + @Auth AuthenticatedDevice authenticatedDevice, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent, @NotNull @QueryParam("redemptionStartSeconds") Long startSeconds, @NotNull @QueryParam("redemptionEndSeconds") Long endSeconds) { @@ -260,27 +280,33 @@ public class ArchiveController { final Map> credentialsByType = new ConcurrentHashMap<>(); - return CompletableFuture.allOf(Arrays.stream(BackupCredentialType.values()) - .map(credentialType -> this.backupAuthManager.getBackupAuthCredentials( - auth.getAccount(), - credentialType, - Instant.ofEpochSecond(startSeconds), Instant.ofEpochSecond(endSeconds)) - .thenAccept(credentials -> { - backupMetrics.updateGetCredentialCounter( - UserAgentTagUtil.getPlatformTag(userAgent), - credentialType, - credentials.size()); - credentialsByType.put(credentialType, credentials.stream() - .map(credential -> new BackupAuthCredentialsResponse.BackupAuthCredential( - credential.credential().serialize(), - credential.redemptionTime().getEpochSecond())) - .toList()); - })) - .toArray(CompletableFuture[]::new)) - .thenApply(ignored -> new BackupAuthCredentialsResponse(credentialsByType.entrySet().stream() - .collect(Collectors.toMap( - e -> BackupAuthCredentialsResponse.CredentialType.fromLibsignalType(e.getKey()), - Map.Entry::getValue)))); + return accountsManager.getByAccountIdentifierAsync(authenticatedDevice.getAccountIdentifier()) + .thenCompose(maybeAccount -> { + final Account account = maybeAccount + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); + + return CompletableFuture.allOf(Arrays.stream(BackupCredentialType.values()) + .map(credentialType -> this.backupAuthManager.getBackupAuthCredentials( + account, + credentialType, + Instant.ofEpochSecond(startSeconds), Instant.ofEpochSecond(endSeconds)) + .thenAccept(credentials -> { + backupMetrics.updateGetCredentialCounter( + UserAgentTagUtil.getPlatformTag(userAgent), + credentialType, + credentials.size()); + credentialsByType.put(credentialType, credentials.stream() + .map(credential -> new BackupAuthCredentialsResponse.BackupAuthCredential( + credential.credential().serialize(), + credential.redemptionTime().getEpochSecond())) + .toList()); + })) + .toArray(CompletableFuture[]::new)) + .thenApply(ignored -> new BackupAuthCredentialsResponse(credentialsByType.entrySet().stream() + .collect(Collectors.toMap( + e -> BackupAuthCredentialsResponse.CredentialType.fromLibsignalType(e.getKey()), + Map.Entry::getValue)))); + }); } @@ -343,7 +369,7 @@ public class ArchiveController { @ApiResponse(responseCode = "429", description = "Rate limited.") @ApiResponseZkAuth public CompletionStage readAuth( - @ReadOnly @Auth final Optional account, + @Auth final Optional account, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent, @Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class)) @@ -395,7 +421,7 @@ public class ArchiveController { @ApiResponse(responseCode = "429", description = "Rate limited.") @ApiResponseZkAuth public CompletionStage backupInfo( - @ReadOnly @Auth final Optional account, + @Auth final Optional account, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent, @Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class)) @@ -441,7 +467,7 @@ public class ArchiveController { @ApiResponse(responseCode = "204", description = "The public key was set") @ApiResponse(responseCode = "429", description = "Rate limited.") public CompletionStage setPublicKey( - @ReadOnly @Auth final Optional account, + @Auth final Optional account, @Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class)) @NotNull @@ -481,7 +507,7 @@ public class ArchiveController { @ApiResponse(responseCode = "429", description = "Rate limited.") @ApiResponseZkAuth public CompletionStage backup( - @ReadOnly @Auth final Optional account, + @Auth final Optional account, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent, @Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class)) @@ -518,7 +544,7 @@ public class ArchiveController { @ApiResponse(responseCode = "429", description = "Rate limited.") @ApiResponseZkAuth public CompletionStage uploadTemporaryAttachment( - @ReadOnly @Auth final Optional account, + @Auth final Optional account, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent, @@ -606,7 +632,7 @@ public class ArchiveController { @ApiResponse(responseCode = "429", description = "Rate limited.") @ApiResponseZkAuth public CompletionStage copyMedia( - @ReadOnly @Auth final Optional account, + @Auth final Optional account, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent, @Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class)) @@ -705,7 +731,7 @@ public class ArchiveController { @ApiResponse(responseCode = "429", description = "Rate limited.") @ApiResponseZkAuth public CompletionStage copyMedia( - @ReadOnly @Auth final Optional account, + @Auth final Optional account, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent, @Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class)) @@ -744,7 +770,7 @@ public class ArchiveController { @ApiResponse(responseCode = "429", description = "Rate limited.") @ApiResponseZkAuth public CompletionStage refresh( - @ReadOnly @Auth final Optional account, + @Auth final Optional account, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent, @Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class)) @@ -811,7 +837,7 @@ public class ArchiveController { @ApiResponse(responseCode = "429", description = "Rate limited.") @ApiResponseZkAuth public CompletionStage listMedia( - @ReadOnly @Auth final Optional account, + @Auth final Optional account, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent, @Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class)) @@ -867,7 +893,7 @@ public class ArchiveController { @ApiResponse(responseCode = "429", description = "Rate limited.") @ApiResponseZkAuth public CompletionStage deleteMedia( - @ReadOnly @Auth final Optional account, + @Auth final Optional account, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent, @Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class)) @@ -904,7 +930,7 @@ public class ArchiveController { @ApiResponse(responseCode = "429", description = "Rate limited.") @ApiResponseZkAuth public CompletionStage deleteBackup( - @ReadOnly @Auth final Optional account, + @Auth final Optional account, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent, @Parameter(description = BackupAuthCredentialPresentationHeader.DESCRIPTION, schema = @Schema(implementation = String.class)) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AttachmentControllerV4.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AttachmentControllerV4.java index 2523c5a0f..5ea8baf5d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AttachmentControllerV4.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AttachmentControllerV4.java @@ -26,7 +26,6 @@ import org.whispersystems.textsecuregcm.entities.AttachmentDescriptorV3; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; -import org.whispersystems.websocket.auth.ReadOnly; /** @@ -78,11 +77,11 @@ public class AttachmentControllerV4 { @ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header( name = "Retry-After", description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed")) - public AttachmentDescriptorV3 getAttachmentUploadForm(@ReadOnly @Auth AuthenticatedDevice auth) + public AttachmentDescriptorV3 getAttachmentUploadForm(@Auth AuthenticatedDevice auth) throws RateLimitExceededException { - rateLimiter.validate(auth.getAccount().getUuid()); + rateLimiter.validate(auth.getAccountIdentifier()); final String key = generateAttachmentKey(); - final boolean useCdn3 = this.experimentEnrollmentManager.isEnrolled(auth.getAccount().getUuid(), CDN3_EXPERIMENT_NAME); + final boolean useCdn3 = this.experimentEnrollmentManager.isEnrolled(auth.getAccountIdentifier(), CDN3_EXPERIMENT_NAME); int cdn = useCdn3 ? 3 : 2; final AttachmentGenerator.Descriptor descriptor = this.attachmentGenerators.get(cdn).generateAttachment(key); return new AttachmentDescriptorV3(cdn, key, descriptor.headers(), descriptor.signedUploadLocation()); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CallLinkController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CallLinkController.java index ff698fa5b..edc884286 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CallLinkController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CallLinkController.java @@ -20,7 +20,6 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.entities.CreateCallLinkCredential; import org.whispersystems.textsecuregcm.entities.GetCreateCallLinkCredentialsRequest; import org.whispersystems.textsecuregcm.limits.RateLimiters; -import org.whispersystems.websocket.auth.ReadOnly; @Path("/v1/call-link") @io.swagger.v3.oas.annotations.tags.Tag(name = "CallLink") @@ -52,11 +51,11 @@ public class CallLinkController { @ApiResponse(responseCode = "422", description = "Invalid request format.") @ApiResponse(responseCode = "429", description = "Ratelimited.") public CreateCallLinkCredential getCreateAuth( - final @ReadOnly @Auth AuthenticatedDevice auth, + final @Auth AuthenticatedDevice auth, final @NotNull @Valid GetCreateCallLinkCredentialsRequest request ) throws RateLimitExceededException { - rateLimiters.getCreateCallLinkLimiter().validate(auth.getAccount().getUuid()); + rateLimiters.getCreateCallLinkLimiter().validate(auth.getAccountIdentifier()); final Instant truncatedDayTimestamp = Instant.now().truncatedTo(ChronoUnit.DAYS); @@ -68,7 +67,7 @@ public class CallLinkController { } return new CreateCallLinkCredential( - createCallLinkCredentialRequest.issueCredential(new ServiceId.Aci(auth.getAccount().getUuid()), truncatedDayTimestamp, genericServerSecretParams).serialize(), + createCallLinkCredentialRequest.issueCredential(new ServiceId.Aci(auth.getAccountIdentifier()), truncatedDayTimestamp, genericServerSecretParams).serialize(), truncatedDayTimestamp.getEpochSecond() ); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CallRoutingControllerV2.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CallRoutingControllerV2.java index e90549934..d3762e42a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CallRoutingControllerV2.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CallRoutingControllerV2.java @@ -18,11 +18,9 @@ import jakarta.ws.rs.Produces; import jakarta.ws.rs.core.MediaType; import java.io.IOException; import java.util.List; -import java.util.UUID; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.CloudflareTurnCredentialsManager; import org.whispersystems.textsecuregcm.limits.RateLimiters; -import org.whispersystems.websocket.auth.ReadOnly; @io.swagger.v3.oas.annotations.tags.Tag(name = "Calling") @Path("/v2/calling") @@ -56,11 +54,10 @@ public class CallRoutingControllerV2 { @ApiResponse(responseCode = "401", description = "Account authentication check failed.") @ApiResponse(responseCode = "422", description = "Invalid request format.") @ApiResponse(responseCode = "429", description = "Rate limited.") - public GetCallingRelaysResponse getCallingRelays(final @ReadOnly @Auth AuthenticatedDevice auth) + public GetCallingRelaysResponse getCallingRelays(final @Auth AuthenticatedDevice auth) throws RateLimitExceededException, IOException { - final UUID aci = auth.getAccount().getUuid(); - rateLimiters.getCallEndpointLimiter().validate(aci); + rateLimiters.getCallEndpointLimiter().validate(auth.getAccountIdentifier()); try { return new GetCallingRelaysResponse(List.of(cloudflareTurnCredentialsManager.retrieveFromCloudflare())); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CertificateController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CertificateController.java index 508a2aea6..fa71cc5b5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CertificateController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CertificateController.java @@ -8,18 +8,18 @@ package org.whispersystems.textsecuregcm.controllers; import static com.codahale.metrics.MetricRegistry.name; import com.google.common.annotations.VisibleForTesting; -import com.google.common.net.HttpHeaders; import io.dropwizard.auth.Auth; import io.micrometer.core.instrument.Metrics; import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.ws.rs.BadRequestException; import jakarta.ws.rs.DefaultValue; import jakarta.ws.rs.GET; -import jakarta.ws.rs.HeaderParam; import jakarta.ws.rs.Path; import jakarta.ws.rs.Produces; import jakarta.ws.rs.QueryParam; +import jakarta.ws.rs.WebApplicationException; import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import java.security.InvalidKeyException; import java.time.Clock; import java.time.Duration; @@ -38,13 +38,16 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.CertificateGenerator; import org.whispersystems.textsecuregcm.entities.DeliveryCertificate; import org.whispersystems.textsecuregcm.entities.GroupCredentials; -import org.whispersystems.websocket.auth.ReadOnly; +import org.whispersystems.textsecuregcm.identity.IdentityType; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") @Path("/v1/certificate") @Tag(name = "Certificate") public class CertificateController { + private final AccountsManager accountsManager; private final CertificateGenerator certificateGenerator; private final ServerZkAuthOperations serverZkAuthOperations; private final GenericServerSecretParams genericServerSecretParams; @@ -56,10 +59,13 @@ public class CertificateController { private static final String INCLUDE_E164_TAG_NAME = "includeE164"; public CertificateController( + final AccountsManager accountsManager, @Nonnull CertificateGenerator certificateGenerator, @Nonnull ServerZkAuthOperations serverZkAuthOperations, @Nonnull GenericServerSecretParams genericServerSecretParams, @Nonnull Clock clock) { + + this.accountsManager = accountsManager; this.certificateGenerator = Objects.requireNonNull(certificateGenerator); this.serverZkAuthOperations = Objects.requireNonNull(serverZkAuthOperations); this.genericServerSecretParams = genericServerSecretParams; @@ -69,23 +75,25 @@ public class CertificateController { @GET @Produces(MediaType.APPLICATION_JSON) @Path("/delivery") - public DeliveryCertificate getDeliveryCertificate(@ReadOnly @Auth AuthenticatedDevice auth, + public DeliveryCertificate getDeliveryCertificate(@Auth AuthenticatedDevice auth, @QueryParam("includeE164") @DefaultValue("true") boolean includeE164) throws InvalidKeyException { Metrics.counter(GENERATE_DELIVERY_CERTIFICATE_COUNTER_NAME, INCLUDE_E164_TAG_NAME, String.valueOf(includeE164)) .increment(); + final Account account = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); + return new DeliveryCertificate( - certificateGenerator.createFor(auth.getAccount(), auth.getAuthenticatedDevice(), includeE164)); + certificateGenerator.createFor(account, auth.getDeviceId(), includeE164)); } @GET @Produces(MediaType.APPLICATION_JSON) @Path("/auth/group") public GroupCredentials getGroupAuthenticationCredentials( - @ReadOnly @Auth AuthenticatedDevice auth, - @HeaderParam(HttpHeaders.USER_AGENT) String userAgent, + @Auth AuthenticatedDevice auth, @QueryParam("redemptionStartSeconds") long startSeconds, @QueryParam("redemptionEndSeconds") long endSeconds) { @@ -102,13 +110,16 @@ public class CertificateController { throw new BadRequestException(); } + final Account account = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); + final List credentials = new ArrayList<>(); final List callLinkAuthCredentials = new ArrayList<>(); Instant redemption = redemptionStart; - ServiceId.Aci aci = new ServiceId.Aci(auth.getAccount().getUuid()); - ServiceId.Pni pni = new ServiceId.Pni(auth.getAccount().getPhoneNumberIdentifier()); + final ServiceId.Aci aci = new ServiceId.Aci(account.getIdentifier(IdentityType.ACI)); + final ServiceId.Pni pni = new ServiceId.Pni(account.getIdentifier(IdentityType.PNI)); while (!redemption.isAfter(redemptionEnd)) { AuthCredentialWithPniResponse authCredentialWithPni = serverZkAuthOperations.issueAuthCredentialWithPniZkc(aci, pni, redemption); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ChallengeController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ChallengeController.java index 12cc974a3..12a2b772f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ChallengeController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ChallengeController.java @@ -25,6 +25,7 @@ import jakarta.ws.rs.POST; import jakarta.ws.rs.PUT; import jakarta.ws.rs.Path; import jakarta.ws.rs.Produces; +import jakarta.ws.rs.WebApplicationException; import jakarta.ws.rs.container.ContainerRequestContext; import jakarta.ws.rs.core.Context; import jakarta.ws.rs.core.MediaType; @@ -40,12 +41,14 @@ import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.spam.ChallengeConstraintChecker; import org.whispersystems.textsecuregcm.spam.ChallengeConstraintChecker.ChallengeConstraints; -import org.whispersystems.websocket.auth.ReadOnly; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; @Path("/v1/challenge") @Tag(name = "Challenge") public class ChallengeController { + private final AccountsManager accountsManager; private final RateLimitChallengeManager rateLimitChallengeManager; private final ChallengeConstraintChecker challengeConstraintChecker; @@ -53,8 +56,10 @@ public class ChallengeController { private static final String CHALLENGE_TYPE_TAG = "type"; public ChallengeController( + final AccountsManager accountsManager, final RateLimitChallengeManager rateLimitChallengeManager, final ChallengeConstraintChecker challengeConstraintChecker) { + this.accountsManager = accountsManager; this.rateLimitChallengeManager = rateLimitChallengeManager; this.challengeConstraintChecker = challengeConstraintChecker; } @@ -77,15 +82,18 @@ public class ChallengeController { @ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header( name = "Retry-After", description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed")) - public Response handleChallengeResponse(@ReadOnly @Auth final AuthenticatedDevice auth, + public Response handleChallengeResponse(@Auth final AuthenticatedDevice auth, @Valid final AnswerChallengeRequest answerRequest, @Context ContainerRequestContext requestContext, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) throws RateLimitExceededException, IOException { + final Account account = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); + Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(userAgent)); final ChallengeConstraints constraints = challengeConstraintChecker.challengeConstraints( - requestContext, auth.getAccount()); + requestContext, account); try { if (answerRequest instanceof final AnswerPushChallengeRequest pushChallengeRequest) { tags = tags.and(CHALLENGE_TYPE_TAG, "push"); @@ -93,14 +101,14 @@ public class ChallengeController { if (!constraints.pushPermitted()) { return Response.status(429).build(); } - rateLimitChallengeManager.answerPushChallenge(auth.getAccount(), pushChallengeRequest.getChallenge()); + rateLimitChallengeManager.answerPushChallenge(account, pushChallengeRequest.getChallenge()); } else if (answerRequest instanceof AnswerCaptchaChallengeRequest captchaChallengeRequest) { tags = tags.and(CHALLENGE_TYPE_TAG, "captcha"); final String remoteAddress = (String) requestContext.getProperty( RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); boolean success = rateLimitChallengeManager.answerCaptchaChallenge( - auth.getAccount(), + account, captchaChallengeRequest.getCaptcha(), remoteAddress, userAgent, @@ -126,7 +134,7 @@ public class ChallengeController { summary = "Request a push challenge", description = """ Clients may proactively request a push challenge by making an empty POST request. Push challenges will only be - sent to the requesting account’s main device. When the push is received it may be provided as proof of completed + sent to the requesting account’s main device. When the push is received it may be provided as proof of completed challenge to /v1/challenge. APNs challenge payloads will be formatted as follows: ``` @@ -140,12 +148,12 @@ public class ChallengeController { "rateLimitChallenge": "{CHALLENGE_TOKEN}" } ``` - FCM challenge payloads will be formatted as follows: + FCM challenge payloads will be formatted as follows: ``` {"rateLimitChallenge": "{CHALLENGE_TOKEN}"} ``` - Clients may retry the PUT in the event of an HTTP/5xx response (except HTTP/508) from the server, but must + Clients may retry the PUT in the event of an HTTP/5xx response (except HTTP/508) from the server, but must implement an exponential back-off system and limit the total number of retries. """ ) @@ -163,15 +171,18 @@ public class ChallengeController { @ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header( name = "Retry-After", description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed")) - public Response requestPushChallenge(@ReadOnly @Auth final AuthenticatedDevice auth, + public Response requestPushChallenge(@Auth final AuthenticatedDevice auth, @Context ContainerRequestContext requestContext) { - final ChallengeConstraints constraints = challengeConstraintChecker.challengeConstraints( - requestContext, auth.getAccount()); + + final Account account = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); + + final ChallengeConstraints constraints = challengeConstraintChecker.challengeConstraints(requestContext, account); if (!constraints.pushPermitted()) { return Response.status(429).build(); } try { - rateLimitChallengeManager.sendPushChallenge(auth.getAccount()); + rateLimitChallengeManager.sendPushChallenge(account); return Response.status(200).build(); } catch (final NotPushRegisteredException e) { return Response.status(404).build(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceCheckController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceCheckController.java index 52605edc1..9eaec898f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceCheckController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceCheckController.java @@ -33,6 +33,7 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.backup.BackupAuthManager; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.devicecheck.AppleDeviceCheckManager; import org.whispersystems.textsecuregcm.storage.devicecheck.ChallengeNotFoundException; import org.whispersystems.textsecuregcm.storage.devicecheck.DeviceCheckKeyIdNotFoundException; @@ -41,7 +42,6 @@ import org.whispersystems.textsecuregcm.storage.devicecheck.DuplicatePublicKeyEx import org.whispersystems.textsecuregcm.storage.devicecheck.RequestReuseException; import org.whispersystems.textsecuregcm.storage.devicecheck.TooManyKeysException; import org.whispersystems.textsecuregcm.util.SystemMapper; -import org.whispersystems.websocket.auth.ReadOnly; /** * Process platform device attestations. @@ -55,6 +55,7 @@ import org.whispersystems.websocket.auth.ReadOnly; public class DeviceCheckController { private final Clock clock; + private final AccountsManager accountsManager; private final BackupAuthManager backupAuthManager; private final AppleDeviceCheckManager deviceCheckManager; private final RateLimiters rateLimiters; @@ -63,12 +64,14 @@ public class DeviceCheckController { public DeviceCheckController( final Clock clock, + final AccountsManager accountsManager, final BackupAuthManager backupAuthManager, final AppleDeviceCheckManager deviceCheckManager, final RateLimiters rateLimiters, final long backupRedemptionLevel, final Duration backupRedemptionDuration) { this.clock = clock; + this.accountsManager = accountsManager; this.backupAuthManager = backupAuthManager; this.deviceCheckManager = deviceCheckManager; this.backupRedemptionLevel = backupRedemptionLevel; @@ -94,14 +97,17 @@ public class DeviceCheckController { @ApiResponse(responseCode = "200", description = "The response body includes a challenge") @ApiResponse(responseCode = "429", description = "Ratelimited.") @ManagedAsync - public ChallengeResponse attestChallenge(@ReadOnly @Auth AuthenticatedDevice authenticatedDevice) + public ChallengeResponse attestChallenge(@Auth AuthenticatedDevice authenticatedDevice) throws RateLimitExceededException { rateLimiters.forDescriptor(RateLimiters.For.DEVICE_CHECK_CHALLENGE) - .validate(authenticatedDevice.getAccount().getUuid()); + .validate(authenticatedDevice.getAccountIdentifier()); + + final Account account = accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); return new ChallengeResponse(deviceCheckManager.createChallenge( AppleDeviceCheckManager.ChallengeType.ATTEST, - authenticatedDevice.getAccount())); + account)); } @PUT @@ -125,7 +131,7 @@ public class DeviceCheckController { @ApiResponse(responseCode = "409", description = "The provided keyId has already been registered to a different account") @ManagedAsync public void attest( - @ReadOnly @Auth final AuthenticatedDevice authenticatedDevice, + @Auth final AuthenticatedDevice authenticatedDevice, @Valid @NotNull @@ -135,8 +141,11 @@ public class DeviceCheckController { @RequestBody(description = "The attestation data, created by [attestKey](https://developer.apple.com/documentation/devicecheck/dcappattestservice/attestkey(_:clientdatahash:completionhandler:))") @NotNull final byte[] attestation) { + final Account account = accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); + try { - deviceCheckManager.registerAttestation(authenticatedDevice.getAccount(), parseKeyId(keyId), attestation); + deviceCheckManager.registerAttestation(account, parseKeyId(keyId), attestation); } catch (TooManyKeysException e) { throw new WebApplicationException(Response.status(413).build()); } catch (ChallengeNotFoundException e) { @@ -166,17 +175,19 @@ public class DeviceCheckController { @ApiResponse(responseCode = "429", description = "Ratelimited.") @ManagedAsync public ChallengeResponse assertChallenge( - @ReadOnly @Auth AuthenticatedDevice authenticatedDevice, + @Auth AuthenticatedDevice authenticatedDevice, @Parameter(schema = @Schema(description = "The type of action you will make an assertion for", allowableValues = {"backup"}, implementation = String.class)) @QueryParam("action") Action action) throws RateLimitExceededException { rateLimiters.forDescriptor(RateLimiters.For.DEVICE_CHECK_CHALLENGE) - .validate(authenticatedDevice.getAccount().getUuid()); - return new ChallengeResponse( - deviceCheckManager.createChallenge(toChallengeType(action), - authenticatedDevice.getAccount())); + .validate(authenticatedDevice.getAccountIdentifier()); + + final Account account = accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); + + return new ChallengeResponse(deviceCheckManager.createChallenge(toChallengeType(action), account)); } @POST @@ -199,7 +210,7 @@ public class DeviceCheckController { @ApiResponse(responseCode = "401", description = "The assertion could not be verified") @ManagedAsync public void assertion( - @ReadOnly @Auth final AuthenticatedDevice authenticatedDevice, + @Auth final AuthenticatedDevice authenticatedDevice, @Valid @NotNull @@ -218,9 +229,12 @@ public class DeviceCheckController { @RequestBody(description = "The assertion created by [generateAssertion](https://developer.apple.com/documentation/devicecheck/dcappattestservice/generateassertion(_:clientdatahash:completionhandler:))") @NotNull final byte[] assertion) { + final Account account = accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); + try { deviceCheckManager.validateAssert( - authenticatedDevice.getAccount(), + account, parseKeyId(keyId), toChallengeType(request.assertionRequest().action()), request.assertionRequest().challenge(), @@ -237,7 +251,7 @@ public class DeviceCheckController { // The request assertion was validated, execute it switch (request.assertionRequest().action()) { case BACKUP -> backupAuthManager.extendBackupVoucher( - authenticatedDevice.getAccount(), + account, new Account.BackupVoucher(backupRedemptionLevel, clock.instant().plus(backupRedemptionDuration))) .join(); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index 657ebd9e1..0a050ce6c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -34,7 +34,6 @@ import jakarta.ws.rs.PathParam; import jakarta.ws.rs.Produces; import jakarta.ws.rs.QueryParam; import jakarta.ws.rs.WebApplicationException; -import jakarta.ws.rs.core.Context; import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.Response; import java.time.Duration; @@ -51,11 +50,9 @@ import java.util.concurrent.CompletionStage; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import javax.annotation.Nullable; -import org.glassfish.jersey.server.ContainerRequest; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.BasicAuthorizationHeader; import org.whispersystems.textsecuregcm.auth.ChangesLinkedDevices; -import org.whispersystems.textsecuregcm.auth.LinkedDeviceRefreshRequirementProvider; import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest; import org.whispersystems.textsecuregcm.entities.DeviceInfo; @@ -87,11 +84,10 @@ import org.whispersystems.textsecuregcm.util.DeviceCapabilityAdapter; import org.whispersystems.textsecuregcm.util.EnumMapUtil; import org.whispersystems.textsecuregcm.util.ExceptionUtils; import org.whispersystems.textsecuregcm.util.LinkDeviceToken; +import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; -import org.whispersystems.websocket.auth.Mutable; -import org.whispersystems.websocket.auth.ReadOnly; @Path("/v1/devices") @Tag(name = "Devices") @@ -152,10 +148,10 @@ public class DeviceController { @GET @Produces(MediaType.APPLICATION_JSON) - public DeviceInfoList getDevices(@ReadOnly @Auth AuthenticatedDevice auth) { + public DeviceInfoList getDevices(@Auth AuthenticatedDevice auth) { // Devices may change their own names (and primary devices may change the names of linked devices) and so the device // state associated with the authenticated account may be stale. Fetch a fresh copy to compensate. - return accounts.getByAccountIdentifier(auth.getAccount().getIdentifier(IdentityType.ACI)) + return accounts.getByAccountIdentifier(auth.getAccountIdentifier()) .map(account -> new DeviceInfoList(account.getDevices().stream() .map(DeviceInfo::forDevice) .toList())) @@ -166,9 +162,8 @@ public class DeviceController { @Produces(MediaType.APPLICATION_JSON) @Path("/{device_id}") @ChangesLinkedDevices - public void removeDevice(@Mutable @Auth AuthenticatedDevice auth, @PathParam("device_id") byte deviceId) { - if (auth.getAuthenticatedDevice().getId() != Device.PRIMARY_ID && - auth.getAuthenticatedDevice().getId() != deviceId) { + public void removeDevice(@Auth AuthenticatedDevice auth, @PathParam("device_id") byte deviceId) { + if (auth.getDeviceId() != Device.PRIMARY_ID && auth.getDeviceId() != deviceId) { throw new WebApplicationException(Response.Status.UNAUTHORIZED); } @@ -176,13 +171,16 @@ public class DeviceController { throw new ForbiddenException(); } - accounts.removeDevice(auth.getAccount(), deviceId).join(); + final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); + + accounts.removeDevice(account, deviceId).join(); } /** * Generates a signed device-linking token. Generally, primary devices will include the signed device-linking token in * a provisioning message to a new device, and then the new device will include the token in its request to - * {@link #linkDevice(BasicAuthorizationHeader, String, LinkDeviceRequest, ContainerRequest)}. + * {@link #linkDevice(BasicAuthorizationHeader, String, LinkDeviceRequest)}. * * @param auth the authenticated account/device * @@ -207,10 +205,11 @@ public class DeviceController { @ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header( name = "Retry-After", description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed")) - public LinkDeviceToken createDeviceToken(@ReadOnly @Auth AuthenticatedDevice auth) + public LinkDeviceToken createDeviceToken(@Auth AuthenticatedDevice auth) throws RateLimitExceededException, DeviceLimitExceededException { - final Account account = auth.getAccount(); + final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); rateLimiters.getAllocateDeviceLimiter().validate(account.getUuid()); @@ -224,7 +223,7 @@ public class DeviceController { throw new DeviceLimitExceededException(account.getDevices().size(), maxDeviceLimit); } - if (auth.getAuthenticatedDevice().getId() != Device.PRIMARY_ID) { + if (auth.getDeviceId() != Device.PRIMARY_ID) { throw new WebApplicationException(Response.Status.UNAUTHORIZED); } @@ -252,8 +251,7 @@ public class DeviceController { description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed")) public LinkDeviceResponse linkDevice(@HeaderParam(HttpHeaders.AUTHORIZATION) BasicAuthorizationHeader authorizationHeader, @HeaderParam(HttpHeaders.USER_AGENT) @Nullable String userAgent, - @NotNull @Valid LinkDeviceRequest linkDeviceRequest, - @Context ContainerRequest containerRequest) + @NotNull @Valid LinkDeviceRequest linkDeviceRequest) throws RateLimitExceededException, DeviceLimitExceededException { final Account account = accounts.checkDeviceLinkingToken(linkDeviceRequest.verificationCode()) @@ -279,11 +277,6 @@ public class DeviceController { throw new WebApplicationException(Response.status(422).build()); } - // Normally, the "do we need to refresh somebody's websockets" listener can do this on its own. In this case, - // we're not using the conventional authentication system, and so we need to give it a hint so it knows who the - // active user is and what their device states look like. - LinkedDeviceRefreshRequirementProvider.setAccount(containerRequest, account); - final int maxDeviceLimit = maxDeviceConfiguration.getOrDefault(account.getNumber(), MAX_DEVICES); if (account.getDevices().size() >= maxDeviceLimit) { @@ -351,7 +344,7 @@ public class DeviceController { @ApiResponse(responseCode = "400", description = "The given token identifier or timeout was invalid") @ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay") public CompletionStage waitForLinkedDevice( - @ReadOnly @Auth final AuthenticatedDevice authenticatedDevice, + @Auth final AuthenticatedDevice authenticatedDevice, @PathParam("tokenIdentifier") @Schema(description = "A 'link device' token identifier provided by the 'create link device token' endpoint") @@ -374,12 +367,18 @@ public class DeviceController { final AtomicInteger linkedDeviceListenerCounter = getCounterForLinkedDeviceListeners(userAgent); linkedDeviceListenerCounter.incrementAndGet(); - return rateLimiters.getWaitForLinkedDeviceLimiter() - .validateAsync(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI)) - .thenCompose(ignored -> persistentTimer.start(WAIT_FOR_LINKED_DEVICE_TIMER_NAMESPACE, tokenIdentifier)) - .thenCompose(sample -> accounts.waitForNewLinkedDevice( - authenticatedDevice.getAccount().getUuid(), - authenticatedDevice.getAuthenticatedDevice(), + return rateLimiters.getWaitForLinkedDeviceLimiter().validateAsync(authenticatedDevice.getAccountIdentifier()) + .thenCompose(ignored -> accounts.getByAccountIdentifierAsync(authenticatedDevice.getAccountIdentifier())) + .thenCompose(maybeAccount -> { + final Account account = maybeAccount.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); + + return persistentTimer.start(WAIT_FOR_LINKED_DEVICE_TIMER_NAMESPACE, tokenIdentifier) + .thenApply(sample -> new Pair<>(account, sample)); + }) + .thenCompose(accountAndSample -> accounts.waitForNewLinkedDevice( + authenticatedDevice.getAccountIdentifier(), + accountAndSample.first().getDevice(authenticatedDevice.getDeviceId()) + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)), tokenIdentifier, Duration.ofSeconds(timeoutSeconds)) .thenApply(maybeDeviceInfo -> maybeDeviceInfo @@ -391,7 +390,7 @@ public class DeviceController { linkedDeviceListenerCounter.decrementAndGet(); if (response != null && response.getStatus() == Response.Status.OK.getStatusCode()) { - sample.stop(Timer.builder(WAIT_FOR_LINKED_DEVICE_TIMER_NAME) + accountAndSample.second().stop(Timer.builder(WAIT_FOR_LINKED_DEVICE_TIMER_NAME) .publishPercentileHistogram(true) .tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent))) .register(Metrics.globalRegistry)); @@ -410,14 +409,15 @@ public class DeviceController { @PUT @Produces(MediaType.APPLICATION_JSON) @Path("/capabilities") - public void setCapabilities(@Mutable @Auth final AuthenticatedDevice auth, + public void setCapabilities(@Auth final AuthenticatedDevice auth, @NotNull final Map capabilities) { - assert (auth.getAuthenticatedDevice() != null); - final byte deviceId = auth.getAuthenticatedDevice().getId(); - accounts.updateDevice(auth.getAccount(), deviceId, + final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); + + accounts.updateDevice(account, auth.getDeviceId(), d -> d.setCapabilities(DeviceCapabilityAdapter.mapToSet(capabilities))); } @@ -435,12 +435,13 @@ public class DeviceController { @ApiResponse(responseCode = "200", description = "Public key stored successfully") @ApiResponse(responseCode = "401", description = "Account authentication check failed") @ApiResponse(responseCode = "422", description = "Invalid request format") - public CompletableFuture setPublicKey(@Mutable @Auth final AuthenticatedDevice auth, + public CompletableFuture setPublicKey(@Auth final AuthenticatedDevice auth, final SetPublicKeyRequest setPublicKeyRequest) { - return clientPublicKeysManager.setPublicKey(auth.getAccount(), - auth.getAuthenticatedDevice().getId(), - setPublicKeyRequest.publicKey()); + final Account account = accounts.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); + + return clientPublicKeysManager.setPublicKey(account, auth.getDeviceId(), setPublicKeyRequest.publicKey()); } private static boolean isCapabilityDowngrade(final Account account, final Set capabilities) { @@ -531,15 +532,21 @@ public class DeviceController { @ApiResponse(responseCode = "204", description = "Success") @ApiResponse(responseCode = "422", description = "The request object could not be parsed or was otherwise invalid") @ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay") - public CompletionStage recordTransferArchiveUploaded(@ReadOnly @Auth final AuthenticatedDevice authenticatedDevice, + public CompletionStage recordTransferArchiveUploaded(@Auth final AuthenticatedDevice authenticatedDevice, @NotNull @Valid final TransferArchiveUploadedRequest transferArchiveUploadedRequest) { return rateLimiters.getUploadTransferArchiveLimiter() - .validateAsync(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI)) - .thenCompose(ignored -> accounts.recordTransferArchiveUpload(authenticatedDevice.getAccount(), - transferArchiveUploadedRequest.destinationDeviceId(), - Instant.ofEpochMilli(transferArchiveUploadedRequest.destinationDeviceCreated()), - transferArchiveUploadedRequest.transferArchive())); + .validateAsync(authenticatedDevice.getAccountIdentifier()) + .thenCompose(ignored -> accounts.getByAccountIdentifierAsync(authenticatedDevice.getAccountIdentifier())) + .thenCompose(maybeAccount -> { + + final Account account = maybeAccount.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); + + return accounts.recordTransferArchiveUpload(account, + transferArchiveUploadedRequest.destinationDeviceId(), + Instant.ofEpochMilli(transferArchiveUploadedRequest.destinationDeviceCreated()), + transferArchiveUploadedRequest.transferArchive()); + }); } @GET @@ -558,7 +565,7 @@ public class DeviceController { @ApiResponse(responseCode = "204", description = "No transfer archive was uploaded before the call completed; clients may repeat the call to continue waiting") @ApiResponse(responseCode = "400", description = "The given timeout was invalid") @ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay") - public CompletionStage waitForTransferArchive(@ReadOnly @Auth final AuthenticatedDevice authenticatedDevice, + public CompletionStage waitForTransferArchive(@Auth final AuthenticatedDevice authenticatedDevice, @QueryParam("timeout") @DefaultValue("30") @@ -575,24 +582,30 @@ public class DeviceController { @HeaderParam(HttpHeaders.USER_AGENT) @Nullable String userAgent) { - final String rateLimiterKey = authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI) + - ":" + authenticatedDevice.getAuthenticatedDevice().getId(); + final String rateLimiterKey = authenticatedDevice.getAccountIdentifier() + ":" + authenticatedDevice.getDeviceId(); return rateLimiters.getWaitForTransferArchiveLimiter().validateAsync(rateLimiterKey) - .thenCompose(ignored -> persistentTimer.start(WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAMESPACE, rateLimiterKey)) - .thenCompose(sample -> accounts.waitForTransferArchive(authenticatedDevice.getAccount(), - authenticatedDevice.getAuthenticatedDevice(), + .thenCompose(ignored -> accounts.getByAccountIdentifierAsync(authenticatedDevice.getAccountIdentifier())) + .thenCompose(maybeAccount -> { + final Account account = maybeAccount.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); + + return persistentTimer.start(WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAMESPACE, rateLimiterKey) + .thenApply(sample -> new Pair<>(account, sample)); + }) + .thenCompose(accountAndSample -> accounts.waitForTransferArchive(accountAndSample.first(), + accountAndSample.first().getDevice(authenticatedDevice.getDeviceId()) + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)), Duration.ofSeconds(timeoutSeconds)) .thenApply(maybeTransferArchive -> maybeTransferArchive .map(transferArchive -> Response.status(Response.Status.OK).entity(transferArchive).build()) .orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build())) .whenComplete((response, throwable) -> { if (response != null && response.getStatus() == Response.Status.OK.getStatusCode()) { - sample.stop(Timer.builder(WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAME) + accountAndSample.second().stop(Timer.builder(WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAME) .publishPercentileHistogram(true) .tags(Tags.of( UserAgentTagUtil.getPlatformTag(userAgent), - primaryPlatformTag(authenticatedDevice.getAccount()))) + primaryPlatformTag(accountAndSample.first()))) .register(Metrics.globalRegistry)); } })); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DirectoryV2Controller.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DirectoryV2Controller.java index 51d43f429..150006c26 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DirectoryV2Controller.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DirectoryV2Controller.java @@ -14,12 +14,10 @@ import jakarta.ws.rs.Path; import jakarta.ws.rs.Produces; import jakarta.ws.rs.core.MediaType; import java.time.Clock; -import java.util.UUID; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator; import org.whispersystems.textsecuregcm.configuration.DirectoryV2ClientConfiguration; -import org.whispersystems.websocket.auth.ReadOnly; @Path("/v2/directory") @Tag(name = "Directory") @@ -57,8 +55,7 @@ public class DirectoryV2Controller { """ ) @ApiResponse(responseCode = "200", description = "`JSON` with generated credentials.", useReturnTypeSchema = true) - public ExternalServiceCredentials getAuthToken(final @ReadOnly @Auth AuthenticatedDevice auth) { - final UUID uuid = auth.getAccount().getUuid(); - return directoryServiceTokenGenerator.generateForUuid(uuid); + public ExternalServiceCredentials getAuthToken(final @Auth AuthenticatedDevice auth) { + return directoryServiceTokenGenerator.generateForUuid(auth.getAccountIdentifier()); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DonationController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DonationController.java index de6e54102..369463675 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DonationController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DonationController.java @@ -15,6 +15,7 @@ import jakarta.ws.rs.Consumes; import jakarta.ws.rs.POST; import jakarta.ws.rs.Path; import jakarta.ws.rs.Produces; +import jakarta.ws.rs.WebApplicationException; import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.Response; import jakarta.ws.rs.core.Response.Status; @@ -33,10 +34,10 @@ import org.signal.libsignal.zkgroup.receipts.ServerZkReceiptOperations; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.configuration.BadgesConfiguration; import org.whispersystems.textsecuregcm.entities.RedeemReceiptRequest; +import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountBadge; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.RedeemedReceiptsManager; -import org.whispersystems.websocket.auth.Mutable; @Path("/v1/donation") @Tag(name = "Donations") @@ -86,7 +87,7 @@ public class DonationController { """) @ApiResponse(responseCode = "429", description = "Rate limited.") public CompletionStage redeemReceipt( - @Mutable @Auth final AuthenticatedDevice auth, + @Auth final AuthenticatedDevice auth, @NotNull @Valid final RedeemReceiptRequest request) { return CompletableFuture.supplyAsync(() -> { ReceiptCredentialPresentation receiptCredentialPresentation; @@ -118,23 +119,29 @@ public class DonationController { .type(MediaType.TEXT_PLAIN_TYPE) .build()); } - return redeemedReceiptsManager.put( - receiptSerial, receiptExpiration.getEpochSecond(), receiptLevel, auth.getAccount().getUuid()) - .thenCompose(receiptMatched -> { - if (!receiptMatched) { - return CompletableFuture.completedFuture(Response.status(Status.BAD_REQUEST) - .entity("receipt serial is already redeemed") - .type(MediaType.TEXT_PLAIN_TYPE) - .build()); - } - return accountsManager.updateAsync(auth.getAccount(), a -> { - a.addBadge(clock, new AccountBadge(badgeId, receiptExpiration, request.isVisible())); - if (request.isPrimary()) { - a.makeBadgePrimaryIfExists(clock, badgeId); + return accountsManager.getByAccountIdentifierAsync(auth.getAccountIdentifier()) + .thenCompose(maybeAccount -> { + final Account account = maybeAccount.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); + + return redeemedReceiptsManager.put( + receiptSerial, receiptExpiration.getEpochSecond(), receiptLevel, auth.getAccountIdentifier()) + .thenCompose(receiptMatched -> { + if (!receiptMatched) { + return CompletableFuture.completedFuture(Response.status(Status.BAD_REQUEST) + .entity("receipt serial is already redeemed") + .type(MediaType.TEXT_PLAIN_TYPE) + .build()); } - }) - .thenApply(ignored -> Response.ok().build()); + + return accountsManager.updateAsync(account, a -> { + a.addBadge(clock, new AccountBadge(badgeId, receiptExpiration, request.isVisible())); + if (request.isPrimary()) { + a.makeBadgePrimaryIfExists(clock, badgeId); + } + }) + .thenApply(ignored -> Response.ok().build()); + }); }); }).thenCompose(Function.identity()); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/GetCallingRelaysResponse.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/GetCallingRelaysResponse.java index 753043e9d..68e2a1da4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/GetCallingRelaysResponse.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/GetCallingRelaysResponse.java @@ -5,9 +5,8 @@ package org.whispersystems.textsecuregcm.controllers; -import org.whispersystems.textsecuregcm.auth.TurnToken; - import java.util.List; +import org.whispersystems.textsecuregcm.auth.TurnToken; public record GetCallingRelaysResponse(List relays) { } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeepAliveController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeepAliveController.java index 21a627b7e..b19397fb1 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeepAliveController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeepAliveController.java @@ -23,7 +23,6 @@ import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager; -import org.whispersystems.websocket.auth.ReadOnly; import org.whispersystems.websocket.session.WebSocketSession; import org.whispersystems.websocket.session.WebSocketSessionContext; @@ -45,16 +44,16 @@ public class KeepAliveController { } @GET - public Response getKeepAlive(@ReadOnly @Auth Optional maybeAuth, + public Response getKeepAlive(@Auth Optional maybeAuth, @WebSocketSession WebSocketSessionContext context) { maybeAuth.ifPresent(auth -> { - if (!webSocketConnectionEventManager.isLocallyPresent(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId())) { + if (!webSocketConnectionEventManager.isLocallyPresent(auth.getAccountIdentifier(), auth.getDeviceId())) { final Duration age = Duration.between(context.getClient().getCreated(), Instant.now()); logger.debug("***** No local subscription found for {}::{}; age = {}ms, User-Agent = {}", - auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(), age.toMillis(), + auth.getAccountIdentifier(), auth.getDeviceId(), age.toMillis(), context.getClient().getUserAgent()); context.getClient().close(1000, "OK"); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeyTransparencyController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeyTransparencyController.java index 5dfb22ccc..1f8934d36 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeyTransparencyController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeyTransparencyController.java @@ -49,7 +49,6 @@ import org.whispersystems.textsecuregcm.keytransparency.KeyTransparencyServiceCl import org.whispersystems.textsecuregcm.limits.RateLimitedByIp; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.util.ExceptionUtils; -import org.whispersystems.websocket.auth.ReadOnly; @Path("/v1/key-transparency") @Tag(name = "KeyTransparency") @@ -90,7 +89,7 @@ public class KeyTransparencyController { @RateLimitedByIp(RateLimiters.For.KEY_TRANSPARENCY_SEARCH_PER_IP) @Produces(MediaType.APPLICATION_JSON) public KeyTransparencySearchResponse search( - @ReadOnly @Auth final Optional authenticatedAccount, + @Auth final Optional authenticatedAccount, @NotNull @Valid final KeyTransparencySearchRequest request) { // Disallow clients from making authenticated requests to this endpoint @@ -142,7 +141,7 @@ public class KeyTransparencyController { @RateLimitedByIp(RateLimiters.For.KEY_TRANSPARENCY_MONITOR_PER_IP) @Produces(MediaType.APPLICATION_JSON) public KeyTransparencyMonitorResponse monitor( - @ReadOnly @Auth final Optional authenticatedAccount, + @Auth final Optional authenticatedAccount, @NotNull @Valid final KeyTransparencyMonitorRequest request) { // Disallow clients from making authenticated requests to this endpoint @@ -204,7 +203,7 @@ public class KeyTransparencyController { @RateLimitedByIp(RateLimiters.For.KEY_TRANSPARENCY_DISTINGUISHED_PER_IP) @Produces(MediaType.APPLICATION_JSON) public KeyTransparencyDistinguishedKeyResponse getDistinguishedKey( - @ReadOnly @Auth final Optional authenticatedAccount, + @Auth final Optional authenticatedAccount, @Parameter(description = "The distinguished tree head size returned by a previously verified call") @QueryParam("lastTreeHeadSize") @Valid final Optional<@Positive Long> lastTreeHeadSize) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index a2ada6b1d..85d661fa6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -74,7 +74,6 @@ import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.KeysManager; import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.Util; -import org.whispersystems.websocket.auth.ReadOnly; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") @Path("/v2/keys") @@ -111,16 +110,21 @@ public class KeysController { description = "Gets the number of one-time prekeys uploaded for this device and still available") @ApiResponse(responseCode = "200", description = "Body contains the number of available one-time prekeys for the device.", useReturnTypeSchema = true) @ApiResponse(responseCode = "401", description = "Account authentication check failed.") - public CompletableFuture getStatus(@ReadOnly @Auth final AuthenticatedDevice auth, + public CompletableFuture getStatus(@Auth final AuthenticatedDevice auth, @QueryParam("identity") @DefaultValue("aci") final IdentityType identityType) { - final CompletableFuture ecCountFuture = - keysManager.getEcCount(auth.getAccount().getIdentifier(identityType), auth.getAuthenticatedDevice().getId()); + return accounts.getByAccountIdentifierAsync(auth.getAccountIdentifier()) + .thenCompose(maybeAccount -> { + final Account account = maybeAccount.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); - final CompletableFuture pqCountFuture = - keysManager.getPqCount(auth.getAccount().getIdentifier(identityType), auth.getAuthenticatedDevice().getId()); + final CompletableFuture ecCountFuture = + keysManager.getEcCount(account.getIdentifier(identityType), auth.getDeviceId()); - return ecCountFuture.thenCombine(pqCountFuture, PreKeyCount::new); + final CompletableFuture pqCountFuture = + keysManager.getPqCount(account.getIdentifier(identityType), auth.getDeviceId()); + + return ecCountFuture.thenCombine(pqCountFuture, PreKeyCount::new); + }); } @PUT @@ -132,7 +136,7 @@ public class KeysController { @ApiResponse(responseCode = "403", description = "Attempt to change identity key from a non-primary device.") @ApiResponse(responseCode = "422", description = "Invalid request format.") public CompletableFuture setKeys( - @ReadOnly @Auth final AuthenticatedDevice auth, + @Auth final AuthenticatedDevice auth, @RequestBody @NotNull @Valid final SetKeysRequest setKeysRequest, @Parameter(allowEmptyValue=true) @@ -143,63 +147,70 @@ public class KeysController { @QueryParam("identity") @DefaultValue("aci") final IdentityType identityType, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) { - final Account account = auth.getAccount(); - final Device device = auth.getAuthenticatedDevice(); - final UUID identifier = account.getIdentifier(identityType); + return accounts.getByAccountIdentifierAsync(auth.getAccountIdentifier()) + .thenCompose(maybeAccount -> { + final Account account = maybeAccount + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); - checkSignedPreKeySignatures(setKeysRequest, account.getIdentityKey(identityType), userAgent); + final Device device = account.getDevice(auth.getDeviceId()) + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); - final Tag platformTag = UserAgentTagUtil.getPlatformTag(userAgent); - final Tag primaryDeviceTag = Tag.of(PRIMARY_DEVICE_TAG_NAME, String.valueOf(auth.getAuthenticatedDevice().isPrimary())); - final Tag identityTypeTag = Tag.of(IDENTITY_TYPE_TAG_NAME, identityType.name()); + final UUID identifier = account.getIdentifier(identityType); - final List> storeFutures = new ArrayList<>(4); + checkSignedPreKeySignatures(setKeysRequest, account.getIdentityKey(identityType), userAgent); - if (!setKeysRequest.preKeys().isEmpty()) { - final Tags tags = Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "ec")); + final Tag platformTag = UserAgentTagUtil.getPlatformTag(userAgent); + final Tag primaryDeviceTag = Tag.of(PRIMARY_DEVICE_TAG_NAME, String.valueOf(auth.getDeviceId() == Device.PRIMARY_ID)); + final Tag identityTypeTag = Tag.of(IDENTITY_TYPE_TAG_NAME, identityType.name()); - Metrics.counter(STORE_KEYS_COUNTER_NAME, tags).increment(); + final List> storeFutures = new ArrayList<>(4); - DistributionSummary.builder(STORE_KEY_BUNDLE_SIZE_DISTRIBUTION_NAME) - .tags(tags) - .publishPercentileHistogram() - .register(Metrics.globalRegistry) - .record(setKeysRequest.preKeys().size()); + if (!setKeysRequest.preKeys().isEmpty()) { + final Tags tags = Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "ec")); - storeFutures.add(keysManager.storeEcOneTimePreKeys(identifier, device.getId(), setKeysRequest.preKeys())); - } + Metrics.counter(STORE_KEYS_COUNTER_NAME, tags).increment(); - if (setKeysRequest.signedPreKey() != null) { - Metrics.counter(STORE_KEYS_COUNTER_NAME, - Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "ec-signed"))) - .increment(); + DistributionSummary.builder(STORE_KEY_BUNDLE_SIZE_DISTRIBUTION_NAME) + .tags(tags) + .publishPercentileHistogram() + .register(Metrics.globalRegistry) + .record(setKeysRequest.preKeys().size()); - storeFutures.add(keysManager.storeEcSignedPreKeys(identifier, device.getId(), setKeysRequest.signedPreKey())); - } + storeFutures.add(keysManager.storeEcOneTimePreKeys(identifier, device.getId(), setKeysRequest.preKeys())); + } - if (!setKeysRequest.pqPreKeys().isEmpty()) { - final Tags tags = Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "kyber")); - Metrics.counter(STORE_KEYS_COUNTER_NAME, tags).increment(); + if (setKeysRequest.signedPreKey() != null) { + Metrics.counter(STORE_KEYS_COUNTER_NAME, + Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "ec-signed"))) + .increment(); - DistributionSummary.builder(STORE_KEY_BUNDLE_SIZE_DISTRIBUTION_NAME) - .tags(tags) - .publishPercentileHistogram() - .register(Metrics.globalRegistry) - .record(setKeysRequest.pqPreKeys().size()); + storeFutures.add(keysManager.storeEcSignedPreKeys(identifier, device.getId(), setKeysRequest.signedPreKey())); + } - storeFutures.add(keysManager.storeKemOneTimePreKeys(identifier, device.getId(), setKeysRequest.pqPreKeys())); - } + if (!setKeysRequest.pqPreKeys().isEmpty()) { + final Tags tags = Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "kyber")); + Metrics.counter(STORE_KEYS_COUNTER_NAME, tags).increment(); - if (setKeysRequest.pqLastResortPreKey() != null) { - Metrics.counter(STORE_KEYS_COUNTER_NAME, - Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "kyber-last-resort"))) - .increment(); + DistributionSummary.builder(STORE_KEY_BUNDLE_SIZE_DISTRIBUTION_NAME) + .tags(tags) + .publishPercentileHistogram() + .register(Metrics.globalRegistry) + .record(setKeysRequest.pqPreKeys().size()); - storeFutures.add(keysManager.storePqLastResort(identifier, device.getId(), setKeysRequest.pqLastResortPreKey())); - } + storeFutures.add(keysManager.storeKemOneTimePreKeys(identifier, device.getId(), setKeysRequest.pqPreKeys())); + } - return CompletableFuture.allOf(storeFutures.toArray(EMPTY_FUTURE_ARRAY)) - .thenApply(Util.ASYNC_EMPTY_RESPONSE); + if (setKeysRequest.pqLastResortPreKey() != null) { + Metrics.counter(STORE_KEYS_COUNTER_NAME, + Tags.of(platformTag, primaryDeviceTag, identityTypeTag, Tag.of(KEY_TYPE_TAG_NAME, "kyber-last-resort"))) + .increment(); + + storeFutures.add(keysManager.storePqLastResort(identifier, device.getId(), setKeysRequest.pqLastResortPreKey())); + } + + return CompletableFuture.allOf(storeFutures.toArray(EMPTY_FUTURE_ARRAY)) + .thenApply(Util.ASYNC_EMPTY_RESPONSE); + }); } private void checkSignedPreKeySignatures(final SetKeysRequest setKeysRequest, @@ -253,64 +264,69 @@ public class KeysController { """) @ApiResponse(responseCode = "422", description = "Invalid request format") public CompletableFuture checkKeys( - @ReadOnly @Auth final AuthenticatedDevice auth, + @Auth final AuthenticatedDevice auth, @RequestBody @NotNull @Valid final CheckKeysRequest checkKeysRequest) { - final UUID identifier = auth.getAccount().getIdentifier(checkKeysRequest.identityType()); - final byte deviceId = auth.getAuthenticatedDevice().getId(); + return accounts.getByAccountIdentifierAsync(auth.getAccountIdentifier()) + .thenCompose(maybeAccount -> { + final Account account = maybeAccount.orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); - final CompletableFuture> ecSignedPreKeyFuture = - keysManager.getEcSignedPreKey(identifier, deviceId); + final UUID identifier = account.getIdentifier(checkKeysRequest.identityType()); + final byte deviceId = auth.getDeviceId(); - final CompletableFuture> lastResortKeyFuture = - keysManager.getLastResort(identifier, deviceId); + final CompletableFuture> ecSignedPreKeyFuture = + keysManager.getEcSignedPreKey(identifier, deviceId); - return CompletableFuture.allOf(ecSignedPreKeyFuture, lastResortKeyFuture) - .thenApply(ignored -> { - final Optional maybeSignedPreKey = ecSignedPreKeyFuture.join(); - final Optional maybeLastResortKey = lastResortKeyFuture.join(); + final CompletableFuture> lastResortKeyFuture = + keysManager.getLastResort(identifier, deviceId); - final boolean digestsMatch; + return CompletableFuture.allOf(ecSignedPreKeyFuture, lastResortKeyFuture) + .thenApply(ignored -> { + final Optional maybeSignedPreKey = ecSignedPreKeyFuture.join(); + final Optional maybeLastResortKey = lastResortKeyFuture.join(); - if (maybeSignedPreKey.isPresent() && maybeLastResortKey.isPresent()) { - final IdentityKey identityKey = auth.getAccount().getIdentityKey(checkKeysRequest.identityType()); - final ECSignedPreKey ecSignedPreKey = maybeSignedPreKey.get(); - final KEMSignedPreKey lastResortKey = maybeLastResortKey.get(); + final boolean digestsMatch; - final MessageDigest messageDigest; + if (maybeSignedPreKey.isPresent() && maybeLastResortKey.isPresent()) { + final IdentityKey identityKey = account.getIdentityKey(checkKeysRequest.identityType()); + final ECSignedPreKey ecSignedPreKey = maybeSignedPreKey.get(); + final KEMSignedPreKey lastResortKey = maybeLastResortKey.get(); - try { - messageDigest = MessageDigest.getInstance("SHA-256"); - } catch (final NoSuchAlgorithmException e) { - throw new AssertionError("Every implementation of the Java platform is required to support SHA-256", e); - } + final MessageDigest messageDigest; - messageDigest.update(identityKey.serialize()); + try { + messageDigest = MessageDigest.getInstance("SHA-256"); + } catch (final NoSuchAlgorithmException e) { + throw new AssertionError("Every implementation of the Java platform is required to support SHA-256", e); + } - { - final ByteBuffer ecSignedPreKeyIdBuffer = ByteBuffer.allocate(Long.BYTES); - ecSignedPreKeyIdBuffer.putLong(ecSignedPreKey.keyId()); - ecSignedPreKeyIdBuffer.flip(); + messageDigest.update(identityKey.serialize()); - messageDigest.update(ecSignedPreKeyIdBuffer); - messageDigest.update(ecSignedPreKey.serializedPublicKey()); - } + { + final ByteBuffer ecSignedPreKeyIdBuffer = ByteBuffer.allocate(Long.BYTES); + ecSignedPreKeyIdBuffer.putLong(ecSignedPreKey.keyId()); + ecSignedPreKeyIdBuffer.flip(); - { - final ByteBuffer lastResortKeyIdBuffer = ByteBuffer.allocate(Long.BYTES); - lastResortKeyIdBuffer.putLong(lastResortKey.keyId()); - lastResortKeyIdBuffer.flip(); + messageDigest.update(ecSignedPreKeyIdBuffer); + messageDigest.update(ecSignedPreKey.serializedPublicKey()); + } - messageDigest.update(lastResortKeyIdBuffer); - messageDigest.update(lastResortKey.serializedPublicKey()); - } + { + final ByteBuffer lastResortKeyIdBuffer = ByteBuffer.allocate(Long.BYTES); + lastResortKeyIdBuffer.putLong(lastResortKey.keyId()); + lastResortKeyIdBuffer.flip(); - digestsMatch = MessageDigest.isEqual(messageDigest.digest(), checkKeysRequest.digest()); - } else { - digestsMatch = false; - } + messageDigest.update(lastResortKeyIdBuffer); + messageDigest.update(lastResortKey.serializedPublicKey()); + } - return Response.status(digestsMatch ? Response.Status.OK : Response.Status.CONFLICT).build(); + digestsMatch = MessageDigest.isEqual(messageDigest.digest(), checkKeysRequest.digest()); + } else { + digestsMatch = false; + } + + return Response.status(digestsMatch ? Response.Status.OK : Response.Status.CONFLICT).build(); + }); }); } @@ -327,7 +343,7 @@ public class KeysController { name = "Retry-After", description = "If present, a positive integer indicating the number of seconds before a subsequent attempt could succeed")) public PreKeyResponse getDeviceKeys( - @ReadOnly @Auth Optional auth, + @Auth Optional maybeAuthenticatedDevice, @HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional accessKey, @HeaderParam(HeaderUtils.GROUP_SEND_TOKEN) Optional groupSendToken, @@ -340,15 +356,18 @@ public class KeysController { @HeaderParam(HttpHeaders.USER_AGENT) String userAgent) throws RateLimitExceededException { - if (auth.isEmpty() && accessKey.isEmpty() && groupSendToken.isEmpty()) { + if (maybeAuthenticatedDevice.isEmpty() && accessKey.isEmpty() && groupSendToken.isEmpty()) { throw new WebApplicationException(Response.Status.UNAUTHORIZED); } - final Optional account = auth.map(AuthenticatedDevice::getAccount); + final Optional account = maybeAuthenticatedDevice + .map(authenticatedDevice -> accounts.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED))); + final Optional maybeTarget = accounts.getByServiceIdentifier(targetIdentifier); if (groupSendToken.isPresent()) { - if (auth.isPresent() || accessKey.isPresent()) { + if (maybeAuthenticatedDevice.isPresent() || accessKey.isPresent()) { throw new BadRequestException(); } try { @@ -364,7 +383,7 @@ public class KeysController { if (account.isPresent()) { rateLimiters.getPreKeysLimiter().validate( - account.get().getUuid() + "." + auth.get().getAuthenticatedDevice().getId() + "__" + targetIdentifier.uuid() + account.get().getUuid() + "." + maybeAuthenticatedDevice.get().getDeviceId() + "__" + targetIdentifier.uuid() + "." + deviceId); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 04d564563..ee3cb2641 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -105,6 +105,7 @@ import org.whispersystems.textsecuregcm.spam.SpamChecker; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; +import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers; import org.whispersystems.textsecuregcm.storage.ReportMessageManager; @@ -112,7 +113,6 @@ import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.websocket.WebSocketConnection; import org.whispersystems.websocket.WebsocketHeaders; -import org.whispersystems.websocket.auth.ReadOnly; import reactor.core.scheduler.Scheduler; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") @@ -236,7 +236,7 @@ public class MessageController { @ApiResponse( responseCode="428", description="The sender should complete a challenge before proceeding") - public Response sendMessage(@ReadOnly @Auth final Optional source, + public Response sendMessage(@Auth final Optional source, @Parameter(description="The recipient's unidentified access key") @HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) final Optional accessKey, @@ -274,12 +274,14 @@ public class MessageController { sendStoryMessage(destinationIdentifier, messages, context); } else if (source.isPresent()) { final AuthenticatedDevice authenticatedDevice = source.get(); + final Account account = accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); - if (authenticatedDevice.getAccount().isIdentifiedBy(destinationIdentifier)) { + if (account.isIdentifiedBy(destinationIdentifier)) { needsSync = false; - sendSyncMessage(source.get(), destinationIdentifier, messages, context); + sendSyncMessage(source.get(), account, destinationIdentifier, messages, context); } else { - needsSync = authenticatedDevice.getAccount().getDevices().size() > 1; + needsSync = account.getDevices().size() > 1; sendIdentifiedSenderIndividualMessage(authenticatedDevice, destinationIdentifier, messages, context); } } else { @@ -302,7 +304,7 @@ public class MessageController { final Account destination = accountsManager.getByServiceIdentifier(destinationIdentifier).orElseThrow(NotFoundException::new); - rateLimiters.getMessagesLimiter().validate(source.getAccount().getUuid(), destination.getUuid()); + rateLimiters.getMessagesLimiter().validate(source.getAccountIdentifier(), destination.getUuid()); sendIndividualMessage(destination, destinationIdentifier, @@ -314,6 +316,7 @@ public class MessageController { } private void sendSyncMessage(final AuthenticatedDevice source, + final Account sourceAccount, final ServiceIdentifier destinationIdentifier, final IncomingMessageList messages, final ContainerRequestContext context) @@ -323,7 +326,7 @@ public class MessageController { throw new WebApplicationException(Status.FORBIDDEN); } - sendIndividualMessage(source.getAccount(), + sendIndividualMessage(sourceAccount, destinationIdentifier, source, messages, @@ -420,8 +423,8 @@ public class MessageController { try { return message.toEnvelope( destinationIdentifier, - sender != null ? sender.getAccount() : null, - sender != null ? sender.getAuthenticatedDevice().getId() : null, + sender != null ? new AciServiceIdentifier(sender.getAccountIdentifier()) : null, + sender != null ? sender.getDeviceId() : null, messages.timestamp() == 0 ? System.currentTimeMillis() : messages.timestamp(), isStory, messages.online(), @@ -437,7 +440,7 @@ public class MessageController { .collect(Collectors.toMap(IncomingMessage::destinationDeviceId, IncomingMessage::destinationRegistrationId)); final Optional syncMessageSenderDeviceId = messageType == MessageType.SYNC - ? Optional.ofNullable(sender).map(authenticatedDevice -> authenticatedDevice.getAuthenticatedDevice().getId()) + ? Optional.ofNullable(sender).map(AuthenticatedDevice::getDeviceId) : Optional.empty(); try { @@ -755,31 +758,37 @@ public class MessageController { @Timed @GET @Produces(MediaType.APPLICATION_JSON) - public CompletableFuture getPendingMessages(@ReadOnly @Auth AuthenticatedDevice auth, + public CompletableFuture getPendingMessages(@Auth AuthenticatedDevice auth, @HeaderParam(WebsocketHeaders.X_SIGNAL_RECEIVE_STORIES) String receiveStoriesHeader, @HeaderParam(HttpHeaders.USER_AGENT) String userAgent) { - boolean shouldReceiveStories = WebsocketHeaders.parseReceiveStoriesHeader(receiveStoriesHeader); + return accountsManager.getByAccountIdentifierAsync(auth.getAccountIdentifier()) + .thenCompose(maybeAccount -> { + final Account account = maybeAccount.orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); + final Device device = account.getDevice(auth.getDeviceId()) + .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); - pushNotificationManager.handleMessagesRetrieved(auth.getAccount(), auth.getAuthenticatedDevice(), userAgent); + final boolean shouldReceiveStories = WebsocketHeaders.parseReceiveStoriesHeader(receiveStoriesHeader); - return messagesManager.getMessagesForDevice( - auth.getAccount().getUuid(), - auth.getAuthenticatedDevice(), - false) - .map(messagesAndHasMore -> { - Stream envelopes = messagesAndHasMore.first().stream(); - if (!shouldReceiveStories) { - envelopes = envelopes.filter(e -> !e.getStory()); - } + pushNotificationManager.handleMessagesRetrieved(account, device, userAgent); + + return messagesManager.getMessagesForDevice( + auth.getAccountIdentifier(), + device, + false) + .map(messagesAndHasMore -> { + Stream envelopes = messagesAndHasMore.first().stream(); + if (!shouldReceiveStories) { + envelopes = envelopes.filter(e -> !e.getStory()); + } final OutgoingMessageEntityList messages = new OutgoingMessageEntityList(envelopes .map(OutgoingMessageEntity::fromEnvelope) .peek(outgoingMessageEntity -> { - messageMetrics.measureAccountOutgoingMessageUuidMismatches(auth.getAccount(), outgoingMessageEntity); + messageMetrics.measureAccountOutgoingMessageUuidMismatches(account, outgoingMessageEntity); messageMetrics.measureOutgoingMessageLatency(outgoingMessageEntity.serverTimestamp(), "rest", - auth.getAuthenticatedDevice().isPrimary(), + auth.getDeviceId() == Device.PRIMARY_ID, outgoingMessageEntity.urgent(), // Messages fetched via this endpoint (as opposed to WebSocketConnection) are never ephemeral // because, by definition, the client doesn't have a "live" connection via which to receive @@ -791,26 +800,27 @@ public class MessageController { .collect(Collectors.toList()), messagesAndHasMore.second()); - Metrics.summary(OUTGOING_MESSAGE_LIST_SIZE_BYTES_DISTRIBUTION_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent))) - .record(estimateMessageListSizeBytes(messages)); + Metrics.summary(OUTGOING_MESSAGE_LIST_SIZE_BYTES_DISTRIBUTION_NAME, Tags.of(UserAgentTagUtil.getPlatformTag(userAgent))) + .record(estimateMessageListSizeBytes(messages)); - if (!messages.messages().isEmpty()) { - messageDeliveryLoopMonitor.recordDeliveryAttempt(auth.getAccount().getIdentifier(IdentityType.ACI), - auth.getAuthenticatedDevice().getId(), - messages.messages().getFirst().guid(), - userAgent, - "rest"); - } + if (!messages.messages().isEmpty()) { + messageDeliveryLoopMonitor.recordDeliveryAttempt(auth.getAccountIdentifier(), + auth.getDeviceId(), + messages.messages().getFirst().guid(), + userAgent, + "rest"); + } - if (messagesAndHasMore.second()) { - pushNotificationScheduler.scheduleDelayedNotification(auth.getAccount(), auth.getAuthenticatedDevice(), NOTIFY_FOR_REMAINING_MESSAGES_DELAY); - } + if (messagesAndHasMore.second()) { + pushNotificationScheduler.scheduleDelayedNotification(account, device, NOTIFY_FOR_REMAINING_MESSAGES_DELAY); + } - return messages; - }) - .timeout(Duration.ofSeconds(5)) - .subscribeOn(messageDeliveryScheduler) - .toFuture(); + return messages; + }) + .timeout(Duration.ofSeconds(5)) + .subscribeOn(messageDeliveryScheduler) + .toFuture(); + }); } private static long estimateMessageListSizeBytes(final OutgoingMessageEntityList messageList) { @@ -827,22 +837,27 @@ public class MessageController { @Timed @DELETE @Path("/uuid/{uuid}") - public CompletableFuture removePendingMessage(@ReadOnly @Auth AuthenticatedDevice auth, @PathParam("uuid") UUID uuid) { + public CompletableFuture removePendingMessage(@Auth AuthenticatedDevice auth, @PathParam("uuid") UUID uuid) { + final Account account = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); + + final Device device = account.getDevice(auth.getDeviceId()) + .orElseThrow(() -> new WebApplicationException(Status.UNAUTHORIZED)); + return messagesManager.delete( - auth.getAccount().getUuid(), - auth.getAuthenticatedDevice(), + auth.getAccountIdentifier(), + device, uuid, null) .thenAccept(maybeRemovedMessage -> maybeRemovedMessage.ifPresent(removedMessage -> { - WebSocketConnection.recordMessageDeliveryDuration(removedMessage.serverTimestamp(), - auth.getAuthenticatedDevice()); + WebSocketConnection.recordMessageDeliveryDuration(removedMessage.serverTimestamp(), device); if (removedMessage.sourceServiceId().isPresent() && removedMessage.envelopeType() != Type.SERVER_DELIVERY_RECEIPT) { if (removedMessage.sourceServiceId().get() instanceof AciServiceIdentifier aciServiceIdentifier) { try { - receiptSender.sendReceipt(removedMessage.destinationServiceId(), auth.getAuthenticatedDevice().getId(), + receiptSender.sendReceipt(removedMessage.destinationServiceId(), auth.getDeviceId(), aciServiceIdentifier, removedMessage.clientTimestamp()); } catch (Exception e) { logger.warn("Failed to send delivery receipt", e); @@ -863,7 +878,7 @@ public class MessageController { @Consumes(MediaType.APPLICATION_JSON) @Path("/report/{source}/{messageGuid}") public Response reportSpamMessage( - @ReadOnly @Auth AuthenticatedDevice auth, + @Auth AuthenticatedDevice auth, @PathParam("source") String source, @PathParam("messageGuid") UUID messageGuid, @Nullable SpamReport spamReport, @@ -899,7 +914,7 @@ public class MessageController { } } - UUID spamReporterUuid = auth.getAccount().getUuid(); + UUID spamReporterUuid = auth.getAccountIdentifier(); // spam report token is optional, but if provided ensure it is non-empty. final Optional maybeSpamReportToken = diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MultiRecipientMismatchedDevicesException.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MultiRecipientMismatchedDevicesException.java index 1b66c764e..7daee831c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MultiRecipientMismatchedDevicesException.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MultiRecipientMismatchedDevicesException.java @@ -5,8 +5,8 @@ package org.whispersystems.textsecuregcm.controllers; -import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import java.util.Map; +import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; public class MultiRecipientMismatchedDevicesException extends Exception { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/OneTimeDonationController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/OneTimeDonationController.java index ab856f86e..c3ed49c2a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/OneTimeDonationController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/OneTimeDonationController.java @@ -69,7 +69,6 @@ import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; -import org.whispersystems.websocket.auth.ReadOnly; /** @@ -163,7 +162,7 @@ public class OneTimeDonationController { @StringToClassMapItem(key = "error", value = String.class) }))) public CompletableFuture createBoostPaymentIntent( - @ReadOnly @Auth Optional authenticatedAccount, + @Auth Optional authenticatedAccount, @NotNull @Valid CreateBoostRequest request, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) { @@ -249,7 +248,7 @@ public class OneTimeDonationController { @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) public CompletableFuture createPayPalBoost( - @ReadOnly @Auth Optional authenticatedAccount, + @Auth Optional authenticatedAccount, @NotNull @Valid CreatePayPalBoostRequest request, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent, @Context ContainerRequestContext containerRequestContext) { @@ -296,7 +295,7 @@ public class OneTimeDonationController { @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) public CompletableFuture confirmPayPalBoost( - @ReadOnly @Auth Optional authenticatedAccount, + @Auth Optional authenticatedAccount, @NotNull @Valid ConfirmPayPalBoostRequest request, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) { @@ -342,7 +341,7 @@ public class OneTimeDonationController { @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) public CompletableFuture createBoostReceiptCredentials( - @ReadOnly @Auth Optional authenticatedAccount, + @Auth Optional authenticatedAccount, @NotNull @Valid final CreateBoostReceiptCredentialsRequest request, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/PaymentsController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/PaymentsController.java index 6d6383fdc..15d0b8fcf 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/PaymentsController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/PaymentsController.java @@ -17,7 +17,6 @@ import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator import org.whispersystems.textsecuregcm.configuration.PaymentsServiceConfiguration; import org.whispersystems.textsecuregcm.currency.CurrencyConversionManager; import org.whispersystems.textsecuregcm.entities.CurrencyConversionEntityList; -import org.whispersystems.websocket.auth.ReadOnly; @Path("/v1/payments") @Tag(name = "Payments") @@ -43,14 +42,14 @@ public class PaymentsController { @GET @Path("/auth") @Produces(MediaType.APPLICATION_JSON) - public ExternalServiceCredentials getAuth(final @ReadOnly @Auth AuthenticatedDevice auth) { - return paymentsServiceCredentialsGenerator.generateForUuid(auth.getAccount().getUuid()); + public ExternalServiceCredentials getAuth(final @Auth AuthenticatedDevice auth) { + return paymentsServiceCredentialsGenerator.generateForUuid(auth.getAccountIdentifier()); } @GET @Path("/conversions") @Produces(MediaType.APPLICATION_JSON) - public CurrencyConversionEntityList getConversions(final @ReadOnly @Auth AuthenticatedDevice auth) { + public CurrencyConversionEntityList getConversions(final @Auth AuthenticatedDevice auth) { return currencyManager.getCurrencyConversions().orElseThrow(); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java index ebe5c38ed..469a93504 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java @@ -26,6 +26,7 @@ import jakarta.ws.rs.Path; import jakarta.ws.rs.PathParam; import jakarta.ws.rs.Produces; import jakarta.ws.rs.QueryParam; +import jakarta.ws.rs.WebApplicationException; import jakarta.ws.rs.container.ContainerRequestContext; import jakarta.ws.rs.core.Context; import jakarta.ws.rs.core.HttpHeaders; @@ -94,8 +95,6 @@ import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.ProfileHelper; import org.whispersystems.textsecuregcm.util.Util; -import org.whispersystems.websocket.auth.Mutable; -import org.whispersystems.websocket.auth.ReadOnly; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") @Path("/v1/profile") @@ -152,15 +151,18 @@ public class ProfileController { @PUT @Produces(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON) - public Response setProfile(@Mutable @Auth AuthenticatedDevice auth, @NotNull @Valid CreateProfileRequest request) { + public Response setProfile(@Auth AuthenticatedDevice auth, @NotNull @Valid CreateProfileRequest request) { - final Optional currentProfile = profilesManager.get(auth.getAccount().getUuid(), - request.version()); + final Account account = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED)); + + final Optional currentProfile = + profilesManager.get(auth.getAccountIdentifier(), request.version()); if (request.paymentAddress() != null && request.paymentAddress().length != 0) { final boolean hasDisallowedPrefix = dynamicConfigurationManager.getConfiguration().getPaymentsConfiguration().getDisallowedPrefixes().stream() - .anyMatch(prefix -> auth.getAccount().getNumber().startsWith(prefix)); + .anyMatch(prefix -> account.getNumber().startsWith(prefix)); if (hasDisallowedPrefix && currentProfile.map(VersionedProfile::paymentAddress).isEmpty()) { return Response.status(Response.Status.FORBIDDEN).build(); @@ -179,7 +181,7 @@ public class ProfileController { case UPDATE -> ProfileHelper.generateAvatarObjectName(); }; - profilesManager.set(auth.getAccount().getUuid(), + profilesManager.set(auth.getAccountIdentifier(), new VersionedProfile( request.version(), request.name(), @@ -194,7 +196,7 @@ public class ProfileController { currentAvatar.ifPresent(s -> profilesManager.deleteAvatar(s).join()); } - accountsManager.update(auth.getAccount(), a -> { + accountsManager.update(account, a -> { final List updatedBadges = request.badges() .map(badges -> ProfileHelper.mergeBadgeIdsWithExistingAccountBadges(clock, badgeConfigurationMap, badges, a.getBadges())) @@ -216,7 +218,7 @@ public class ProfileController { @Path("/{identifier}/{version}") @ManagedAsync public VersionedProfileResponse getProfile( - @ReadOnly @Auth Optional auth, + @Auth Optional maybeAuthenticatedDevice, @HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional accessKey, @Context ContainerRequestContext containerRequestContext, @PathParam("identifier") AciServiceIdentifier accountIdentifier, @@ -224,7 +226,11 @@ public class ProfileController { @HeaderParam(HttpHeaders.USER_AGENT) String userAgent) throws RateLimitExceededException { - final Optional maybeRequester = auth.map(AuthenticatedDevice::getAccount); + final Optional maybeRequester = + maybeAuthenticatedDevice.map( + authenticatedDevice -> accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED))); + final Account targetAccount = verifyPermissionToReceiveProfile(maybeRequester, accessKey, accountIdentifier, "getVersionedProfile", userAgent); return buildVersionedProfileResponse(targetAccount, @@ -238,7 +244,7 @@ public class ProfileController { @Produces(MediaType.APPLICATION_JSON) @Path("/{identifier}/{version}/{credentialRequest}") public CredentialProfileResponse getProfile( - @ReadOnly @Auth Optional auth, + @Auth Optional maybeAuthenticatedDevice, @HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional accessKey, @Context ContainerRequestContext containerRequestContext, @PathParam("identifier") AciServiceIdentifier accountIdentifier, @@ -252,7 +258,11 @@ public class ProfileController { throw new BadRequestException(); } - final Optional maybeRequester = auth.map(AuthenticatedDevice::getAccount); + final Optional maybeRequester = + maybeAuthenticatedDevice.map( + authenticatedDevice -> accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED))); + final Account targetAccount = verifyPermissionToReceiveProfile(maybeRequester, accessKey, accountIdentifier, "credentialRequest", userAgent); final boolean isSelf = maybeRequester.map(requester -> ProfileHelper.isSelfProfileRequest(requester.getUuid(), accountIdentifier)).orElse(false); @@ -270,7 +280,7 @@ public class ProfileController { @Path("/{identifier}") @ManagedAsync public BaseProfileResponse getUnversionedProfile( - @ReadOnly @Auth Optional auth, + @Auth Optional maybeAuthenticatedDevice, @HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional accessKey, @HeaderParam(HeaderUtils.GROUP_SEND_TOKEN) Optional groupSendToken, @Context ContainerRequestContext containerRequestContext, @@ -278,7 +288,10 @@ public class ProfileController { @PathParam("identifier") ServiceIdentifier identifier) throws RateLimitExceededException { - final Optional maybeRequester = auth.map(AuthenticatedDevice::getAccount); + final Optional maybeRequester = + maybeAuthenticatedDevice.map( + authenticatedDevice -> accountsManager.getByAccountIdentifier(authenticatedDevice.getAccountIdentifier()) + .orElseThrow(() -> new WebApplicationException(Response.Status.UNAUTHORIZED))); final Account targetAccount; if (groupSendToken.isPresent()) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProvisioningController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProvisioningController.java index 5c6232b46..00c2af142 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProvisioningController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProvisioningController.java @@ -34,7 +34,6 @@ import org.whispersystems.textsecuregcm.entities.ProvisioningMessage; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.push.ProvisioningManager; -import org.whispersystems.websocket.auth.ReadOnly; /** * The provisioning controller facilitates transmission of provisioning messages from the primary device associated with @@ -77,7 +76,7 @@ public class ProvisioningController { @ApiResponse(responseCode="204", description="The provisioning message was delivered to the given provisioning address") @ApiResponse(responseCode="400", description="The provisioning message was too large") @ApiResponse(responseCode="404", description="No device with the given provisioning address was connected at the time of the request") - public void sendProvisioningMessage(@ReadOnly @Auth final AuthenticatedDevice auth, + public void sendProvisioningMessage(@Auth final AuthenticatedDevice auth, @Parameter(description = "The temporary provisioning address to which to send a provisioning message") @PathParam("destination") final String provisioningAddress, @@ -93,7 +92,7 @@ public class ProvisioningController { throw new WebApplicationException(Response.Status.BAD_REQUEST); } - rateLimiters.getMessagesLimiter().validate(auth.getAccount().getUuid()); + rateLimiters.getMessagesLimiter().validate(auth.getAccountIdentifier()); final boolean subscriberPresent = provisioningManager.sendProvisioningMessage(provisioningAddress, Base64.getMimeDecoder().decode(message.body())); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RemoteConfigController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RemoteConfigController.java index a0555af1b..85791435c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RemoteConfigController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RemoteConfigController.java @@ -30,7 +30,6 @@ import org.whispersystems.textsecuregcm.entities.UserRemoteConfigList; import org.whispersystems.textsecuregcm.storage.RemoteConfigsManager; import org.whispersystems.textsecuregcm.util.Conversions; import org.whispersystems.textsecuregcm.util.Util; -import org.whispersystems.websocket.auth.ReadOnly; @Path("/v1/config") @Tag(name = "Remote Config") @@ -64,7 +63,7 @@ public class RemoteConfigController { """ ) @ApiResponse(responseCode = "200", description = "Remote configuration values for the authenticated user", useReturnTypeSchema = true) - public UserRemoteConfigList getAll(@ReadOnly @Auth AuthenticatedDevice auth) { + public UserRemoteConfigList getAll(@Auth AuthenticatedDevice auth) { try { MessageDigest digest = MessageDigest.getInstance("SHA1"); @@ -73,7 +72,7 @@ public class RemoteConfigController { return new UserRemoteConfigList(Stream.concat(remoteConfigsManager.getAll().stream().map(config -> { final byte[] hashKey = config.getHashKey() != null ? config.getHashKey().getBytes(StandardCharsets.UTF_8) : config.getName().getBytes(StandardCharsets.UTF_8); - boolean inBucket = isInBucket(digest, auth.getAccount().getUuid(), hashKey, config.getPercentage(), + boolean inBucket = isInBucket(digest, auth.getAccountIdentifier(), hashKey, config.getPercentage(), config.getUuids()); return new UserRemoteConfig(config.getName(), inBucket, inBucket ? config.getValue() : config.getDefaultValue()); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SecureStorageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SecureStorageController.java index c152e06af..7df8de81a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SecureStorageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SecureStorageController.java @@ -17,7 +17,6 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator; import org.whispersystems.textsecuregcm.configuration.SecureStorageServiceConfiguration; -import org.whispersystems.websocket.auth.ReadOnly; @Path("/v1/storage") @Tag(name = "Secure Storage") @@ -47,7 +46,7 @@ public class SecureStorageController { """ ) @ApiResponse(responseCode = "200", description = "`JSON` with generated credentials.", useReturnTypeSchema = true) - public ExternalServiceCredentials getAuth(@ReadOnly @Auth AuthenticatedDevice auth) { - return storageServiceCredentialsGenerator.generateForUuid(auth.getAccount().getUuid()); + public ExternalServiceCredentials getAuth(@Auth AuthenticatedDevice auth) { + return storageServiceCredentialsGenerator.generateForUuid(auth.getAccountIdentifier()); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SecureValueRecovery2Controller.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SecureValueRecovery2Controller.java index f0bc31379..e74052085 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SecureValueRecovery2Controller.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SecureValueRecovery2Controller.java @@ -34,7 +34,6 @@ import org.whispersystems.textsecuregcm.limits.RateLimitedByIp; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; -import org.whispersystems.websocket.auth.ReadOnly; @Path("/v2/backup") @Tag(name = "Secure Value Recovery") @@ -78,8 +77,8 @@ public class SecureValueRecovery2Controller { ) @ApiResponse(responseCode = "200", description = "`JSON` with generated credentials.", useReturnTypeSchema = true) @ApiResponse(responseCode = "401", description = "Account authentication check failed.") - public ExternalServiceCredentials getAuth(@ReadOnly @Auth final AuthenticatedDevice auth) { - return backupServiceCredentialGenerator.generateFor(auth.getAccount().getUuid().toString()); + public ExternalServiceCredentials getAuth(@Auth final AuthenticatedDevice auth) { + return backupServiceCredentialGenerator.generateFor(auth.getAccountIdentifier().toString()); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/StickerController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/StickerController.java index 1b94418c8..0a5001a3d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/StickerController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/StickerController.java @@ -28,7 +28,6 @@ import org.whispersystems.textsecuregcm.s3.PolicySigner; import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator; import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Pair; -import org.whispersystems.websocket.auth.ReadOnly; @Path("/v1/sticker") @Tag(name = "Stickers") @@ -47,10 +46,10 @@ public class StickerController { @GET @Produces(MediaType.APPLICATION_JSON) @Path("/pack/form/{count}") - public StickerPackFormUploadAttributes getStickersForm(@ReadOnly @Auth AuthenticatedDevice auth, + public StickerPackFormUploadAttributes getStickersForm(@Auth AuthenticatedDevice auth, @PathParam("count") @Min(1) @Max(201) int stickerCount) throws RateLimitExceededException { - rateLimiters.getStickerPackLimiter().validate(auth.getAccount().getUuid()); + rateLimiters.getStickerPackLimiter().validate(auth.getAccountIdentifier()); ZonedDateTime now = ZonedDateTime.now(ZoneOffset.UTC); String packId = generatePackId(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SubscriptionController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SubscriptionController.java index ab644223e..0b00aea24 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SubscriptionController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SubscriptionController.java @@ -88,7 +88,6 @@ import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; -import org.whispersystems.websocket.auth.ReadOnly; @Path("/v1/subscription") @io.swagger.v3.oas.annotations.tags.Tag(name = "Subscriptions") @@ -220,7 +219,7 @@ public class SubscriptionController { @Path("/{subscriberId}") @Produces(MediaType.APPLICATION_JSON) public CompletableFuture deleteSubscriber( - @ReadOnly @Auth Optional authenticatedAccount, + @Auth Optional authenticatedAccount, @PathParam("subscriberId") String subscriberId) throws SubscriptionException { SubscriberCredentials subscriberCredentials = SubscriberCredentials.process(authenticatedAccount, subscriberId, clock); @@ -232,7 +231,7 @@ public class SubscriptionController { @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) public CompletableFuture updateSubscriber( - @ReadOnly @Auth Optional authenticatedAccount, + @Auth Optional authenticatedAccount, @PathParam("subscriberId") String subscriberId) throws SubscriptionException { SubscriberCredentials subscriberCredentials = SubscriberCredentials.process(authenticatedAccount, subscriberId, clock); @@ -248,7 +247,7 @@ public class SubscriptionController { @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) public CompletableFuture createPaymentMethod( - @ReadOnly @Auth Optional authenticatedAccount, + @Auth Optional authenticatedAccount, @PathParam("subscriberId") String subscriberId, @QueryParam("type") @DefaultValue("CARD") PaymentMethod paymentMethodType, @HeaderParam(HttpHeaders.USER_AGENT) @Nullable final String userAgentString) throws SubscriptionException { @@ -284,7 +283,7 @@ public class SubscriptionController { @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) public CompletableFuture createPayPalPaymentMethod( - @ReadOnly @Auth Optional authenticatedAccount, + @Auth Optional authenticatedAccount, @PathParam("subscriberId") String subscriberId, @NotNull @Valid CreatePayPalBillingAgreementRequest request, @Context ContainerRequestContext containerRequestContext, @@ -323,7 +322,7 @@ public class SubscriptionController { @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) public CompletableFuture setDefaultPaymentMethodWithProcessor( - @ReadOnly @Auth Optional authenticatedAccount, + @Auth Optional authenticatedAccount, @PathParam("subscriberId") String subscriberId, @PathParam("processor") PaymentProvider processor, @PathParam("paymentMethodToken") @NotEmpty String paymentMethodToken) throws SubscriptionException { @@ -360,7 +359,7 @@ public class SubscriptionController { @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) public CompletableFuture setSubscriptionLevel( - @ReadOnly @Auth Optional authenticatedAccount, + @Auth Optional authenticatedAccount, @PathParam("subscriberId") String subscriberId, @PathParam("level") long level, @PathParam("currency") String currency, @@ -432,7 +431,7 @@ public class SubscriptionController { @ApiResponse(responseCode = "409", description = "subscriberId is already linked to a processor that does not support appstore payments. Delete this subscriberId and use a new one.") @ApiResponse(responseCode = "429", description = "Rate limit exceeded.") public CompletableFuture setAppStoreSubscription( - @ReadOnly @Auth Optional authenticatedAccount, + @Auth Optional authenticatedAccount, @PathParam("subscriberId") String subscriberId, @PathParam("originalTransactionId") String originalTransactionId) throws SubscriptionException { final SubscriberCredentials subscriberCredentials = @@ -473,7 +472,7 @@ public class SubscriptionController { @ApiResponse(responseCode = "404", description = "No such subscriberId exists or subscriberId is malformed or the purchaseToken does not exist") @ApiResponse(responseCode = "409", description = "subscriberId is already linked to a processor that does not support Play Billing. Delete this subscriberId and use a new one.") public CompletableFuture setPlayStoreSubscription( - @ReadOnly @Auth Optional authenticatedAccount, + @Auth Optional authenticatedAccount, @PathParam("subscriberId") String subscriberId, @PathParam("purchaseToken") String purchaseToken) throws SubscriptionException { final SubscriberCredentials subscriberCredentials = @@ -627,7 +626,7 @@ public class SubscriptionController { @ApiResponse(responseCode = "403", description = "subscriberId authentication failure OR account authentication is present") @ApiResponse(responseCode = "404", description = "No such subscriberId exists or subscriberId is malformed") public CompletableFuture getSubscriptionInformation( - @ReadOnly @Auth Optional authenticatedAccount, + @Auth Optional authenticatedAccount, @PathParam("subscriberId") String subscriberId) throws SubscriptionException { SubscriberCredentials subscriberCredentials = SubscriberCredentials.process(authenticatedAccount, subscriberId, clock); @@ -662,7 +661,7 @@ public class SubscriptionController { @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) public CompletableFuture createSubscriptionReceiptCredentials( - @ReadOnly @Auth Optional authenticatedAccount, + @Auth Optional authenticatedAccount, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent, @PathParam("subscriberId") String subscriberId, @NotNull @Valid GetReceiptCredentialsRequest request) throws SubscriptionException { @@ -691,7 +690,7 @@ public class SubscriptionController { @Path("/{subscriberId}/default_payment_method_for_ideal/{setupIntentId}") @Produces(MediaType.APPLICATION_JSON) public CompletableFuture setDefaultPaymentMethodForIdeal( - @ReadOnly @Auth Optional authenticatedAccount, + @Auth Optional authenticatedAccount, @PathParam("subscriberId") String subscriberId, @PathParam("setupIntentId") @NotEmpty String setupIntentId) throws SubscriptionException { SubscriberCredentials subscriberCredentials = diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/VerificationSessionRateLimitExceededException.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/VerificationSessionRateLimitExceededException.java index c96c6eae1..c615743d5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/VerificationSessionRateLimitExceededException.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/VerificationSessionRateLimitExceededException.java @@ -5,9 +5,9 @@ package org.whispersystems.textsecuregcm.controllers; -import org.whispersystems.textsecuregcm.entities.RegistrationServiceSession; -import javax.annotation.Nullable; import java.time.Duration; +import javax.annotation.Nullable; +import org.whispersystems.textsecuregcm.entities.RegistrationServiceSession; public class VerificationSessionRateLimitExceededException extends RateLimitExceededException { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java index c3a0acd42..3c5953905 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java @@ -10,15 +10,14 @@ import com.webauthn4j.converter.jackson.deserializer.json.ByteArrayBase64Deseria import io.micrometer.core.instrument.Metrics; import io.swagger.v3.oas.annotations.media.Schema; import jakarta.validation.constraints.AssertTrue; -import javax.annotation.Nullable; import jakarta.validation.constraints.NotNull; import jakarta.validation.constraints.Size; +import java.util.Arrays; +import java.util.Objects; +import javax.annotation.Nullable; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; -import org.whispersystems.textsecuregcm.storage.Account; -import java.util.Arrays; -import java.util.Objects; public record IncomingMessage(int type, byte destinationDeviceId, @@ -35,7 +34,7 @@ public record IncomingMessage(int type, MetricsUtil.name(IncomingMessage.class, "rejectInvalidEnvelopeType"); public MessageProtos.Envelope toEnvelope(final ServiceIdentifier destinationIdentifier, - @Nullable Account sourceAccount, + @Nullable AciServiceIdentifier sourceServiceIdentifier, @Nullable Byte sourceDeviceId, final long timestamp, final boolean story, @@ -54,9 +53,9 @@ public record IncomingMessage(int type, .setEphemeral(ephemeral) .setUrgent(urgent); - if (sourceAccount != null && sourceDeviceId != null) { + if (sourceServiceIdentifier != null && sourceDeviceId != null) { envelopeBuilder - .setSourceServiceId(new AciServiceIdentifier(sourceAccount.getUuid()).toServiceIdentifierString()) + .setSourceServiceId(sourceServiceIdentifier.toServiceIdentifierString()) .setSourceDevice(sourceDeviceId.intValue()); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/filters/RestDeprecationFilter.java b/service/src/main/java/org/whispersystems/textsecuregcm/filters/RestDeprecationFilter.java index bb7a59e34..aa973a681 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/filters/RestDeprecationFilter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/filters/RestDeprecationFilter.java @@ -96,8 +96,8 @@ public class RestDeprecationFilter implements ContainerRequestFilter { return false; } - if (securityContext.getUserPrincipal() instanceof AuthenticatedDevice ad) { - return experimentEnrollmentManager.isEnrolled(ad.getAccount().getUuid(), AUTHENTICATED_EXPERIMENT_NAME); + if (securityContext.getUserPrincipal() instanceof AuthenticatedDevice authenticatedDevice) { + return experimentEnrollmentManager.isEnrolled(authenticatedDevice.getAccountIdentifier(), AUTHENTICATED_EXPERIMENT_NAME); } else { log.error("Security context was not null but user principal was of type {}", securityContext.getUserPrincipal().getClass().getName()); return false; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/spam/SpamChecker.java b/service/src/main/java/org/whispersystems/textsecuregcm/spam/SpamChecker.java index 42058cd79..29eec2213 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/spam/SpamChecker.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/spam/SpamChecker.java @@ -9,8 +9,6 @@ import java.util.Optional; import jakarta.ws.rs.core.Response; import org.signal.chat.messages.SendMessageResponse; import org.signal.chat.messages.SendMultiRecipientMessageResponse; -import org.whispersystems.textsecuregcm.auth.AccountAndAuthenticatedDeviceHolder; -import org.whispersystems.textsecuregcm.auth.grpc.AuthenticatedDevice; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; import org.whispersystems.textsecuregcm.storage.Account; @@ -31,7 +29,7 @@ public interface SpamChecker { SpamCheckResult checkForIndividualRecipientSpamHttp( final MessageType messageType, final ContainerRequestContext requestContext, - final Optional maybeSource, + final Optional maybeSource, final Optional maybeDestination, final ServiceIdentifier destinationIdentifier); @@ -79,7 +77,7 @@ public interface SpamChecker { @Override public SpamCheckResult checkForIndividualRecipientSpamHttp(final MessageType messageType, final ContainerRequestContext requestContext, - final Optional maybeSource, + final Optional maybeSource, final Optional maybeDestination, final ServiceIdentifier destinationIdentifier) { @@ -95,7 +93,7 @@ public interface SpamChecker { @Override public SpamCheckResult> checkForIndividualRecipientSpamGrpc(final MessageType messageType, - final Optional maybeSource, + final Optional maybeSource, final Optional maybeDestination, final ServiceIdentifier destinationIdentifier) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountPrincipalSupplier.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountPrincipalSupplier.java deleted file mode 100644 index 893a9770a..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountPrincipalSupplier.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright 2024 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.textsecuregcm.storage; - -import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; -import org.whispersystems.websocket.auth.PrincipalSupplier; - -public class AccountPrincipalSupplier implements PrincipalSupplier { - - private final AccountsManager accountsManager; - - public AccountPrincipalSupplier(final AccountsManager accountsManager) { - this.accountsManager = accountsManager; - } - - @Override - public AuthenticatedDevice refresh(final AuthenticatedDevice oldAccount) { - final Account account = accountsManager.getByAccountIdentifier(oldAccount.getAccount().getUuid()) - .orElseThrow(() -> new RefreshingAccountNotFoundException("Could not find account")); - final Device device = account.getDevice(oldAccount.getAuthenticatedDevice().getId()) - .orElseThrow(() -> new RefreshingAccountNotFoundException("Could not find device")); - return new AuthenticatedDevice(account, device); - } - - @Override - public AuthenticatedDevice deepCopy(final AuthenticatedDevice authenticatedDevice) { - final Account cloned = AccountUtil.cloneAccountAsNotStale(authenticatedDevice.getAccount()); - return new AuthenticatedDevice( - cloned, - cloned.getDevice(authenticatedDevice.getAuthenticatedDevice().getId()) - .orElseThrow(() -> new IllegalStateException( - "Could not find device from a clone of an account where the device was present"))); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index 0f27deddf..85f8bc06d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -600,13 +600,10 @@ public class AccountsManager extends RedisPubSubAdapter implemen } /** - * Unlink a device from the given account. The device will be immediately disconnected if it is - * connected to any chat frontend, but it is the caller's responsibility to make sure that the - * account's *other* devices are disconnected, either by use of - * {@link org.whispersystems.textsecuregcm.auth.LinkedDeviceRefreshRequirementProvider} or - * directly by calling {@link DeviceDisconnectionManager#requestDisconnection}. + * Unlink a device from the given account. The device will be immediately disconnected if it is connected to any chat + * frontend. * - * @returns the updated Account + * @return the updated Account */ public CompletableFuture removeDevice(final Account account, final byte deviceId) { if (deviceId == Device.PRIMARY_ID) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java index f448cbfdd..b609107c7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java @@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.websocket; import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; import io.micrometer.core.instrument.Tags; +import java.util.Optional; import java.util.concurrent.ScheduledExecutorService; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -20,7 +21,10 @@ import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.ReceiptSender; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; +import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.websocket.session.WebSocketSessionContext; import org.whispersystems.websocket.setup.WebSocketConnectListener; @@ -36,6 +40,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { private static final Logger log = LoggerFactory.getLogger(AuthenticatedConnectListener.class); + private final AccountsManager accountsManager; private final ReceiptSender receiptSender; private final MessagesManager messagesManager; private final MessageMetrics messageMetrics; @@ -51,7 +56,9 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { private final OpenWebSocketCounter openAuthenticatedWebSocketCounter; private final OpenWebSocketCounter openUnauthenticatedWebSocketCounter; - public AuthenticatedConnectListener(ReceiptSender receiptSender, + public AuthenticatedConnectListener( + AccountsManager accountsManager, + ReceiptSender receiptSender, MessagesManager messagesManager, MessageMetrics messageMetrics, PushNotificationManager pushNotificationManager, @@ -62,6 +69,8 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { ClientReleaseManager clientReleaseManager, MessageDeliveryLoopMonitor messageDeliveryLoopMonitor, final ExperimentEnrollmentManager experimentEnrollmentManager) { + + this.accountsManager = accountsManager; this.receiptSender = receiptSender; this.messagesManager = messagesManager; this.messageMetrics = messageMetrics; @@ -82,7 +91,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { } @Override - public void onWebSocketConnect(WebSocketSessionContext context) { + public void onWebSocketConnect(final WebSocketSessionContext context) { final boolean authenticated = (context.getAuthenticated() != null); final OpenWebSocketCounter openWebSocketCounter = @@ -92,12 +101,24 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { if (authenticated) { final AuthenticatedDevice auth = context.getAuthenticated(AuthenticatedDevice.class); + + final Optional maybeAuthenticatedAccount = accountsManager.getByAccountIdentifier(auth.getAccountIdentifier()); + final Optional maybeAuthenticatedDevice = maybeAuthenticatedAccount.flatMap(account -> account.getDevice(auth.getDeviceId()));; + + if (maybeAuthenticatedAccount.isEmpty() || maybeAuthenticatedDevice.isEmpty()) { + log.warn("{}:{} not found when opening authenticated WebSocket", auth.getAccountIdentifier(), auth.getDeviceId()); + + context.getClient().close(1011, "Unexpected error initializing connection"); + return; + } + final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, messageMetrics, pushNotificationManager, pushNotificationScheduler, - auth, + maybeAuthenticatedAccount.get(), + maybeAuthenticatedDevice.get(), context.getClient(), scheduledExecutorService, messageDeliveryScheduler, @@ -110,8 +131,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { // receive push notifications for inbound messages. We should do this first because, at this point, the // connection has already closed and attempts to actually deliver a message via the connection will not succeed. // It's preferable to start sending push notifications as soon as possible. - webSocketConnectionEventManager.handleClientDisconnected(auth.getAccount().getUuid(), - auth.getAuthenticatedDevice().getId()); + webSocketConnectionEventManager.handleClientDisconnected(auth.getAccountIdentifier(), auth.getDeviceId()); // Finally, stop trying to deliver messages and send a push notification if the connection is aware of any // undelivered messages. @@ -127,7 +147,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { // Finally, we register this client's presence, which suppresses push notifications. We do this last because // receiving extra push notifications is generally preferable to missing out on a push notification. - webSocketConnectionEventManager.handleClientConnected(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(), connection); + webSocketConnectionEventManager.handleClientConnected(auth.getAccountIdentifier(), auth.getDeviceId(), connection); } catch (final Exception e) { log.warn("Failed to initialize websocket", e); context.getClient().close(1011, "Unexpected error initializing connection"); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java index e5b445fe9..c88fa5e29 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java @@ -9,41 +9,39 @@ import static org.whispersystems.textsecuregcm.util.HeaderUtils.basicCredentials import com.google.common.net.HttpHeaders; import javax.annotation.Nullable; +import io.dropwizard.auth.basic.BasicCredentials; import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; -import org.whispersystems.websocket.ReusableAuth; import org.whispersystems.websocket.auth.InvalidCredentialsException; -import org.whispersystems.websocket.auth.PrincipalSupplier; import org.whispersystems.websocket.auth.WebSocketAuthenticator; +import java.util.Optional; public class WebSocketAccountAuthenticator implements WebSocketAuthenticator { - private static final ReusableAuth CREDENTIALS_NOT_PRESENTED = ReusableAuth.anonymous(); - private final AccountAuthenticator accountAuthenticator; - private final PrincipalSupplier principalSupplier; - public WebSocketAccountAuthenticator(final AccountAuthenticator accountAuthenticator, - final PrincipalSupplier principalSupplier) { + public WebSocketAccountAuthenticator(final AccountAuthenticator accountAuthenticator) { this.accountAuthenticator = accountAuthenticator; - this.principalSupplier = principalSupplier; } @Override - public ReusableAuth authenticate(final UpgradeRequest request) + public Optional authenticate(final UpgradeRequest request) throws InvalidCredentialsException { @Nullable final String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION); if (authHeader == null) { - return CREDENTIALS_NOT_PRESENTED; + return Optional.empty(); } - return basicCredentialsFromAuthHeader(authHeader) - .flatMap(accountAuthenticator::authenticate) - .map(authenticatedAccount -> ReusableAuth.authenticated(authenticatedAccount, this.principalSupplier)) + final BasicCredentials credentials = basicCredentialsFromAuthHeader(authHeader) .orElseThrow(InvalidCredentialsException::new); + + final AuthenticatedDevice authenticatedDevice = accountAuthenticator.authenticate(credentials) + .orElseThrow(InvalidCredentialsException::new); + + return Optional.of(authenticatedDevice); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index a3d9083f6..6d32574f1 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -37,7 +37,6 @@ import org.eclipse.jetty.util.StaticException; import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.controllers.MessageController; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; @@ -52,6 +51,7 @@ import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.WebSocketConnectionEventListener; +import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; @@ -123,7 +123,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener { private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor; private final ExperimentEnrollmentManager experimentEnrollmentManager; - private final AuthenticatedDevice auth; + private final Account authenticatedAccount; + private final Device authenticatedDevice; private final WebSocketClient client; private final int sendFuturesTimeoutMillis; @@ -156,7 +157,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener { MessageMetrics messageMetrics, PushNotificationManager pushNotificationManager, PushNotificationScheduler pushNotificationScheduler, - AuthenticatedDevice auth, + Account authenticatedAccount, + Device authenticatedDevice, WebSocketClient client, ScheduledExecutorService scheduledExecutorService, Scheduler messageDeliveryScheduler, @@ -169,7 +171,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener { messageMetrics, pushNotificationManager, pushNotificationScheduler, - auth, + authenticatedAccount, + authenticatedDevice, client, DEFAULT_SEND_FUTURES_TIMEOUT_MILLIS, scheduledExecutorService, @@ -184,7 +187,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener { MessageMetrics messageMetrics, PushNotificationManager pushNotificationManager, PushNotificationScheduler pushNotificationScheduler, - AuthenticatedDevice auth, + Account authenticatedAccount, + Device authenticatedDevice, WebSocketClient client, int sendFuturesTimeoutMillis, ScheduledExecutorService scheduledExecutorService, @@ -198,7 +202,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener { this.messageMetrics = messageMetrics; this.pushNotificationManager = pushNotificationManager; this.pushNotificationScheduler = pushNotificationScheduler; - this.auth = auth; + this.authenticatedAccount = authenticatedAccount; + this.authenticatedDevice = authenticatedDevice; this.client = client; this.sendFuturesTimeoutMillis = sendFuturesTimeoutMillis; this.scheduledExecutorService = scheduledExecutorService; @@ -209,7 +214,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener { } public void start() { - pushNotificationManager.handleMessagesRetrieved(auth.getAccount(), auth.getAuthenticatedDevice(), client.getUserAgent()); + pushNotificationManager.handleMessagesRetrieved(authenticatedAccount, authenticatedDevice, client.getUserAgent()); queueDrainStartTime.set(System.currentTimeMillis()); processStoredMessages(); } @@ -229,8 +234,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener { client.close(1000, "OK"); if (storedMessageState.get() != StoredMessageState.EMPTY) { - pushNotificationScheduler.scheduleDelayedNotification(auth.getAccount(), - auth.getAuthenticatedDevice(), + pushNotificationScheduler.scheduleDelayedNotification(authenticatedAccount, + authenticatedDevice, CLOSE_WITH_PENDING_MESSAGES_NOTIFICATION_DELAY); } } @@ -242,7 +247,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener { sendMessageCounter.increment(); sentMessageCounter.increment(); bytesSentCounter.increment(body.map(bytes -> bytes.length).orElse(0)); - messageMetrics.measureAccountEnvelopeUuidMismatches(auth.getAccount(), message); + messageMetrics.measureAccountEnvelopeUuidMismatches(authenticatedAccount, message); // X-Signal-Key: false must be sent until Android stops assuming it missing means true return client.sendRequest("PUT", "/api/v1/message", @@ -253,7 +258,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener { } else { messageMetrics.measureOutgoingMessageLatency(message.getServerTimestamp(), "websocket", - auth.getAuthenticatedDevice().isPrimary(), + authenticatedDevice.isPrimary(), message.getUrgent(), message.getEphemeral(), client.getUserAgent(), @@ -263,12 +268,12 @@ public class WebSocketConnection implements WebSocketConnectionEventListener { final CompletableFuture result; if (isSuccessResponse(response)) { - result = messagesManager.delete(auth.getAccount().getUuid(), auth.getAuthenticatedDevice(), + result = messagesManager.delete(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice, storedMessageInfo.guid(), storedMessageInfo.serverTimestamp()) .thenApply(ignored -> null); if (message.getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT) { - recordMessageDeliveryDuration(message.getServerTimestamp(), auth.getAuthenticatedDevice()); + recordMessageDeliveryDuration(message.getServerTimestamp(), authenticatedDevice); sendDeliveryReceiptFor(message); } } else { @@ -307,7 +312,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener { try { receiptSender.sendReceipt(ServiceIdentifier.valueOf(message.getDestinationServiceId()), - auth.getAuthenticatedDevice().getId(), AciServiceIdentifier.valueOf(message.getSourceServiceId()), + authenticatedDevice.getId(), AciServiceIdentifier.valueOf(message.getSourceServiceId()), message.getClientTimestamp()); } catch (IllegalArgumentException e) { logger.error("Could not parse UUID: {}", message.getSourceServiceId()); @@ -338,7 +343,6 @@ public class WebSocketConnection implements WebSocketConnectionEventListener { // Cleared the queue! Send a queue empty message if we need to consecutiveRetries.set(0); if (sentInitialQueueEmptyMessage.compareAndSet(false, true)) { - final Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(client.getUserAgent())); final long drainDuration = System.currentTimeMillis() - queueDrainStartTime.get(); @@ -399,7 +403,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener { final CompletableFuture queueCleared = new CompletableFuture<>(); final Publisher messages = - messagesManager.getMessagesForDeviceReactive(auth.getAccount().getUuid(), auth.getAuthenticatedDevice(), cachedMessagesOnly); + messagesManager.getMessagesForDeviceReactive(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice, cachedMessagesOnly); final AtomicBoolean hasSentFirstMessage = new AtomicBoolean(); final AtomicBoolean hasErrored = new AtomicBoolean(); @@ -410,8 +414,8 @@ public class WebSocketConnection implements WebSocketConnectionEventListener { .limitRate(MESSAGE_PUBLISHER_LIMIT_RATE) .doOnNext(envelope -> { if (hasSentFirstMessage.compareAndSet(false, true)) { - messageDeliveryLoopMonitor.recordDeliveryAttempt(auth.getAccount().getIdentifier(IdentityType.ACI), - auth.getAuthenticatedDevice().getId(), + messageDeliveryLoopMonitor.recordDeliveryAttempt(authenticatedAccount.getIdentifier(IdentityType.ACI), + authenticatedDevice.getId(), UUID.fromString(envelope.getServerGuid()), client.getUserAgent(), "websocket"); @@ -471,7 +475,7 @@ public class WebSocketConnection implements WebSocketConnectionEventListener { final UUID messageGuid = UUID.fromString(envelope.getServerGuid()); if (envelope.getStory() && !client.shouldDeliverStories()) { - messagesManager.delete(auth.getAccount().getUuid(), auth.getAuthenticatedDevice(), messageGuid, envelope.getServerTimestamp()); + messagesManager.delete(authenticatedAccount.getIdentifier(IdentityType.ACI), authenticatedDevice, messageGuid, envelope.getServerTimestamp()); return CompletableFuture.completedFuture(null); } else { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/WebsocketResourceProviderIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/WebsocketResourceProviderIntegrationTest.java index fd68380c5..d74e599b0 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/WebsocketResourceProviderIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/WebsocketResourceProviderIntegrationTest.java @@ -21,6 +21,7 @@ import jakarta.ws.rs.core.MediaType; import java.io.IOException; import java.net.URI; import java.util.EnumSet; +import java.util.Optional; import org.apache.commons.lang3.RandomStringUtils; import org.eclipse.jetty.websocket.client.WebSocketClient; import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; @@ -34,9 +35,7 @@ import org.junit.jupiter.params.provider.ValueSource; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener; -import org.whispersystems.websocket.ReusableAuth; import org.whispersystems.websocket.WebSocketResourceProviderFactory; -import org.whispersystems.websocket.auth.PrincipalSupplier; import org.whispersystems.websocket.configuration.WebSocketConfiguration; import org.whispersystems.websocket.messages.WebSocketResponseMessage; import org.whispersystems.websocket.setup.WebSocketEnvironment; @@ -78,8 +77,7 @@ public class WebsocketResourceProviderIntegrationTest { .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*"); webSocketEnvironment.jersey().register(testController); webSocketEnvironment.jersey().register(new RemoteAddressFilter()); - webSocketEnvironment.setAuthenticator(upgradeRequest -> - ReusableAuth.authenticated(mock(AuthenticatedDevice.class), PrincipalSupplier.forImmutablePrincipal())); + webSocketEnvironment.setAuthenticator(upgradeRequest -> Optional.of(mock(AuthenticatedDevice.class))); webSocketEnvironment.jersey().property(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE); webSocketEnvironment.setConnectListener(webSocketSessionContext -> { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/WebsocketReuseAuthIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/WebsocketReuseAuthIntegrationTest.java deleted file mode 100644 index 47c7c4bef..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/WebsocketReuseAuthIntegrationTest.java +++ /dev/null @@ -1,279 +0,0 @@ -package org.whispersystems.textsecuregcm; - -import static org.assertj.core.api.AssertionsForClassTypes.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.reset; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; -import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME; - -import io.dropwizard.auth.Auth; -import io.dropwizard.core.Application; -import io.dropwizard.core.Configuration; -import io.dropwizard.core.setup.Environment; -import io.dropwizard.testing.junit5.DropwizardAppExtension; -import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; -import jakarta.servlet.DispatcherType; -import jakarta.servlet.ServletRegistration; -import jakarta.ws.rs.GET; -import jakarta.ws.rs.Path; -import jakarta.ws.rs.PathParam; -import java.io.IOException; -import java.net.URI; -import java.util.EnumSet; -import java.util.List; -import java.util.Optional; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; -import java.util.stream.IntStream; -import org.eclipse.jetty.websocket.client.WebSocketClient; -import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; -import org.glassfish.jersey.server.ManagedAsync; -import org.glassfish.jersey.server.ServerProperties; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; -import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; -import org.whispersystems.textsecuregcm.storage.RefreshingAccountNotFoundException; -import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener; -import org.whispersystems.websocket.ReusableAuth; -import org.whispersystems.websocket.WebSocketResourceProviderFactory; -import org.whispersystems.websocket.auth.PrincipalSupplier; -import org.whispersystems.websocket.auth.ReadOnly; -import org.whispersystems.websocket.configuration.WebSocketConfiguration; -import org.whispersystems.websocket.messages.WebSocketResponseMessage; -import org.whispersystems.websocket.setup.WebSocketEnvironment; - -@ExtendWith(DropwizardExtensionsSupport.class) -public class WebsocketReuseAuthIntegrationTest { - - private static final AuthenticatedDevice ACCOUNT = mock(AuthenticatedDevice.class); - @SuppressWarnings("unchecked") - private static final PrincipalSupplier PRINCIPAL_SUPPLIER = mock(PrincipalSupplier.class); - private static final DropwizardAppExtension DROPWIZARD_APP_EXTENSION = - new DropwizardAppExtension<>(TestApplication.class); - - - private WebSocketClient client; - - @BeforeEach - void setUp() throws Exception { - reset(PRINCIPAL_SUPPLIER); - reset(ACCOUNT); - when(ACCOUNT.getName()).thenReturn("original"); - client = new WebSocketClient(); - client.start(); - } - - @AfterEach - void tearDown() throws Exception { - client.stop(); - } - - - public static class TestApplication extends Application { - - @Override - public void run(final Configuration configuration, final Environment environment) throws Exception { - final TestController testController = new TestController(); - - final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration(); - - final WebSocketEnvironment webSocketEnvironment = - new WebSocketEnvironment<>(environment, webSocketConfiguration); - - environment.jersey().register(testController); - environment.servlets() - .addFilter("RemoteAddressFilter", new RemoteAddressFilter()) - .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*"); - webSocketEnvironment.jersey().register(testController); - webSocketEnvironment.jersey().register(new RemoteAddressFilter()); - webSocketEnvironment.setAuthenticator(upgradeRequest -> ReusableAuth.authenticated(ACCOUNT, PRINCIPAL_SUPPLIER)); - - webSocketEnvironment.jersey().property(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE); - webSocketEnvironment.setConnectListener(webSocketSessionContext -> { - }); - - final WebSocketResourceProviderFactory webSocketServlet = - new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class, - webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME); - - JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null); - - final ServletRegistration.Dynamic websocketServlet = - environment.servlets().addServlet("WebSocket", webSocketServlet); - - websocketServlet.addMapping("/websocket"); - websocketServlet.setAsyncSupported(true); - } - } - - private WebSocketResponseMessage make1WebsocketRequest(final String requestPath) throws IOException { - - final TestWebsocketListener testWebsocketListener = new TestWebsocketListener(); - - client.connect(testWebsocketListener, - URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort()))); - return testWebsocketListener.doGet(requestPath).join(); - } - - @ParameterizedTest - @ValueSource(strings = {"/test/read-auth", "/test/optional-read-auth"}) - public void readAuth(final String path) throws IOException { - final WebSocketResponseMessage response = make1WebsocketRequest(path); - assertThat(response.getStatus()).isEqualTo(200); - verifyNoMoreInteractions(PRINCIPAL_SUPPLIER); - } - - @ParameterizedTest - @ValueSource(strings = {"/test/write-auth", "/test/optional-write-auth"}) - public void writeAuth(final String path) throws IOException { - final AuthenticatedDevice copiedAccount = mock(AuthenticatedDevice.class); - when(copiedAccount.getName()).thenReturn("copy"); - when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(copiedAccount); - - final WebSocketResponseMessage response = make1WebsocketRequest(path); - assertThat(response.getStatus()).isEqualTo(200); - assertThat(response.getBody().map(String::new)).get().isEqualTo("copy"); - verify(PRINCIPAL_SUPPLIER, times(1)).deepCopy(any()); - verifyNoMoreInteractions(PRINCIPAL_SUPPLIER); - } - - @Test - public void readAfterWrite() throws IOException { - when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(ACCOUNT); - final AuthenticatedDevice account2 = mock(AuthenticatedDevice.class); - when(account2.getName()).thenReturn("refresh"); - when(PRINCIPAL_SUPPLIER.refresh(any())).thenReturn(account2); - - final TestWebsocketListener testWebsocketListener = new TestWebsocketListener(); - client.connect(testWebsocketListener, - URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort()))); - - final WebSocketResponseMessage readResponse = testWebsocketListener.doGet("/test/read-auth").join(); - assertThat(readResponse.getBody().map(String::new)).get().isEqualTo("original"); - - final WebSocketResponseMessage writeResponse = testWebsocketListener.doGet("/test/write-auth").join(); - assertThat(writeResponse.getBody().map(String::new)).get().isEqualTo("original"); - - final WebSocketResponseMessage readResponse2 = testWebsocketListener.doGet("/test/read-auth").join(); - assertThat(readResponse2.getBody().map(String::new)).get().isEqualTo("refresh"); - } - - @Test - public void readAfterWriteRefreshFails() throws IOException { - when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(ACCOUNT); - when(PRINCIPAL_SUPPLIER.refresh(any())).thenThrow(RefreshingAccountNotFoundException.class); - - final TestWebsocketListener testWebsocketListener = new TestWebsocketListener(); - client.connect(testWebsocketListener, - URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort()))); - - final WebSocketResponseMessage writeResponse = testWebsocketListener.doGet("/test/write-auth").join(); - assertThat(writeResponse.getBody().map(String::new)).get().isEqualTo("original"); - - final WebSocketResponseMessage readResponse2 = testWebsocketListener.doGet("/test/read-auth").join(); - assertThat(readResponse2.getStatus()).isEqualTo(500); - } - - @Test - public void readConcurrentWithWrite() throws IOException, ExecutionException, InterruptedException, TimeoutException { - final AuthenticatedDevice deepCopy = mock(AuthenticatedDevice.class); - when(deepCopy.getName()).thenReturn("deepCopy"); - when(PRINCIPAL_SUPPLIER.deepCopy(any())).thenReturn(deepCopy); - - final AuthenticatedDevice refresh = mock(AuthenticatedDevice.class); - when(refresh.getName()).thenReturn("refresh"); - when(PRINCIPAL_SUPPLIER.refresh(any())).thenReturn(refresh); - - final TestWebsocketListener testWebsocketListener = new TestWebsocketListener(); - client.connect(testWebsocketListener, - URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort()))); - - // start a write request that takes a while to finish - final CompletableFuture writeResponse = - testWebsocketListener.doGet("/test/start-delayed-write/foo"); - - // send a bunch of reads, they should reflect the original auth - final List> futures = IntStream.range(0, 10) - .boxed().map(i -> testWebsocketListener.doGet("/test/read-auth")) - .toList(); - CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new)).join(); - for (CompletableFuture future : futures) { - assertThat(future.join().getBody().map(String::new)).get().isEqualTo("original"); - } - - assertThat(writeResponse.isDone()).isFalse(); - - // finish the delayed write request - testWebsocketListener.doGet("/test/finish-delayed-write/foo").get(1, TimeUnit.SECONDS); - assertThat(writeResponse.join().getBody().map(String::new)).get().isEqualTo("deepCopy"); - - // subsequent reads should have the refreshed auth - final WebSocketResponseMessage readResponse = testWebsocketListener.doGet("/test/read-auth").join(); - assertThat(readResponse.getBody().map(String::new)).get().isEqualTo("refresh"); - } - - - @Path("/test") - public static class TestController { - - private final ConcurrentHashMap delayedWriteLatches = new ConcurrentHashMap<>(); - - @GET - @Path("/read-auth") - @ManagedAsync - public String readAuth(@ReadOnly @Auth final AuthenticatedDevice account) { - return account.getName(); - } - - @GET - @Path("/optional-read-auth") - @ManagedAsync - public String optionalReadAuth(@ReadOnly @Auth final Optional account) { - return account.map(AuthenticatedDevice::getName).orElse("empty"); - } - - @GET - @Path("/write-auth") - @ManagedAsync - public String writeAuth(@Auth final AuthenticatedDevice account) { - return account.getName(); - } - - @GET - @Path("/optional-write-auth") - @ManagedAsync - public String optionalWriteAuth(@Auth final Optional account) { - return account.map(AuthenticatedDevice::getName).orElse("empty"); - } - - @GET - @Path("/start-delayed-write/{id}") - @ManagedAsync - public String startDelayedWrite(@Auth final AuthenticatedDevice account, @PathParam("id") String id) - throws InterruptedException { - delayedWriteLatches.computeIfAbsent(id, i -> new CountDownLatch(1)).await(); - return account.getName(); - } - - @GET - @Path("/finish-delayed-write/{id}") - @ManagedAsync - public String finishDelayedWrite(@PathParam("id") String id) { - delayedWriteLatches.computeIfAbsent(id, i -> new CountDownLatch(1)).countDown(); - return "ok"; - } - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/CertificateGeneratorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/CertificateGeneratorTest.java index 9cbd2e350..a9cb942df 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/CertificateGeneratorTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/CertificateGeneratorTest.java @@ -18,7 +18,6 @@ import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.Curve; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.storage.Account; -import org.whispersystems.textsecuregcm.storage.Device; class CertificateGeneratorTest { @@ -37,7 +36,7 @@ class CertificateGeneratorTest { @Test void testCreateFor() throws IOException, InvalidKeyException, org.signal.libsignal.protocol.InvalidKeyException { final Account account = mock(Account.class); - final Device device = mock(Device.class); + final byte deviceId = 4; final CertificateGenerator certificateGenerator = new CertificateGenerator( Base64.getDecoder().decode(SIGNING_CERTIFICATE), Curve.decodePrivatePoint(Base64.getDecoder().decode(SIGNING_KEY)), 1); @@ -45,9 +44,8 @@ class CertificateGeneratorTest { when(account.getIdentityKey(IdentityType.ACI)).thenReturn(IDENTITY_KEY); when(account.getUuid()).thenReturn(UUID.randomUUID()); when(account.getNumber()).thenReturn("+18005551234"); - when(device.getId()).thenReturn((byte) 4); - assertTrue(certificateGenerator.createFor(account, device, true).length > 0); - assertTrue(certificateGenerator.createFor(account, device, false).length > 0); + assertTrue(certificateGenerator.createFor(account, deviceId, true).length > 0); + assertTrue(certificateGenerator.createFor(account, deviceId, false).length > 0); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilterTest.java index a52240016..075e01e23 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilterTest.java @@ -31,8 +31,6 @@ import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.KeysManager; import org.whispersystems.textsecuregcm.util.TestClock; -import org.whispersystems.websocket.ReusableAuth; -import org.whispersystems.websocket.auth.PrincipalSupplier; class IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilterTest { @@ -59,9 +57,9 @@ class IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilterTest { final boolean expectPqKeyCheck, @Nullable final String expectedAlertHeader) { - final ReusableAuth reusableAuth = authenticatedDevice != null - ? ReusableAuth.authenticated(authenticatedDevice, PrincipalSupplier.forImmutablePrincipal()) - : ReusableAuth.anonymous(); + final Optional reusableAuth = authenticatedDevice != null + ? Optional.of(authenticatedDevice) + : Optional.empty(); final JettyServerUpgradeResponse response = mock(JettyServerUpgradeResponse.class); @@ -88,20 +86,24 @@ class IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilterTest { private static List handleAuthentication() { final Device activePrimaryDevice = mock(Device.class); + when(activePrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID); when(activePrimaryDevice.isPrimary()).thenReturn(true); when(activePrimaryDevice.getLastSeen()).thenReturn(CLOCK.millis()); final Device minIdlePrimaryDevice = mock(Device.class); + when(minIdlePrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID); when(minIdlePrimaryDevice.isPrimary()).thenReturn(true); when(minIdlePrimaryDevice.getLastSeen()) .thenReturn(CLOCK.instant().minus(MIN_IDLE_DURATION).minusSeconds(1).toEpochMilli()); final Device longIdlePrimaryDevice = mock(Device.class); + when(longIdlePrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID); when(longIdlePrimaryDevice.isPrimary()).thenReturn(true); when(longIdlePrimaryDevice.getLastSeen()) .thenReturn(CLOCK.instant().minus(IdlePrimaryDeviceAuthenticatedWebSocketUpgradeFilter.PQ_KEY_CHECK_THRESHOLD).minusSeconds(1).toEpochMilli()); final Device linkedDevice = mock(Device.class); + when(linkedDevice.getId()).thenReturn((byte) (Device.PRIMARY_ID + 1)); when(linkedDevice.isPrimary()).thenReturn(false); final Account accountWithActivePrimaryDevice = mock(Account.class); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/LinkedDeviceRefreshRequirementProviderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/LinkedDeviceRefreshRequirementProviderTest.java deleted file mode 100644 index 3c44f14c1..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/LinkedDeviceRefreshRequirementProviderTest.java +++ /dev/null @@ -1,328 +0,0 @@ -/* - * Copyright 2013 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.auth; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; - -import io.dropwizard.auth.Auth; -import io.dropwizard.auth.AuthDynamicFeature; -import io.dropwizard.auth.AuthValueFactoryProvider; -import io.dropwizard.auth.basic.BasicCredentialAuthFilter; -import io.dropwizard.jersey.DropwizardResourceConfig; -import io.dropwizard.jersey.jackson.JacksonMessageBodyProvider; -import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; -import io.dropwizard.testing.junit5.ResourceExtension; -import jakarta.ws.rs.DELETE; -import jakarta.ws.rs.GET; -import jakarta.ws.rs.PUT; -import jakarta.ws.rs.Path; -import jakarta.ws.rs.PathParam; -import jakarta.ws.rs.client.Entity; -import jakarta.ws.rs.core.MediaType; -import jakarta.ws.rs.core.Response; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.security.Principal; -import java.time.Duration; -import java.util.Arrays; -import java.util.Base64; -import java.util.LinkedList; -import java.util.List; -import java.util.Optional; -import java.util.UUID; -import java.util.function.Supplier; -import java.util.stream.Collectors; -import java.util.stream.IntStream; -import org.eclipse.jetty.websocket.api.RemoteEndpoint; -import org.eclipse.jetty.websocket.api.Session; -import org.eclipse.jetty.websocket.api.UpgradeRequest; -import org.eclipse.jetty.websocket.api.WriteCallback; -import org.glassfish.jersey.server.ApplicationHandler; -import org.glassfish.jersey.server.ResourceConfig; -import org.glassfish.jersey.server.monitoring.ApplicationEventListener; -import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Nested; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import org.mockito.ArgumentCaptor; -import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; -import org.whispersystems.textsecuregcm.storage.Account; -import org.whispersystems.textsecuregcm.storage.AccountsManager; -import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; -import org.whispersystems.textsecuregcm.util.SystemMapper; -import org.whispersystems.websocket.ReusableAuth; -import org.whispersystems.websocket.WebSocketResourceProvider; -import org.whispersystems.websocket.auth.PrincipalSupplier; -import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider; -import org.whispersystems.websocket.logging.WebsocketRequestLog; -import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory; -import org.whispersystems.websocket.messages.protobuf.SubProtocol; -import org.whispersystems.websocket.session.WebSocketSessionContextValueFactoryProvider; - -@ExtendWith(DropwizardExtensionsSupport.class) -class LinkedDeviceRefreshRequirementProviderTest { - - private final ApplicationEventListener applicationEventListener = mock(ApplicationEventListener.class); - - private final Account account = new Account(); - private final Device authenticatedDevice = DevicesHelper.createDevice(Device.PRIMARY_ID); - - private final Supplier> principalSupplier = () -> Optional.of( - new TestPrincipal("test", account, authenticatedDevice)); - - private final ResourceExtension resources = ResourceExtension.builder() - .addProvider(new AuthDynamicFeature(new BasicCredentialAuthFilter.Builder() - .setAuthenticator(c -> principalSupplier.get()).buildAuthFilter())) - .addProvider(new AuthValueFactoryProvider.Binder<>(TestPrincipal.class)) - .addProvider(applicationEventListener) - .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) - .addResource(new TestResource()) - .build(); - - private AccountsManager accountsManager; - private DisconnectionRequestManager disconnectionRequestManager; - - private LinkedDeviceRefreshRequirementProvider provider; - - @BeforeEach - void setup() { - accountsManager = mock(AccountsManager.class); - disconnectionRequestManager = mock(DisconnectionRequestManager.class); - - provider = new LinkedDeviceRefreshRequirementProvider(accountsManager); - - final WebsocketRefreshRequestEventListener listener = - new WebsocketRefreshRequestEventListener(disconnectionRequestManager, provider); - - when(applicationEventListener.onRequest(any())).thenReturn(listener); - - final UUID uuid = UUID.randomUUID(); - account.setUuid(uuid); - account.addDevice(authenticatedDevice); - IntStream.range(2, 4) - .forEach(deviceId -> account.addDevice(DevicesHelper.createDevice((byte) deviceId))); - - when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account)); - } - - @Test - void testDeviceAdded() { - final int initialDeviceCount = account.getDevices().size(); - - final List addedDeviceNames = List.of( - Base64.getEncoder().encodeToString("newDevice1".getBytes(StandardCharsets.UTF_8)), - Base64.getEncoder().encodeToString("newDevice2".getBytes(StandardCharsets.UTF_8))); - - final Response response = resources.getJerseyTest() - .target("/v1/test/account/devices") - .request() - .header("Authorization", - "Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8))) - .put(Entity.entity(addedDeviceNames, MediaType.APPLICATION_JSON_PATCH_JSON)); - - assertEquals(200, response.getStatus()); - - assertEquals(initialDeviceCount + addedDeviceNames.size(), account.getDevices().size()); - - verify(disconnectionRequestManager).requestDisconnection(account.getUuid(), List.of((byte) 1)); - verify(disconnectionRequestManager).requestDisconnection(account.getUuid(), List.of((byte) 2)); - verify(disconnectionRequestManager).requestDisconnection(account.getUuid(), List.of((byte) 3)); - } - - @ParameterizedTest - @ValueSource(ints = {1, 2}) - void testDeviceRemoved(final int removedDeviceCount) { - final List initialDeviceIds = account.getDevices().stream().map(Device::getId).toList(); - - final List deletedDeviceIds = account.getDevices().stream() - .map(Device::getId) - .filter(deviceId -> deviceId != Device.PRIMARY_ID) - .limit(removedDeviceCount) - .toList(); - - assert deletedDeviceIds.size() == removedDeviceCount; - - final String deletedDeviceIdsParam = deletedDeviceIds.stream().map(String::valueOf) - .collect(Collectors.joining(",")); - - final Response response = resources.getJerseyTest() - .target("/v1/test/account/devices/" + deletedDeviceIdsParam) - .request() - .header("Authorization", - "Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8))) - .delete(); - - assertEquals(200, response.getStatus()); - - initialDeviceIds.forEach(deviceId -> - verify(disconnectionRequestManager).requestDisconnection(account.getUuid(), List.of(deviceId))); - - verifyNoMoreInteractions(disconnectionRequestManager); - } - - @Test - void testOnEvent() { - Response response = resources.getJerseyTest() - .target("/v1/test/hello") - .request() - // no authorization required - .get(); - - assertEquals(200, response.getStatus()); - - response = resources.getJerseyTest() - .target("/v1/test/authorized") - .request() - .header("Authorization", - "Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8))) - .get(); - - assertEquals(200, response.getStatus()); - - verify(accountsManager, never()).getByAccountIdentifier(any(UUID.class)); - } - - @Nested - class WebSocket { - - private WebSocketResourceProvider provider; - private RemoteEndpoint remoteEndpoint; - - @BeforeEach - void setup() { - ResourceConfig resourceConfig = new DropwizardResourceConfig(); - resourceConfig.register(applicationEventListener); - resourceConfig.register(new TestResource()); - resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder()); - resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class)); - resourceConfig.register(new JacksonMessageBodyProvider(SystemMapper.jsonMapper())); - - ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); - WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); - - provider = new WebSocketResourceProvider<>("127.0.0.1", RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, - applicationHandler, requestLog, TestPrincipal.reusableAuth("test", account, authenticatedDevice), - new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); - - remoteEndpoint = mock(RemoteEndpoint.class); - Session session = mock(Session.class); - UpgradeRequest request = mock(UpgradeRequest.class); - - when(session.getRemote()).thenReturn(remoteEndpoint); - when(session.getUpgradeRequest()).thenReturn(request); - - provider.onWebSocketConnect(session); - } - - @Test - void testOnEvent() throws Exception { - - byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello", - new LinkedList<>(), Optional.empty()).toByteArray(); - - provider.onWebSocketBinary(message, 0, message.length); - - final SubProtocol.WebSocketResponseMessage response = verifyAndGetResponse(remoteEndpoint); - - assertEquals(200, response.getStatus()); - } - - private SubProtocol.WebSocketResponseMessage verifyAndGetResponse(final RemoteEndpoint remoteEndpoint) - throws IOException { - ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); - verify(remoteEndpoint).sendBytes(responseBytesCaptor.capture(), any(WriteCallback.class)); - - return SubProtocol.WebSocketMessage.parseFrom(responseBytesCaptor.getValue().array()).getResponse(); - } - } - - public static class TestPrincipal implements Principal, AccountAndAuthenticatedDeviceHolder { - - private final String name; - private final Account account; - private final Device device; - - private TestPrincipal(final String name, final Account account, final Device device) { - this.name = name; - this.account = account; - this.device = device; - } - - @Override - public String getName() { - return name; - } - - @Override - public Account getAccount() { - return account; - } - - @Override - public Device getAuthenticatedDevice() { - return device; - } - - public static ReusableAuth reusableAuth(final String name, final Account account, final Device device) { - return ReusableAuth.authenticated(new TestPrincipal(name, account, device), PrincipalSupplier.forImmutablePrincipal()); - } - - } - - @Path("/v1/test") - public static class TestResource { - - @GET - @Path("/hello") - public String testGetHello() { - return "Hello!"; - } - - @GET - @Path("/authorized") - public String testAuth(@Auth TestPrincipal principal) { - return "You’re in!"; - } - - @PUT - @Path("/account/devices") - @ChangesLinkedDevices - public String addDevices(@Auth TestPrincipal auth, List deviceNames) { - - deviceNames.forEach(name -> { - final Device device = DevicesHelper.createDevice(auth.getAccount().getNextDeviceId()); - auth.getAccount().addDevice(device); - - device.setName(name); - }); - - return "Added devices " + deviceNames; - } - - @DELETE - @Path("/account/devices/{deviceIds}") - @ChangesLinkedDevices - public String removeDevices(@Auth TestPrincipal auth, @PathParam("deviceIds") String deviceIds) { - - Arrays.stream(deviceIds.split(",")) - .map(Byte::valueOf) - .forEach(auth.getAccount()::removeDevice); - - return "Removed device(s) " + deviceIds; - } - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProviderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProviderTest.java deleted file mode 100644 index 55a47d34e..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/PhoneNumberChangeRefreshRequirementProviderTest.java +++ /dev/null @@ -1,294 +0,0 @@ -/* - * Copyright 2013-2021 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.auth; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.reset; -import static org.mockito.Mockito.timeout; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; -import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME; - -import com.google.common.net.HttpHeaders; -import io.dropwizard.auth.Auth; -import io.dropwizard.auth.AuthDynamicFeature; -import io.dropwizard.auth.basic.BasicCredentialAuthFilter; -import io.dropwizard.core.Application; -import io.dropwizard.core.Configuration; -import io.dropwizard.core.setup.Environment; -import io.dropwizard.testing.junit5.DropwizardAppExtension; -import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; -import jakarta.servlet.DispatcherType; -import jakarta.servlet.ServletRegistration; -import jakarta.ws.rs.GET; -import jakarta.ws.rs.Path; -import jakarta.ws.rs.client.Invocation; -import java.io.IOException; -import java.net.URI; -import java.util.Collections; -import java.util.EnumSet; -import java.util.List; -import java.util.Optional; -import java.util.UUID; -import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; -import org.eclipse.jetty.websocket.client.WebSocketClient; -import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; -import org.glassfish.jersey.server.ManagedAsync; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.EnumSource; -import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; -import org.whispersystems.textsecuregcm.storage.Account; -import org.whispersystems.textsecuregcm.storage.AccountsManager; -import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.tests.util.DevicesHelper; -import org.whispersystems.textsecuregcm.tests.util.TestWebsocketListener; -import org.whispersystems.textsecuregcm.util.HeaderUtils; -import org.whispersystems.textsecuregcm.websocket.WebSocketAccountAuthenticator; -import org.whispersystems.websocket.WebSocketResourceProviderFactory; -import org.whispersystems.websocket.auth.PrincipalSupplier; -import org.whispersystems.websocket.auth.ReadOnly; -import org.whispersystems.websocket.configuration.WebSocketConfiguration; -import org.whispersystems.websocket.setup.WebSocketEnvironment; - -@ExtendWith(DropwizardExtensionsSupport.class) -class PhoneNumberChangeRefreshRequirementProviderTest { - - private static final String NUMBER = "+18005551234"; - private static final String CHANGED_NUMBER = "+18005554321"; - private static final String TEST_CRED_HEADER = HeaderUtils.basicAuthHeader("test", "password"); - - - private static final DropwizardAppExtension DROPWIZARD_APP_EXTENSION = new DropwizardAppExtension<>( - TestApplication.class); - - private static final AccountAuthenticator AUTHENTICATOR = mock(AccountAuthenticator.class); - private static final AccountsManager ACCOUNTS_MANAGER = mock(AccountsManager.class); - private static final DisconnectionRequestManager DISCONNECTION_REQUEST_MANAGER = - mock(DisconnectionRequestManager.class); - - private WebSocketClient client; - private final Account account1 = new Account(); - private final Account account2 = new Account(); - private final Device authenticatedDevice = DevicesHelper.createDevice(Device.PRIMARY_ID); - - - @BeforeEach - void setUp() throws Exception { - reset(AUTHENTICATOR, ACCOUNTS_MANAGER, DISCONNECTION_REQUEST_MANAGER); - client = new WebSocketClient(); - client.start(); - - final UUID uuid = UUID.randomUUID(); - account1.setUuid(uuid); - account1.addDevice(authenticatedDevice); - account1.setNumber(NUMBER, UUID.randomUUID()); - - account2.setUuid(uuid); - account2.addDevice(authenticatedDevice); - account2.setNumber(CHANGED_NUMBER, UUID.randomUUID()); - - } - - @AfterEach - void tearDown() throws Exception { - client.stop(); - } - - - public static class TestApplication extends Application { - - @Override - public void run(final Configuration configuration, final Environment environment) throws Exception { - final TestController testController = new TestController(); - - final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration(); - - final WebSocketEnvironment webSocketEnvironment = - new WebSocketEnvironment<>(environment, webSocketConfiguration); - - environment.jersey().register(testController); - webSocketEnvironment.jersey().register(testController); - environment.servlets() - .addFilter("RemoteAddressFilter", new RemoteAddressFilter()) - .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*"); - webSocketEnvironment.jersey().register(new RemoteAddressFilter()); - webSocketEnvironment.jersey() - .register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, DISCONNECTION_REQUEST_MANAGER)); - environment.jersey() - .register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, DISCONNECTION_REQUEST_MANAGER)); - webSocketEnvironment.setConnectListener(webSocketSessionContext -> { - }); - - - environment.jersey().register(new AuthDynamicFeature(new BasicCredentialAuthFilter.Builder() - .setAuthenticator(AUTHENTICATOR) - .buildAuthFilter())); - webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(AUTHENTICATOR, mock(PrincipalSupplier.class))); - - final WebSocketResourceProviderFactory webSocketServlet = - new WebSocketResourceProviderFactory<>(webSocketEnvironment, AuthenticatedDevice.class, - webSocketConfiguration, REMOTE_ADDRESS_ATTRIBUTE_NAME); - - JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null); - - final ServletRegistration.Dynamic websocketServlet = - environment.servlets().addServlet("WebSocket", webSocketServlet); - - websocketServlet.addMapping("/websocket"); - websocketServlet.setAsyncSupported(true); - } - } - - enum Protocol { HTTP, WEBSOCKET } - - private void makeAnonymousRequest(final Protocol protocol, final String requestPath) throws IOException { - makeRequest(protocol, requestPath, true); - } - - /* - * Make an authenticated request that will return account1 as the principal - */ - private void makeAuthenticatedRequest( - final Protocol protocol, - final String requestPath) throws IOException { - when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedDevice(account1, authenticatedDevice))); - makeRequest(protocol,requestPath, false); - } - - private void makeRequest(final Protocol protocol, final String requestPath, final boolean anonymous) throws IOException { - switch (protocol) { - case WEBSOCKET -> { - final TestWebsocketListener testWebsocketListener = new TestWebsocketListener(); - final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest(); - if (!anonymous) { - upgradeRequest.setHeader(HttpHeaders.AUTHORIZATION, TEST_CRED_HEADER); - } - client.connect( - testWebsocketListener, - URI.create(String.format("ws://127.0.0.1:%d/websocket", DROPWIZARD_APP_EXTENSION.getLocalPort())), - upgradeRequest); - testWebsocketListener.sendRequest(requestPath, "GET", Collections.emptyList(), Optional.empty()).join(); - } - case HTTP -> { - final Invocation.Builder request = DROPWIZARD_APP_EXTENSION.client() - .target("http://127.0.0.1:%s%s".formatted(DROPWIZARD_APP_EXTENSION.getLocalPort(), requestPath)) - .request(); - if (!anonymous) { - request.header(HttpHeaders.AUTHORIZATION, TEST_CRED_HEADER); - } - request.get(); - } - } - } - - @ParameterizedTest - @EnumSource(Protocol.class) - void handleRequestNoChange(final Protocol protocol) throws IOException { - when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account1)); - makeAuthenticatedRequest(protocol, "/test/annotated"); - - // Event listeners can fire after responses are sent - verify(ACCOUNTS_MANAGER, timeout(5000).times(1)).getByAccountIdentifier(eq(account1.getUuid())); - verifyNoMoreInteractions(ACCOUNTS_MANAGER); - } - - @ParameterizedTest - @EnumSource(Protocol.class) - void handleRequestChange(final Protocol protocol) throws IOException { - when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account2)); - when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedDevice(account1, authenticatedDevice))); - - makeAuthenticatedRequest(protocol, "/test/annotated"); - - // Make sure we disconnect the account if the account has changed numbers. Event listeners can fire after responses - // are sent, so use a timeout. - verify(DISCONNECTION_REQUEST_MANAGER, timeout(5000)) - .requestDisconnection(account1.getUuid(), List.of(authenticatedDevice.getId())); - verifyNoMoreInteractions(DISCONNECTION_REQUEST_MANAGER); - } - - @Test - void handleRequestChangeAsyncEndpoint() throws IOException { - when(ACCOUNTS_MANAGER.getByAccountIdentifier(any())).thenReturn(Optional.of(account2)); - when(AUTHENTICATOR.authenticate(any())).thenReturn(Optional.of(new AuthenticatedDevice(account1, authenticatedDevice))); - - // Event listeners with asynchronous HTTP endpoints don't currently correctly maintain state between request and - // response - makeAuthenticatedRequest(Protocol.WEBSOCKET, "/test/async-annotated"); - - // Make sure we disconnect the account if the account has changed numbers. Event listeners can fire after responses - // are sent, so use a timeout. - verify(DISCONNECTION_REQUEST_MANAGER, timeout(5000)) - .requestDisconnection(account1.getUuid(), List.of(authenticatedDevice.getId())); - verifyNoMoreInteractions(DISCONNECTION_REQUEST_MANAGER); - } - - @ParameterizedTest - @EnumSource(Protocol.class) - void handleRequestNotAnnotated(final Protocol protocol) throws IOException, InterruptedException { - makeAuthenticatedRequest(protocol,"/test/not-annotated"); - - // Give a tick for event listeners to run. Racy, but should occasionally catch an errant running listener if one is - // introduced. - Thread.sleep(100); - - // Shouldn't even read the account if the method has not been annotated - verifyNoMoreInteractions(ACCOUNTS_MANAGER); - } - - @ParameterizedTest - @EnumSource(Protocol.class) - void handleRequestNotAuthenticated(final Protocol protocol) throws IOException, InterruptedException { - makeAnonymousRequest(protocol, "/test/not-authenticated"); - - // Give a tick for event listeners to run. Racy, but should occasionally catch an errant running listener if one is - // introduced. - Thread.sleep(100); - - // Shouldn't even read the account if the method has not been annotated - verifyNoMoreInteractions(ACCOUNTS_MANAGER); - } - - - @Path("/test") - public static class TestController { - - @GET - @Path("/annotated") - @ChangesPhoneNumber - public String annotated(@ReadOnly @Auth final AuthenticatedDevice account) { - return "ok"; - } - - @GET - @Path("/async-annotated") - @ChangesPhoneNumber - @ManagedAsync - public String asyncAnnotated(@ReadOnly @Auth final AuthenticatedDevice account) { - return "ok"; - } - - @GET - @Path("/not-authenticated") - @ChangesPhoneNumber - public String notAuthenticated() { - return "ok"; - } - - @GET - @Path("/not-annotated") - public String notAnnotated(@ReadOnly @Auth final AuthenticatedDevice account) { - return "ok"; - } - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java index fb34b0d57..5c15f054f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java @@ -211,6 +211,11 @@ class AccountControllerTest { when(accountsManager.getByE164(eq(SENDER_HAS_STORAGE))).thenReturn(Optional.of(senderHasStorage)); when(accountsManager.getByE164(eq(SENDER_TRANSFER))).thenReturn(Optional.of(senderTransfer)); + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_TWO)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_TWO)); + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_3)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_3)); + when(accountsManager.getByAccountIdentifier(AuthHelper.UNDISCOVERABLE_UUID)).thenReturn(Optional.of(AuthHelper.UNDISCOVERABLE_ACCOUNT)); + doAnswer(invocation -> { final byte[] proof = invocation.getArgument(0); final byte[] hash = invocation.getArgument(1); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java index 8d9293bb5..d29b56377 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerV2Test.java @@ -145,6 +145,8 @@ class AccountControllerV2Test { void setUp() throws Exception { when(rateLimiters.getRegistrationLimiter()).thenReturn(registrationLimiter); + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); + when(changeNumberManager.changeNumber(any(), any(), any(), any(), any(), any(), any(), any())).thenAnswer( (Answer) invocation -> { final Account account = invocation.getArgument(0); @@ -607,6 +609,8 @@ class AccountControllerV2Test { @BeforeEach void setUp() throws Exception { + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); + when(changeNumberManager.updatePniKeys(any(), any(), any(), any(), any(), any(), any())).thenAnswer( (Answer) invocation -> { final Account account = invocation.getArgument(0); @@ -768,7 +772,9 @@ class AccountControllerV2Test { @BeforeEach void setup() { AccountsHelper.setupMockUpdate(accountsManager); + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); } + @Test void testSetPhoneNumberDiscoverability() { Response response = resources.getJerseyTest() @@ -805,6 +811,7 @@ class AccountControllerV2Test { @MethodSource void testGetAccountDataReport(final Account account, final String expectedTextAfterHeader) throws Exception { when(AuthHelper.ACCOUNTS_MANAGER.getByAccountIdentifier(account.getUuid())).thenReturn(Optional.of(account)); + when(accountsManager.getByAccountIdentifier(account.getUuid())).thenReturn(Optional.of(account)); final Response response = resources.getJerseyTest() .target("/v2/accounts/data_report") diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ArchiveControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ArchiveControllerTest.java index 7200a6317..5feb07c65 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ArchiveControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ArchiveControllerTest.java @@ -73,6 +73,7 @@ import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper; import org.whispersystems.textsecuregcm.mappers.GrpcStatusRuntimeExceptionMapper; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.metrics.BackupMetrics; +import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.util.EnumMapUtil; import org.whispersystems.textsecuregcm.util.SystemMapper; @@ -82,6 +83,7 @@ import reactor.core.publisher.Flux; @ExtendWith(DropwizardExtensionsSupport.class) public class ArchiveControllerTest { + private static final AccountsManager accountsManager = mock(AccountsManager.class); private static final BackupAuthManager backupAuthManager = mock(BackupAuthManager.class); private static final BackupManager backupManager = mock(BackupManager.class); private final BackupAuthTestUtil backupAuthTestUtil = new BackupAuthTestUtil(Clock.systemUTC()); @@ -95,7 +97,7 @@ public class ArchiveControllerTest { .addProvider(new RateLimitExceededExceptionMapper()) .setMapper(SystemMapper.jsonMapper()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) - .addResource(new ArchiveController(backupAuthManager, backupManager, new BackupMetrics())) + .addResource(new ArchiveController(accountsManager, backupAuthManager, backupManager, new BackupMetrics())) .build(); private final UUID aci = UUID.randomUUID(); @@ -106,6 +108,9 @@ public class ArchiveControllerTest { public void setUp() { reset(backupAuthManager); reset(backupManager); + + when(accountsManager.getByAccountIdentifierAsync(AuthHelper.VALID_UUID)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(AuthHelper.VALID_ACCOUNT))); } @ParameterizedTest diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/CertificateControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/CertificateControllerTest.java index a75f3f242..49ce4a5db 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/CertificateControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/CertificateControllerTest.java @@ -9,6 +9,8 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import io.dropwizard.auth.AuthValueFactoryProvider; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; @@ -21,9 +23,11 @@ import java.time.Instant; import java.time.ZoneId; import java.time.temporal.ChronoUnit; import java.util.Base64; +import java.util.Optional; import java.util.stream.Stream; import org.apache.commons.lang3.StringUtils; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; @@ -44,6 +48,7 @@ import org.whispersystems.textsecuregcm.entities.DeliveryCertificate; import org.whispersystems.textsecuregcm.entities.GroupCredentials; import org.whispersystems.textsecuregcm.entities.MessageProtos.SenderCertificate; import org.whispersystems.textsecuregcm.entities.MessageProtos.ServerCertificate; +import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.SystemMapper; @@ -66,6 +71,8 @@ class CertificateControllerTest { private static final ServerZkAuthOperations serverZkAuthOperations; private static final Clock clock = Clock.fixed(Instant.now(), ZoneId.systemDefault()); + private static final AccountsManager accountsManager = mock(AccountsManager.class); + static { try { certificateGenerator = new CertificateGenerator(Base64.getDecoder().decode(signingCertificate), @@ -82,9 +89,14 @@ class CertificateControllerTest { .addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class)) .setMapper(SystemMapper.jsonMapper()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) - .addResource(new CertificateController(certificateGenerator, serverZkAuthOperations, genericServerSecretParams, clock)) + .addResource(new CertificateController(accountsManager, certificateGenerator, serverZkAuthOperations, genericServerSecretParams, clock)) .build(); + @BeforeEach + void setUp() { + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); + } + @Test void testValidCertificate() throws Exception { DeliveryCertificate certificateObject = resources.getJerseyTest() diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ChallengeControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ChallengeControllerTest.java index f99010465..448745394 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ChallengeControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ChallengeControllerTest.java @@ -38,6 +38,7 @@ import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.spam.ChallengeConstraintChecker; import org.whispersystems.textsecuregcm.spam.ChallengeConstraintChecker.ChallengeConstraints; +import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider; @@ -45,11 +46,12 @@ import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider; @ExtendWith(DropwizardExtensionsSupport.class) class ChallengeControllerTest { + private static final AccountsManager accountsManager = mock(AccountsManager.class); private static final RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class); private static final ChallengeConstraintChecker challengeConstraintChecker = mock(ChallengeConstraintChecker.class); private static final ChallengeController challengeController = - new ChallengeController(rateLimitChallengeManager, challengeConstraintChecker); + new ChallengeController(accountsManager, rateLimitChallengeManager, challengeConstraintChecker); private static final ResourceExtension EXTENSION = ResourceExtension.builder() .addProvider(AuthHelper.getAuthFilter()) @@ -63,6 +65,9 @@ class ChallengeControllerTest { @BeforeEach void setup() { + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_TWO)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_TWO)); + when(challengeConstraintChecker.challengeConstraints(any(), any())) .thenReturn(new ChallengeConstraints(true, Optional.empty())); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceCheckControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceCheckControllerTest.java index b4150d953..18f300f69 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceCheckControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceCheckControllerTest.java @@ -27,6 +27,7 @@ import java.time.Duration; import java.time.Instant; import java.util.Base64; import java.util.Map; +import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; import org.glassfish.jersey.server.ServerProperties; @@ -45,6 +46,7 @@ import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper; import org.whispersystems.textsecuregcm.mappers.GrpcStatusRuntimeExceptionMapper; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.devicecheck.AppleDeviceCheckManager; import org.whispersystems.textsecuregcm.storage.devicecheck.ChallengeNotFoundException; import org.whispersystems.textsecuregcm.storage.devicecheck.DeviceCheckKeyIdNotFoundException; @@ -62,6 +64,7 @@ class DeviceCheckControllerTest { private final static Duration REDEMPTION_DURATION = Duration.ofDays(5); private final static long REDEMPTION_LEVEL = 201L; + private static final AccountsManager accountsManager = mock(AccountsManager.class); private final static BackupAuthManager backupAuthManager = mock(BackupAuthManager.class); private final static AppleDeviceCheckManager appleDeviceCheckManager = mock(AppleDeviceCheckManager.class); private final static RateLimiters rateLimiters = mock(RateLimiters.class); @@ -76,7 +79,7 @@ class DeviceCheckControllerTest { .addProvider(new RateLimitExceededExceptionMapper()) .setMapper(SystemMapper.jsonMapper()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) - .addResource(new DeviceCheckController(clock, backupAuthManager, appleDeviceCheckManager, rateLimiters, + .addResource(new DeviceCheckController(clock, accountsManager, backupAuthManager, appleDeviceCheckManager, rateLimiters, REDEMPTION_LEVEL, REDEMPTION_DURATION)) .build(); @@ -86,6 +89,8 @@ class DeviceCheckControllerTest { reset(appleDeviceCheckManager); reset(rateLimiters); when(rateLimiters.forDescriptor(any())).thenReturn(mock(RateLimiter.class)); + + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); } @ParameterizedTest diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java index 7abb8b42f..8f60de7ed 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java @@ -42,14 +42,12 @@ import java.util.concurrent.CompletableFuture; import java.util.stream.IntStream; import java.util.stream.Stream; import org.apache.commons.lang3.RandomStringUtils; -import org.apache.commons.lang3.StringUtils; import org.glassfish.jersey.server.ServerProperties; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.CsvSource; @@ -62,8 +60,6 @@ import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; -import org.whispersystems.textsecuregcm.auth.DisconnectionRequestManager; -import org.whispersystems.textsecuregcm.auth.WebsocketRefreshApplicationEventListener; import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.ApnRegistrationId; import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest; @@ -116,7 +112,6 @@ class DeviceControllerTest { private static final Account account = mock(Account.class); private static final Account maxedAccount = mock(Account.class); private static final Device primaryDevice = mock(Device.class); - private static final DisconnectionRequestManager disconnectionRequestManager = mock(DisconnectionRequestManager.class); private static final Map deviceConfiguration = new HashMap<>(); private static final TestClock testClock = TestClock.now(); @@ -129,16 +124,12 @@ class DeviceControllerTest { persistentTimer, deviceConfiguration); - @RegisterExtension - public static final AuthHelper.AuthFilterExtension AUTH_FILTER_EXTENSION = new AuthHelper.AuthFilterExtension(); - private static final ResourceExtension resources = ResourceExtension.builder() .addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE) .addProvider(AuthHelper.getAuthFilter()) .addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class)) .addProvider(new RateLimitExceededExceptionMapper()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) - .addProvider(new WebsocketRefreshApplicationEventListener(accountsManager, disconnectionRequestManager)) .addProvider(new DeviceLimitExceededExceptionMapper()) .addResource(deviceController) .build(); @@ -157,8 +148,15 @@ class DeviceControllerTest { when(account.getNumber()).thenReturn(AuthHelper.VALID_NUMBER); when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID); when(account.getPhoneNumberIdentifier()).thenReturn(AuthHelper.VALID_PNI); + when(account.getPrimaryDevice()).thenReturn(primaryDevice); + when(account.getDevice(anyByte())).thenReturn(Optional.empty()); + when(account.getDevice(Device.PRIMARY_ID)).thenReturn(Optional.of(primaryDevice)); + when(account.getDevices()).thenReturn(List.of(primaryDevice)); when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account)); + when(accountsManager.getByAccountIdentifierAsync(AuthHelper.VALID_UUID)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(account))); + when(accountsManager.getByE164(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(account)); when(accountsManager.getByE164(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.of(maxedAccount)); @@ -229,7 +227,7 @@ class DeviceControllerTest { final Device existingDevice = mock(Device.class); when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID); - when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice)); + when(account.getDevices()).thenReturn(List.of(existingDevice)); final ECSignedPreKey aciSignedPreKey; final ECSignedPreKey pniSignedPreKey; @@ -310,7 +308,7 @@ class DeviceControllerTest { final Device primaryDevice = mock(Device.class); when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID); - when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(primaryDevice)); + when(account.getDevices()).thenReturn(List.of(primaryDevice)); final ECSignedPreKey aciSignedPreKey; final ECSignedPreKey pniSignedPreKey; @@ -362,7 +360,7 @@ class DeviceControllerTest { final Device primaryDevice = mock(Device.class); when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID); - when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(primaryDevice)); + when(account.getDevices()).thenReturn(List.of(primaryDevice)); final ECSignedPreKey aciSignedPreKey; final ECSignedPreKey pniSignedPreKey; @@ -398,7 +396,7 @@ class DeviceControllerTest { void linkDeviceAtomicReusedToken() { final Device existingDevice = mock(Device.class); when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID); - when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice)); + when(account.getDevices()).thenReturn(List.of(existingDevice)); final ECSignedPreKey aciSignedPreKey; final ECSignedPreKey pniSignedPreKey; @@ -447,7 +445,7 @@ class DeviceControllerTest { final Device existingDevice = mock(Device.class); when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID); - when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice)); + when(account.getDevices()).thenReturn(List.of(existingDevice)); final ECSignedPreKey aciSignedPreKey; final ECSignedPreKey pniSignedPreKey; @@ -487,12 +485,12 @@ class DeviceControllerTest { void linkDeviceAtomicConflictingChannel(final boolean fetchesMessages, final Optional apnRegistrationId, final Optional gcmRegistrationId) { - when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account)); when(accountsManager.generateLinkDeviceToken(any())).thenReturn("test"); final Device existingDevice = mock(Device.class); when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID); - when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice)); + when(account.getDevices()).thenReturn(List.of(existingDevice)); final LinkDeviceToken deviceCode = resources.getJerseyTest() .target("/v1/devices/provisioning/code") @@ -548,12 +546,12 @@ class DeviceControllerTest { final KEMSignedPreKey aciPqLastResortPreKey, final KEMSignedPreKey pniPqLastResortPreKey) { - when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account)); when(accountsManager.generateLinkDeviceToken(any())).thenReturn("test"); final Device existingDevice = mock(Device.class); when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID); - when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice)); + when(account.getDevices()).thenReturn(List.of(existingDevice)); final LinkDeviceToken deviceCode = resources.getJerseyTest() .target("/v1/devices/provisioning/code") @@ -613,11 +611,11 @@ class DeviceControllerTest { aciPqLastResortPreKey = KeysHelper.signedKEMPreKey(3, aciIdentityKeyPair); pniPqLastResortPreKey = KeysHelper.signedKEMPreKey(4, pniIdentityKeyPair); - when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account)); final Device existingDevice = mock(Device.class); when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID); - when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice)); + when(account.getDevices()).thenReturn(List.of(existingDevice)); when(account.getIdentityKey(IdentityType.ACI)).thenReturn(new IdentityKey(aciIdentityKeyPair.getPublicKey())); when(account.getIdentityKey(IdentityType.PNI)).thenReturn(new IdentityKey(pniIdentityKeyPair.getPublicKey())); @@ -647,11 +645,11 @@ class DeviceControllerTest { final KEMSignedPreKey aciPqLastResortPreKey, final KEMSignedPreKey pniPqLastResortPreKey) { - when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(account)); final Device existingDevice = mock(Device.class); when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID); - when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice)); + when(account.getDevices()).thenReturn(List.of(existingDevice)); when(account.getIdentityKey(IdentityType.ACI)).thenReturn(aciIdentityKey); when(account.getIdentityKey(IdentityType.PNI)).thenReturn(pniIdentityKey); @@ -698,7 +696,7 @@ class DeviceControllerTest { final Device existingDevice = mock(Device.class); when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID); - when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice)); + when(account.getDevices()).thenReturn(List.of(existingDevice)); final ECSignedPreKey aciSignedPreKey; final ECSignedPreKey pniSignedPreKey; @@ -735,7 +733,7 @@ class DeviceControllerTest { void linkDeviceRegistrationId(final int registrationId, final int pniRegistrationId, final int expectedStatusCode) { final Device existingDevice = mock(Device.class); when(existingDevice.getId()).thenReturn(Device.PRIMARY_ID); - when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(List.of(existingDevice)); + when(account.getDevices()).thenReturn(List.of(existingDevice)); final ECKeyPair aciIdentityKeyPair = Curve.generateKeyPair(); final ECKeyPair pniIdentityKeyPair = Curve.generateKeyPair(); @@ -800,17 +798,16 @@ class DeviceControllerTest { @Test void maxDevicesTest() { - final AuthHelper.TestAccount testAccount = AUTH_FILTER_EXTENSION.createTestAccount(); - final List devices = IntStream.range(0, DeviceController.MAX_DEVICES + 1) .mapToObj(i -> mock(Device.class)) .toList(); - when(testAccount.account.getDevices()).thenReturn(devices); + + when(account.getDevices()).thenReturn(devices); Response response = resources.getJerseyTest() .target("/v1/devices/provisioning/code") .request() - .header("Authorization", testAccount.getAuthHeader()) + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .get(); assertEquals(411, response.getStatus()); @@ -829,7 +826,7 @@ class DeviceControllerTest { assertThat(response.getStatus()).isEqualTo(204); assertThat(response.hasEntity()).isFalse(); - verify(AuthHelper.VALID_DEVICE).setCapabilities(Set.of(DeviceCapability.DELETE_SYNC)); + verify(primaryDevice).setCapabilities(Set.of(DeviceCapability.DELETE_SYNC)); } } @@ -851,12 +848,12 @@ class DeviceControllerTest { void removeDevice() { // this is a static mock, so it might have previous invocations - clearInvocations(AuthHelper.VALID_ACCOUNT); + clearInvocations(account); final byte deviceId = 2; - when(accountsManager.removeDevice(AuthHelper.VALID_ACCOUNT, deviceId)) - .thenReturn(CompletableFuture.completedFuture(AuthHelper.VALID_ACCOUNT)); + when(accountsManager.removeDevice(account, deviceId)) + .thenReturn(CompletableFuture.completedFuture(account)); try (final Response response = resources .getJerseyTest() @@ -869,14 +866,14 @@ class DeviceControllerTest { assertThat(response.getStatus()).isEqualTo(204); assertThat(response.hasEntity()).isFalse(); - verify(accountsManager).removeDevice(AuthHelper.VALID_ACCOUNT, deviceId); + verify(accountsManager).removeDevice(account, deviceId); } } @Test void unlinkPrimaryDevice() { // this is a static mock, so it might have previous invocations - clearInvocations(AuthHelper.VALID_ACCOUNT); + clearInvocations(account); try (final Response response = resources .getJerseyTest() @@ -897,7 +894,10 @@ class DeviceControllerTest { final byte deviceId = 2; when(accountsManager.removeDevice(AuthHelper.VALID_ACCOUNT_3, deviceId)) - .thenReturn(CompletableFuture.completedFuture(AuthHelper.VALID_ACCOUNT)); + .thenReturn(CompletableFuture.completedFuture(account)); + + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_3)) + .thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_3)); try (final Response response = resources .getJerseyTest() @@ -946,7 +946,7 @@ class DeviceControllerTest { assertEquals(204, response.getStatus()); } - verify(clientPublicKeysManager).setPublicKey(AuthHelper.VALID_ACCOUNT, AuthHelper.VALID_DEVICE.getId(), request.publicKey()); + verify(clientPublicKeysManager).setPublicKey(account, AuthHelper.VALID_DEVICE.getId(), request.publicKey()); } @Test @@ -959,7 +959,7 @@ class DeviceControllerTest { final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]); when(accountsManager - .waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any())) + .waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(primaryDevice), eq(tokenIdentifier), any())) .thenReturn(CompletableFuture.completedFuture(Optional.of(deviceInfo))); when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null)); @@ -985,7 +985,7 @@ class DeviceControllerTest { final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]); when(accountsManager - .waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any())) + .waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(primaryDevice), eq(tokenIdentifier), any())) .thenReturn(CompletableFuture.completedFuture(Optional.empty())); when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null)); @@ -1005,7 +1005,7 @@ class DeviceControllerTest { final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]); when(accountsManager - .waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any())) + .waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(primaryDevice), eq(tokenIdentifier), any())) .thenReturn(CompletableFuture.failedFuture(new IllegalArgumentException())); when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null)); @@ -1079,7 +1079,7 @@ class DeviceControllerTest { new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8))); when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null)); - when(accountsManager.recordTransferArchiveUpload(AuthHelper.VALID_ACCOUNT, deviceId, deviceCreated, transferArchive)) + when(accountsManager.recordTransferArchiveUpload(account, deviceId, deviceCreated, transferArchive)) .thenReturn(CompletableFuture.completedFuture(null)); try (final Response response = resources.getJerseyTest() @@ -1092,7 +1092,7 @@ class DeviceControllerTest { assertEquals(204, response.getStatus()); verify(accountsManager) - .recordTransferArchiveUpload(AuthHelper.VALID_ACCOUNT, deviceId, deviceCreated, transferArchive); + .recordTransferArchiveUpload(account, deviceId, deviceCreated, transferArchive); } } @@ -1103,7 +1103,7 @@ class DeviceControllerTest { final RemoteAttachmentError transferFailure = new RemoteAttachmentError(RemoteAttachmentError.ErrorType.CONTINUE_WITHOUT_UPLOAD); when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null)); - when(accountsManager.recordTransferArchiveUpload(AuthHelper.VALID_ACCOUNT, deviceId, deviceCreated, transferFailure)) + when(accountsManager.recordTransferArchiveUpload(account, deviceId, deviceCreated, transferFailure)) .thenReturn(CompletableFuture.completedFuture(null)); try (final Response response = resources.getJerseyTest() @@ -1116,7 +1116,7 @@ class DeviceControllerTest { assertEquals(204, response.getStatus()); verify(accountsManager) - .recordTransferArchiveUpload(AuthHelper.VALID_ACCOUNT, deviceId, deviceCreated, transferFailure); + .recordTransferArchiveUpload(account, deviceId, deviceCreated, transferFailure); } } @@ -1186,7 +1186,7 @@ class DeviceControllerTest { new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8))); when(rateLimiter.validateAsync(anyString())).thenReturn(CompletableFuture.completedFuture(null)); - when(accountsManager.waitForTransferArchive(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_DEVICE), any())) + when(accountsManager.waitForTransferArchive(eq(account), eq(primaryDevice), any())) .thenReturn(CompletableFuture.completedFuture(Optional.of(transferArchive))); try (final Response response = resources.getJerseyTest() @@ -1206,7 +1206,7 @@ class DeviceControllerTest { new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8))); when(rateLimiter.validateAsync(anyString())).thenReturn(CompletableFuture.completedFuture(null)); - when(accountsManager.waitForTransferArchive(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_DEVICE), any())) + when(accountsManager.waitForTransferArchive(eq(account), eq(primaryDevice), any())) .thenReturn(CompletableFuture.completedFuture(Optional.of(transferArchive))); try (final Response response = resources.getJerseyTest() @@ -1223,7 +1223,7 @@ class DeviceControllerTest { @Test void waitForTransferArchiveNoArchiveUploaded() { when(rateLimiter.validateAsync(anyString())).thenReturn(CompletableFuture.completedFuture(null)); - when(accountsManager.waitForTransferArchive(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_DEVICE), any())) + when(accountsManager.waitForTransferArchive(eq(account), eq(primaryDevice), any())) .thenReturn(CompletableFuture.completedFuture(Optional.empty())); try (final Response response = resources.getJerseyTest() diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DirectoryControllerV2Test.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DirectoryControllerV2Test.java index 55edd6752..991d84ec7 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DirectoryControllerV2Test.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DirectoryControllerV2Test.java @@ -19,6 +19,7 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialsGenerator; import org.whispersystems.textsecuregcm.configuration.DirectoryV2ClientConfiguration; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; @@ -35,7 +36,7 @@ class DirectoryControllerV2Test { final Account account = mock(Account.class); final UUID uuid = UUID.fromString("11111111-1111-1111-1111-111111111111"); - when(account.getUuid()).thenReturn(uuid); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(uuid); final ExternalServiceCredentials credentials = controller.getAuthToken( new AuthenticatedDevice(account, mock(Device.class))); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DonationControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DonationControllerTest.java index 5b79ba553..434a6be4e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DonationControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DonationControllerTest.java @@ -140,6 +140,8 @@ class DonationControllerTest { when(receiptCredentialPresentation.getReceiptExpirationTime()).thenReturn(receiptExpiration); when(redeemedReceiptsManager.put(same(receiptSerial), eq(receiptExpiration), eq(receiptLevel), eq(AuthHelper.VALID_UUID))).thenReturn( CompletableFuture.completedFuture(Boolean.FALSE)); + when(accountsManager.getByAccountIdentifierAsync(eq(AuthHelper.VALID_UUID))).thenReturn( + CompletableFuture.completedFuture(Optional.of(AuthHelper.VALID_ACCOUNT))); RedeemReceiptRequest request = new RedeemReceiptRequest(presentation, true, true); Response response = resources.getJerseyTest() diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java index a5478e4ee..3a6cf4750 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java @@ -242,10 +242,21 @@ class KeysControllerTest { when(existsAccount.getUnidentifiedAccessKey()).thenReturn(Optional.of("1337".getBytes())); when(accounts.getByServiceIdentifier(any())).thenReturn(Optional.empty()); + when(accounts.getByServiceIdentifierAsync(any())).thenReturn(CompletableFuture.completedFuture(Optional.empty())); when(accounts.getByServiceIdentifier(new AciServiceIdentifier(EXISTS_UUID))).thenReturn(Optional.of(existsAccount)); when(accounts.getByServiceIdentifier(new PniServiceIdentifier(EXISTS_PNI))).thenReturn(Optional.of(existsAccount)); + when(accounts.getByServiceIdentifierAsync(new AciServiceIdentifier(EXISTS_UUID))) + .thenReturn(CompletableFuture.completedFuture(Optional.of(existsAccount))); + + when(accounts.getByServiceIdentifierAsync(new PniServiceIdentifier(EXISTS_PNI))) + .thenReturn(CompletableFuture.completedFuture(Optional.of(existsAccount))); + + when(accounts.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); + when(accounts.getByAccountIdentifierAsync(AuthHelper.VALID_UUID)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(AuthHelper.VALID_ACCOUNT))); + when(rateLimiters.getPreKeysLimiter()).thenReturn(rateLimiter); when(KEYS.storeEcOneTimePreKeys(any(), anyByte(), any())) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index 89c317b2b..381cfd2df 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -236,6 +236,14 @@ class MessageControllerTest { when(accountsManager.getByServiceIdentifierAsync(MULTI_DEVICE_PNI_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(multiDeviceAccount))); when(accountsManager.getByServiceIdentifierAsync(new AciServiceIdentifier(INTERNATIONAL_UUID))).thenReturn(CompletableFuture.completedFuture(Optional.of(internationalAccount))); + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT)); + when(accountsManager.getByAccountIdentifierAsync(AuthHelper.VALID_UUID)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(AuthHelper.VALID_ACCOUNT))); + + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_3)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_3)); + when(accountsManager.getByAccountIdentifierAsync(AuthHelper.VALID_UUID_3)) + .thenReturn(CompletableFuture.completedFuture(Optional.of(AuthHelper.VALID_ACCOUNT_3))); + when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter); when(rateLimiters.getStoriesLimiter()).thenReturn(rateLimiter); when(rateLimiters.getInboundMessageBytes()).thenReturn(rateLimiter); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java index fb1c40f4b..3c2ea856b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ProfileControllerTest.java @@ -221,6 +221,8 @@ class ProfileControllerTest { when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID)).thenReturn(Optional.of(capabilitiesAccount)); when(accountsManager.getByServiceIdentifier(new AciServiceIdentifier(AuthHelper.VALID_UUID))).thenReturn(Optional.of(capabilitiesAccount)); + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_TWO)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_TWO)); + final byte[] name = TestRandomUtil.nextBytes(81); final byte[] emoji = TestRandomUtil.nextBytes(60); final byte[] about = TestRandomUtil.nextBytes(156); @@ -1155,6 +1157,7 @@ class ProfileControllerTest { reset(accountsManager); final int accountsManagerUpdateRetryCount = 2; AccountsHelper.setupMockUpdateWithRetries(accountsManager, accountsManagerUpdateRetryCount); + when(accountsManager.getByAccountIdentifier(AuthHelper.VALID_UUID_TWO)).thenReturn(Optional.of(AuthHelper.VALID_ACCOUNT_TWO)); // set up two invocations -- one for each AccountsManager#update try when(AuthHelper.VALID_ACCOUNT_TWO.getBadges()) .thenReturn(List.of( diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RemoteConfigControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RemoteConfigControllerTest.java index 6baab2041..9e4dd89c7 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RemoteConfigControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RemoteConfigControllerTest.java @@ -170,7 +170,12 @@ class RemoteConfigControllerTest { void testHashKeyLinkedConfigs() { boolean allUnlinkedConfigsMatched = true; for (AuthHelper.TestAccount testAccount : AuthHelper.TEST_ACCOUNTS) { - UserRemoteConfigList configuration = resources.getJerseyTest().target("/v1/config/").request().header("Authorization", testAccount.getAuthHeader()).get(UserRemoteConfigList.class); + UserRemoteConfigList configuration = resources.getJerseyTest() + .target("/v1/config/") + .request() + .header("Authorization", testAccount.getAuthHeader()) + .get(UserRemoteConfigList.class); + assertThat(configuration.getConfig()).hasSize(11); final UserRemoteConfig linkedConfig0 = configuration.getConfig().get(7); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java index a3ed44661..86a62432b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/entities/OutgoingMessageEntityTest.java @@ -7,7 +7,6 @@ package org.whispersystems.textsecuregcm.entities; import static org.junit.jupiter.api.Assertions.assertEquals; -import java.util.Random; import java.util.UUID; import javax.annotation.Nullable; import org.junit.jupiter.api.Test; @@ -16,7 +15,6 @@ import org.junitpioneer.jupiter.cartesian.CartesianTest; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; -import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.util.TestRandomUtil; @@ -65,15 +63,13 @@ class OutgoingMessageEntityTest { @Test void entityPreservesEnvelope() { final byte[] reportSpamToken = TestRandomUtil.nextBytes(8); + final AciServiceIdentifier sourceServiceIdentifier = new AciServiceIdentifier(UUID.randomUUID()); - final Account account = new Account(); - account.setUuid(UUID.randomUUID()); - - IncomingMessage message = new IncomingMessage(1, (byte) 44, 55, TestRandomUtil.nextBytes(4)); + final IncomingMessage message = new IncomingMessage(1, (byte) 44, 55, TestRandomUtil.nextBytes(4)); MessageProtos.Envelope baseEnvelope = message.toEnvelope( new AciServiceIdentifier(UUID.randomUUID()), - account, + sourceServiceIdentifier, (byte) 123, System.currentTimeMillis(), false, diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java index 3595e697f..4844c1583 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java @@ -153,7 +153,7 @@ class MetricsRequestEventListenerTest { final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); final WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, TestPrincipal.reusableAuth("foo"), + RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, TestPrincipal.authenticatedTestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); final Session session = mock(Session.class); @@ -220,7 +220,7 @@ class MetricsRequestEventListenerTest { final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); final WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, TestPrincipal.reusableAuth("foo"), + RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, TestPrincipal.authenticatedTestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); final Session session = mock(Session.class); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java index e20bf0651..b6770da00 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java @@ -220,6 +220,8 @@ public class AccountsHelper { case "getBackupVoucher" -> when(updatedAccount.getBackupVoucher()).thenAnswer(stubbing); case "getLastSeen" -> when(updatedAccount.getLastSeen()).thenAnswer(stubbing); case "hasLockedCredentials" -> when(updatedAccount.hasLockedCredentials()).thenAnswer(stubbing); + case "getCurrentProfileVersion" -> when(updatedAccount.getCurrentProfileVersion()).thenAnswer(stubbing); + case "getUnidentifiedAccessKey" -> when(updatedAccount.getUnidentifiedAccessKey()).thenAnswer(stubbing); default -> throw new IllegalArgumentException("unsupported method: Account#" + stubbing.getInvocation().getMethod().getName()); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java index 2619df688..952375654 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java @@ -277,6 +277,7 @@ public class AuthHelper { when(account.getPrimaryDevice()).thenReturn(device); when(account.getNumber()).thenReturn(number); when(account.getUuid()).thenReturn(uuid); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(uuid); when(accountsManager.getByE164(number)).thenReturn(Optional.of(account)); when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account)); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestPrincipal.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestPrincipal.java index 201e10a50..186ec5264 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestPrincipal.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/TestPrincipal.java @@ -5,8 +5,7 @@ package org.whispersystems.textsecuregcm.tests.util; import java.security.Principal; -import org.whispersystems.websocket.ReusableAuth; -import org.whispersystems.websocket.auth.PrincipalSupplier; +import java.util.Optional; public class TestPrincipal implements Principal { @@ -21,7 +20,7 @@ public class TestPrincipal implements Principal { return name; } - public static ReusableAuth reusableAuth(final String name) { - return ReusableAuth.authenticated(new TestPrincipal(name), PrincipalSupplier.forImmutablePrincipal()); + public static Optional authenticatedTestPrincipal(final String name) { + return Optional.of(new TestPrincipal(name)); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java index 79168d475..c91d91908 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java @@ -176,7 +176,7 @@ class LoggingUnhandledExceptionMapperTest { WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, - TestPrincipal.reusableAuth("foo"), + TestPrincipal.authenticatedTestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticatorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticatorTest.java index 9bed3c009..c1056bb26 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticatorTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticatorTest.java @@ -28,7 +28,6 @@ import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.websocket.auth.InvalidCredentialsException; -import org.whispersystems.websocket.auth.PrincipalSupplier; class WebSocketAccountAuthenticatorTest { @@ -70,14 +69,12 @@ class WebSocketAccountAuthenticatorTest { when(upgradeRequest.getHeader(eq(HttpHeaders.AUTHORIZATION))).thenReturn(authorizationHeaderValue); } - final WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator( - accountAuthenticator, - mock(PrincipalSupplier.class)); + final WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator); if (expectInvalid) { assertThrows(InvalidCredentialsException.class, () -> webSocketAuthenticator.authenticate(upgradeRequest)); } else { - assertEquals(expectAccount, webSocketAuthenticator.authenticate(upgradeRequest).ref().isPresent()); + assertEquals(expectAccount, webSocketAuthenticator.authenticate(upgradeRequest).isPresent()); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java index 35a704307..6e28fd180 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java @@ -43,11 +43,11 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; -import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor; import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.push.PushNotificationManager; @@ -111,7 +111,7 @@ class WebSocketConnectionIntegrationTest { clientReleaseManager = mock(ClientReleaseManager.class); when(account.getNumber()).thenReturn("+18005551234"); - when(account.getUuid()).thenReturn(UUID.randomUUID()); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(UUID.randomUUID()); when(device.getId()).thenReturn(Device.PRIMARY_ID); } @@ -137,7 +137,8 @@ class WebSocketConnectionIntegrationTest { new MessageMetrics(), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), - new AuthenticatedDevice(account, device), + account, + device, webSocketClient, scheduledExecutorService, messageDeliveryScheduler, @@ -159,14 +160,14 @@ class WebSocketConnectionIntegrationTest { expectedMessages.add(envelope); } - messagesDynamoDb.store(persistedMessages, account.getUuid(), device); + messagesDynamoDb.store(persistedMessages, account.getIdentifier(IdentityType.ACI), device); } for (int i = 0; i < cachedMessageCount; i++) { final UUID messageGuid = UUID.randomUUID(); final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid); - messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join(); + messagesCache.insert(messageGuid, account.getIdentifier(IdentityType.ACI), device.getId(), envelope).join(); expectedMessages.add(envelope); } @@ -226,7 +227,8 @@ class WebSocketConnectionIntegrationTest { new MessageMetrics(), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), - new AuthenticatedDevice(account, device), + account, + device, webSocketClient, scheduledExecutorService, messageDeliveryScheduler, @@ -250,13 +252,13 @@ class WebSocketConnectionIntegrationTest { expectedMessages.add(envelope); } - messagesDynamoDb.store(persistedMessages, account.getUuid(), device); + messagesDynamoDb.store(persistedMessages, account.getIdentifier(IdentityType.ACI), device); } for (int i = 0; i < cachedMessageCount; i++) { final UUID messageGuid = UUID.randomUUID(); final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid); - messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join(); + messagesCache.insert(messageGuid, account.getIdentifier(IdentityType.ACI), device.getId(), envelope).join(); expectedMessages.add(envelope); } @@ -296,7 +298,8 @@ class WebSocketConnectionIntegrationTest { new MessageMetrics(), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), - new AuthenticatedDevice(account, device), + account, + device, webSocketClient, 100, // use a very short timeout, so that this test completes quickly scheduledExecutorService, @@ -321,13 +324,13 @@ class WebSocketConnectionIntegrationTest { expectedMessages.add(envelope); } - messagesDynamoDb.store(persistedMessages, account.getUuid(), device); + messagesDynamoDb.store(persistedMessages, account.getIdentifier(IdentityType.ACI), device); } for (int i = 0; i < cachedMessageCount; i++) { final UUID messageGuid = UUID.randomUUID(); final MessageProtos.Envelope envelope = generateRandomMessage(messageGuid); - messagesCache.insert(messageGuid, account.getUuid(), device.getId(), envelope).join(); + messagesCache.insert(messageGuid, account.getIdentifier(IdentityType.ACI), device.getId(), envelope).join(); expectedMessages.add(envelope); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index bf54719c7..c05c50f9b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -56,6 +56,7 @@ import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; +import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor; import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.push.PushNotificationManager; @@ -67,9 +68,7 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; -import org.whispersystems.websocket.ReusableAuth; import org.whispersystems.websocket.WebSocketClient; -import org.whispersystems.websocket.auth.PrincipalSupplier; import org.whispersystems.websocket.messages.WebSocketResponseMessage; import org.whispersystems.websocket.session.WebSocketSessionContext; import reactor.core.publisher.Flux; @@ -90,7 +89,6 @@ class WebSocketConnectionTest { private AccountsManager accountsManager; private Account account; private Device device; - private AuthenticatedDevice auth; private UpgradeRequest upgradeRequest; private MessagesManager messagesManager; private ReceiptSender receiptSender; @@ -104,7 +102,6 @@ class WebSocketConnectionTest { accountsManager = mock(AccountsManager.class); account = mock(Account.class); device = mock(Device.class); - auth = new AuthenticatedDevice(account, device); upgradeRequest = mock(UpgradeRequest.class); messagesManager = mock(MessagesManager.class); receiptSender = mock(ReceiptSender.class); @@ -122,8 +119,8 @@ class WebSocketConnectionTest { @Test void testCredentials() throws Exception { WebSocketAccountAuthenticator webSocketAuthenticator = - new WebSocketAccountAuthenticator(accountAuthenticator, mock(PrincipalSupplier.class)); - AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, messagesManager, + new WebSocketAccountAuthenticator(accountAuthenticator); + AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(accountsManager, receiptSender, messagesManager, new MessageMetrics(), mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), mock(WebSocketConnectionEventManager.class), retrySchedulingExecutor, messageDeliveryScheduler, clientReleaseManager, mock(MessageDeliveryLoopMonitor.class), @@ -133,9 +130,9 @@ class WebSocketConnectionTest { when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD)))) .thenReturn(Optional.of(new AuthenticatedDevice(account, device))); - ReusableAuth account = webSocketAuthenticator.authenticate(upgradeRequest); - when(sessionContext.getAuthenticated()).thenReturn(account.ref().orElse(null)); - when(sessionContext.getAuthenticated(AuthenticatedDevice.class)).thenReturn(account.ref().orElse(null)); + Optional account = webSocketAuthenticator.authenticate(upgradeRequest); + when(sessionContext.getAuthenticated()).thenReturn(account.orElse(null)); + when(sessionContext.getAuthenticated(AuthenticatedDevice.class)).thenReturn(account.orElse(null)); final WebSocketClient webSocketClient = mock(WebSocketClient.class); when(webSocketClient.getUserAgent()).thenReturn("Signal-Android/6.22.8"); @@ -150,7 +147,7 @@ class WebSocketConnectionTest { // unauthenticated when(upgradeRequest.getParameterMap()).thenReturn(Map.of()); account = webSocketAuthenticator.authenticate(upgradeRequest); - assertFalse(account.ref().isPresent()); + assertFalse(account.isPresent()); connectListener.onWebSocketConnect(sessionContext); verify(sessionContext, times(2)).addWebsocketClosedListener( @@ -174,7 +171,7 @@ class WebSocketConnectionTest { when(device.getId()).thenReturn(deviceId); when(account.getNumber()).thenReturn("+14152222222"); - when(account.getUuid()).thenReturn(accountUuid); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); final Device sender1device = mock(Device.class); @@ -191,7 +188,7 @@ class WebSocketConnectionTest { String userAgent = HttpHeaders.USER_AGENT; - when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false)) + when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false)) .thenReturn(Flux.fromIterable(outgoingMessages)); final List> futures = new LinkedList<>(); @@ -237,7 +234,7 @@ class WebSocketConnectionTest { final UUID accountUuid = UUID.randomUUID(); when(account.getNumber()).thenReturn("+18005551234"); - when(account.getUuid()).thenReturn(accountUuid); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); when(device.getId()).thenReturn(Device.PRIMARY_ID); when(client.isOpen()).thenReturn(true); @@ -320,7 +317,7 @@ class WebSocketConnectionTest { when(device.getId()).thenReturn(deviceId); when(account.getNumber()).thenReturn("+14152222222"); - when(account.getUuid()).thenReturn(accountUuid); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); final Device sender1device = mock(Device.class); @@ -337,7 +334,7 @@ class WebSocketConnectionTest { String userAgent = HttpHeaders.USER_AGENT; - when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false)) + when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false)) .thenReturn(Flux.fromIterable(pendingMessages)); final List> futures = new LinkedList<>(); @@ -364,7 +361,7 @@ class WebSocketConnectionTest { futures.get(1).complete(response); futures.get(0).completeExceptionally(new IOException()); - verify(receiptSender, times(1)).sendReceipt(eq(new AciServiceIdentifier(account.getUuid())), eq(deviceId), eq(new AciServiceIdentifier(senderTwoUuid)), + verify(receiptSender, times(1)).sendReceipt(eq(new AciServiceIdentifier(account.getIdentifier(IdentityType.ACI))), eq(deviceId), eq(new AciServiceIdentifier(senderTwoUuid)), eq(secondMessage.getClientTimestamp())); connection.stop(); @@ -377,7 +374,7 @@ class WebSocketConnectionTest { final WebSocketConnection connection = webSocketConnection(client); when(account.getNumber()).thenReturn("+18005551234"); - when(account.getUuid()).thenReturn(UUID.randomUUID()); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(UUID.randomUUID()); when(device.getId()).thenReturn(Device.PRIMARY_ID); when(client.isOpen()).thenReturn(true); @@ -385,7 +382,7 @@ class WebSocketConnectionTest { final AtomicBoolean returnMessageList = new AtomicBoolean(false); when( - messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false)) + messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false)) .thenAnswer(invocation -> { synchronized (threadWaiting) { threadWaiting.set(true); @@ -442,7 +439,7 @@ class WebSocketConnectionTest { when(account.getNumber()).thenReturn("+18005551234"); final UUID accountUuid = UUID.randomUUID(); - when(account.getUuid()).thenReturn(accountUuid); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); when(device.getId()).thenReturn(Device.PRIMARY_ID); when(client.isOpen()).thenReturn(true); @@ -490,7 +487,7 @@ class WebSocketConnectionTest { when(account.getNumber()).thenReturn("+18005551234"); final UUID accountUuid = UUID.randomUUID(); - when(account.getUuid()).thenReturn(accountUuid); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); when(device.getId()).thenReturn(Device.PRIMARY_ID); when(client.isOpen()).thenReturn(true); @@ -573,7 +570,7 @@ class WebSocketConnectionTest { when(account.getNumber()).thenReturn("+18005551234"); final UUID accountUuid = UUID.randomUUID(); - when(account.getUuid()).thenReturn(accountUuid); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); when(device.getId()).thenReturn(Device.PRIMARY_ID); when(client.isOpen()).thenReturn(true); @@ -581,7 +578,7 @@ class WebSocketConnectionTest { final List messages = List.of( createMessage(senderUuid, UUID.randomUUID(), 1111L, "message the first")); - when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false)) + when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false)) .thenReturn(Flux.fromIterable(messages)) .thenReturn(Flux.empty()); @@ -629,7 +626,7 @@ class WebSocketConnectionTest { private WebSocketConnection webSocketConnection(final WebSocketClient client) { return new WebSocketConnection(receiptSender, messagesManager, new MessageMetrics(), - mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), auth, client, + mock(PushNotificationManager.class), mock(PushNotificationScheduler.class), account, device, client, retrySchedulingExecutor, Schedulers.immediate(), clientReleaseManager, mock(MessageDeliveryLoopMonitor.class), mock(ExperimentEnrollmentManager.class)); } @@ -642,7 +639,7 @@ class WebSocketConnectionTest { final UUID accountUuid = UUID.randomUUID(); when(account.getNumber()).thenReturn("+18005551234"); - when(account.getUuid()).thenReturn(accountUuid); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); when(device.getId()).thenReturn(Device.PRIMARY_ID); when(client.isOpen()).thenReturn(true); @@ -669,7 +666,7 @@ class WebSocketConnectionTest { final UUID accountUuid = UUID.randomUUID(); when(account.getNumber()).thenReturn("+18005551234"); - when(account.getUuid()).thenReturn(accountUuid); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); when(device.getId()).thenReturn(Device.PRIMARY_ID); when(client.isOpen()).thenReturn(true); @@ -725,7 +722,7 @@ class WebSocketConnectionTest { final UUID accountUuid = UUID.randomUUID(); when(account.getNumber()).thenReturn("+18005551234"); - when(account.getUuid()).thenReturn(accountUuid); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); when(device.getId()).thenReturn(Device.PRIMARY_ID); when(client.isOpen()).thenReturn(true); @@ -741,11 +738,11 @@ class WebSocketConnectionTest { // anything. connection.processStoredMessages(); - verify(messagesManager).getMessagesForDeviceReactive(account.getUuid(), device, false); + verify(messagesManager).getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false); connection.handleNewMessageAvailable(); - verify(messagesManager).getMessagesForDeviceReactive(account.getUuid(), device, true); + verify(messagesManager).getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, true); } @Test @@ -756,7 +753,7 @@ class WebSocketConnectionTest { final UUID accountUuid = UUID.randomUUID(); when(account.getNumber()).thenReturn("+18005551234"); - when(account.getUuid()).thenReturn(accountUuid); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); when(device.getId()).thenReturn(Device.PRIMARY_ID); when(client.isOpen()).thenReturn(true); @@ -773,7 +770,7 @@ class WebSocketConnectionTest { connection.processStoredMessages(); connection.handleMessagesPersisted(); - verify(messagesManager, times(2)).getMessagesForDeviceReactive(account.getUuid(), device, false); + verify(messagesManager, times(2)).getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false); } @Test @@ -783,9 +780,9 @@ class WebSocketConnectionTest { when(device.getId()).thenReturn((byte) 2); when(account.getNumber()).thenReturn("+14152222222"); - when(account.getUuid()).thenReturn(accountUuid); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); - when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false)) + when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false)) .thenReturn(Flux.error(new RedisException("OH NO"))); when(retrySchedulingExecutor.schedule(any(Runnable.class), anyLong(), any())).thenAnswer( @@ -812,9 +809,9 @@ class WebSocketConnectionTest { when(device.getId()).thenReturn((byte) 2); when(account.getNumber()).thenReturn("+14152222222"); - when(account.getUuid()).thenReturn(accountUuid); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); - when(messagesManager.getMessagesForDeviceReactive(account.getUuid(), device, false)) + when(messagesManager.getMessagesForDeviceReactive(account.getIdentifier(IdentityType.ACI), device, false)) .thenReturn(Flux.error(new RedisException("OH NO"))); final WebSocketClient client = mock(WebSocketClient.class); @@ -835,7 +832,7 @@ class WebSocketConnectionTest { when(device.getId()).thenReturn(deviceId); when(account.getNumber()).thenReturn("+14152222222"); - when(account.getUuid()).thenReturn(accountUuid); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); final int totalMessages = 1000; @@ -884,7 +881,7 @@ class WebSocketConnectionTest { when(device.getId()).thenReturn(deviceId); when(account.getNumber()).thenReturn("+14152222222"); - when(account.getUuid()).thenReturn(accountUuid); + when(account.getIdentifier(IdentityType.ACI)).thenReturn(accountUuid); final AtomicBoolean canceled = new AtomicBoolean(); diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/ReusableAuth.java b/websocket-resources/src/main/java/org/whispersystems/websocket/ReusableAuth.java deleted file mode 100644 index fec4f67e0..000000000 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/ReusableAuth.java +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Copyright 2024 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.websocket; - -import java.security.Principal; -import java.util.Optional; -import java.util.concurrent.atomic.AtomicBoolean; -import org.whispersystems.websocket.auth.PrincipalSupplier; - -/** - * This class holds a principal that can be reused across requests on a websocket. Since two requests may operate - * concurrently on the same principal, and some principals contain non thread-safe mutable state, appropriate use of - * this class ensures that no data races occur. It also ensures that after a principal is modified, a subsequent request - * gets the up-to-date principal - * - * @param The underlying principal type - * @see PrincipalSupplier - */ -public abstract sealed class ReusableAuth { - - /** - * Get a reference to the underlying principal that callers pledge not to modify. - *

- * The reference returned will potentially be provided to many threads concurrently accessing the principal. Callers - * should use this method only if they can ensure that they will not modify the in-memory principal object AND they do - * not intend to modify the underlying canonical representation of the principal. - *

- * For example, if a caller retrieves a reference to a principal, does not modify the in memory state, but updates a - * field on a database that should be reflected in subsequent retrievals of the principal, they will have met the - * first criteria, but not the second. In that case they should instead use {@link #mutableRef()}. - *

- * If other callers have modified the underlying principal by using {@link #mutableRef()}, this method may need to - * refresh the principal via {@link PrincipalSupplier#refresh} which could be a blocking operation. - * - * @return If authenticated, a reference to the underlying principal that should not be modified - */ - public abstract Optional ref(); - - - public interface MutableRef { - - T ref(); - - void close(); - } - - /** - * Get a reference to the underlying principal that may be modified. - *

- * The underlying principal can be safely modified. Multiple threads may operate on the same {@link ReusableAuth} so - * long as they each have their own {@link MutableRef}. After any modifications, the caller must call - * {@link MutableRef#close} to notify the principal has become dirty. Close should be called after modifications but - * before sending a response on the websocket. This ensures that a request that comes in after a modification response - * is received is guaranteed to see the modification. - * - * @return If authenticated, a reference to the underlying principal that may be modified - */ - public abstract Optional> mutableRef(); - - /** - * @return A {@link ReusableAuth} indicating no credential were provided - */ - public static ReusableAuth anonymous() { - //noinspection unchecked - return (ReusableAuth) Anonymous.ANON_RESULT; - } - - /** - * Create a successfully authenticated {@link ReusableAuth} - * - * @param principal The authenticated principal - * @param principalSupplier Instructions for how to refresh or copy this principal - * @param The principal type - * @return A {@link ReusableAuth} for a successfully authenticated principal - */ - public static ReusableAuth authenticated(T principal, - PrincipalSupplier principalSupplier) { - return new Authenticated<>(principal, principalSupplier); - } - - private static final class Anonymous extends ReusableAuth { - - @SuppressWarnings({"rawtypes"}) - private static final ReusableAuth ANON_RESULT = new Anonymous(); - - @Override - public Optional ref() { - return Optional.empty(); - } - - @Override - public Optional> mutableRef() { - return Optional.empty(); - } - } - - private static final class Authenticated extends ReusableAuth { - - private T basePrincipal; - private final AtomicBoolean needRefresh = new AtomicBoolean(false); - private final PrincipalSupplier principalSupplier; - - Authenticated(final T basePrincipal, PrincipalSupplier principalSupplier) { - this.basePrincipal = basePrincipal; - this.principalSupplier = principalSupplier; - - } - - @Override - public Optional ref() { - maybeRefresh(); - return Optional.of(basePrincipal); - } - - @Override - public Optional> mutableRef() { - maybeRefresh(); - return Optional.of(new AuthenticatedMutableRef(principalSupplier.deepCopy(basePrincipal))); - } - - private void maybeRefresh() { - if (needRefresh.compareAndSet(true, false)) { - basePrincipal = principalSupplier.refresh(basePrincipal); - } - } - - private class AuthenticatedMutableRef implements MutableRef { - - final T ref; - - private AuthenticatedMutableRef(T ref) { - this.ref = ref; - } - - public T ref() { - return ref; - } - - public void close() { - needRefresh.set(true); - } - } - } - - private ReusableAuth() { - } -} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java index 0ec463ddb..0cd80f02f 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java @@ -58,7 +58,7 @@ public class WebSocketResourceProvider implements WebSocket private final Map> requestMap = new ConcurrentHashMap<>(); - private final ReusableAuth reusableAuth; + private final Optional reusableAuth; private final WebSocketMessageFactory messageFactory; private final Optional connectListener; private final ApplicationHandler jerseyHandler; @@ -77,7 +77,7 @@ public class WebSocketResourceProvider implements WebSocket String remoteAddressPropertyName, ApplicationHandler jerseyHandler, WebsocketRequestLog requestLog, - ReusableAuth authenticated, + Optional authenticated, WebSocketMessageFactory messageFactory, Optional connectListener, Duration idleTimeout) { @@ -97,7 +97,7 @@ public class WebSocketResourceProvider implements WebSocket this.remoteEndpoint = session.getRemote(); this.context = new WebSocketSessionContext( new WebSocketClient(session, remoteEndpoint, messageFactory, requestMap)); - this.context.setAuthenticated(reusableAuth.ref().orElse(null)); + this.context.setAuthenticated(reusableAuth.orElse(null)); this.session.setIdleTimeout(idleTimeout); connectListener.ifPresent(listener -> listener.onWebSocketConnect(this.context)); @@ -164,16 +164,10 @@ public class WebSocketResourceProvider implements WebSocket /** * The property name where {@link org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider} can find an - * {@link ReusableAuth} object that lives for the lifetime of the websocket + * authenticated principal that lives for the lifetime of the websocket */ public static final String REUSABLE_AUTH_PROPERTY = WebSocketResourceProvider.class.getName() + ".reusableAuth"; - /** - * The property name where {@link org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider} can install a - * {@link org.whispersystems.websocket.ReusableAuth.MutableRef} for us to close when the request is finished - */ - public static final String RESOLVED_PRINCIPAL_PROPERTY = WebSocketResourceProvider.class.getName() + ".resolvedPrincipal"; - /** * The property name where request byte count is stored for metrics collection */ @@ -205,16 +199,6 @@ public class WebSocketResourceProvider implements WebSocket containerRequest, responseBody); responseFuture - .whenComplete((ignoredResponse, ignoredError) -> { - // If the request ended up being one that mutates our principal, we have to close it to indicate we're done - // with the mutation operation - final Object resolvedPrincipal = containerRequest.getProperty(RESOLVED_PRINCIPAL_PROPERTY); - if (resolvedPrincipal instanceof ReusableAuth.MutableRef ref) { - ref.close(); - } else if (resolvedPrincipal != null) { - logger.warn("unexpected resolved principal type {} : {}", resolvedPrincipal.getClass(), resolvedPrincipal); - } - }) .thenAccept(response -> { try { final int responseBytes = responseBody.size(); diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java index 6a2246ea9..4532794d4 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java @@ -64,9 +64,9 @@ public class WebSocketResourceProviderFactory extends Jetty try { Optional> authenticator = Optional.ofNullable(environment.getAuthenticator()); - final ReusableAuth authenticated = authenticator.isPresent() + final Optional authenticated = authenticator.isPresent() ? authenticator.get().authenticate(request) - : ReusableAuth.anonymous(); + : Optional.empty(); Optional.ofNullable(environment.getAuthenticatedWebSocketUpgradeFilter()) .ifPresent(filter -> filter.handleAuthentication(authenticated, request, response)); diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/AuthenticatedWebSocketUpgradeFilter.java b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/AuthenticatedWebSocketUpgradeFilter.java index daf25b080..bf0593dc0 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/AuthenticatedWebSocketUpgradeFilter.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/AuthenticatedWebSocketUpgradeFilter.java @@ -6,13 +6,13 @@ package org.whispersystems.websocket.auth; import java.security.Principal; +import java.util.Optional; import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest; import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse; -import org.whispersystems.websocket.ReusableAuth; public interface AuthenticatedWebSocketUpgradeFilter { - void handleAuthentication(ReusableAuth authenticated, + void handleAuthentication(@SuppressWarnings("OptionalUsedAsFieldOrParameterType") Optional authenticated, JettyServerUpgradeRequest request, JettyServerUpgradeResponse response); } diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/Mutable.java b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/Mutable.java deleted file mode 100644 index b5fadf7fd..000000000 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/Mutable.java +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Copyright 2024 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.websocket.auth; - -import io.dropwizard.auth.Auth; - -import java.lang.annotation.ElementType; -import java.lang.annotation.Retention; -import java.lang.annotation.RetentionPolicy; -import java.lang.annotation.Target; - -/** - * An @{@link Auth} object annotated with {@link Mutable} indicates that the consumer of the object - * will modify the object or its underlying canonical source. - * - * Note: An {@link Auth} object that does not specify @{@link ReadOnly} will be assumed to be @Mutable - * - * @see org.whispersystems.websocket.ReusableAuth - */ -@Retention(RetentionPolicy.RUNTIME) -@Target({ElementType.FIELD, ElementType.PARAMETER}) -public @interface Mutable { -} - diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/PrincipalSupplier.java b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/PrincipalSupplier.java deleted file mode 100644 index 414d11d66..000000000 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/PrincipalSupplier.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright 2024 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.websocket.auth; - -/** - * Teach {@link org.whispersystems.websocket.ReusableAuth} how to make a deep copy of a principal (that is safe to - * concurrently modify while the original principal is being read), and how to refresh a principal after it has been - * potentially modified. - * - * @param The underlying principal type - */ -public interface PrincipalSupplier { - - /** - * Re-fresh the principal after it has been modified. - *

- * If the principal is populated from a backing store, refresh should re-read it. - * - * @param t the potentially stale principal to refresh - * @return The up-to-date principal - */ - T refresh(T t); - - /** - * Create a deep, in-memory copy of the principal. This should be identical to the original principal, but should - * share no mutable state with the original. It should be safe for two threads to independently write and read from - * two independent deep copies. - * - * @param t the principal to copy - * @return An in-memory copy of the principal - */ - T deepCopy(T t); - - class ImmutablePrincipalSupplier implements PrincipalSupplier { - @SuppressWarnings({"rawtypes"}) - private static final PrincipalSupplier INSTANCE = new ImmutablePrincipalSupplier(); - - @Override - public T refresh(final T t) { - return t; - } - - @Override - public T deepCopy(final T t) { - return t; - } - } - - /** - * @return A principal supplier that can be used if the principal type does not support modification. - */ - static PrincipalSupplier forImmutablePrincipal() { - //noinspection unchecked - return (PrincipalSupplier) ImmutablePrincipalSupplier.INSTANCE; - } -} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/ReadOnly.java b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/ReadOnly.java deleted file mode 100644 index 8deab76ce..000000000 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/ReadOnly.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright 2024 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ -package org.whispersystems.websocket.auth; - -import java.lang.annotation.ElementType; -import java.lang.annotation.Retention; -import java.lang.annotation.RetentionPolicy; -import java.lang.annotation.Target; - -/** - * An @{@link io.dropwizard.auth.Auth} object annotated with {@link ReadOnly} indicates that the consumer of the object - * will never modify the object, nor its underlying canonical source. - *

- * For example, a consumer of a @ReadOnly AuthenticatedAccount promises to never modify the in-memory - * AuthenticatedAccount and to never modify the underlying Account database for the account. - * - * @see org.whispersystems.websocket.ReusableAuth - */ -@Retention(RetentionPolicy.RUNTIME) -@Target({ElementType.FIELD, ElementType.PARAMETER}) -public @interface ReadOnly { -} - diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebSocketAuthenticator.java b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebSocketAuthenticator.java index d002bff23..7836b5b8d 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebSocketAuthenticator.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebSocketAuthenticator.java @@ -5,9 +5,20 @@ package org.whispersystems.websocket.auth; import java.security.Principal; +import java.util.Optional; import org.eclipse.jetty.websocket.api.UpgradeRequest; -import org.whispersystems.websocket.ReusableAuth; public interface WebSocketAuthenticator { - ReusableAuth authenticate(UpgradeRequest request) throws InvalidCredentialsException; + + /** + * Authenticates an account from credential headers provided in a WebSocket upgrade request. + * + * @param request the request from which to extract credentials + * + * @return the authenticated principal if credentials were provided and authenticated or empty if the caller is + * anonymous + * + * @throws InvalidCredentialsException if credentials were provided, but could not be authenticated + */ + Optional authenticate(UpgradeRequest request) throws InvalidCredentialsException; } diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebsocketAuthValueFactoryProvider.java b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebsocketAuthValueFactoryProvider.java index d7bfaf657..2fe00890a 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebsocketAuthValueFactoryProvider.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebsocketAuthValueFactoryProvider.java @@ -21,7 +21,6 @@ import org.glassfish.jersey.server.model.Parameter; import org.glassfish.jersey.server.spi.internal.ValueParamProvider; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.whispersystems.websocket.ReusableAuth; import org.whispersystems.websocket.WebSocketResourceProvider; @Singleton @@ -43,36 +42,30 @@ public class WebsocketAuthValueFactoryProvider extends Abst return null; } - final boolean readOnly = parameter.isAnnotationPresent(ReadOnly.class); + final boolean readOnly = true; if (parameter.getRawType() == Optional.class && ParameterizedType.class.isAssignableFrom(parameter.getType().getClass()) && principalClass == ((ParameterizedType) parameter.getType()).getActualTypeArguments()[0]) { - return containerRequest -> createPrincipal(containerRequest, readOnly); + return this::createPrincipal; } else if (principalClass.equals(parameter.getRawType())) { return containerRequest -> - createPrincipal(containerRequest, readOnly) + createPrincipal(containerRequest) .orElseThrow(() -> new WebApplicationException("Authenticated resource", 401)); } else { throw new IllegalStateException("Can't inject unassignable principal: " + principalClass + " for parameter: " + parameter); } } - private Optional createPrincipal(final ContainerRequest request, final boolean readOnly) { + private Optional createPrincipal(final ContainerRequest request) { final Object obj = request.getProperty(WebSocketResourceProvider.REUSABLE_AUTH_PROPERTY); - if (!(obj instanceof ReusableAuth)) { + if (!(obj instanceof Optional)) { logger.warn("Unexpected reusable auth property type {} : {}", obj.getClass(), obj); return Optional.empty(); } - @SuppressWarnings("unchecked") final ReusableAuth reusableAuth = (ReusableAuth) obj; - if (readOnly) { - return reusableAuth.ref(); - } else { - return reusableAuth.mutableRef().map(writeRef -> { - request.setProperty(WebSocketResourceProvider.RESOLVED_PRINCIPAL_PROPERTY, writeRef); - return writeRef.ref(); - }); - } + + //noinspection unchecked + return (Optional) obj; } @Singleton diff --git a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java index 5a790c47c..31e44e09f 100644 --- a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java +++ b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java @@ -17,6 +17,7 @@ import io.dropwizard.jersey.DropwizardResourceConfig; import jakarta.servlet.http.HttpServletRequest; import java.io.IOException; import java.security.Principal; +import java.util.Optional; import javax.security.auth.Subject; import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest; @@ -26,7 +27,6 @@ import org.glassfish.jersey.server.ResourceConfig; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.whispersystems.websocket.auth.InvalidCredentialsException; -import org.whispersystems.websocket.auth.PrincipalSupplier; import org.whispersystems.websocket.auth.AuthenticatedWebSocketUpgradeFilter; import org.whispersystems.websocket.auth.WebSocketAuthenticator; import org.whispersystems.websocket.configuration.WebSocketConfiguration; @@ -75,7 +75,7 @@ public class WebSocketResourceProviderFactoryTest { when(environment.getAuthenticator()).thenReturn(authenticator); when(authenticator.authenticate(eq(request))) - .thenReturn(ReusableAuth.authenticated(account, PrincipalSupplier.forImmutablePrincipal())); + .thenReturn(Optional.of(account)); when(environment.jersey()).thenReturn(jerseyEnvironment); final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class); when(httpServletRequest.getAttribute(REMOTE_ADDRESS_PROPERTY_NAME)).thenReturn("127.0.0.1"); @@ -129,8 +129,7 @@ public class WebSocketResourceProviderFactoryTest { @Test void testAuthenticatedWebSocketUpgradeFilter() throws InvalidCredentialsException { final Account account = new Account(); - final ReusableAuth reusableAuth = - ReusableAuth.authenticated(account, PrincipalSupplier.forImmutablePrincipal()); + final Optional reusableAuth = Optional.of(account); when(environment.getAuthenticator()).thenReturn(authenticator); when(authenticator.authenticate(eq(request))).thenReturn(reusableAuth); diff --git a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java index e4a8f5b9a..953190bac 100644 --- a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java +++ b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java @@ -59,7 +59,6 @@ import org.glassfish.jersey.server.ResourceConfig; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; -import org.whispersystems.websocket.auth.PrincipalSupplier; import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider; import org.whispersystems.websocket.logging.WebsocketRequestLog; import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory; @@ -81,7 +80,7 @@ class WebSocketResourceProviderTest { WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, - immutableTestPrincipal("fooz"), + Optional.of(new TestPrincipal("fooz")), new ProtobufWebSocketMessageFactory(), Optional.of(connectListener), Duration.ofMillis(30000)); @@ -109,7 +108,7 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = mock(ApplicationHandler.class); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"), + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("foo")), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); @@ -186,7 +185,7 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = mock(ApplicationHandler.class); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"), + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("foo")), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); @@ -242,7 +241,7 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"), + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("foo")), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); @@ -282,7 +281,7 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("foo"), + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("foo")), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); @@ -322,7 +321,7 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("authorizedUserName"), + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("authorizedUserName")), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); @@ -362,7 +361,7 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, ReusableAuth.anonymous(), + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.empty(), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); @@ -401,7 +400,7 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("something"), + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("something")), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); @@ -441,7 +440,7 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, ReusableAuth.anonymous(), + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.empty(), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); @@ -481,7 +480,7 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"), + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("gooduser")), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); @@ -522,7 +521,7 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"), + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("gooduser")), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); @@ -564,7 +563,7 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"), + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("gooduser")), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); @@ -604,7 +603,7 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, immutableTestPrincipal("gooduser"), + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, Optional.of(new TestPrincipal("gooduser")), new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); @@ -729,10 +728,6 @@ class WebSocketResourceProviderTest { } } - public static ReusableAuth immutableTestPrincipal(final String name) { - return ReusableAuth.authenticated(new TestPrincipal(name), PrincipalSupplier.forImmutablePrincipal()); - } - public static class TestException extends Exception { public TestException(String message) {