diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 2bcfdeba8..d2a5be3f5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -64,8 +64,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.dispatch.DispatchManager; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.CertificateGenerator; -import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; +import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccountAuthenticator; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator; import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator; @@ -149,7 +150,6 @@ import org.whispersystems.textsecuregcm.sms.TwilioSmsSender; import org.whispersystems.textsecuregcm.sms.TwilioVerifyExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; import org.whispersystems.textsecuregcm.storage.AbusiveHostRules; -import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountCleaner; import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawler; import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawlerCache; @@ -544,31 +544,40 @@ public class WhisperServerService extends Application accountAuthFilter = new BasicCredentialAuthFilter.Builder().setAuthenticator(accountAuthenticator).buildAuthFilter (); - AuthFilter disabledPermittedAccountAuthFilter = new BasicCredentialAuthFilter.Builder().setAuthenticator(disabledPermittedAccountAuthenticator).buildAuthFilter(); + AuthFilter accountAuthFilter = new BasicCredentialAuthFilter.Builder().setAuthenticator( + accountAuthenticator).buildAuthFilter(); + AuthFilter disabledPermittedAccountAuthFilter = new BasicCredentialAuthFilter.Builder().setAuthenticator( + disabledPermittedAccountAuthenticator).buildAuthFilter(); - environment.servlets().addFilter("RemoteDeprecationFilter", new RemoteDeprecationFilter(dynamicConfigurationManager)) + environment.servlets() + .addFilter("RemoteDeprecationFilter", new RemoteDeprecationFilter(dynamicConfigurationManager)) .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*"); environment.jersey().register(new ContentLengthFilter(TrafficSource.HTTP)); environment.jersey().register(MultiRecipientMessageProvider.class); environment.jersey().register(new MetricsApplicationEventListener(TrafficSource.HTTP)); - environment.jersey().register(new PolymorphicAuthDynamicFeature<>(ImmutableMap.of(Account.class, accountAuthFilter, - DisabledPermittedAccount.class, disabledPermittedAccountAuthFilter))); - environment.jersey().register(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))); + environment.jersey() + .register(new PolymorphicAuthDynamicFeature<>(ImmutableMap.of(AuthenticatedAccount.class, accountAuthFilter, + DisabledPermittedAuthenticatedAccount.class, disabledPermittedAccountAuthFilter))); + environment.jersey().register(new PolymorphicAuthValueFactoryProvider.Binder<>( + ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))); environment.jersey().register(new TimestampResponseFilter()); - environment.jersey().register(new VoiceVerificationController(config.getVoiceVerificationConfiguration().getUrl(), config.getVoiceVerificationConfiguration().getLocales())); + environment.jersey().register(new VoiceVerificationController(config.getVoiceVerificationConfiguration().getUrl(), + config.getVoiceVerificationConfiguration().getLocales())); /// - WebSocketEnvironment webSocketEnvironment = new WebSocketEnvironment<>(environment, config.getWebSocketConfiguration(), 90000); + WebSocketEnvironment webSocketEnvironment = new WebSocketEnvironment<>(environment, + config.getWebSocketConfiguration(), 90000); webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator)); webSocketEnvironment.setConnectListener( new AuthenticatedConnectListener(receiptSender, messagesManager, messageSender, apnFallbackManager, @@ -602,15 +611,18 @@ public class WhisperServerService extends Application provisioningEnvironment = new WebSocketEnvironment<>(environment, webSocketEnvironment.getRequestLog(), 60000); + WebSocketEnvironment provisioningEnvironment = new WebSocketEnvironment<>(environment, + webSocketEnvironment.getRequestLog(), 60000); provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(pubSubManager)); provisioningEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET)); provisioningEnvironment.jersey().register(new KeepAliveController(clientPresenceManager)); @@ -618,16 +630,19 @@ public class WhisperServerService extends Application webSocketServlet = new WebSocketResourceProviderFactory<>(webSocketEnvironment, Account.class); - WebSocketResourceProviderFactory provisioningServlet = new WebSocketResourceProviderFactory<>(provisioningEnvironment, Account.class); + WebSocketResourceProviderFactory webSocketServlet = new WebSocketResourceProviderFactory<>( + webSocketEnvironment, AuthenticatedAccount.class); + WebSocketResourceProviderFactory provisioningServlet = new WebSocketResourceProviderFactory<>( + provisioningEnvironment, AuthenticatedAccount.class); - ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", webSocketServlet ); + ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", webSocketServlet); ServletRegistration.Dynamic provisioning = environment.servlets().addServlet("Provisioning", provisioningServlet); websocket.addMapping("/v1/websocket/"); @@ -649,14 +664,18 @@ public class WhisperServerService extends Application webSocketEnvironment, WebSocketEnvironment provisioningEnvironment) { + private void registerExceptionMappers(Environment environment, + WebSocketEnvironment webSocketEnvironment, + WebSocketEnvironment provisioningEnvironment) { environment.jersey().register(new LoggingUnhandledExceptionMapper()); environment.jersey().register(new IOExceptionMapper()); environment.jersey().register(new RateLimitExceededExceptionMapper()); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AccountAuthenticator.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AccountAuthenticator.java index 563d5be8f..382c9eff7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AccountAuthenticator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AccountAuthenticator.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.textsecuregcm.auth; @@ -7,17 +7,17 @@ package org.whispersystems.textsecuregcm.auth; import io.dropwizard.auth.Authenticator; import io.dropwizard.auth.basic.BasicCredentials; import java.util.Optional; -import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; -public class AccountAuthenticator extends BaseAccountAuthenticator implements Authenticator { +public class AccountAuthenticator extends BaseAccountAuthenticator implements + Authenticator { public AccountAuthenticator(AccountsManager accountsManager) { super(accountsManager); } @Override - public Optional authenticate(BasicCredentials basicCredentials) { + public Optional authenticate(BasicCredentials basicCredentials) { return super.authenticate(basicCredentials, true); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthenticatedAccount.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthenticatedAccount.java new file mode 100644 index 000000000..e1a34f45d --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthenticatedAccount.java @@ -0,0 +1,42 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.auth; + +import java.security.Principal; +import java.util.function.Supplier; +import javax.security.auth.Subject; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.util.Pair; + +public class AuthenticatedAccount implements Principal { + + private final Supplier> accountAndDevice; + + public AuthenticatedAccount(final Supplier> accountAndDevice) { + this.accountAndDevice = accountAndDevice; + } + + public Account getAccount() { + return accountAndDevice.get().first(); + } + + public Device getAuthenticatedDevice() { + return accountAndDevice.get().second(); + } + + // Principal implementation + + @Override + public String getName() { + return null; + } + + @Override + public boolean implies(final Subject subject) { + return false; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/BaseAccountAuthenticator.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/BaseAccountAuthenticator.java index dec507aef..1cf04d7e0 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/BaseAccountAuthenticator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/BaseAccountAuthenticator.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -19,6 +19,7 @@ import org.apache.commons.lang3.StringUtils; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.storage.RefreshingAccountAndDeviceSupplier; import org.whispersystems.textsecuregcm.util.Util; public class BaseAccountAuthenticator { @@ -45,14 +46,15 @@ public class BaseAccountAuthenticator { this.clock = clock; } - public Optional authenticate(BasicCredentials basicCredentials, boolean enabledRequired) { + public Optional authenticate(BasicCredentials basicCredentials, boolean enabledRequired) { boolean succeeded = false; String failureReason = null; String credentialType = null; try { - AuthorizationHeader authorizationHeader = AuthorizationHeader.fromUserAndPassword(basicCredentials.getUsername(), basicCredentials.getPassword()); - Optional account = accountsManager.get(authorizationHeader.getIdentifier()); + AuthorizationHeader authorizationHeader = AuthorizationHeader.fromUserAndPassword(basicCredentials.getUsername(), + basicCredentials.getPassword()); + Optional account = accountsManager.get(authorizationHeader.getIdentifier()); credentialType = authorizationHeader.getIdentifier().hasNumber() ? "e164" : "uuid"; @@ -83,9 +85,8 @@ public class BaseAccountAuthenticator { if (device.get().getAuthenticationCredentials().verify(basicCredentials.getPassword())) { succeeded = true; final Account authenticatedAccount = updateLastSeen(account.get(), device.get()); - // the device in scope might be stale after the update, so get the latest from the authenticated account - authenticatedAccount.setAuthenticatedDevice(authenticatedAccount.getDevice(device.get().getId()).orElseThrow()); - return Optional.of(authenticatedAccount); + return Optional.of(new AuthenticatedAccount( + new RefreshingAccountAndDeviceSupplier(authenticatedAccount, device.get().getId(), accountsManager))); } return Optional.empty(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisabledPermittedAccount.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisabledPermittedAccount.java deleted file mode 100644 index e2ea8368e..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisabledPermittedAccount.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright 2013-2020 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.auth; - -import org.whispersystems.textsecuregcm.storage.Account; - -import javax.security.auth.Subject; -import java.security.Principal; - -public class DisabledPermittedAccount implements Principal { - - private final Account account; - - public DisabledPermittedAccount(Account account) { - this.account = account; - } - - public Account getAccount() { - return account; - } - - // Principal implementation - - @Override - public String getName() { - return null; - } - - @Override - public boolean implies(Subject subject) { - return false; - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisabledPermittedAccountAuthenticator.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisabledPermittedAccountAuthenticator.java index ffdbcdc31..cb7b8b78b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisabledPermittedAccountAuthenticator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisabledPermittedAccountAuthenticator.java @@ -1,27 +1,25 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.textsecuregcm.auth; -import org.whispersystems.textsecuregcm.storage.Account; -import org.whispersystems.textsecuregcm.storage.AccountsManager; - -import java.util.Optional; - import io.dropwizard.auth.Authenticator; import io.dropwizard.auth.basic.BasicCredentials; +import java.util.Optional; +import org.whispersystems.textsecuregcm.storage.AccountsManager; -public class DisabledPermittedAccountAuthenticator extends BaseAccountAuthenticator implements Authenticator { +public class DisabledPermittedAccountAuthenticator extends BaseAccountAuthenticator implements + Authenticator { public DisabledPermittedAccountAuthenticator(AccountsManager accountsManager) { super(accountsManager); } - + @Override - public Optional authenticate(BasicCredentials credentials) { - Optional account = super.authenticate(credentials, false); - return account.map(DisabledPermittedAccount::new); + public Optional authenticate(BasicCredentials credentials) { + Optional account = super.authenticate(credentials, false); + return account.map(DisabledPermittedAuthenticatedAccount::new); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisabledPermittedAuthenticatedAccount.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisabledPermittedAuthenticatedAccount.java new file mode 100644 index 000000000..4001a9573 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisabledPermittedAuthenticatedAccount.java @@ -0,0 +1,40 @@ +/* + * Copyright 2013-2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.auth; + +import java.security.Principal; +import javax.security.auth.Subject; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.Device; + +public class DisabledPermittedAuthenticatedAccount implements Principal { + + private final AuthenticatedAccount authenticatedAccount; + + public DisabledPermittedAuthenticatedAccount(final AuthenticatedAccount authenticatedAccount) { + this.authenticatedAccount = authenticatedAccount; + } + + public Account getAccount() { + return authenticatedAccount.getAccount(); + } + + public Device getAuthenticatedDevice() { + return authenticatedAccount.getAuthenticatedDevice(); + } + + // Principal implementation + + @Override + public String getName() { + return null; + } + + @Override + public boolean implies(Subject subject) { + return false; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java index d04b29dc9..ba1578e8f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AccountController.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.textsecuregcm.controllers; @@ -39,9 +39,10 @@ import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials; import org.whispersystems.textsecuregcm.auth.AuthorizationHeader; -import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; +import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials; import org.whispersystems.textsecuregcm.auth.InvalidAuthorizationHeaderException; @@ -404,8 +405,8 @@ public class AccountController { @GET @Path("/turn/") @Produces(MediaType.APPLICATION_JSON) - public TurnToken getTurnToken(@Auth Account account) throws RateLimitExceededException { - rateLimiters.getTurnLimiter().validate(account.getUuid()); + public TurnToken getTurnToken(@Auth AuthenticatedAccount auth) throws RateLimitExceededException { + rateLimiters.getTurnLimiter().validate(auth.getAccount().getUuid()); return turnTokenGenerator.generate(); } @@ -413,13 +414,13 @@ public class AccountController { @PUT @Path("/gcm/") @Consumes(MediaType.APPLICATION_JSON) - public void setGcmRegistrationId(@Auth DisabledPermittedAccount disabledPermittedAccount, @Valid GcmRegistrationId registrationId) { - Account account = disabledPermittedAccount.getAccount(); - Device device = account.getAuthenticatedDevice().get(); + public void setGcmRegistrationId(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth, + @Valid GcmRegistrationId registrationId) { + Account account = disabledPermittedAuth.getAccount(); + Device device = disabledPermittedAuth.getAuthenticatedDevice(); if (device.getGcmId() != null && - device.getGcmId().equals(registrationId.getGcmRegistrationId())) - { + device.getGcmId().equals(registrationId.getGcmRegistrationId())) { return; } @@ -434,9 +435,9 @@ public class AccountController { @Timed @DELETE @Path("/gcm/") - public void deleteGcmRegistrationId(@Auth DisabledPermittedAccount disabledPermittedAccount) { - Account account = disabledPermittedAccount.getAccount(); - Device device = account.getAuthenticatedDevice().get(); + public void deleteGcmRegistrationId(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth) { + Account account = disabledPermittedAuth.getAccount(); + Device device = disabledPermittedAuth.getAuthenticatedDevice(); accounts.updateDevice(account, device.getId(), d -> { d.setGcmId(null); @@ -449,9 +450,10 @@ public class AccountController { @PUT @Path("/apn/") @Consumes(MediaType.APPLICATION_JSON) - public void setApnRegistrationId(@Auth DisabledPermittedAccount disabledPermittedAccount, @Valid ApnRegistrationId registrationId) { - Account account = disabledPermittedAccount.getAccount(); - Device device = account.getAuthenticatedDevice().get(); + public void setApnRegistrationId(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth, + @Valid ApnRegistrationId registrationId) { + Account account = disabledPermittedAuth.getAccount(); + Device device = disabledPermittedAuth.getAuthenticatedDevice(); accounts.updateDevice(account, device.getId(), d -> { d.setApnId(registrationId.getApnRegistrationId()); @@ -464,9 +466,9 @@ public class AccountController { @Timed @DELETE @Path("/apn/") - public void deleteApnRegistrationId(@Auth DisabledPermittedAccount disabledPermittedAccount) { - Account account = disabledPermittedAccount.getAccount(); - Device device = account.getAuthenticatedDevice().get(); + public void deleteApnRegistrationId(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth) { + Account account = disabledPermittedAuth.getAccount(); + Device device = disabledPermittedAuth.getAuthenticatedDevice(); accounts.updateDevice(account, device.getId(), d -> { d.setApnId(null); @@ -483,57 +485,54 @@ public class AccountController { @PUT @Produces(MediaType.APPLICATION_JSON) @Path("/registration_lock") - public void setRegistrationLock(@Auth Account account, @Valid RegistrationLock accountLock) { + public void setRegistrationLock(@Auth AuthenticatedAccount auth, @Valid RegistrationLock accountLock) { AuthenticationCredentials credentials = new AuthenticationCredentials(accountLock.getRegistrationLock()); - accounts.update(account, a -> { - a.setRegistrationLock(credentials.getHashedAuthenticationToken(), credentials.getSalt()); - }); + accounts.update(auth.getAccount(), + a -> a.setRegistrationLock(credentials.getHashedAuthenticationToken(), credentials.getSalt())); } @Timed @DELETE @Path("/registration_lock") - public void removeRegistrationLock(@Auth Account account) { - accounts.update(account, a -> a.setRegistrationLock(null, null)); + public void removeRegistrationLock(@Auth AuthenticatedAccount auth) { + accounts.update(auth.getAccount(), a -> a.setRegistrationLock(null, null)); } @Timed @PUT @Path("/name/") - public void setName(@Auth DisabledPermittedAccount disabledPermittedAccount, @Valid DeviceName deviceName) { - Account account = disabledPermittedAccount.getAccount(); - Device device = account.getAuthenticatedDevice().get(); + public void setName(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth, @Valid DeviceName deviceName) { + Account account = disabledPermittedAuth.getAccount(); + Device device = disabledPermittedAuth.getAuthenticatedDevice(); accounts.updateDevice(account, device.getId(), d -> d.setName(deviceName.getDeviceName())); } @Timed @DELETE @Path("/signaling_key") - public void removeSignalingKey(@Auth DisabledPermittedAccount disabledPermittedAccount) { + public void removeSignalingKey(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth) { } @Timed @PUT @Path("/attributes/") @Consumes(MediaType.APPLICATION_JSON) - public void setAccountAttributes(@Auth DisabledPermittedAccount disabledPermittedAccount, - @HeaderParam("X-Signal-Agent") String userAgent, - @Valid AccountAttributes attributes) - { - Account account = disabledPermittedAccount.getAccount(); - long deviceId = account.getAuthenticatedDevice().get().getId(); - - accounts.update(account, a-> { + public void setAccountAttributes(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth, + @HeaderParam("X-Signal-Agent") String userAgent, + @Valid AccountAttributes attributes) { + Account account = disabledPermittedAuth.getAccount(); + long deviceId = disabledPermittedAuth.getAuthenticatedDevice().getId(); + accounts.update(account, a -> { a.getDevice(deviceId).ifPresent(d -> { - d.setFetchesMessages(attributes.getFetchesMessages()); - d.setName(attributes.getName()); - d.setLastSeen(Util.todayInMillis()); - d.setCapabilities(attributes.getCapabilities()); - d.setRegistrationId(attributes.getRegistrationId()); - d.setUserAgent(userAgent); - }); + d.setFetchesMessages(attributes.getFetchesMessages()); + d.setName(attributes.getName()); + d.setLastSeen(Util.todayInMillis()); + d.setCapabilities(attributes.getCapabilities()); + d.setRegistrationId(attributes.getRegistrationId()); + d.setUserAgent(userAgent); + }); a.setRegistrationLockFromAttributes(attributes); @@ -546,29 +545,30 @@ public class AccountController { @GET @Path("/me") @Produces(MediaType.APPLICATION_JSON) - public AccountCreationResult getMe(@Auth Account account) { - return whoAmI(account); + public AccountCreationResult getMe(@Auth AuthenticatedAccount auth) { + return whoAmI(auth); } @GET @Path("/whoami") @Produces(MediaType.APPLICATION_JSON) - public AccountCreationResult whoAmI(@Auth Account account) { - return new AccountCreationResult(account.getUuid(), account.isStorageSupported()); + public AccountCreationResult whoAmI(@Auth AuthenticatedAccount auth) { + return new AccountCreationResult(auth.getAccount().getUuid(), auth.getAccount().isStorageSupported()); } @DELETE @Path("/username") @Produces(MediaType.APPLICATION_JSON) - public void deleteUsername(@Auth Account account) { - usernames.delete(account.getUuid()); + public void deleteUsername(@Auth AuthenticatedAccount auth) { + usernames.delete(auth.getAccount().getUuid()); } @PUT @Path("/username/{username}") @Produces(MediaType.APPLICATION_JSON) - public Response setUsername(@Auth Account account, @PathParam("username") String username) throws RateLimitExceededException { - rateLimiters.getUsernameSetLimiter().validate(account.getUuid()); + public Response setUsername(@Auth AuthenticatedAccount auth, @PathParam("username") String username) + throws RateLimitExceededException { + rateLimiters.getUsernameSetLimiter().validate(auth.getAccount().getUuid()); if (username == null || username.isEmpty()) { return Response.status(Response.Status.BAD_REQUEST).build(); @@ -580,7 +580,7 @@ public class AccountController { return Response.status(Response.Status.BAD_REQUEST).build(); } - if (!usernames.put(account.getUuid(), username)) { + if (!usernames.put(auth.getAccount().getUuid(), username)) { return Response.status(Response.Status.CONFLICT).build(); } @@ -678,8 +678,8 @@ public class AccountController { @Timed @DELETE @Path("/me") - public void deleteAccount(@Auth Account account) throws InterruptedException { - accounts.delete(account, AccountsManager.DeletionReason.USER_REQUEST); + public void deleteAccount(@Auth AuthenticatedAccount auth) throws InterruptedException { + accounts.delete(auth.getAccount(), AccountsManager.DeletionReason.USER_REQUEST); } private boolean shouldAutoBlock(String sourceHost) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AttachmentControllerV1.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AttachmentControllerV1.java index cf25aef65..4c6365d52 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AttachmentControllerV1.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AttachmentControllerV1.java @@ -1,29 +1,27 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.textsecuregcm.controllers; import com.amazonaws.HttpMethod; import com.codahale.metrics.annotation.Timed; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.whispersystems.textsecuregcm.entities.AttachmentDescriptorV1; -import org.whispersystems.textsecuregcm.entities.AttachmentUri; -import org.whispersystems.textsecuregcm.limits.RateLimiters; -import org.whispersystems.textsecuregcm.s3.UrlSigner; -import org.whispersystems.textsecuregcm.storage.Account; - +import io.dropwizard.auth.Auth; +import java.io.IOException; +import java.net.URL; +import java.util.stream.Stream; import javax.ws.rs.GET; import javax.ws.rs.Path; import javax.ws.rs.PathParam; import javax.ws.rs.Produces; import javax.ws.rs.core.MediaType; -import java.io.IOException; -import java.net.URL; -import java.util.stream.Stream; - -import io.dropwizard.auth.Auth; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.entities.AttachmentDescriptorV1; +import org.whispersystems.textsecuregcm.entities.AttachmentUri; +import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.s3.UrlSigner; @Path("/v1/attachments") @@ -35,25 +33,25 @@ public class AttachmentControllerV1 extends AttachmentControllerBase { private static final String[] UNACCELERATED_REGIONS = {"+20", "+971", "+968", "+974"}; private final RateLimiters rateLimiters; - private final UrlSigner urlSigner; + private final UrlSigner urlSigner; public AttachmentControllerV1(RateLimiters rateLimiters, String accessKey, String accessSecret, String bucket) { this.rateLimiters = rateLimiters; - this.urlSigner = new UrlSigner(accessKey, accessSecret, bucket); + this.urlSigner = new UrlSigner(accessKey, accessSecret, bucket); } @Timed @GET @Produces(MediaType.APPLICATION_JSON) - public AttachmentDescriptorV1 allocateAttachment(@Auth Account account) - throws RateLimitExceededException - { - if (account.isRateLimited()) { - rateLimiters.getAttachmentLimiter().validate(account.getUuid()); + public AttachmentDescriptorV1 allocateAttachment(@Auth AuthenticatedAccount auth) + throws RateLimitExceededException { + if (auth.getAccount().isRateLimited()) { + rateLimiters.getAttachmentLimiter().validate(auth.getAccount().getUuid()); } long attachmentId = generateAttachmentId(); - URL url = urlSigner.getPreSignedUrl(attachmentId, HttpMethod.PUT, Stream.of(UNACCELERATED_REGIONS).anyMatch(region -> account.getNumber().startsWith(region))); + URL url = urlSigner.getPreSignedUrl(attachmentId, HttpMethod.PUT, + Stream.of(UNACCELERATED_REGIONS).anyMatch(region -> auth.getAccount().getNumber().startsWith(region))); return new AttachmentDescriptorV1(attachmentId, url.toExternalForm()); @@ -63,11 +61,11 @@ public class AttachmentControllerV1 extends AttachmentControllerBase { @GET @Produces(MediaType.APPLICATION_JSON) @Path("/{attachmentId}") - public AttachmentUri redirectToAttachment(@Auth Account account, - @PathParam("attachmentId") long attachmentId) - throws IOException - { - return new AttachmentUri(urlSigner.getPreSignedUrl(attachmentId, HttpMethod.GET, Stream.of(UNACCELERATED_REGIONS).anyMatch(region -> account.getNumber().startsWith(region)))); + public AttachmentUri redirectToAttachment(@Auth AuthenticatedAccount auth, + @PathParam("attachmentId") long attachmentId) + throws IOException { + return new AttachmentUri(urlSigner.getPreSignedUrl(attachmentId, HttpMethod.GET, + Stream.of(UNACCELERATED_REGIONS).anyMatch(region -> auth.getAccount().getNumber().startsWith(region)))); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AttachmentControllerV2.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AttachmentControllerV2.java index 54b576ca3..09dd37be9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AttachmentControllerV2.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AttachmentControllerV2.java @@ -1,28 +1,26 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.textsecuregcm.controllers; import com.codahale.metrics.annotation.Timed; +import io.dropwizard.auth.Auth; +import java.time.ZoneOffset; +import java.time.ZonedDateTime; +import javax.ws.rs.GET; +import javax.ws.rs.Path; +import javax.ws.rs.Produces; +import javax.ws.rs.core.MediaType; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.entities.AttachmentDescriptorV2; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.s3.PolicySigner; import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator; -import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.util.Pair; -import javax.ws.rs.GET; -import javax.ws.rs.Path; -import javax.ws.rs.Produces; -import javax.ws.rs.core.MediaType; -import java.time.ZoneOffset; -import java.time.ZonedDateTime; - -import io.dropwizard.auth.Auth; - @Path("/v2/attachments") public class AttachmentControllerV2 extends AttachmentControllerBase { @@ -40,19 +38,20 @@ public class AttachmentControllerV2 extends AttachmentControllerBase { @GET @Produces(MediaType.APPLICATION_JSON) @Path("/form/upload") - public AttachmentDescriptorV2 getAttachmentUploadForm(@Auth Account account) throws RateLimitExceededException { - rateLimiter.validate(account.getUuid()); + public AttachmentDescriptorV2 getAttachmentUploadForm(@Auth AuthenticatedAccount auth) + throws RateLimitExceededException { + rateLimiter.validate(auth.getAccount().getUuid()); - ZonedDateTime now = ZonedDateTime.now(ZoneOffset.UTC); - long attachmentId = generateAttachmentId(); - String objectName = String.valueOf(attachmentId); - Pair policy = policyGenerator.createFor(now, String.valueOf(objectName), 100 * 1024 * 1024); - String signature = policySigner.getSignature(now, policy.second()); + ZonedDateTime now = ZonedDateTime.now(ZoneOffset.UTC); + long attachmentId = generateAttachmentId(); + String objectName = String.valueOf(attachmentId); + Pair policy = policyGenerator.createFor(now, String.valueOf(objectName), 100 * 1024 * 1024); + String signature = policySigner.getSignature(now, policy.second()); return new AttachmentDescriptorV2(attachmentId, objectName, policy.first(), - "private", "AWS4-HMAC-SHA256", - now.format(PostPolicyGenerator.AWS_DATE_TIME), - policy.second(), signature); + "private", "AWS4-HMAC-SHA256", + now.format(PostPolicyGenerator.AWS_DATE_TIME), + policy.second(), signature); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AttachmentControllerV3.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AttachmentControllerV3.java index fc767cca0..08a8813bf 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AttachmentControllerV3.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/AttachmentControllerV3.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -7,19 +7,6 @@ package org.whispersystems.textsecuregcm.controllers; import com.codahale.metrics.annotation.Timed; import io.dropwizard.auth.Auth; -import org.whispersystems.textsecuregcm.entities.AttachmentDescriptorV3; -import org.whispersystems.textsecuregcm.gcp.CanonicalRequest; -import org.whispersystems.textsecuregcm.gcp.CanonicalRequestGenerator; -import org.whispersystems.textsecuregcm.gcp.CanonicalRequestSigner; -import org.whispersystems.textsecuregcm.limits.RateLimiter; -import org.whispersystems.textsecuregcm.limits.RateLimiters; -import org.whispersystems.textsecuregcm.storage.Account; - -import javax.annotation.Nonnull; -import javax.ws.rs.GET; -import javax.ws.rs.Path; -import javax.ws.rs.Produces; -import javax.ws.rs.core.MediaType; import java.io.IOException; import java.security.InvalidKeyException; import java.security.SecureRandom; @@ -29,6 +16,18 @@ import java.time.ZonedDateTime; import java.util.Base64; import java.util.HashMap; import java.util.Map; +import javax.annotation.Nonnull; +import javax.ws.rs.GET; +import javax.ws.rs.Path; +import javax.ws.rs.Produces; +import javax.ws.rs.core.MediaType; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.entities.AttachmentDescriptorV3; +import org.whispersystems.textsecuregcm.gcp.CanonicalRequest; +import org.whispersystems.textsecuregcm.gcp.CanonicalRequestGenerator; +import org.whispersystems.textsecuregcm.gcp.CanonicalRequestSigner; +import org.whispersystems.textsecuregcm.limits.RateLimiter; +import org.whispersystems.textsecuregcm.limits.RateLimiters; @Path("/v3/attachments") public class AttachmentControllerV3 extends AttachmentControllerBase { @@ -45,26 +44,29 @@ public class AttachmentControllerV3 extends AttachmentControllerBase { @Nonnull private final SecureRandom secureRandom; - public AttachmentControllerV3(@Nonnull RateLimiters rateLimiters, @Nonnull String domain, @Nonnull String email, int maxSizeInBytes, @Nonnull String pathPrefix, @Nonnull String rsaSigningKey) + public AttachmentControllerV3(@Nonnull RateLimiters rateLimiters, @Nonnull String domain, @Nonnull String email, + int maxSizeInBytes, @Nonnull String pathPrefix, @Nonnull String rsaSigningKey) throws IOException, InvalidKeyException, InvalidKeySpecException { - this.rateLimiter = rateLimiters.getAttachmentLimiter(); + this.rateLimiter = rateLimiters.getAttachmentLimiter(); this.canonicalRequestGenerator = new CanonicalRequestGenerator(domain, email, maxSizeInBytes, pathPrefix); - this.canonicalRequestSigner = new CanonicalRequestSigner(rsaSigningKey); - this.secureRandom = new SecureRandom(); + this.canonicalRequestSigner = new CanonicalRequestSigner(rsaSigningKey); + this.secureRandom = new SecureRandom(); } @Timed @GET @Produces(MediaType.APPLICATION_JSON) @Path("/form/upload") - public AttachmentDescriptorV3 getAttachmentUploadForm(@Auth Account account) throws RateLimitExceededException { - rateLimiter.validate(account.getUuid()); + public AttachmentDescriptorV3 getAttachmentUploadForm(@Auth AuthenticatedAccount auth) + throws RateLimitExceededException { + rateLimiter.validate(auth.getAccount().getUuid()); - final ZonedDateTime now = ZonedDateTime.now(ZoneOffset.UTC); - final String key = generateAttachmentKey(); + final ZonedDateTime now = ZonedDateTime.now(ZoneOffset.UTC); + final String key = generateAttachmentKey(); final CanonicalRequest canonicalRequest = canonicalRequestGenerator.createFor(key, now); - return new AttachmentDescriptorV3(2, key, getHeaderMap(canonicalRequest), getSignedUploadLocation(canonicalRequest)); + return new AttachmentDescriptorV3(2, key, getHeaderMap(canonicalRequest), + getSignedUploadLocation(canonicalRequest)); } private String getSignedUploadLocation(@Nonnull CanonicalRequest canonicalRequest) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CertificateController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CertificateController.java index 31725f997..e4548b8f6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CertificateController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CertificateController.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -10,7 +10,6 @@ import static com.codahale.metrics.MetricRegistry.name; import com.codahale.metrics.annotation.Timed; import io.dropwizard.auth.Auth; import io.micrometer.core.instrument.Metrics; -import java.io.IOException; import java.security.InvalidKeyException; import java.util.LinkedList; import java.util.List; @@ -24,10 +23,10 @@ import javax.ws.rs.WebApplicationException; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import org.signal.zkgroup.auth.ServerZkAuthOperations; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.CertificateGenerator; import org.whispersystems.textsecuregcm.entities.DeliveryCertificate; import org.whispersystems.textsecuregcm.entities.GroupCredentials; -import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.util.Util; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") @@ -51,43 +50,49 @@ public class CertificateController { @GET @Produces(MediaType.APPLICATION_JSON) @Path("/delivery") - public DeliveryCertificate getDeliveryCertificate(@Auth Account account, - @QueryParam("includeE164") Optional maybeIncludeE164) - throws InvalidKeyException - { - if (account.getAuthenticatedDevice().isEmpty()) { - throw new AssertionError(); - } - if (Util.isEmpty(account.getIdentityKey())) { + public DeliveryCertificate getDeliveryCertificate(@Auth AuthenticatedAccount auth, + @QueryParam("includeE164") Optional maybeIncludeE164) + throws InvalidKeyException { + if (Util.isEmpty(auth.getAccount().getIdentityKey())) { throw new WebApplicationException(Response.Status.BAD_REQUEST); } final boolean includeE164 = maybeIncludeE164.orElse(true); - Metrics.counter(GENERATE_DELIVERY_CERTIFICATE_COUNTER_NAME, INCLUDE_E164_TAG_NAME, String.valueOf(includeE164)).increment(); + Metrics.counter(GENERATE_DELIVERY_CERTIFICATE_COUNTER_NAME, INCLUDE_E164_TAG_NAME, String.valueOf(includeE164)) + .increment(); - return new DeliveryCertificate(certificateGenerator.createFor(account, account.getAuthenticatedDevice().get(), includeE164)); + return new DeliveryCertificate( + certificateGenerator.createFor(auth.getAccount(), auth.getAuthenticatedDevice(), includeE164)); } @Timed @GET @Produces(MediaType.APPLICATION_JSON) @Path("/group/{startRedemptionTime}/{endRedemptionTime}") - public GroupCredentials getAuthenticationCredentials(@Auth Account account, - @PathParam("startRedemptionTime") int startRedemptionTime, - @PathParam("endRedemptionTime") int endRedemptionTime) - { - if (!isZkEnabled) throw new WebApplicationException(Response.Status.NOT_FOUND); - if (startRedemptionTime > endRedemptionTime) throw new WebApplicationException(Response.Status.BAD_REQUEST); - if (endRedemptionTime > Util.currentDaysSinceEpoch() + 7) throw new WebApplicationException(Response.Status.BAD_REQUEST); - if (startRedemptionTime < Util.currentDaysSinceEpoch()) throw new WebApplicationException(Response.Status.BAD_REQUEST); + public GroupCredentials getAuthenticationCredentials(@Auth AuthenticatedAccount auth, + @PathParam("startRedemptionTime") int startRedemptionTime, + @PathParam("endRedemptionTime") int endRedemptionTime) { + if (!isZkEnabled) { + throw new WebApplicationException(Response.Status.NOT_FOUND); + } + if (startRedemptionTime > endRedemptionTime) { + throw new WebApplicationException(Response.Status.BAD_REQUEST); + } + if (endRedemptionTime > Util.currentDaysSinceEpoch() + 7) { + throw new WebApplicationException(Response.Status.BAD_REQUEST); + } + if (startRedemptionTime < Util.currentDaysSinceEpoch()) { + throw new WebApplicationException(Response.Status.BAD_REQUEST); + } List credentials = new LinkedList<>(); - for (int i=startRedemptionTime;i<=endRedemptionTime;i++) { - credentials.add(new GroupCredentials.GroupCredential(serverZkAuthOperations.issueAuthCredential(account.getUuid(), i) - .serialize(), - i)); + for (int i = startRedemptionTime; i <= endRedemptionTime; i++) { + credentials.add(new GroupCredentials.GroupCredential( + serverZkAuthOperations.issueAuthCredential(auth.getAccount().getUuid(), i) + .serialize(), + i)); } return new GroupCredentials(credentials); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ChallengeController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ChallengeController.java index 71e982933..e0dedfdcd 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ChallengeController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ChallengeController.java @@ -17,12 +17,12 @@ import javax.ws.rs.Path; import javax.ws.rs.Produces; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.entities.AnswerChallengeRequest; import org.whispersystems.textsecuregcm.entities.AnswerPushChallengeRequest; import org.whispersystems.textsecuregcm.entities.AnswerRecaptchaChallengeRequest; import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; -import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.util.ForwardedIpUtil; @Path("/v1/challenge") @@ -38,7 +38,7 @@ public class ChallengeController { @PUT @Produces(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON) - public Response handleChallengeResponse(@Auth final Account account, + public Response handleChallengeResponse(@Auth final AuthenticatedAccount auth, @Valid final AnswerChallengeRequest answerRequest, @HeaderParam("X-Forwarded-For") String forwardedFor) throws RetryLaterException { @@ -46,14 +46,15 @@ public class ChallengeController { if (answerRequest instanceof AnswerPushChallengeRequest) { final AnswerPushChallengeRequest pushChallengeRequest = (AnswerPushChallengeRequest) answerRequest; - rateLimitChallengeManager.answerPushChallenge(account, pushChallengeRequest.getChallenge()); + rateLimitChallengeManager.answerPushChallenge(auth.getAccount(), pushChallengeRequest.getChallenge()); } else if (answerRequest instanceof AnswerRecaptchaChallengeRequest) { try { final AnswerRecaptchaChallengeRequest recaptchaChallengeRequest = (AnswerRecaptchaChallengeRequest) answerRequest; final String mostRecentProxy = ForwardedIpUtil.getMostRecentProxy(forwardedFor).orElseThrow(); - rateLimitChallengeManager.answerRecaptchaChallenge(account, recaptchaChallengeRequest.getCaptcha(), mostRecentProxy); + rateLimitChallengeManager.answerRecaptchaChallenge(auth.getAccount(), recaptchaChallengeRequest.getCaptcha(), + mostRecentProxy); } catch (final NoSuchElementException e) { return Response.status(400).build(); @@ -69,9 +70,9 @@ public class ChallengeController { @Timed @POST @Path("/push") - public Response requestPushChallenge(@Auth final Account account) { + public Response requestPushChallenge(@Auth final AuthenticatedAccount auth) { try { - rateLimitChallengeManager.sendPushChallenge(account); + rateLimitChallengeManager.sendPushChallenge(auth.getAccount()); return Response.status(200).build(); } catch (final NotPushRegisteredException e) { return Response.status(404).build(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index 31f0faf16..bc9c994e7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -26,6 +26,7 @@ import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials; import org.whispersystems.textsecuregcm.auth.AuthorizationHeader; import org.whispersystems.textsecuregcm.auth.InvalidAuthorizationHeaderException; @@ -79,12 +80,12 @@ public class DeviceController { @Timed @GET @Produces(MediaType.APPLICATION_JSON) - public DeviceInfoList getDevices(@Auth Account account) { + public DeviceInfoList getDevices(@Auth AuthenticatedAccount auth) { List devices = new LinkedList<>(); - for (Device device : account.getDevices()) { + for (Device device : auth.getAccount().getDevices()) { devices.add(new DeviceInfo(device.getId(), device.getName(), - device.getLastSeen(), device.getCreated())); + device.getLastSeen(), device.getCreated())); } return new DeviceInfoList(devices); @@ -93,8 +94,9 @@ public class DeviceController { @Timed @DELETE @Path("/{device_id}") - public void removeDevice(@Auth Account account, @PathParam("device_id") long deviceId) { - if (account.getAuthenticatedDevice().get().getId() != Device.MASTER_ID) { + public void removeDevice(@Auth AuthenticatedAccount auth, @PathParam("device_id") long deviceId) { + Account account = auth.getAccount(); + if (auth.getAuthenticatedDevice().getId() != Device.MASTER_ID) { throw new WebApplicationException(Response.Status.UNAUTHORIZED); } @@ -109,9 +111,11 @@ public class DeviceController { @GET @Path("/provisioning/code") @Produces(MediaType.APPLICATION_JSON) - public VerificationCode createDeviceToken(@Auth Account account) - throws RateLimitExceededException, DeviceLimitExceededException - { + public VerificationCode createDeviceToken(@Auth AuthenticatedAccount auth) + throws RateLimitExceededException, DeviceLimitExceededException { + + final Account account = auth.getAccount(); + rateLimiters.getAllocateDeviceLimiter().validate(account.getUuid()); int maxDeviceLimit = MAX_DEVICES; @@ -124,7 +128,7 @@ public class DeviceController { throw new DeviceLimitExceededException(account.getDevices().size(), MAX_DEVICES); } - if (account.getAuthenticatedDevice().get().getId() != Device.MASTER_ID) { + if (auth.getAuthenticatedDevice().getId() != Device.MASTER_ID) { throw new WebApplicationException(Response.Status.UNAUTHORIZED); } @@ -213,18 +217,18 @@ public class DeviceController { @Timed @PUT @Path("/unauthenticated_delivery") - public void setUnauthenticatedDelivery(@Auth Account account) { - assert(account.getAuthenticatedDevice().isPresent()); + public void setUnauthenticatedDelivery(@Auth AuthenticatedAccount auth) { + assert (auth.getAuthenticatedDevice() != null); // Deprecated } @Timed @PUT @Path("/capabilities") - public void setCapabiltities(@Auth Account account, @Valid DeviceCapabilities capabilities) { - assert(account.getAuthenticatedDevice().isPresent()); - final long deviceId = account.getAuthenticatedDevice().get().getId(); - accounts.updateDevice(account, deviceId, d -> d.setCapabilities(capabilities)); + public void setCapabiltities(@Auth AuthenticatedAccount auth, @Valid DeviceCapabilities capabilities) { + assert (auth.getAuthenticatedDevice() != null); + final long deviceId = auth.getAuthenticatedDevice().getId(); + accounts.updateDevice(auth.getAccount(), deviceId, d -> d.setCapabilities(capabilities)); } @VisibleForTesting protected VerificationCode generateVerificationCode() { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DirectoryController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DirectoryController.java index 4093f0910..8953c545d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DirectoryController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DirectoryController.java @@ -1,14 +1,11 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.textsecuregcm.controllers; import com.codahale.metrics.annotation.Timed; import io.dropwizard.auth.Auth; -import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator; -import org.whispersystems.textsecuregcm.storage.Account; - import javax.ws.rs.Consumes; import javax.ws.rs.GET; import javax.ws.rs.PUT; @@ -16,6 +13,8 @@ import javax.ws.rs.Path; import javax.ws.rs.Produces; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator; @Path("/v1/directory") public class DirectoryController { @@ -30,15 +29,15 @@ public class DirectoryController { @GET @Path("/auth") @Produces(MediaType.APPLICATION_JSON) - public Response getAuthToken(@Auth Account account) { - return Response.ok().entity(directoryServiceTokenGenerator.generateFor(account.getNumber())).build(); + public Response getAuthToken(@Auth AuthenticatedAccount auth) { + return Response.ok().entity(directoryServiceTokenGenerator.generateFor(auth.getAccount().getNumber())).build(); } @PUT @Path("/feedback-v3/{status}") @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) - public Response setFeedback(@Auth Account account) { + public Response setFeedback(@Auth AuthenticatedAccount auth) { return Response.ok().build(); } @@ -47,7 +46,7 @@ public class DirectoryController { @GET @Path("/{token}") @Produces(MediaType.APPLICATION_JSON) - public Response getTokenPresence(@Auth Account account) { + public Response getTokenPresence(@Auth AuthenticatedAccount auth) { return Response.status(429).build(); } @@ -56,7 +55,7 @@ public class DirectoryController { @Path("/tokens") @Produces(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON) - public Response getContactIntersection(@Auth Account account) { + public Response getContactIntersection(@Auth AuthenticatedAccount auth) { return Response.status(429).build(); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DonationController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DonationController.java index 8295e8336..874c99413 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DonationController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DonationController.java @@ -34,12 +34,12 @@ import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.configuration.DonationConfiguration; import org.whispersystems.textsecuregcm.entities.ApplePayAuthorizationRequest; import org.whispersystems.textsecuregcm.entities.ApplePayAuthorizationResponse; import org.whispersystems.textsecuregcm.http.FaultTolerantHttpClient; import org.whispersystems.textsecuregcm.http.FormDataBodyPublisher; -import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.util.SystemMapper; @Path("/v1/donation") @@ -75,7 +75,7 @@ public class DonationController { @Path("/authorize-apple-pay") @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) - public CompletableFuture getApplePayAuthorization(@Auth Account account, @Valid ApplePayAuthorizationRequest request) { + public CompletableFuture getApplePayAuthorization(@Auth AuthenticatedAccount auth, @Valid ApplePayAuthorizationRequest request) { if (!supportedCurrencies.contains(request.getCurrency())) { return CompletableFuture.completedFuture(Response.status(422).build()); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeepAliveController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeepAliveController.java index 1f0b240b4..a0fa52ff2 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeepAliveController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeepAliveController.java @@ -1,28 +1,27 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.textsecuregcm.controllers; +import static com.codahale.metrics.MetricRegistry.name; + import com.codahale.metrics.annotation.Timed; import io.dropwizard.auth.Auth; import io.micrometer.core.instrument.Metrics; +import javax.ws.rs.GET; +import javax.ws.rs.Path; +import javax.ws.rs.core.Response; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; -import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; import org.whispersystems.websocket.session.WebSocketSession; import org.whispersystems.websocket.session.WebSocketSessionContext; -import javax.ws.rs.GET; -import javax.ws.rs.Path; -import javax.ws.rs.core.Response; - -import static com.codahale.metrics.MetricRegistry.name; - @Path("/v1/keepalive") public class KeepAliveController { @@ -40,15 +39,14 @@ public class KeepAliveController { @Timed @GET - public Response getKeepAlive(@Auth Account account, - @WebSocketSession WebSocketSessionContext context) - { - if (account != null) { - if (!clientPresenceManager.isLocallyPresent(account.getUuid(), account.getAuthenticatedDevice().get().getId())) { + public Response getKeepAlive(@Auth AuthenticatedAccount auth, + @WebSocketSession WebSocketSessionContext context) { + if (auth != null) { + if (!clientPresenceManager.isLocallyPresent(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId())) { logger.warn("***** No local subscription found for {}::{}; age = {}ms, User-Agent = {}", - account.getUuid(), account.getAuthenticatedDevice().get().getId(), - System.currentTimeMillis() - context.getClient().getCreatedTimestamp(), - context.getClient().getUserAgent()); + auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(), + System.currentTimeMillis() - context.getClient().getCreatedTimestamp(), + context.getClient().getUserAgent()); context.getClient().close(1000, "OK"); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index 4a33ec593..063202523 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -28,7 +28,8 @@ import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier; import org.whispersystems.textsecuregcm.auth.Anonymous; -import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.entities.PreKey; import org.whispersystems.textsecuregcm.entities.PreKeyCount; @@ -76,8 +77,8 @@ public class KeysController { @GET @Produces(MediaType.APPLICATION_JSON) - public PreKeyCount getStatus(@Auth Account account) { - int count = keysDynamoDb.getCount(account, account.getAuthenticatedDevice().get().getId()); + public PreKeyCount getStatus(@Auth AuthenticatedAccount auth) { + int count = keysDynamoDb.getCount(auth.getAccount(), auth.getAuthenticatedDevice().getId()); if (count > 0) { count = count - 1; @@ -89,10 +90,10 @@ public class KeysController { @Timed @PUT @Consumes(MediaType.APPLICATION_JSON) - public void setKeys(@Auth DisabledPermittedAccount disabledPermittedAccount, @Valid PreKeyState preKeys) { - Account account = disabledPermittedAccount.getAccount(); - Device device = account.getAuthenticatedDevice().get(); - boolean updateAccount = false; + public void setKeys(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth, @Valid PreKeyState preKeys) { + Account account = disabledPermittedAuth.getAccount(); + Device device = disabledPermittedAuth.getAuthenticatedDevice(); + boolean updateAccount = false; if (!preKeys.getSignedPreKey().equals(device.getSignedPreKey())) { updateAccount = true; @@ -116,7 +117,7 @@ public class KeysController { @GET @Path("/{identifier}/{device_id}") @Produces(MediaType.APPLICATION_JSON) - public Response getDeviceKeys(@Auth Optional account, + public Response getDeviceKeys(@Auth Optional auth, @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional accessKey, @PathParam("identifier") AmbiguousIdentifier targetName, @PathParam("device_id") String deviceId, @@ -125,14 +126,16 @@ public class KeysController { targetName.incrementRequestCounter("getDeviceKeys", userAgent); - if (!account.isPresent() && !accessKey.isPresent()) { + if (auth.isEmpty() && accessKey.isEmpty()) { throw new WebApplicationException(Response.Status.UNAUTHORIZED); } + final Optional account = auth.map(AuthenticatedAccount::getAccount); + Optional target = accounts.get(targetName); OptionalAccess.verify(account, accessKey, target, deviceId); - assert(target.isPresent()); + assert (target.isPresent()); { final String sourceCountryCode = account.map(a -> Util.getCountryCode(a.getNumber())).orElse("0"); @@ -146,7 +149,9 @@ public class KeysController { } if (account.isPresent()) { - rateLimiters.getPreKeysLimiter().validate(account.get().getUuid() + "." + account.get().getAuthenticatedDevice().get().getId() + "__" + target.get().getUuid() + "." + deviceId); + rateLimiters.getPreKeysLimiter().validate( + account.get().getUuid() + "." + auth.get().getAuthenticatedDevice().getId() + "__" + target.get().getUuid() + + "." + deviceId); try { preKeyRateLimiter.validate(account.get()); @@ -188,22 +193,25 @@ public class KeysController { @PUT @Path("/signed") @Consumes(MediaType.APPLICATION_JSON) - public void setSignedKey(@Auth Account account, @Valid SignedPreKey signedPreKey) { - Device device = account.getAuthenticatedDevice().get(); + public void setSignedKey(@Auth AuthenticatedAccount auth, @Valid SignedPreKey signedPreKey) { + Device device = auth.getAuthenticatedDevice(); - accounts.updateDevice(account, device.getId(), d -> d.setSignedPreKey(signedPreKey)); + accounts.updateDevice(auth.getAccount(), device.getId(), d -> d.setSignedPreKey(signedPreKey)); } @Timed @GET @Path("/signed") @Produces(MediaType.APPLICATION_JSON) - public Optional getSignedKey(@Auth Account account) { - Device device = account.getAuthenticatedDevice().get(); + public Optional getSignedKey(@Auth AuthenticatedAccount auth) { + Device device = auth.getAuthenticatedDevice(); SignedPreKey signedPreKey = device.getSignedPreKey(); - if (signedPreKey != null) return Optional.of(signedPreKey); - else return Optional.empty(); + if (signedPreKey != null) { + return Optional.of(signedPreKey); + } else { + return Optional.empty(); + } } private Map getLocalKeys(Account destination, String deviceIdSelector) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index f0afd0e00..29c97af51 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -62,6 +62,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier; import org.whispersystems.textsecuregcm.auth.Anonymous; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.CombinedUnidentifiedSenderAccessKeys; import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration; @@ -189,12 +190,12 @@ public class MessageController { @PUT @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) - public Response sendMessage(@Auth Optional source, - @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional accessKey, - @HeaderParam("User-Agent") String userAgent, - @HeaderParam("X-Forwarded-For") String forwardedFor, - @PathParam("destination") AmbiguousIdentifier destinationName, - @Valid IncomingMessageList messages) + public Response sendMessage(@Auth Optional source, + @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional accessKey, + @HeaderParam("User-Agent") String userAgent, + @HeaderParam("X-Forwarded-For") String forwardedFor, + @PathParam("destination") AmbiguousIdentifier destinationName, + @Valid IncomingMessageList messages) throws RateLimitExceededException, RateLimitChallengeException { destinationName.incrementRequestCounter("sendMessage", userAgent); @@ -203,20 +204,22 @@ public class MessageController { throw new WebApplicationException(Response.Status.UNAUTHORIZED); } - if (source.isPresent() && !source.get().isFor(destinationName)) { - assert source.get().getMasterDevice().isPresent(); + if (source.isPresent() && !source.get().getAccount().isFor(destinationName)) { + assert source.get().getAccount().getMasterDevice().isPresent(); - final Device masterDevice = source.get().getMasterDevice().get(); - final String senderCountryCode = Util.getCountryCode(source.get().getNumber()); + final Device masterDevice = source.get().getAccount().getMasterDevice().get(); + final String senderCountryCode = Util.getCountryCode(source.get().getAccount().getNumber()); - if (StringUtils.isAllBlank(masterDevice.getApnId(), masterDevice.getVoipApnId(), masterDevice.getGcmId()) || masterDevice.getUninstalledFeedbackTimestamp() > 0) { - Metrics.counter(UNSEALED_SENDER_WITHOUT_PUSH_TOKEN_COUNTER_NAME, SENDER_COUNTRY_TAG_NAME, senderCountryCode).increment(); + if (StringUtils.isAllBlank(masterDevice.getApnId(), masterDevice.getVoipApnId(), masterDevice.getGcmId()) + || masterDevice.getUninstalledFeedbackTimestamp() > 0) { + Metrics.counter(UNSEALED_SENDER_WITHOUT_PUSH_TOKEN_COUNTER_NAME, SENDER_COUNTRY_TAG_NAME, senderCountryCode) + .increment(); } } final String senderType; - if (source.isPresent() && !source.get().isFor(destinationName)) { + if (source.isPresent() && !source.get().getAccount().isFor(destinationName)) { identifiedMeter.mark(); senderType = "identified"; } else if (source.isEmpty()) { @@ -246,23 +249,26 @@ public class MessageController { } try { - boolean isSyncMessage = source.isPresent() && source.get().isFor(destinationName); + boolean isSyncMessage = source.isPresent() && source.get().getAccount().isFor(destinationName); Optional destination; - if (!isSyncMessage) destination = accountsManager.get(destinationName); - else destination = source; + if (!isSyncMessage) { + destination = accountsManager.get(destinationName); + } else { + destination = source.map(AuthenticatedAccount::getAccount); + } - OptionalAccess.verify(source, accessKey, destination); - assert(destination.isPresent()); + OptionalAccess.verify(source.map(AuthenticatedAccount::getAccount), accessKey, destination); + assert (destination.isPresent()); - if (source.isPresent() && !source.get().isFor(destinationName)) { - rateLimiters.getMessagesLimiter().validate(source.get().getUuid(), destination.get().getUuid()); + if (source.isPresent() && !source.get().getAccount().isFor(destinationName)) { + rateLimiters.getMessagesLimiter().validate(source.get().getAccount().getUuid(), destination.get().getUuid()); - final String senderCountryCode = Util.getCountryCode(source.get().getNumber()); + final String senderCountryCode = Util.getCountryCode(source.get().getAccount().getNumber()); try { - unsealedSenderRateLimiter.validate(source.get(), destination.get()); + unsealedSenderRateLimiter.validate(source.get().getAccount(), destination.get()); } catch (final RateLimitExceededException e) { final boolean legacyClient = rateLimitChallengeManager.isClientBelowMinimumVersion(userAgent); @@ -276,11 +282,11 @@ public class MessageController { throw e; } - throw new RateLimitChallengeException(source.get(), e.getRetryDuration()); + throw new RateLimitChallengeException(source.get().getAccount(), e.getRetryDuration()); } final String destinationCountryCode = Util.getCountryCode(destination.get().getNumber()); - final Device masterDevice = source.get().getMasterDevice().get(); + final Device masterDevice = source.get().getAccount().getMasterDevice().get(); if (!senderCountryCode.equals(destinationCountryCode)) { recordInternationalUnsealedSenderMetrics(forwardedFor, senderCountryCode, destination.get().getNumber()); @@ -293,31 +299,34 @@ public class MessageController { .orElse(false); if (isRateLimitedHost) { - return declineDelivery(messages, source.get(), destination.get()); + return declineDelivery(messages, source.get().getAccount(), destination.get()); } } } } } - validateCompleteDeviceList(destination.get(), messages.getMessages(), isSyncMessage); + validateCompleteDeviceList(destination.get(), messages.getMessages(), isSyncMessage, + source.map(AuthenticatedAccount::getAuthenticatedDevice).map(Device::getId)); validateRegistrationIds(destination.get(), messages.getMessages()); final List tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent), - Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(messages.isOnline())), - Tag.of(SENDER_TYPE_TAG_NAME, senderType), - Tag.of(DESTINATION_TYPE_TAG_NAME, destinationName.hasNumber() ? "e164" : "uuid")); + Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(messages.isOnline())), + Tag.of(SENDER_TYPE_TAG_NAME, senderType), + Tag.of(DESTINATION_TYPE_TAG_NAME, destinationName.hasNumber() ? "e164" : "uuid")); for (IncomingMessage incomingMessage : messages.getMessages()) { Optional destinationDevice = destination.get().getDevice(incomingMessage.getDestinationDeviceId()); if (destinationDevice.isPresent()) { Metrics.counter(SENT_MESSAGE_COUNTER_NAME, tags).increment(); - sendMessage(source, destination.get(), destinationDevice.get(), messages.getTimestamp(), messages.isOnline(), incomingMessage); + sendMessage(source, destination.get(), destinationDevice.get(), messages.getTimestamp(), messages.isOnline(), + incomingMessage); } } - return Response.ok(new SendMessageResponse(!isSyncMessage && source.isPresent() && source.get().getEnabledDeviceCount() > 1)).build(); + return Response.ok(new SendMessageResponse( + !isSyncMessage && source.isPresent() && source.get().getAccount().getEnabledDeviceCount() > 1)).build(); } catch (NoSuchUserException e) { throw new WebApplicationException(Response.status(404).build()); } catch (MismatchedDevicesException e) { @@ -380,7 +389,7 @@ public class MessageController { final Set> deviceIdAndRegistrationIdSet = accountToDeviceIdAndRegistrationIdMap.get(account); final Set deviceIds = deviceIdAndRegistrationIdSet.stream().map(Pair::first).collect(Collectors.toSet()); try { - validateCompleteDeviceList(account, deviceIds, false); + validateCompleteDeviceList(account, deviceIds, false, Optional.empty()); validateRegistrationIds(account, deviceIdAndRegistrationIdSet.stream()); } catch (MismatchedDevicesException e) { accountMismatchedDevices.add(new AccountMismatchedDevices(account.getUuid(), @@ -476,7 +485,9 @@ public class MessageController { if (random.nextDouble() <= messageRateConfiguration.getReceiptProbability()) { receiptExecutorService.schedule(() -> { try { - receiptSender.sendReceipt(destination, source.getNumber(), timestamp); + receiptSender.sendReceipt( + new AuthenticatedAccount(() -> new Pair<>(destination, destination.getMasterDevice().get())), + source.getNumber(), timestamp); } catch (final NoSuchUserException ignored) { } }, receiptDelay.toMillis(), TimeUnit.MILLISECONDS); @@ -503,16 +514,17 @@ public class MessageController { @Timed @GET @Produces(MediaType.APPLICATION_JSON) - public OutgoingMessageEntityList getPendingMessages(@Auth Account account, @HeaderParam("User-Agent") String userAgent) { - assert account.getAuthenticatedDevice().isPresent(); + public OutgoingMessageEntityList getPendingMessages(@Auth AuthenticatedAccount auth, + @HeaderParam("User-Agent") String userAgent) { + assert auth.getAuthenticatedDevice() != null; - if (!Util.isEmpty(account.getAuthenticatedDevice().get().getApnId())) { - RedisOperation.unchecked(() -> apnFallbackManager.cancel(account, account.getAuthenticatedDevice().get())); + if (!Util.isEmpty(auth.getAuthenticatedDevice().getApnId())) { + RedisOperation.unchecked(() -> apnFallbackManager.cancel(auth.getAccount(), auth.getAuthenticatedDevice())); } final OutgoingMessageEntityList outgoingMessages = messagesManager.getMessagesForDevice( - account.getUuid(), - account.getAuthenticatedDevice().get().getId(), + auth.getAccount().getUuid(), + auth.getAuthenticatedDevice().getId(), userAgent, false); @@ -549,21 +561,20 @@ public class MessageController { @Timed @DELETE @Path("/{source}/{timestamp}") - public void removePendingMessage(@Auth Account account, - @PathParam("source") String source, - @PathParam("timestamp") long timestamp) - { + public void removePendingMessage(@Auth AuthenticatedAccount auth, + @PathParam("source") String source, + @PathParam("timestamp") long timestamp) { try { - WebSocketConnection.recordMessageDeliveryDuration(timestamp, account.getAuthenticatedDevice().get()); + WebSocketConnection.recordMessageDeliveryDuration(timestamp, auth.getAuthenticatedDevice()); Optional message = messagesManager.delete( - account.getUuid(), - account.getAuthenticatedDevice().get().getId(), - source, timestamp); + auth.getAccount().getUuid(), + auth.getAuthenticatedDevice().getId(), + source, timestamp); if (message.isPresent() && message.get().getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE) { - receiptSender.sendReceipt(account, - message.get().getSource(), - message.get().getTimestamp()); + receiptSender.sendReceipt(auth, + message.get().getSource(), + message.get().getTimestamp()); } } catch (NoSuchUserException e) { logger.warn("Sending delivery receipt", e); @@ -573,17 +584,18 @@ public class MessageController { @Timed @DELETE @Path("/uuid/{uuid}") - public void removePendingMessage(@Auth Account account, @PathParam("uuid") UUID uuid) { + public void removePendingMessage(@Auth AuthenticatedAccount auth, @PathParam("uuid") UUID uuid) { try { Optional message = messagesManager.delete( - account.getUuid(), - account.getAuthenticatedDevice().get().getId(), - uuid); + auth.getAccount().getUuid(), + auth.getAuthenticatedDevice().getId(), + uuid); if (message.isPresent()) { - WebSocketConnection.recordMessageDeliveryDuration(message.get().getTimestamp(), account.getAuthenticatedDevice().get()); - if (!Util.isEmpty(message.get().getSource()) && message.get().getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE) { - receiptSender.sendReceipt(account, message.get().getSource(), message.get().getTimestamp()); + WebSocketConnection.recordMessageDeliveryDuration(message.get().getTimestamp(), auth.getAuthenticatedDevice()); + if (!Util.isEmpty(message.get().getSource()) + && message.get().getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE) { + receiptSender.sendReceipt(auth, message.get().getSource(), message.get().getTimestamp()); } } @@ -595,7 +607,8 @@ public class MessageController { @Timed @POST @Path("/report/{sourceNumber}/{messageGuid}") - public Response reportMessage(@Auth Account account, @PathParam("sourceNumber") String sourceNumber, @PathParam("messageGuid") UUID messageGuid) { + public Response reportMessage(@Auth AuthenticatedAccount auth, @PathParam("sourceNumber") String sourceNumber, + @PathParam("messageGuid") UUID messageGuid) { reportMessageManager.report(sourceNumber, messageGuid); @@ -603,27 +616,26 @@ public class MessageController { .build(); } - private void sendMessage(Optional source, - Account destinationAccount, - Device destinationDevice, - long timestamp, - boolean online, - IncomingMessage incomingMessage) - throws NoSuchUserException - { + private void sendMessage(Optional source, + Account destinationAccount, + Device destinationDevice, + long timestamp, + boolean online, + IncomingMessage incomingMessage) + throws NoSuchUserException { try (final Timer.Context ignored = sendMessageInternalTimer.time()) { - Optional messageBody = getMessageBody(incomingMessage); + Optional messageBody = getMessageBody(incomingMessage); Optional messageContent = getMessageContent(incomingMessage); Envelope.Builder messageBuilder = Envelope.newBuilder(); messageBuilder.setType(Envelope.Type.forNumber(incomingMessage.getType())) - .setTimestamp(timestamp == 0 ? System.currentTimeMillis() : timestamp) - .setServerTimestamp(System.currentTimeMillis()); + .setTimestamp(timestamp == 0 ? System.currentTimeMillis() : timestamp) + .setServerTimestamp(System.currentTimeMillis()); if (source.isPresent()) { - messageBuilder.setSource(source.get().getNumber()) - .setSourceUuid(source.get().getUuid().toString()) - .setSourceDevice((int)source.get().getAuthenticatedDevice().get().getId()); + messageBuilder.setSource(source.get().getAccount().getNumber()) + .setSourceUuid(source.get().getAccount().getUuid().toString()) + .setSourceDevice((int) source.get().getAuthenticatedDevice().getId()); } if (messageBody.isPresent()) { @@ -697,24 +709,26 @@ public class MessageController { } @VisibleForTesting - public static void validateCompleteDeviceList(Account account, List messages, boolean isSyncMessage) + public static void validateCompleteDeviceList(Account account, List messages, boolean isSyncMessage, + Optional authenticatedDeviceId) throws MismatchedDevicesException { - Set messageDeviceIds = messages.stream().map(IncomingMessage::getDestinationDeviceId).collect(Collectors.toSet()); - validateCompleteDeviceList(account, messageDeviceIds, isSyncMessage); + Set messageDeviceIds = messages.stream().map(IncomingMessage::getDestinationDeviceId) + .collect(Collectors.toSet()); + validateCompleteDeviceList(account, messageDeviceIds, isSyncMessage, authenticatedDeviceId); } @VisibleForTesting - public static void validateCompleteDeviceList(Account account, Set messageDeviceIds, boolean isSyncMessage) + public static void validateCompleteDeviceList(Account account, Set messageDeviceIds, boolean isSyncMessage, + Optional authenticatedDeviceId) throws MismatchedDevicesException { Set accountDeviceIds = new HashSet<>(); List missingDeviceIds = new LinkedList<>(); - List extraDeviceIds = new LinkedList<>(); + List extraDeviceIds = new LinkedList<>(); - for (Device device : account.getDevices()) { + for (Device device : account.getDevices()) { if (device.isEnabled() && - !(isSyncMessage && device.getId() == account.getAuthenticatedDevice().get().getId())) - { + !(isSyncMessage && device.getId() == authenticatedDeviceId.get())) { accountDeviceIds.add(device.getId()); if (!messageDeviceIds.contains(device.getId())) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/PaymentsController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/PaymentsController.java index ef1b95b68..f84f7a8c4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/PaymentsController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/PaymentsController.java @@ -1,23 +1,21 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.textsecuregcm.controllers; import com.codahale.metrics.annotation.Timed; -import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator; -import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials; -import org.whispersystems.textsecuregcm.entities.CurrencyConversionEntityList; -import org.whispersystems.textsecuregcm.storage.Account; -import org.whispersystems.textsecuregcm.currency.CurrencyConversionManager; - +import io.dropwizard.auth.Auth; import javax.ws.rs.GET; import javax.ws.rs.Path; import javax.ws.rs.Produces; import javax.ws.rs.core.MediaType; - -import io.dropwizard.auth.Auth; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator; +import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials; +import org.whispersystems.textsecuregcm.currency.CurrencyConversionManager; +import org.whispersystems.textsecuregcm.entities.CurrencyConversionEntityList; @Path("/v1/payments") public class PaymentsController { @@ -34,15 +32,15 @@ public class PaymentsController { @GET @Path("/auth") @Produces(MediaType.APPLICATION_JSON) - public ExternalServiceCredentials getAuth(@Auth Account account) { - return paymentsServiceCredentialGenerator.generateFor(account.getUuid().toString()); + public ExternalServiceCredentials getAuth(@Auth AuthenticatedAccount auth) { + return paymentsServiceCredentialGenerator.generateFor(auth.getAccount().getUuid().toString()); } @Timed @GET @Path("/conversions") @Produces(MediaType.APPLICATION_JSON) - public CurrencyConversionEntityList getConversions(@Auth Account account) { + public CurrencyConversionEntityList getConversions(@Auth AuthenticatedAccount auth) { return currencyManager.getCurrencyConversions().orElseThrow(); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java index 49ee627ce..410a278a7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -41,6 +41,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier; import org.whispersystems.textsecuregcm.auth.Anonymous; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessChecksum; import org.whispersystems.textsecuregcm.entities.CreateProfileRequest; @@ -107,60 +108,64 @@ public class ProfileController { this.isZkEnabled = isZkEnabled; } - @Timed - @PUT - @Produces(MediaType.APPLICATION_JSON) - @Consumes(MediaType.APPLICATION_JSON) - public Response setProfile(@Auth Account account, @Valid CreateProfileRequest request) { - if (!isZkEnabled) throw new WebApplicationException(Response.Status.NOT_FOUND); + @Timed + @PUT + @Produces(MediaType.APPLICATION_JSON) + @Consumes(MediaType.APPLICATION_JSON) + public Response setProfile(@Auth AuthenticatedAccount auth, @Valid CreateProfileRequest request) { + if (!isZkEnabled) { + throw new WebApplicationException(Response.Status.NOT_FOUND); + } - final Set allowedPaymentsCountryCodes = - dynamicConfigurationManager.getConfiguration().getPaymentsConfiguration().getAllowedCountryCodes(); + final Set allowedPaymentsCountryCodes = + dynamicConfigurationManager.getConfiguration().getPaymentsConfiguration().getAllowedCountryCodes(); - if (StringUtils.isNotBlank(request.getPaymentAddress()) && - !allowedPaymentsCountryCodes.contains(Util.getCountryCode(account.getNumber()))) { + if (StringUtils.isNotBlank(request.getPaymentAddress()) && + !allowedPaymentsCountryCodes.contains(Util.getCountryCode(auth.getAccount().getNumber()))) { - return Response.status(Status.FORBIDDEN).build(); - } + return Response.status(Status.FORBIDDEN).build(); + } - Optional currentProfile = profilesManager.get(account.getUuid(), request.getVersion()); - String avatar = request.isAvatar() ? generateAvatarObjectName() : null; - Optional response = Optional.empty(); + Optional currentProfile = profilesManager.get(auth.getAccount().getUuid(), request.getVersion()); + String avatar = request.isAvatar() ? generateAvatarObjectName() : null; + Optional response = Optional.empty(); - profilesManager.set(account.getUuid(), - new VersionedProfile( - request.getVersion(), - request.getName(), - avatar, - request.getAboutEmoji(), - request.getAbout(), - request.getPaymentAddress(), - request.getCommitment().serialize())); + profilesManager.set(auth.getAccount().getUuid(), + new VersionedProfile( + request.getVersion(), + request.getName(), + avatar, + request.getAboutEmoji(), + request.getAbout(), + request.getPaymentAddress(), + request.getCommitment().serialize())); if (request.isAvatar()) { - Optional currentAvatar = Optional.empty(); + Optional currentAvatar = Optional.empty(); - if (currentProfile.isPresent() && currentProfile.get().getAvatar() != null && currentProfile.get().getAvatar().startsWith("profiles/")) { - currentAvatar = Optional.of(currentProfile.get().getAvatar()); - } + if (currentProfile.isPresent() && currentProfile.get().getAvatar() != null && currentProfile.get().getAvatar() + .startsWith("profiles/")) { + currentAvatar = Optional.of(currentProfile.get().getAvatar()); + } - if (currentAvatar.isEmpty() && account.getAvatar() != null && account.getAvatar().startsWith("profiles/")) { - currentAvatar = Optional.of(account.getAvatar()); - } + if (currentAvatar.isEmpty() && auth.getAccount().getAvatar() != null && auth.getAccount().getAvatar() + .startsWith("profiles/")) { + currentAvatar = Optional.of(auth.getAccount().getAvatar()); + } - currentAvatar.ifPresent(s -> s3client.deleteObject(DeleteObjectRequest.builder() - .bucket(bucket) - .key(s) - .build())); + currentAvatar.ifPresent(s -> s3client.deleteObject(DeleteObjectRequest.builder() + .bucket(bucket) + .key(s) + .build())); - response = Optional.of(generateAvatarUploadForm(avatar)); + response = Optional.of(generateAvatarUploadForm(avatar)); } - accountsManager.update(account, a -> { - a.setProfileName(request.getName()); - a.setAvatar(avatar); - a.setCurrentProfileVersion(request.getVersion()); - }); + accountsManager.update(auth.getAccount(), a -> { + a.setProfileName(request.getName()); + a.setAvatar(avatar); + a.setCurrentProfileVersion(request.getVersion()); + }); if (response.isPresent()) return Response.ok(response).build(); else return Response.ok().build(); @@ -170,29 +175,32 @@ public class ProfileController { @GET @Produces(MediaType.APPLICATION_JSON) @Path("/{uuid}/{version}") - public Optional getProfile(@Auth Optional requestAccount, - @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional accessKey, - @PathParam("uuid") UUID uuid, - @PathParam("version") String version) - throws RateLimitExceededException - { - if (!isZkEnabled) throw new WebApplicationException(Response.Status.NOT_FOUND); - return getVersionedProfile(requestAccount, accessKey, uuid, version, Optional.empty()); + public Optional getProfile(@Auth Optional auth, + @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional accessKey, + @PathParam("uuid") UUID uuid, + @PathParam("version") String version) + throws RateLimitExceededException { + if (!isZkEnabled) { + throw new WebApplicationException(Response.Status.NOT_FOUND); + } + return getVersionedProfile(auth.map(AuthenticatedAccount::getAccount), accessKey, uuid, version, Optional.empty()); } @Timed @GET @Produces(MediaType.APPLICATION_JSON) @Path("/{uuid}/{version}/{credentialRequest}") - public Optional getProfile(@Auth Optional requestAccount, - @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional accessKey, - @PathParam("uuid") UUID uuid, - @PathParam("version") String version, - @PathParam("credentialRequest") String credentialRequest) - throws RateLimitExceededException - { - if (!isZkEnabled) throw new WebApplicationException(Response.Status.NOT_FOUND); - return getVersionedProfile(requestAccount, accessKey, uuid, version, Optional.of(credentialRequest)); + public Optional getProfile(@Auth Optional auth, + @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional accessKey, + @PathParam("uuid") UUID uuid, + @PathParam("version") String version, + @PathParam("credentialRequest") String credentialRequest) + throws RateLimitExceededException { + if (!isZkEnabled) { + throw new WebApplicationException(Response.Status.NOT_FOUND); + } + return getVersionedProfile(auth.map(AuthenticatedAccount::getAccount), accessKey, uuid, version, + Optional.of(credentialRequest)); } private Optional getVersionedProfile(Optional requestAccount, @@ -255,22 +263,23 @@ public class ProfileController { } - @Timed - @GET - @Produces(MediaType.APPLICATION_JSON) - @Path("/username/{username}") - public Profile getProfileByUsername(@Auth Account account, @PathParam("username") String username) throws RateLimitExceededException { - rateLimiters.getUsernameLookupLimiter().validate(account.getUuid()); + @Timed + @GET + @Produces(MediaType.APPLICATION_JSON) + @Path("/username/{username}") + public Profile getProfileByUsername(@Auth AuthenticatedAccount auth, @PathParam("username") String username) + throws RateLimitExceededException { + rateLimiters.getUsernameLookupLimiter().validate(auth.getAccount().getUuid()); - username = username.toLowerCase(); + username = username.toLowerCase(); - Optional uuid = usernamesManager.get(username); + Optional uuid = usernamesManager.get(username); - if (uuid.isEmpty()) { - throw new WebApplicationException(Response.status(Response.Status.NOT_FOUND).build()); - } + if (uuid.isEmpty()) { + throw new WebApplicationException(Response.status(Response.Status.NOT_FOUND).build()); + } - Optional accountProfile = accountsManager.get(uuid.get()); + Optional accountProfile = accountsManager.get(uuid.get()); if (accountProfile.isEmpty()) { throw new WebApplicationException(Response.status(Response.Status.NOT_FOUND).build()); @@ -312,40 +321,40 @@ public class ProfileController { // Old profile endpoints. Replaced by versioned profile endpoints (above) - @Deprecated - @Timed - @PUT - @Produces(MediaType.APPLICATION_JSON) - @Path("/name/{name}") - public void setProfile(@Auth Account account, @PathParam("name") @ExactlySize(value = {72, 108}, payload = {Unwrapping.Unwrap.class}) Optional name) { - accountsManager.update(account, a -> a.setProfileName(name.orElse(null))); - } + @Deprecated + @Timed + @PUT + @Produces(MediaType.APPLICATION_JSON) + @Path("/name/{name}") + public void setProfile(@Auth AuthenticatedAccount auth, + @PathParam("name") @ExactlySize(value = {72, 108}, payload = {Unwrapping.Unwrap.class}) Optional name) { + accountsManager.update(auth.getAccount(), a -> a.setProfileName(name.orElse(null))); + } @Deprecated @Timed @GET @Produces(MediaType.APPLICATION_JSON) @Path("/{identifier}") - public Profile getProfile(@Auth Optional requestAccount, - @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional accessKey, - @HeaderParam("User-Agent") String userAgent, - @PathParam("identifier") AmbiguousIdentifier identifier, - @QueryParam("ca") boolean useCaCertificate) - throws RateLimitExceededException - { + public Profile getProfile(@Auth Optional auth, + @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional accessKey, + @HeaderParam("User-Agent") String userAgent, + @PathParam("identifier") AmbiguousIdentifier identifier, + @QueryParam("ca") boolean useCaCertificate) + throws RateLimitExceededException { identifier.incrementRequestCounter("getProfile", userAgent); - if (requestAccount.isEmpty() && accessKey.isEmpty()) { + if (auth.isEmpty() && accessKey.isEmpty()) { throw new WebApplicationException(Response.Status.UNAUTHORIZED); } - if (requestAccount.isPresent()) { - rateLimiters.getProfileLimiter().validate(requestAccount.get().getUuid()); + if (auth.isPresent()) { + rateLimiters.getProfileLimiter().validate(auth.get().getAccount().getUuid()); } Optional accountProfile = accountsManager.get(identifier); - OptionalAccess.verify(requestAccount, accessKey, accountProfile); + OptionalAccess.verify(auth.map(AuthenticatedAccount::getAccount), accessKey, accountProfile); Optional username = Optional.empty(); @@ -369,24 +378,24 @@ public class ProfileController { } - @Deprecated - @Timed - @GET - @Produces(MediaType.APPLICATION_JSON) - @Path("/form/avatar") - public ProfileAvatarUploadAttributes getAvatarUploadForm(@Auth Account account) { - String previousAvatar = account.getAvatar(); - String objectName = generateAvatarObjectName(); - ProfileAvatarUploadAttributes profileAvatarUploadAttributes = generateAvatarUploadForm(objectName); + @Deprecated + @Timed + @GET + @Produces(MediaType.APPLICATION_JSON) + @Path("/form/avatar") + public ProfileAvatarUploadAttributes getAvatarUploadForm(@Auth AuthenticatedAccount auth) { + String previousAvatar = auth.getAccount().getAvatar(); + String objectName = generateAvatarObjectName(); + ProfileAvatarUploadAttributes profileAvatarUploadAttributes = generateAvatarUploadForm(objectName); - if (previousAvatar != null && previousAvatar.startsWith("profiles/")) { - s3client.deleteObject(DeleteObjectRequest.builder() - .bucket(bucket) - .key(previousAvatar) - .build()); - } + if (previousAvatar != null && previousAvatar.startsWith("profiles/")) { + s3client.deleteObject(DeleteObjectRequest.builder() + .bucket(bucket) + .key(previousAvatar) + .build()); + } - accountsManager.update(account, a -> a.setAvatar(objectName)); + accountsManager.update(auth.getAccount(), a -> a.setAvatar(objectName)); return profileAvatarUploadAttributes; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProvisioningController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProvisioningController.java index a8d4ee177..d1c392f1c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProvisioningController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProvisioningController.java @@ -17,10 +17,10 @@ import javax.ws.rs.Produces; import javax.ws.rs.WebApplicationException; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.entities.ProvisioningMessage; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.push.ProvisioningManager; -import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress; @Path("/v1/provisioning") @@ -39,16 +39,15 @@ public class ProvisioningController { @PUT @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) - public void sendProvisioningMessage(@Auth Account source, - @PathParam("destination") String destinationName, - @Valid ProvisioningMessage message) + public void sendProvisioningMessage(@Auth AuthenticatedAccount auth, + @PathParam("destination") String destinationName, + @Valid ProvisioningMessage message) throws RateLimitExceededException { - rateLimiters.getMessagesLimiter().validate(source.getUuid()); + rateLimiters.getMessagesLimiter().validate(auth.getAccount().getUuid()); if (!provisioningManager.sendProvisioningMessage(new ProvisioningAddress(destinationName, 0), - Base64.getDecoder().decode(message.getBody()))) - { + Base64.getDecoder().decode(message.getBody()))) { throw new WebApplicationException(Response.Status.NOT_FOUND); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RemoteConfigController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RemoteConfigController.java index 7e1d6ccf0..2127cf0d7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RemoteConfigController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RemoteConfigController.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -8,13 +8,16 @@ package org.whispersystems.textsecuregcm.controllers; import com.codahale.metrics.annotation.Timed; import com.google.common.annotations.VisibleForTesting; import io.dropwizard.auth.Auth; -import org.whispersystems.textsecuregcm.entities.UserRemoteConfig; -import org.whispersystems.textsecuregcm.entities.UserRemoteConfigList; -import org.whispersystems.textsecuregcm.storage.Account; -import org.whispersystems.textsecuregcm.storage.RemoteConfig; -import org.whispersystems.textsecuregcm.storage.RemoteConfigsManager; -import org.whispersystems.textsecuregcm.util.Conversions; - +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.Stream; import javax.validation.Valid; import javax.ws.rs.Consumes; import javax.ws.rs.DELETE; @@ -27,16 +30,12 @@ import javax.ws.rs.Produces; import javax.ws.rs.WebApplicationException; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.security.MessageDigest; -import java.security.NoSuchAlgorithmException; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.UUID; -import java.util.stream.Collectors; -import java.util.stream.Stream; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.entities.UserRemoteConfig; +import org.whispersystems.textsecuregcm.entities.UserRemoteConfigList; +import org.whispersystems.textsecuregcm.storage.RemoteConfig; +import org.whispersystems.textsecuregcm.storage.RemoteConfigsManager; +import org.whispersystems.textsecuregcm.util.Conversions; @Path("/v1/config") public class RemoteConfigController { @@ -57,15 +56,19 @@ public class RemoteConfigController { @GET @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) - public UserRemoteConfigList getAll(@Auth Account account) { + public UserRemoteConfigList getAll(@Auth AuthenticatedAccount auth) { try { MessageDigest digest = MessageDigest.getInstance("SHA1"); - final Stream globalConfigStream = globalConfig.entrySet().stream().map(entry -> new UserRemoteConfig(GLOBAL_CONFIG_PREFIX + entry.getKey(), true, entry.getValue())); + final Stream globalConfigStream = globalConfig.entrySet().stream() + .map(entry -> new UserRemoteConfig(GLOBAL_CONFIG_PREFIX + entry.getKey(), true, entry.getValue())); return new UserRemoteConfigList(Stream.concat(remoteConfigsManager.getAll().stream().map(config -> { - final byte[] hashKey = config.getHashKey() != null ? config.getHashKey().getBytes(StandardCharsets.UTF_8) : config.getName().getBytes(StandardCharsets.UTF_8); - boolean inBucket = isInBucket(digest, account.getUuid(), hashKey, config.getPercentage(), config.getUuids()); - return new UserRemoteConfig(config.getName(), inBucket, inBucket ? config.getValue() : config.getDefaultValue()); + final byte[] hashKey = config.getHashKey() != null ? config.getHashKey().getBytes(StandardCharsets.UTF_8) + : config.getName().getBytes(StandardCharsets.UTF_8); + boolean inBucket = isInBucket(digest, auth.getAccount().getUuid(), hashKey, config.getPercentage(), + config.getUuids()); + return new UserRemoteConfig(config.getName(), inBucket, + inBucket ? config.getValue() : config.getDefaultValue()); }), globalConfigStream).collect(Collectors.toList())); } catch (NoSuchAlgorithmException e) { throw new AssertionError(e); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SecureBackupController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SecureBackupController.java index d9eca61a8..abc99c25a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SecureBackupController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SecureBackupController.java @@ -1,21 +1,19 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.textsecuregcm.controllers; import com.codahale.metrics.annotation.Timed; -import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator; -import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials; -import org.whispersystems.textsecuregcm.storage.Account; - +import io.dropwizard.auth.Auth; import javax.ws.rs.GET; import javax.ws.rs.Path; import javax.ws.rs.Produces; import javax.ws.rs.core.MediaType; - -import io.dropwizard.auth.Auth; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator; +import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials; @Path("/v1/backup") public class SecureBackupController { @@ -30,7 +28,7 @@ public class SecureBackupController { @GET @Path("/auth") @Produces(MediaType.APPLICATION_JSON) - public ExternalServiceCredentials getAuth(@Auth Account account) { - return backupServiceCredentialGenerator.generateFor(account.getUuid().toString()); + public ExternalServiceCredentials getAuth(@Auth AuthenticatedAccount auth) { + return backupServiceCredentialGenerator.generateFor(auth.getAccount().getUuid().toString()); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SecureStorageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SecureStorageController.java index 150a62e76..af9a564f1 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SecureStorageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/SecureStorageController.java @@ -1,21 +1,19 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.textsecuregcm.controllers; import com.codahale.metrics.annotation.Timed; -import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator; -import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials; -import org.whispersystems.textsecuregcm.storage.Account; - +import io.dropwizard.auth.Auth; import javax.ws.rs.GET; import javax.ws.rs.Path; import javax.ws.rs.Produces; import javax.ws.rs.core.MediaType; - -import io.dropwizard.auth.Auth; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator; +import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials; @Path("/v1/storage") public class SecureStorageController { @@ -30,7 +28,7 @@ public class SecureStorageController { @GET @Path("/auth") @Produces(MediaType.APPLICATION_JSON) - public ExternalServiceCredentials getAuth(@Auth Account account) { - return storageServiceCredentialGenerator.generateFor(account.getUuid().toString()); + public ExternalServiceCredentials getAuth(@Auth AuthenticatedAccount auth) { + return storageServiceCredentialGenerator.generateFor(auth.getAccount().getUuid().toString()); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/StickerController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/StickerController.java index 5d626af5e..3f8bd0ae5 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/StickerController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/StickerController.java @@ -1,21 +1,16 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.textsecuregcm.controllers; import io.dropwizard.auth.Auth; -import org.whispersystems.textsecuregcm.entities.StickerPackFormUploadAttributes; -import org.whispersystems.textsecuregcm.entities.StickerPackFormUploadAttributes.StickerPackFormUploadItem; -import org.whispersystems.textsecuregcm.limits.RateLimiters; -import org.whispersystems.textsecuregcm.s3.PolicySigner; -import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator; -import org.whispersystems.textsecuregcm.storage.Account; -import org.whispersystems.textsecuregcm.util.Constants; -import org.whispersystems.textsecuregcm.util.Hex; -import org.whispersystems.textsecuregcm.util.Pair; - +import java.security.SecureRandom; +import java.time.ZoneOffset; +import java.time.ZonedDateTime; +import java.util.LinkedList; +import java.util.List; import javax.validation.constraints.Max; import javax.validation.constraints.Min; import javax.ws.rs.GET; @@ -23,11 +18,15 @@ import javax.ws.rs.Path; import javax.ws.rs.PathParam; import javax.ws.rs.Produces; import javax.ws.rs.core.MediaType; -import java.security.SecureRandom; -import java.time.ZoneOffset; -import java.time.ZonedDateTime; -import java.util.LinkedList; -import java.util.List; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.entities.StickerPackFormUploadAttributes; +import org.whispersystems.textsecuregcm.entities.StickerPackFormUploadAttributes.StickerPackFormUploadItem; +import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.s3.PolicySigner; +import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator; +import org.whispersystems.textsecuregcm.util.Constants; +import org.whispersystems.textsecuregcm.util.Hex; +import org.whispersystems.textsecuregcm.util.Pair; @Path("/v1/sticker") public class StickerController { @@ -45,30 +44,31 @@ public class StickerController { @GET @Produces(MediaType.APPLICATION_JSON) @Path("/pack/form/{count}") - public StickerPackFormUploadAttributes getStickersForm(@Auth Account account, - @PathParam("count") @Min(1) @Max(201) int stickerCount) - throws RateLimitExceededException - { - rateLimiters.getStickerPackLimiter().validate(account.getUuid()); - - ZonedDateTime now = ZonedDateTime.now(ZoneOffset.UTC); - String packId = generatePackId(); - String packLocation = "stickers/" + packId; - String manifestKey = packLocation + "/manifest.proto"; - Pair manifestPolicy = policyGenerator.createFor(now, manifestKey, Constants.MAXIMUM_STICKER_MANIFEST_SIZE_BYTES); - String manifestSignature = policySigner.getSignature(now, manifestPolicy.second()); - StickerPackFormUploadItem manifest = new StickerPackFormUploadItem(-1, manifestKey, manifestPolicy.first(), "private", "AWS4-HMAC-SHA256", - now.format(PostPolicyGenerator.AWS_DATE_TIME), manifestPolicy.second(), manifestSignature); + public StickerPackFormUploadAttributes getStickersForm(@Auth AuthenticatedAccount auth, + @PathParam("count") @Min(1) @Max(201) int stickerCount) + throws RateLimitExceededException { + rateLimiters.getStickerPackLimiter().validate(auth.getAccount().getUuid()); + ZonedDateTime now = ZonedDateTime.now(ZoneOffset.UTC); + String packId = generatePackId(); + String packLocation = "stickers/" + packId; + String manifestKey = packLocation + "/manifest.proto"; + Pair manifestPolicy = policyGenerator.createFor(now, manifestKey, + Constants.MAXIMUM_STICKER_MANIFEST_SIZE_BYTES); + String manifestSignature = policySigner.getSignature(now, manifestPolicy.second()); + StickerPackFormUploadItem manifest = new StickerPackFormUploadItem(-1, manifestKey, manifestPolicy.first(), + "private", "AWS4-HMAC-SHA256", + now.format(PostPolicyGenerator.AWS_DATE_TIME), manifestPolicy.second(), manifestSignature); List stickers = new LinkedList<>(); - for (int i=0;i stickerPolicy = policyGenerator.createFor(now, stickerKey, Constants.MAXIMUM_STICKER_SIZE_BYTES); - String stickerSignature = policySigner.getSignature(now, stickerPolicy.second()); + for (int i = 0; i < stickerCount; i++) { + String stickerKey = packLocation + "/full/" + i; + Pair stickerPolicy = policyGenerator.createFor(now, stickerKey, + Constants.MAXIMUM_STICKER_SIZE_BYTES); + String stickerSignature = policySigner.getSignature(now, stickerPolicy.second()); stickers.add(new StickerPackFormUploadItem(i, stickerKey, stickerPolicy.first(), "private", "AWS4-HMAC-SHA256", - now.format(PostPolicyGenerator.AWS_DATE_TIME), stickerPolicy.second(), stickerSignature)); + now.format(PostPolicyGenerator.AWS_DATE_TIME), stickerPolicy.second(), stickerSignature)); } return new StickerPackFormUploadAttributes(packId, manifest, stickers); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java index a193aba30..739a5bcdc 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/ReceiptSender.java @@ -1,20 +1,20 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.textsecuregcm.push; +import java.util.Optional; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.controllers.NoSuchUserException; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; -import java.util.Optional; - public class ReceiptSender { private final MessageSender messageSender; @@ -23,30 +23,29 @@ public class ReceiptSender { private static final Logger logger = LoggerFactory.getLogger(ReceiptSender.class); public ReceiptSender(AccountsManager accountManager, - MessageSender messageSender) - { + MessageSender messageSender) { this.accountManager = accountManager; - this.messageSender = messageSender; + this.messageSender = messageSender; } - public void sendReceipt(Account source, String destination, long messageId) - throws NoSuchUserException - { - if (source.getNumber().equals(destination)) { + public void sendReceipt(AuthenticatedAccount source, String destination, long messageId) + throws NoSuchUserException { + final Account sourceAccount = source.getAccount(); + if (sourceAccount.getNumber().equals(destination)) { return; } - Account destinationAccount = getDestinationAccount(destination); - Envelope.Builder message = Envelope.newBuilder() - .setServerTimestamp(System.currentTimeMillis()) - .setSource(source.getNumber()) - .setSourceUuid(source.getUuid().toString()) - .setSourceDevice((int) source.getAuthenticatedDevice().get().getId()) - .setTimestamp(messageId) - .setType(Envelope.Type.SERVER_DELIVERY_RECEIPT); + Account destinationAccount = getDestinationAccount(destination); + Envelope.Builder message = Envelope.newBuilder() + .setServerTimestamp(System.currentTimeMillis()) + .setSource(sourceAccount.getNumber()) + .setSourceUuid(sourceAccount.getUuid().toString()) + .setSourceDevice((int) source.getAuthenticatedDevice().getId()) + .setTimestamp(messageId) + .setType(Envelope.Type.SERVER_DELIVERY_RECEIPT); - if (source.getRelay().isPresent()) { - message.setRelay(source.getRelay().get()); + if (sourceAccount.getRelay().isPresent()) { + message.setRelay(sourceAccount.getRelay().get()); } for (final Device destinationDevice : destinationAccount.getDevices()) { @@ -63,7 +62,7 @@ public class ReceiptSender { { Optional account = accountManager.get(destination); - if (!account.isPresent()) { + if (account.isEmpty()) { throw new NoSuchUserException(destination); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java index 5291f60b3..757dabded 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Account.java @@ -8,12 +8,10 @@ package org.whispersystems.textsecuregcm.storage; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.annotations.VisibleForTesting; -import java.security.Principal; import java.util.HashSet; import java.util.Optional; import java.util.Set; import java.util.UUID; -import javax.security.auth.Subject; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier; @@ -22,7 +20,7 @@ import org.whispersystems.textsecuregcm.auth.StoredRegistrationLock; import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.util.Util; -public class Account implements Principal { +public class Account { @JsonIgnore private static final Logger logger = LoggerFactory.getLogger(Account.class); @@ -63,9 +61,6 @@ public class Account implements Principal { @JsonProperty("inCds") private boolean discoverableByPhoneNumber = true; - @JsonIgnore - private Device authenticatedDevice; - @JsonProperty private int version; @@ -82,18 +77,6 @@ public class Account implements Principal { this.unidentifiedAccessKey = unidentifiedAccessKey; } - public Optional getAuthenticatedDevice() { - requireNotStale(); - - return Optional.ofNullable(authenticatedDevice); - } - - public void setAuthenticatedDevice(Device device) { - requireNotStale(); - - this.authenticatedDevice = device; - } - public UUID getUuid() { // this is the one method that may be called on a stale account return uuid; @@ -390,6 +373,10 @@ public class Account implements Principal { this.version = version; } + boolean isStale() { + return stale; + } + public void markStale() { stale = true; } @@ -403,17 +390,4 @@ public class Account implements Principal { } } - // Principal implementation - - @Override - @JsonIgnore - public String getName() { - return null; - } - - @Override - @JsonIgnore - public boolean implies(Subject subject) { - return false; - } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceNotFoundException.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceNotFoundException.java new file mode 100644 index 000000000..0ea57a0ad --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceNotFoundException.java @@ -0,0 +1,14 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +public class RefreshingAccountAndDeviceNotFoundException extends RuntimeException { + + public RefreshingAccountAndDeviceNotFoundException(final String message) { + super(message); + } + +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceSupplier.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceSupplier.java new file mode 100644 index 000000000..b4a40aaae --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceSupplier.java @@ -0,0 +1,35 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import java.util.function.Supplier; +import org.whispersystems.textsecuregcm.util.Pair; + +public class RefreshingAccountAndDeviceSupplier implements Supplier> { + + private Account account; + private Device device; + private final AccountsManager accountsManager; + + public RefreshingAccountAndDeviceSupplier(Account account, long deviceId, AccountsManager accountsManager) { + this.account = account; + this.device = account.getDevice(deviceId) + .orElseThrow(() -> new RefreshingAccountAndDeviceNotFoundException("Could not find device")); + this.accountsManager = accountsManager; + } + + @Override + public Pair get() { + if (account.isStale()) { + account = accountsManager.get(account.getUuid()) + .orElseThrow(() -> new RuntimeException("Could not find account")); + device = account.getDevice(device.getId()) + .orElseThrow(() -> new RefreshingAccountAndDeviceNotFoundException("Could not find device")); + } + + return new Pair<>(account, device); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java index c735724d9..3e7a34d91 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java @@ -1,32 +1,31 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.textsecuregcm.websocket; +import static com.codahale.metrics.MetricRegistry.name; + import com.codahale.metrics.Counter; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.Timer; +import java.util.concurrent.ScheduledExecutorService; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.push.ApnFallbackManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.redis.RedisOperation; -import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.websocket.session.WebSocketSessionContext; import org.whispersystems.websocket.setup.WebSocketConnectListener; -import java.util.concurrent.ScheduledExecutorService; - -import static com.codahale.metrics.MetricRegistry.name; - public class AuthenticatedConnectListener implements WebSocketConnectListener { private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); @@ -60,16 +59,16 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { @Override public void onWebSocketConnect(WebSocketSessionContext context) { if (context.getAuthenticated() != null) { - final Account account = context.getAuthenticated(Account.class); - final Device device = account.getAuthenticatedDevice().get(); - final Timer.Context timer = durationTimer.time(); - final WebSocketConnection connection = new WebSocketConnection(receiptSender, - messagesManager, account, device, - context.getClient(), - retrySchedulingExecutor); + final AuthenticatedAccount auth = context.getAuthenticated(AuthenticatedAccount.class); + final Device device = auth.getAuthenticatedDevice(); + final Timer.Context timer = durationTimer.time(); + final WebSocketConnection connection = new WebSocketConnection(receiptSender, + messagesManager, auth, device, + context.getClient(), + retrySchedulingExecutor); openWebsocketCounter.inc(); - RedisOperation.unchecked(() -> apnFallbackManager.cancel(account, device)); + RedisOperation.unchecked(() -> apnFallbackManager.cancel(auth.getAccount(), device)); context.addListener(new WebSocketSessionContext.WebSocketEventListener() { @Override @@ -79,20 +78,21 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { connection.stop(); - RedisOperation.unchecked(() -> clientPresenceManager.clearPresence(account.getUuid(), device.getId())); + RedisOperation.unchecked( + () -> clientPresenceManager.clearPresence(auth.getAccount().getUuid(), device.getId())); RedisOperation.unchecked(() -> { messagesManager.removeMessageAvailabilityListener(connection); - if (messagesManager.hasCachedMessages(account.getUuid(), device.getId())) { - messageSender.sendNewMessageNotification(account, device); + if (messagesManager.hasCachedMessages(auth.getAccount().getUuid(), device.getId())) { + messageSender.sendNewMessageNotification(auth.getAccount(), device); } }); } }); try { - clientPresenceManager.setPresent(account.getUuid(), device.getId(), connection); - messagesManager.addMessageAvailabilityListener(account.getUuid(), device.getId(), connection); + clientPresenceManager.setPresent(auth.getAccount().getUuid(), device.getId(), connection); + messagesManager.addMessageAvailabilityListener(auth.getAccount().getUuid(), device.getId(), connection); connection.start(); } catch (final Exception e) { log.warn("Failed to initialize websocket", e); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java index 3106ed700..c7dc07497 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticator.java @@ -1,23 +1,21 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.textsecuregcm.websocket; -import org.eclipse.jetty.websocket.api.UpgradeRequest; -import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; -import org.whispersystems.textsecuregcm.storage.Account; -import org.whispersystems.websocket.auth.WebSocketAuthenticator; - +import io.dropwizard.auth.basic.BasicCredentials; import java.util.List; import java.util.Map; import java.util.Optional; - -import io.dropwizard.auth.basic.BasicCredentials; +import org.eclipse.jetty.websocket.api.UpgradeRequest; +import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.websocket.auth.WebSocketAuthenticator; -public class WebSocketAccountAuthenticator implements WebSocketAuthenticator { +public class WebSocketAccountAuthenticator implements WebSocketAuthenticator { private final AccountAuthenticator accountAuthenticator; @@ -26,19 +24,18 @@ public class WebSocketAccountAuthenticator implements WebSocketAuthenticator authenticate(UpgradeRequest request) { + public AuthenticationResult authenticate(UpgradeRequest request) { Map> parameters = request.getParameterMap(); - List usernames = parameters.get("login"); - List passwords = parameters.get("password"); + List usernames = parameters.get("login"); + List passwords = parameters.get("password"); if (usernames == null || usernames.size() == 0 || - passwords == null || passwords.size() == 0) - { + passwords == null || passwords.size() == 0) { return new AuthenticationResult<>(Optional.empty(), false); } BasicCredentials credentials = new BasicCredentials(usernames.get(0).replace(" ", "+"), - passwords.get(0).replace(" ", "+")); + passwords.get(0).replace(" ", "+")); return new AuthenticationResult<>(accountAuthenticator.authenticate(credentials), true); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index f406f7843..2a361d65c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -36,6 +36,7 @@ import javax.ws.rs.WebApplicationException; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.controllers.MessageController; import org.whispersystems.textsecuregcm.controllers.NoSuchUserException; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; @@ -43,7 +44,6 @@ import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.push.DisplacedPresenceListener; import org.whispersystems.textsecuregcm.push.ReceiptSender; -import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessageAvailabilityListener; import org.whispersystems.textsecuregcm.storage.MessagesManager; @@ -90,21 +90,22 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class); - private final ReceiptSender receiptSender; - private final MessagesManager messagesManager; + private final ReceiptSender receiptSender; + private final MessagesManager messagesManager; - private final Account account; - private final Device device; - private final WebSocketClient client; + private final AuthenticatedAccount auth; + private final Device device; + private final WebSocketClient client; private final ScheduledExecutorService retrySchedulingExecutor; - private final boolean isDesktopClient; + private final boolean isDesktopClient; - private final Semaphore processStoredMessagesSemaphore = new Semaphore(1); - private final AtomicReference storedMessageState = new AtomicReference<>(StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE); - private final AtomicBoolean sentInitialQueueEmptyMessage = new AtomicBoolean(false); - private final LongAdder sentMessageCounter = new LongAdder(); - private final AtomicLong queueDrainStartTime = new AtomicLong(); + private final Semaphore processStoredMessagesSemaphore = new Semaphore(1); + private final AtomicReference storedMessageState = new AtomicReference<>( + StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE); + private final AtomicBoolean sentInitialQueueEmptyMessage = new AtomicBoolean(false); + private final LongAdder sentMessageCounter = new LongAdder(); + private final AtomicLong queueDrainStartTime = new AtomicLong(); private final AtomicInteger consecutiveRetries = new AtomicInteger(); private final AtomicReference> retryFuture = new AtomicReference<>(); @@ -118,16 +119,15 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac public WebSocketConnection(ReceiptSender receiptSender, MessagesManager messagesManager, - Account account, + AuthenticatedAccount auth, Device device, WebSocketClient client, - ScheduledExecutorService retrySchedulingExecutor) - { - this.receiptSender = receiptSender; + ScheduledExecutorService retrySchedulingExecutor) { + this.receiptSender = receiptSender; this.messagesManager = messagesManager; - this.account = account; - this.device = device; - this.client = client; + this.auth = auth; + this.device = device; + this.client = client; this.retrySchedulingExecutor = retrySchedulingExecutor; Optional maybePlatform; @@ -168,7 +168,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac if (throwable == null) { if (isSuccessResponse(response)) { if (storedMessageInfo.isPresent()) { - messagesManager.delete(account.getUuid(), device.getId(), storedMessageInfo.get().getGuid()); + messagesManager.delete(auth.getAccount().getUuid(), device.getId(), storedMessageInfo.get().getGuid()); } if (message.getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT) { @@ -204,7 +204,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac if (!message.hasSource()) return; try { - receiptSender.sendReceipt(account, message.getSource(), message.getTimestamp()); + receiptSender.sendReceipt(auth, message.getSource(), message.getTimestamp()); } catch (NoSuchUserException e) { logger.info("No longer registered " + e.getMessage()); } catch (WebApplicationException e) { @@ -267,7 +267,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac private void sendNextMessagePage(final boolean cachedMessagesOnly, final CompletableFuture queueClearedFuture) { try { final OutgoingMessageEntityList messages = messagesManager - .getMessagesForDevice(account.getUuid(), device.getId(), client.getUserAgent(), cachedMessagesOnly); + .getMessagesForDevice(auth.getAccount().getUuid(), device.getId(), client.getUserAgent(), cachedMessagesOnly); final CompletableFuture[] sendFutures = new CompletableFuture[messages.getMessages().size()]; @@ -303,7 +303,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac final Envelope envelope = builder.build(); if (envelope.getSerializedSize() > MAX_DESKTOP_MESSAGE_SIZE && isDesktopClient) { - messagesManager.delete(account.getUuid(), device.getId(), message.getGuid()); + messagesManager.delete(auth.getAccount().getUuid(), device.getId(), message.getGuid()); discardedMessagesMeter.mark(); sendFutures[i] = CompletableFuture.completedFuture(null); @@ -340,7 +340,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac public void handleNewEphemeralMessageAvailable() { ephemeralMessageAvailableMeter.mark(); - messagesManager.takeEphemeralMessage(account.getUuid(), device.getId()) + messagesManager.takeEphemeralMessage(auth.getAccount().getUuid(), device.getId()) .ifPresent(message -> sendMessage(message, Optional.empty())); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ChallengeControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ChallengeControllerTest.java index a19e4fbc6..d83126d58 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ChallengeControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ChallengeControllerTest.java @@ -24,11 +24,11 @@ import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager; import org.whispersystems.textsecuregcm.mappers.RetryLaterExceptionMapper; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; -import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.util.SystemMapper; @@ -41,7 +41,8 @@ class ChallengeControllerTest { private static final ResourceExtension EXTENSION = ResourceExtension.builder() .addProvider(AuthHelper.getAuthFilter()) - .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(Set.of(Account.class, DisabledPermittedAccount.class))) + .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( + Set.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) .setMapper(SystemMapper.getMapper()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .addResource(new RetryLaterExceptionMapper()) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceSupplierTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceSupplierTest.java new file mode 100644 index 000000000..1ef30dcef --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RefreshingAccountAndDeviceSupplierTest.java @@ -0,0 +1,71 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.Optional; +import java.util.UUID; +import org.junit.jupiter.api.Test; +import org.whispersystems.textsecuregcm.util.Pair; + +class RefreshingAccountAndDeviceSupplierTest { + + @Test + void test() { + + final AccountsManager accountsManager = mock(AccountsManager.class); + + final UUID uuid = UUID.randomUUID(); + final long deviceId = 2L; + + final Account initialAccount = mock(Account.class); + final Device initialDevice = mock(Device.class); + + when(initialAccount.getUuid()).thenReturn(uuid); + when(initialDevice.getId()).thenReturn(deviceId); + when(initialAccount.getDevice(deviceId)).thenReturn(Optional.of(initialDevice)); + + when(accountsManager.get(any(UUID.class))).thenAnswer(answer -> { + final Account account = mock(Account.class); + final Device device = mock(Device.class); + + when(account.getUuid()).thenReturn(answer.getArgument(0, UUID.class)); + when(account.getDevice(deviceId)).thenReturn(Optional.of(device)); + when(device.getId()).thenReturn(deviceId); + + return Optional.of(account); + }); + + final RefreshingAccountAndDeviceSupplier refreshingAccountAndDeviceSupplier = new RefreshingAccountAndDeviceSupplier( + initialAccount, deviceId, accountsManager); + + Pair accountAndDevice = refreshingAccountAndDeviceSupplier.get(); + + assertSame(initialAccount, accountAndDevice.first()); + assertSame(initialDevice, accountAndDevice.second()); + + accountAndDevice = refreshingAccountAndDeviceSupplier.get(); + + assertSame(initialAccount, accountAndDevice.first()); + assertSame(initialDevice, accountAndDevice.second()); + + when(initialAccount.isStale()).thenReturn(true); + + accountAndDevice = refreshingAccountAndDeviceSupplier.get(); + + assertNotSame(initialAccount, accountAndDevice.first()); + assertNotSame(initialDevice, accountAndDevice.second()); + + assertEquals(uuid, accountAndDevice.first().getUuid()); + } + +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java index 728be8203..6a627e499 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AccountControllerTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -52,8 +52,9 @@ import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials; -import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; +import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator; import org.whispersystems.textsecuregcm.auth.StoredRegistrationLock; import org.whispersystems.textsecuregcm.auth.StoredVerificationCode; @@ -145,27 +146,29 @@ class AccountControllerTest { private static ExternalServiceCredentialGenerator storageCredentialGenerator = new ExternalServiceCredentialGenerator(new byte[32], new byte[32], false); private static final ResourceExtension resources = ResourceExtension.builder() - .addProvider(AuthHelper.getAuthFilter()) - .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( - ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) - .addProvider(new RateLimitExceededExceptionMapper()) - .setMapper(SystemMapper.getMapper()) - .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) - .addResource(new AccountController(pendingAccountsManager, - accountsManager, - usernamesManager, - abusiveHostRules, - rateLimiters, - smsSender, - dynamicConfigurationManager, - turnTokenGenerator, - new HashMap<>(), - recaptchaClient, - gcmSender, - apnSender, - storageCredentialGenerator, - verifyExperimentEnrollmentManager)) - .build(); + .addProvider(AuthHelper.getAuthFilter()) + .addProvider( + new PolymorphicAuthValueFactoryProvider.Binder<>( + ImmutableSet.of(AuthenticatedAccount.class, + DisabledPermittedAuthenticatedAccount.class))) + .addProvider(new RateLimitExceededExceptionMapper()) + .setMapper(SystemMapper.getMapper()) + .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) + .addResource(new AccountController(pendingAccountsManager, + accountsManager, + usernamesManager, + abusiveHostRules, + rateLimiters, + smsSender, + dynamicConfigurationManager, + turnTokenGenerator, + new HashMap<>(), + recaptchaClient, + gcmSender, + apnSender, + storageCredentialGenerator, + verifyExperimentEnrollmentManager)) + .build(); @BeforeEach diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AttachmentControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AttachmentControllerTest.java index 0985f8310..a61226e9f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AttachmentControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/AttachmentControllerTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -33,7 +33,8 @@ import org.assertj.core.api.InstanceOfAssertFactories; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV1; import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV2; import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV3; @@ -43,7 +44,6 @@ import org.whispersystems.textsecuregcm.entities.AttachmentDescriptorV3; import org.whispersystems.textsecuregcm.entities.AttachmentUri; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; -import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.util.SystemMapper; @@ -78,8 +78,9 @@ class AttachmentControllerTest { static { try { resources = ResourceExtension.builder() - .addProvider(AuthHelper.getAuthFilter()) - .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) + .addProvider(AuthHelper.getAuthFilter()) + .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( + ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) .setMapper(SystemMapper.getMapper()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .addResource(new AttachmentControllerV1(rateLimiters, "accessKey", "accessSecret", "attachment-bucket")) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/CertificateControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/CertificateControllerTest.java index 796b13908..4f4086299 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/CertificateControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/CertificateControllerTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -28,8 +28,9 @@ import org.signal.zkgroup.auth.AuthCredential; import org.signal.zkgroup.auth.AuthCredentialResponse; import org.signal.zkgroup.auth.ClientZkAuthOperations; import org.signal.zkgroup.auth.ServerZkAuthOperations; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.CertificateGenerator; -import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; +import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.controllers.CertificateController; import org.whispersystems.textsecuregcm.crypto.Curve; @@ -37,7 +38,6 @@ import org.whispersystems.textsecuregcm.entities.DeliveryCertificate; import org.whispersystems.textsecuregcm.entities.GroupCredentials; import org.whispersystems.textsecuregcm.entities.MessageProtos.SenderCertificate; import org.whispersystems.textsecuregcm.entities.MessageProtos.ServerCertificate; -import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.Util; @@ -66,12 +66,13 @@ class CertificateControllerTest { private static final ResourceExtension resources = ResourceExtension.builder() - .addProvider(AuthHelper.getAuthFilter()) - .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) - .setMapper(SystemMapper.getMapper()) - .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) - .addResource(new CertificateController(certificateGenerator, serverZkAuthOperations, true)) - .build(); + .addProvider(AuthHelper.getAuthFilter()) + .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( + ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) + .setMapper(SystemMapper.getMapper()) + .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) + .addResource(new CertificateController(certificateGenerator, serverZkAuthOperations, true)) + .build(); @Test void testValidCertificate() throws Exception { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java index 3e9ff27b3..df235811c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ package org.whispersystems.textsecuregcm.tests.controllers; @@ -36,7 +36,8 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.StoredVerificationCode; import org.whispersystems.textsecuregcm.controllers.DeviceController; import org.whispersystems.textsecuregcm.entities.AccountAttributes; @@ -88,17 +89,18 @@ class DeviceControllerTest { private static Map deviceConfiguration = new HashMap<>(); private static final ResourceExtension resources = ResourceExtension.builder() - .addProvider(AuthHelper.getAuthFilter()) - .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) - .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) - .addProvider(new DeviceLimitExceededExceptionMapper()) - .addResource(new DumbVerificationDeviceController(pendingDevicesManager, - accountsManager, - messagesManager, - keys, - rateLimiters, - deviceConfiguration)) - .build(); + .addProvider(AuthHelper.getAuthFilter()) + .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( + ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) + .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) + .addProvider(new DeviceLimitExceededExceptionMapper()) + .addResource(new DumbVerificationDeviceController(pendingDevicesManager, + accountsManager, + messagesManager, + keys, + rateLimiters, + deviceConfiguration)) + .build(); @BeforeEach @@ -114,15 +116,14 @@ class DeviceControllerTest { when(account.getNextDeviceId()).thenReturn(42L); when(account.getNumber()).thenReturn(AuthHelper.VALID_NUMBER); when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID); -// when(maxedAccount.getActiveDeviceCount()).thenReturn(6); - when(account.getAuthenticatedDevice()).thenReturn(Optional.of(masterDevice)); when(account.isEnabled()).thenReturn(false); when(account.isGroupsV2Supported()).thenReturn(true); when(account.isGv1MigrationSupported()).thenReturn(true); when(account.isSenderKeySupported()).thenReturn(true); when(account.isAnnouncementGroupSupported()).thenReturn(true); - when(pendingDevicesManager.getCodeForNumber(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(new StoredVerificationCode("5678901", System.currentTimeMillis(), null, null))); + when(pendingDevicesManager.getCodeForNumber(AuthHelper.VALID_NUMBER)).thenReturn( + Optional.of(new StoredVerificationCode("5678901", System.currentTimeMillis(), null, null))); when(pendingDevicesManager.getCodeForNumber(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.empty()); when(accountsManager.get(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(account)); when(accountsManager.get(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.of(maxedAccount)); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DirectoryControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DirectoryControllerTest.java index d6d84e770..48852a9c1 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DirectoryControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DirectoryControllerTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -20,14 +20,14 @@ import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import javax.ws.rs.core.Response.Status.Family; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; -import org.junit.jupiter.api.Test; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials; import org.whispersystems.textsecuregcm.controllers.DirectoryController; -import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; @ExtendWith(DropwizardExtensionsSupport.class) @@ -37,11 +37,12 @@ class DirectoryControllerTest { private static final ExternalServiceCredentials validCredentials = new ExternalServiceCredentials("username", "password"); private static final ResourceExtension resources = ResourceExtension.builder() - .addProvider(AuthHelper.getAuthFilter()) - .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) - .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) - .addResource(new DirectoryController(directoryCredentialsGenerator)) - .build(); + .addProvider(AuthHelper.getAuthFilter()) + .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( + ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) + .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) + .addResource(new DirectoryController(directoryCredentialsGenerator)) + .build(); @BeforeEach void setup() { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DonationControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DonationControllerTest.java index 2bc0bd398..3a473dd2e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DonationControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DonationControllerTest.java @@ -26,14 +26,14 @@ import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; -import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; import org.whispersystems.textsecuregcm.configuration.DonationConfiguration; import org.whispersystems.textsecuregcm.configuration.RetryConfiguration; import org.whispersystems.textsecuregcm.controllers.DonationController; import org.whispersystems.textsecuregcm.entities.ApplePayAuthorizationRequest; import org.whispersystems.textsecuregcm.entities.ApplePayAuthorizationResponse; -import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.util.SystemMapper; @@ -57,7 +57,8 @@ public class DonationControllerTest { configuration.setSupportedCurrencies(Set.of("usd", "gbp")); resources = ResourceExtension.builder() .addProvider(AuthHelper.getAuthFilter()) - .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) + .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( + ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) .setMapper(SystemMapper.getMapper()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .addResource(new DonationController(executor, configuration)) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java index d52e813b1..885d155cf 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java @@ -44,7 +44,8 @@ import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatcher; import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier; -import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.controllers.KeysController; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; @@ -102,8 +103,8 @@ class KeysControllerTest { private static final ResourceExtension resources = ResourceExtension.builder() .addProvider(AuthHelper.getAuthFilter()) - .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( - ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) + .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of( + AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .addResource(new RateLimitChallengeExceptionMapper(rateLimitChallengeManager)) .addResource(new ServerRejectedExceptionMapper()) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java index f3ff86672..661e5ac88 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/MessageControllerTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -68,7 +68,8 @@ import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatcher; import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier; -import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration; @@ -142,15 +143,16 @@ class MessageControllerTest { private final ObjectMapper mapper = new ObjectMapper(); private static final ResourceExtension resources = ResourceExtension.builder() - .addProvider(AuthHelper.getAuthFilter()) - .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) - .addProvider(RateLimitExceededExceptionMapper.class) - .addProvider(new RateLimitChallengeExceptionMapper(rateLimitChallengeManager)) - .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) - .addResource(new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, - messagesManager, unsealedSenderRateLimiter, apnFallbackManager, dynamicConfigurationManager, - rateLimitChallengeManager, reportMessageManager, metricsCluster, receiptExecutor)) - .build(); + .addProvider(AuthHelper.getAuthFilter()) + .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( + ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) + .addProvider(RateLimitExceededExceptionMapper.class) + .addProvider(new RateLimitChallengeExceptionMapper(rateLimitChallengeManager)) + .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) + .addResource(new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, + messagesManager, unsealedSenderRateLimiter, apnFallbackManager, dynamicConfigurationManager, + rateLimitChallengeManager, reportMessageManager, metricsCluster, receiptExecutor)) + .build(); @BeforeEach void setup() throws Exception { @@ -576,7 +578,7 @@ class MessageControllerTest { .delete(); assertThat("Good Response Code", response.getStatus(), is(equalTo(204))); - verify(receiptSender).sendReceipt(any(Account.class), eq("+14152222222"), eq(timestamp)); + verify(receiptSender).sendReceipt(any(AuthenticatedAccount.class), eq("+14152222222"), eq(timestamp)); response = resources.getJerseyTest() .target(String.format("/v1/messages/%s/%d", "+14152222222", 31338)) @@ -731,22 +733,54 @@ class MessageControllerTest { mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), Set.of(1L, 3L), null, + null, + false, null), arguments( mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), Set.of(1L, 2L, 3L), null, - Set.of(2L)), + Set.of(2L), + false, + null), arguments( mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), Set.of(1L), Set.of(3L), + null, + false, null), arguments( mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), Set.of(1L, 2L), Set.of(3L), - Set.of(2L)) + Set.of(2L), + false, + null), + arguments( + mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), + Set.of(1L), + Set.of(3L), + Set.of(1L), + true, + 1L + ), + arguments( + mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), + Set.of(2L), + Set.of(3L), + Set.of(2L), + true, + 1L + ), + arguments( + mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), + Set.of(3L), + null, + null, + true, + 1L + ) ); } @@ -756,10 +790,13 @@ class MessageControllerTest { Account account, Set deviceIds, Collection expectedMissingDeviceIds, - Collection expectedExtraDeviceIds) throws Exception { + Collection expectedExtraDeviceIds, + boolean isSyncMessage, + Long authenticatedDeviceId) throws Exception { if (expectedMissingDeviceIds != null || expectedExtraDeviceIds != null) { final MismatchedDevicesException mismatchedDevicesException = assertThrows(MismatchedDevicesException.class, - () -> MessageController.validateCompleteDeviceList(account, deviceIds, false)); + () -> MessageController.validateCompleteDeviceList(account, deviceIds, isSyncMessage, + Optional.ofNullable(authenticatedDeviceId))); if (expectedMissingDeviceIds != null) { Assertions.assertThat(mismatchedDevicesException.getMissingDevices()) .hasSameElementsAs(expectedMissingDeviceIds); @@ -768,7 +805,8 @@ class MessageControllerTest { Assertions.assertThat(mismatchedDevicesException.getExtraDevices()).hasSameElementsAs(expectedExtraDeviceIds); } } else { - MessageController.validateCompleteDeviceList(account, deviceIds, false); + MessageController.validateCompleteDeviceList(account, deviceIds, isSyncMessage, + Optional.ofNullable(authenticatedDeviceId)); } } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/PaymentsControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/PaymentsControllerTest.java index 570ec7a9c..188853e27 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/PaymentsControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/PaymentsControllerTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -22,14 +22,14 @@ import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials; import org.whispersystems.textsecuregcm.controllers.PaymentsController; import org.whispersystems.textsecuregcm.currency.CurrencyConversionManager; import org.whispersystems.textsecuregcm.entities.CurrencyConversionEntity; import org.whispersystems.textsecuregcm.entities.CurrencyConversionEntityList; -import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; @ExtendWith(DropwizardExtensionsSupport.class) @@ -41,11 +41,12 @@ class PaymentsControllerTest { private final ExternalServiceCredentials validCredentials = new ExternalServiceCredentials("username", "password"); private static final ResourceExtension resources = ResourceExtension.builder() - .addProvider(AuthHelper.getAuthFilter()) - .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) - .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) - .addResource(new PaymentsController(currencyManager, paymentsCredentialGenerator)) - .build(); + .addProvider(AuthHelper.getAuthFilter()) + .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( + ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) + .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) + .addResource(new PaymentsController(currencyManager, paymentsCredentialGenerator)) + .build(); @BeforeEach diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/ProfileControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/ProfileControllerTest.java index a92ac3d1a..4b20fc8e9 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/ProfileControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/ProfileControllerTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -41,7 +41,8 @@ import org.signal.zkgroup.profiles.ProfileKey; import org.signal.zkgroup.profiles.ProfileKeyCommitment; import org.signal.zkgroup.profiles.ServerZkProfileOperations; import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier; -import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicPaymentsConfiguration; import org.whispersystems.textsecuregcm.controllers.ProfileController; @@ -87,22 +88,23 @@ class ProfileControllerTest { private Account profileAccount; private static final ResourceExtension resources = ResourceExtension.builder() - .addProvider(AuthHelper.getAuthFilter()) - .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) - .setMapper(SystemMapper.getMapper()) - .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) - .addResource(new ProfileController(rateLimiters, - accountsManager, - profilesManager, - usernamesManager, - dynamicConfigurationManager, - s3client, - postPolicyGenerator, - policySigner, - "profilesBucket", - zkProfileOperations, - true)) - .build(); + .addProvider(AuthHelper.getAuthFilter()) + .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( + ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) + .setMapper(SystemMapper.getMapper()) + .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) + .addResource(new ProfileController(rateLimiters, + accountsManager, + profilesManager, + usernamesManager, + dynamicConfigurationManager, + s3client, + postPolicyGenerator, + policySigner, + "profilesBucket", + zkProfileOperations, + true)) + .build(); @BeforeEach void setup() throws Exception { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/RemoteConfigControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/RemoteConfigControllerTest.java index d023e3737..d7e761bc5 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/RemoteConfigControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/RemoteConfigControllerTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -30,17 +30,17 @@ import javax.ws.rs.client.Entity; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; -import org.junit.jupiter.api.Test; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.ArgumentCaptor; -import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.controllers.RemoteConfigController; import org.whispersystems.textsecuregcm.entities.UserRemoteConfig; import org.whispersystems.textsecuregcm.entities.UserRemoteConfigList; import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper; -import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.RemoteConfig; import org.whispersystems.textsecuregcm.storage.RemoteConfigsManager; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; @@ -52,12 +52,13 @@ class RemoteConfigControllerTest { private static final List remoteConfigsAuth = List.of("foo", "bar"); private static final ResourceExtension resources = ResourceExtension.builder() - .addProvider(AuthHelper.getAuthFilter()) - .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) - .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) - .addProvider(new DeviceLimitExceededExceptionMapper()) - .addResource(new RemoteConfigController(remoteConfigsManager, remoteConfigsAuth, Map.of("maxGroupSize", "42"))) - .build(); + .addProvider(AuthHelper.getAuthFilter()) + .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( + ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) + .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) + .addProvider(new DeviceLimitExceededExceptionMapper()) + .addResource(new RemoteConfigController(remoteConfigsManager, remoteConfigsAuth, Map.of("maxGroupSize", "42"))) + .build(); @BeforeEach diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/SecureStorageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/SecureStorageControllerTest.java index 94fb62361..ae6d30dd1 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/SecureStorageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/SecureStorageControllerTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -15,11 +15,11 @@ import javax.ws.rs.core.Response; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials; import org.whispersystems.textsecuregcm.controllers.SecureStorageController; -import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.util.SystemMapper; @@ -29,12 +29,13 @@ class SecureStorageControllerTest { private static final ExternalServiceCredentialGenerator storageCredentialGenerator = new ExternalServiceCredentialGenerator(new byte[32], new byte[32], false); private static final ResourceExtension resources = ResourceExtension.builder() - .addProvider(AuthHelper.getAuthFilter()) - .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) - .setMapper(SystemMapper.getMapper()) - .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) - .addResource(new SecureStorageController(storageCredentialGenerator)) - .build(); + .addProvider(AuthHelper.getAuthFilter()) + .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( + ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) + .setMapper(SystemMapper.getMapper()) + .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) + .addResource(new SecureStorageController(storageCredentialGenerator)) + .build(); @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/StickerControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/StickerControllerTest.java index 132d6f9fc..239a9f8ac 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/StickerControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/StickerControllerTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -21,13 +21,13 @@ import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.controllers.StickerController; import org.whispersystems.textsecuregcm.entities.StickerPackFormUploadAttributes; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; -import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.util.SystemMapper; @@ -38,12 +38,13 @@ class StickerControllerTest { private static final RateLimiters rateLimiters = mock(RateLimiters.class); private static final ResourceExtension resources = ResourceExtension.builder() - .addProvider(AuthHelper.getAuthFilter()) - .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) - .setMapper(SystemMapper.getMapper()) - .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) - .addResource(new StickerController(rateLimiters, "foo", "bar", "us-east-1", "mybucket")) - .build(); + .addProvider(AuthHelper.getAuthFilter()) + .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( + ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) + .setMapper(SystemMapper.getMapper()) + .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) + .addResource(new StickerController(rateLimiters, "foo", "bar", "us-east-1", "mybucket")) + .build(); @BeforeEach void setup() { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/VoiceVerificationControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/VoiceVerificationControllerTest.java index a6381de15..9ce5ce726 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/VoiceVerificationControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/VoiceVerificationControllerTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -18,10 +18,10 @@ import javax.ws.rs.core.Response; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.controllers.VoiceVerificationController; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; -import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.util.SystemMapper; @@ -29,14 +29,15 @@ import org.whispersystems.textsecuregcm.util.SystemMapper; class VoiceVerificationControllerTest { private static final ResourceExtension resources = ResourceExtension.builder() - .addProvider(AuthHelper.getAuthFilter()) - .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) - .addProvider(new RateLimitExceededExceptionMapper()) - .setMapper(SystemMapper.getMapper()) - .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) - .addResource(new VoiceVerificationController("https://foo.com/bar", - new HashSet<>(Arrays.asList("pt-BR", "ru")))) - .build(); + .addProvider(AuthHelper.getAuthFilter()) + .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( + ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) + .addProvider(new RateLimitExceededExceptionMapper()) + .setMapper(SystemMapper.getMapper()) + .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) + .addResource(new VoiceVerificationController("https://foo.com/bar", + new HashSet<>(Arrays.asList("pt-BR", "ru")))) + .build(); @Test void testTwimlLocale() { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java index bd0034a32..fe5db145a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AccountsHelper.java @@ -113,10 +113,6 @@ public class AccountsHelper { when(updatedAccount.getMasterDevice()).thenAnswer(stubbing); break; } - case "getAuthenticatedDevice": { - when(updatedAccount.getAuthenticatedDevice()).thenAnswer(stubbing); - break; - } case "isEnabled": { when(updatedAccount.isEnabled()).thenAnswer(stubbing); break; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java index 8cf49b2d4..5a7839d30 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/util/AuthHelper.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -24,8 +24,9 @@ import java.util.UUID; import org.mockito.ArgumentMatcher; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials; -import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; +import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccountAuthenticator; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; @@ -121,11 +122,6 @@ public class AuthHelper { when(UNDISCOVERABLE_ACCOUNT.getNumber()).thenReturn(UNDISCOVERABLE_NUMBER); when(UNDISCOVERABLE_ACCOUNT.getUuid()).thenReturn(UNDISCOVERABLE_UUID); - when(VALID_ACCOUNT.getAuthenticatedDevice()).thenReturn(Optional.of(VALID_DEVICE)); - when(VALID_ACCOUNT_TWO.getAuthenticatedDevice()).thenReturn(Optional.of(VALID_DEVICE_TWO)); - when(DISABLED_ACCOUNT.getAuthenticatedDevice()).thenReturn(Optional.of(DISABLED_DEVICE)); - when(UNDISCOVERABLE_ACCOUNT.getAuthenticatedDevice()).thenReturn(Optional.of(UNDISCOVERABLE_DEVICE)); - when(VALID_ACCOUNT.getRelay()).thenReturn(Optional.empty()); when(VALID_ACCOUNT_TWO.getRelay()).thenReturn(Optional.empty()); when(UNDISCOVERABLE_ACCOUNT.getRelay()).thenReturn(Optional.empty()); @@ -151,18 +147,30 @@ public class AuthHelper { when(ACCOUNTS_MANAGER.get(VALID_NUMBER_TWO)).thenReturn(Optional.of(VALID_ACCOUNT_TWO)); when(ACCOUNTS_MANAGER.get(VALID_UUID_TWO)).thenReturn(Optional.of(VALID_ACCOUNT_TWO)); - when(ACCOUNTS_MANAGER.get(argThat((ArgumentMatcher) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(VALID_NUMBER_TWO)))).thenReturn(Optional.of(VALID_ACCOUNT_TWO)); - when(ACCOUNTS_MANAGER.get(argThat((ArgumentMatcher) identifier -> identifier != null && identifier.hasUuid() && identifier.getUuid().equals(VALID_UUID_TWO)))).thenReturn(Optional.of(VALID_ACCOUNT_TWO)); + when(ACCOUNTS_MANAGER.get(argThat( + (ArgumentMatcher) identifier -> identifier != null && identifier.hasNumber() + && identifier.getNumber().equals(VALID_NUMBER_TWO)))).thenReturn(Optional.of(VALID_ACCOUNT_TWO)); + when(ACCOUNTS_MANAGER.get(argThat( + (ArgumentMatcher) identifier -> identifier != null && identifier.hasUuid() + && identifier.getUuid().equals(VALID_UUID_TWO)))).thenReturn(Optional.of(VALID_ACCOUNT_TWO)); when(ACCOUNTS_MANAGER.get(DISABLED_NUMBER)).thenReturn(Optional.of(DISABLED_ACCOUNT)); when(ACCOUNTS_MANAGER.get(DISABLED_UUID)).thenReturn(Optional.of(DISABLED_ACCOUNT)); - when(ACCOUNTS_MANAGER.get(argThat((ArgumentMatcher) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(DISABLED_NUMBER)))).thenReturn(Optional.of(DISABLED_ACCOUNT)); - when(ACCOUNTS_MANAGER.get(argThat((ArgumentMatcher) identifier -> identifier != null && identifier.hasUuid() && identifier.getUuid().equals(DISABLED_UUID)))).thenReturn(Optional.of(DISABLED_ACCOUNT)); + when(ACCOUNTS_MANAGER.get(argThat( + (ArgumentMatcher) identifier -> identifier != null && identifier.hasNumber() + && identifier.getNumber().equals(DISABLED_NUMBER)))).thenReturn(Optional.of(DISABLED_ACCOUNT)); + when(ACCOUNTS_MANAGER.get(argThat( + (ArgumentMatcher) identifier -> identifier != null && identifier.hasUuid() + && identifier.getUuid().equals(DISABLED_UUID)))).thenReturn(Optional.of(DISABLED_ACCOUNT)); when(ACCOUNTS_MANAGER.get(UNDISCOVERABLE_NUMBER)).thenReturn(Optional.of(UNDISCOVERABLE_ACCOUNT)); when(ACCOUNTS_MANAGER.get(UNDISCOVERABLE_UUID)).thenReturn(Optional.of(UNDISCOVERABLE_ACCOUNT)); - when(ACCOUNTS_MANAGER.get(argThat((ArgumentMatcher) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(UNDISCOVERABLE_NUMBER)))).thenReturn(Optional.of(UNDISCOVERABLE_ACCOUNT)); - when(ACCOUNTS_MANAGER.get(argThat((ArgumentMatcher) identifier -> identifier != null && identifier.hasUuid() && identifier.getUuid().equals(UNDISCOVERABLE_UUID)))).thenReturn(Optional.of(UNDISCOVERABLE_ACCOUNT)); + when(ACCOUNTS_MANAGER.get(argThat( + (ArgumentMatcher) identifier -> identifier != null && identifier.hasNumber() + && identifier.getNumber().equals(UNDISCOVERABLE_NUMBER)))).thenReturn(Optional.of(UNDISCOVERABLE_ACCOUNT)); + when(ACCOUNTS_MANAGER.get(argThat( + (ArgumentMatcher) identifier -> identifier != null && identifier.hasUuid() + && identifier.getUuid().equals(UNDISCOVERABLE_UUID)))).thenReturn(Optional.of(UNDISCOVERABLE_ACCOUNT)); AccountsHelper.setupMockUpdateForAuthHelper(ACCOUNTS_MANAGER); @@ -170,11 +178,13 @@ public class AuthHelper { testAccount.setup(ACCOUNTS_MANAGER); } - AuthFilter accountAuthFilter = new BasicCredentialAuthFilter.Builder().setAuthenticator(new AccountAuthenticator(ACCOUNTS_MANAGER)).buildAuthFilter (); - AuthFilter disabledPermittedAccountAuthFilter = new BasicCredentialAuthFilter.Builder().setAuthenticator(new DisabledPermittedAccountAuthenticator(ACCOUNTS_MANAGER)).buildAuthFilter(); + AuthFilter accountAuthFilter = new BasicCredentialAuthFilter.Builder().setAuthenticator( + new AccountAuthenticator(ACCOUNTS_MANAGER)).buildAuthFilter(); + AuthFilter disabledPermittedAccountAuthFilter = new BasicCredentialAuthFilter.Builder().setAuthenticator( + new DisabledPermittedAccountAuthenticator(ACCOUNTS_MANAGER)).buildAuthFilter(); - return new PolymorphicAuthDynamicFeature<>(ImmutableMap.of(Account.class, accountAuthFilter, - DisabledPermittedAccount.class, disabledPermittedAccountAuthFilter)); + return new PolymorphicAuthDynamicFeature<>(ImmutableMap.of(AuthenticatedAccount.class, accountAuthFilter, + DisabledPermittedAuthenticatedAccount.class, disabledPermittedAccountAuthFilter)); } public static String getAuthHeader(String number, String password) { @@ -223,13 +233,16 @@ public class AuthHelper { when(account.getMasterDevice()).thenReturn(Optional.of(device)); when(account.getNumber()).thenReturn(number); when(account.getUuid()).thenReturn(uuid); - when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device)); when(account.getRelay()).thenReturn(Optional.empty()); when(account.isEnabled()).thenReturn(true); when(accountsManager.get(number)).thenReturn(Optional.of(account)); when(accountsManager.get(uuid)).thenReturn(Optional.of(account)); - when(accountsManager.get(argThat((ArgumentMatcher) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(number)))).thenReturn(Optional.of(account)); - when(accountsManager.get(argThat((ArgumentMatcher) identifier -> identifier != null && identifier.hasUuid() && identifier.getUuid().equals(uuid)))).thenReturn(Optional.of(account)); + when(accountsManager.get(argThat( + (ArgumentMatcher) identifier -> identifier != null && identifier.hasNumber() + && identifier.getNumber().equals(number)))).thenReturn(Optional.of(account)); + when(accountsManager.get(argThat( + (ArgumentMatcher) identifier -> identifier != null && identifier.hasUuid() + && identifier.getUuid().equals(uuid)))).thenReturn(Optional.of(account)); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java index b971079a1..78de5abe1 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionIntegrationTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -40,6 +40,7 @@ import org.junit.Rule; import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.metrics.PushLatencyManager; @@ -52,6 +53,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbRule; +import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.websocket.WebSocketClient; import org.whispersystems.websocket.messages.WebSocketResponseMessage; @@ -90,13 +92,13 @@ public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest when(account.getUuid()).thenReturn(UUID.randomUUID()); when(device.getId()).thenReturn(1L); - webSocketConnection = new WebSocketConnection( - mock(ReceiptSender.class), - new MessagesManager(messagesDynamoDb, messagesCache, mock(PushLatencyManager.class), reportMessageManager), - account, - device, - webSocketClient, - retrySchedulingExecutor); + webSocketConnection = new WebSocketConnection( + mock(ReceiptSender.class), + new MessagesManager(messagesDynamoDb, messagesCache, mock(PushLatencyManager.class), reportMessageManager), + new AuthenticatedAccount(() -> new Pair<>(account, device)), + device, + webSocketClient, + retrySchedulingExecutor); } @After diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index 95b32d686..0261cbc3c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2013-2020 Signal Messenger, LLC + * Copyright 2013-2021 Signal Messenger, LLC * SPDX-License-Identifier: AGPL-3.0-only */ @@ -24,6 +24,7 @@ import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; import io.dropwizard.auth.basic.BasicCredentials; +import io.lettuce.core.RedisException; import java.io.IOException; import java.util.Collections; import java.util.HashMap; @@ -39,7 +40,6 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; -import io.lettuce.core.RedisException; import org.apache.commons.lang3.RandomStringUtils; import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.junit.Before; @@ -48,6 +48,7 @@ import org.mockito.ArgumentMatchers; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.push.ApnFallbackManager; @@ -58,6 +59,7 @@ import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.MessagesManager; +import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.websocket.WebSocketClient; import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult; import org.whispersystems.websocket.messages.WebSocketResponseMessage; @@ -75,6 +77,7 @@ public class WebSocketConnectionTest { private AccountsManager accountsManager; private Account account; private Device device; + private AuthenticatedAccount auth; private UpgradeRequest upgradeRequest; private ReceiptSender receiptSender; private ApnFallbackManager apnFallbackManager; @@ -86,6 +89,7 @@ public class WebSocketConnectionTest { accountsManager = mock(AccountsManager.class); account = mock(Account.class); device = mock(Device.class); + auth = new AuthenticatedAccount(() -> new Pair<>(account, device)); upgradeRequest = mock(UpgradeRequest.class); receiptSender = mock(ReceiptSender.class); apnFallbackManager = mock(ApnFallbackManager.class); @@ -94,35 +98,42 @@ public class WebSocketConnectionTest { @Test public void testCredentials() throws Exception { - MessagesManager storedMessages = mock(MessagesManager.class); + MessagesManager storedMessages = mock(MessagesManager.class); WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator); - AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, storedMessages, mock(MessageSender.class), apnFallbackManager, mock(ClientPresenceManager.class), + AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, storedMessages, + mock(MessageSender.class), apnFallbackManager, mock(ClientPresenceManager.class), retrySchedulingExecutor); - WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); + WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD)))) - .thenReturn(Optional.of(account)); + .thenReturn(Optional.of(new AuthenticatedAccount(() -> new Pair<>(account, device)))); when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD)))) - .thenReturn(Optional.empty()); + .thenReturn(Optional.empty()); - when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device)); - - when(upgradeRequest.getParameterMap()).thenReturn(new HashMap>() {{ - put("login", new LinkedList() {{add(VALID_USER);}}); - put("password", new LinkedList() {{add(VALID_PASSWORD);}}); + when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<>() {{ + put("login", new LinkedList<>() {{ + add(VALID_USER); + }}); + put("password", new LinkedList<>() {{ + add(VALID_PASSWORD); + }}); }}); - AuthenticationResult account = webSocketAuthenticator.authenticate(upgradeRequest); - when(sessionContext.getAuthenticated(Account.class)).thenReturn(account.getUser().orElse(null)); + AuthenticationResult account = webSocketAuthenticator.authenticate(upgradeRequest); + when(sessionContext.getAuthenticated(AuthenticatedAccount.class)).thenReturn(account.getUser().orElse(null)); connectListener.onWebSocketConnect(sessionContext); verify(sessionContext).addListener(any(WebSocketSessionContext.WebSocketEventListener.class)); when(upgradeRequest.getParameterMap()).thenReturn(new HashMap>() {{ - put("login", new LinkedList() {{add(INVALID_USER);}}); - put("password", new LinkedList() {{add(INVALID_PASSWORD);}}); + put("login", new LinkedList() {{ + add(INVALID_USER); + }}); + put("password", new LinkedList() {{ + add(INVALID_PASSWORD); + }}); }}); account = webSocketAuthenticator.authenticate(upgradeRequest); @@ -148,7 +159,6 @@ public class WebSocketConnectionTest { when(device.getId()).thenReturn(2L); - when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device)); when(account.getNumber()).thenReturn("+14152222222"); when(account.getUuid()).thenReturn(accountUuid); @@ -184,12 +194,13 @@ public class WebSocketConnectionTest { }); WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, - account, device, client, retrySchedulingExecutor); + auth, device, client, retrySchedulingExecutor); connection.start(); - verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.>any()); + verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), + ArgumentMatchers.>any()); - assertTrue(futures.size() == 3); + assertEquals(3, futures.size()); WebSocketResponseMessage response = mock(WebSocketResponseMessage.class); when(response.getStatus()).thenReturn(200); @@ -199,7 +210,7 @@ public class WebSocketConnectionTest { futures.get(2).completeExceptionally(new IOException()); verify(storedMessages, times(1)).delete(eq(accountUuid), eq(2L), eq(outgoingMessages.get(1).getGuid())); - verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender1"), eq(2222L)); + verify(receiptSender, times(1)).sendReceipt(eq(auth), eq("sender1"), eq(2222L)); connection.stop(); verify(client).close(anyInt(), anyString()); @@ -207,9 +218,10 @@ public class WebSocketConnectionTest { @Test(timeout = 5_000L) public void testOnlineSend() throws Exception { - final MessagesManager messagesManager = mock(MessagesManager.class); - final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor); + final MessagesManager messagesManager = mock(MessagesManager.class); + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, + retrySchedulingExecutor); final UUID accountUuid = UUID.randomUUID(); @@ -219,7 +231,7 @@ public class WebSocketConnectionTest { when(client.getUserAgent()).thenReturn("Test-UA"); when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) - .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)) + .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)) .thenReturn(new OutgoingMessageEntityList(List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first")), false)) .thenReturn(new OutgoingMessageEntityList(List.of(createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second")), false)); @@ -300,7 +312,6 @@ public class WebSocketConnectionTest { when(device.getId()).thenReturn(2L); - when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device)); when(account.getNumber()).thenReturn("+14152222222"); when(account.getUuid()).thenReturn(UUID.randomUUID()); @@ -336,11 +347,12 @@ public class WebSocketConnectionTest { }); WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, - account, device, client, retrySchedulingExecutor); + auth, device, client, retrySchedulingExecutor); connection.start(); - verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.>any()); + verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class), + ArgumentMatchers.>any()); assertEquals(futures.size(), 2); @@ -349,7 +361,7 @@ public class WebSocketConnectionTest { futures.get(1).complete(response); futures.get(0).completeExceptionally(new IOException()); - verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender2"), eq(secondMessage.getTimestamp())); + verify(receiptSender, times(1)).sendReceipt(eq(auth), eq("sender2"), eq(secondMessage.getTimestamp())); connection.stop(); verify(client).close(anyInt(), anyString()); @@ -357,19 +369,21 @@ public class WebSocketConnectionTest { @Test(timeout = 5000L) public void testProcessStoredMessageConcurrency() throws InterruptedException { - final MessagesManager messagesManager = mock(MessagesManager.class); - final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor); + final MessagesManager messagesManager = mock(MessagesManager.class); + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, + retrySchedulingExecutor); when(account.getNumber()).thenReturn("+18005551234"); when(account.getUuid()).thenReturn(UUID.randomUUID()); when(device.getId()).thenReturn(1L); when(client.getUserAgent()).thenReturn("Test-UA"); - final AtomicBoolean threadWaiting = new AtomicBoolean(false); + final AtomicBoolean threadWaiting = new AtomicBoolean(false); final AtomicBoolean returnMessageList = new AtomicBoolean(false); - when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false)).thenAnswer((Answer)invocation -> { + when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false)).thenAnswer( + (Answer) invocation -> { synchronized (threadWaiting) { threadWaiting.set(true); threadWaiting.notifyAll(); @@ -418,9 +432,10 @@ public class WebSocketConnectionTest { @Test(timeout = 5000L) public void testProcessStoredMessagesMultiplePages() throws InterruptedException { - final MessagesManager messagesManager = mock(MessagesManager.class); - final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor); + final MessagesManager messagesManager = mock(MessagesManager.class); + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, + retrySchedulingExecutor); when(account.getNumber()).thenReturn("+18005551234"); when(account.getUuid()).thenReturn(UUID.randomUUID()); @@ -428,8 +443,8 @@ public class WebSocketConnectionTest { when(client.getUserAgent()).thenReturn("Test-UA"); final List firstPageMessages = - List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first"), - createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second")); + List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first"), + createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second")); final List secondPageMessages = List.of(createMessage(3L, false, "sender1", UUID.randomUUID(), 3333, false, "third")); @@ -463,7 +478,8 @@ public class WebSocketConnectionTest { public void testProcessStoredMessagesContainsSenderUuid() throws InterruptedException { final MessagesManager messagesManager = mock(MessagesManager.class); final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor); + final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, + retrySchedulingExecutor); when(account.getNumber()).thenReturn("+18005551234"); when(account.getUuid()).thenReturn(UUID.randomUUID()); @@ -471,7 +487,8 @@ public class WebSocketConnectionTest { when(client.getUserAgent()).thenReturn("Test-UA"); final UUID senderUuid = UUID.randomUUID(); - final List messages = List.of(createMessage(1L, false, "senderE164", senderUuid, 1111L, false, "message the first")); + final List messages = List.of( + createMessage(1L, false, "senderE164", senderUuid, 1111L, false, "message the first")); final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(messages, false); when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false)).thenReturn(firstPage); @@ -511,9 +528,10 @@ public class WebSocketConnectionTest { @Test public void testProcessStoredMessagesSingleEmptyCall() { - final MessagesManager messagesManager = mock(MessagesManager.class); - final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor); + final MessagesManager messagesManager = mock(MessagesManager.class); + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, + retrySchedulingExecutor); final UUID accountUuid = UUID.randomUUID(); @@ -523,7 +541,7 @@ public class WebSocketConnectionTest { when(client.getUserAgent()).thenReturn("Test-UA"); when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) - .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); + .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); @@ -540,10 +558,11 @@ public class WebSocketConnectionTest { @Test(timeout = 5000L) public void testRequeryOnStateMismatch() throws InterruptedException { - final MessagesManager messagesManager = mock(MessagesManager.class); - final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor); - final UUID accountUuid = UUID.randomUUID(); + final MessagesManager messagesManager = mock(MessagesManager.class); + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, + retrySchedulingExecutor); + final UUID accountUuid = UUID.randomUUID(); when(account.getNumber()).thenReturn("+18005551234"); when(account.getUuid()).thenReturn(accountUuid); @@ -551,8 +570,8 @@ public class WebSocketConnectionTest { when(client.getUserAgent()).thenReturn("Test-UA"); final List firstPageMessages = - List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first"), - createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second")); + List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first"), + createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second")); final List secondPageMessages = List.of(createMessage(3L, false, "sender1", UUID.randomUUID(), 3333, false, "third")); @@ -587,9 +606,10 @@ public class WebSocketConnectionTest { @Test public void testProcessCachedMessagesOnly() { - final MessagesManager messagesManager = mock(MessagesManager.class); - final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor); + final MessagesManager messagesManager = mock(MessagesManager.class); + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, + retrySchedulingExecutor); final UUID accountUuid = UUID.randomUUID(); @@ -599,7 +619,7 @@ public class WebSocketConnectionTest { when(client.getUserAgent()).thenReturn("Test-UA"); when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) - .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); + .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); @@ -619,9 +639,10 @@ public class WebSocketConnectionTest { @Test public void testProcessDatabaseMessagesAfterPersist() { - final MessagesManager messagesManager = mock(MessagesManager.class); - final WebSocketClient client = mock(WebSocketClient.class); - final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor); + final MessagesManager messagesManager = mock(MessagesManager.class); + final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client, + retrySchedulingExecutor); final UUID accountUuid = UUID.randomUUID(); @@ -631,7 +652,7 @@ public class WebSocketConnectionTest { when(client.getUserAgent()).thenReturn("Test-UA"); when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) - .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); + .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); when(successResponse.getStatus()).thenReturn(200); @@ -664,7 +685,6 @@ public class WebSocketConnectionTest { when(device.getId()).thenReturn(2L); - when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device)); when(account.getNumber()).thenReturn("+14152222222"); when(account.getUuid()).thenReturn(accountUuid); @@ -689,20 +709,24 @@ public class WebSocketConnectionTest { final WebSocketClient client = mock(WebSocketClient.class); when(client.getUserAgent()).thenReturn(userAgent); - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.>any())) - .thenAnswer(new Answer>() { - @Override - public CompletableFuture answer(InvocationOnMock invocationOnMock) throws Throwable { - CompletableFuture future = new CompletableFuture<>(); - futures.add(future); - return future; - } - }); + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), + ArgumentMatchers.>any())) + .thenAnswer(new Answer>() { + @Override + public CompletableFuture answer(InvocationOnMock invocationOnMock) + throws Throwable { + CompletableFuture future = new CompletableFuture<>(); + futures.add(future); + return future; + } + }); - WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, account, device, client, retrySchedulingExecutor); + WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client, + retrySchedulingExecutor); connection.start(); - verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.>any()); + verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), + ArgumentMatchers.>any()); assertEquals(2, futures.size()); @@ -737,7 +761,6 @@ public class WebSocketConnectionTest { when(device.getId()).thenReturn(2L); - when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device)); when(account.getNumber()).thenReturn("+14152222222"); when(account.getUuid()).thenReturn(accountUuid); @@ -762,20 +785,24 @@ public class WebSocketConnectionTest { final WebSocketClient client = mock(WebSocketClient.class); when(client.getUserAgent()).thenReturn(userAgent); - when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.>any())) - .thenAnswer(new Answer>() { - @Override - public CompletableFuture answer(InvocationOnMock invocationOnMock) throws Throwable { - CompletableFuture future = new CompletableFuture<>(); - futures.add(future); - return future; - } - }); + when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), + ArgumentMatchers.>any())) + .thenAnswer(new Answer>() { + @Override + public CompletableFuture answer(InvocationOnMock invocationOnMock) + throws Throwable { + CompletableFuture future = new CompletableFuture<>(); + futures.add(future); + return future; + } + }); - WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, account, device, client, retrySchedulingExecutor); + WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client, + retrySchedulingExecutor); connection.start(); - verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.>any()); + verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), + ArgumentMatchers.>any()); assertEquals(3, futures.size()); @@ -799,7 +826,6 @@ public class WebSocketConnectionTest { when(device.getId()).thenReturn(2L); - when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device)); when(account.getNumber()).thenReturn("+14152222222"); when(account.getUuid()).thenReturn(accountUuid); @@ -808,17 +834,20 @@ public class WebSocketConnectionTest { when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false)) .thenThrow(new RedisException("OH NO")); - when(retrySchedulingExecutor.schedule(any(Runnable.class), anyLong(), any())).thenAnswer((Answer>) invocation -> { - invocation.getArgument(0, Runnable.class).run(); - return mock(ScheduledFuture.class); - }); + when(retrySchedulingExecutor.schedule(any(Runnable.class), anyLong(), any())).thenAnswer( + (Answer>) invocation -> { + invocation.getArgument(0, Runnable.class).run(); + return mock(ScheduledFuture.class); + }); - final WebSocketClient client = mock(WebSocketClient.class); + final WebSocketClient client = mock(WebSocketClient.class); - WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, account, device, client, retrySchedulingExecutor); + WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client, + retrySchedulingExecutor); connection.start(); - verify(retrySchedulingExecutor, times(WebSocketConnection.MAX_CONSECUTIVE_RETRIES)).schedule(any(Runnable.class), anyLong(), any()); + verify(retrySchedulingExecutor, times(WebSocketConnection.MAX_CONSECUTIVE_RETRIES)).schedule(any(Runnable.class), + anyLong(), any()); verify(client).close(eq(1011), anyString()); }