Use refreshing `AuthenticatedAccount` for `@Auth`

This commit is contained in:
Chris Eager 2021-08-11 14:52:25 -05:00 committed by GitHub
parent b3e6a50dee
commit 31022aeb79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
53 changed files with 1251 additions and 969 deletions

View File

@ -64,8 +64,9 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.DispatchManager; import org.whispersystems.dispatch.DispatchManager;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.CertificateGenerator; import org.whispersystems.textsecuregcm.auth.CertificateGenerator;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccountAuthenticator; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator; import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator;
@ -149,7 +150,6 @@ import org.whispersystems.textsecuregcm.sms.TwilioSmsSender;
import org.whispersystems.textsecuregcm.sms.TwilioVerifyExperimentEnrollmentManager; import org.whispersystems.textsecuregcm.sms.TwilioVerifyExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue; import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.AbusiveHostRules; import org.whispersystems.textsecuregcm.storage.AbusiveHostRules;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountCleaner; import org.whispersystems.textsecuregcm.storage.AccountCleaner;
import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawler; import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawler;
import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawlerCache; import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawlerCache;
@ -544,31 +544,40 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
.credentialsProvider(cdnCredentialsProvider) .credentialsProvider(cdnCredentialsProvider)
.region(Region.of(config.getCdnConfiguration().getRegion())) .region(Region.of(config.getCdnConfiguration().getRegion()))
.build(); .build();
PostPolicyGenerator profileCdnPolicyGenerator = new PostPolicyGenerator(config.getCdnConfiguration().getRegion(), config.getCdnConfiguration().getBucket(), config.getCdnConfiguration().getAccessKey()); PostPolicyGenerator profileCdnPolicyGenerator = new PostPolicyGenerator(config.getCdnConfiguration().getRegion(),
PolicySigner profileCdnPolicySigner = new PolicySigner(config.getCdnConfiguration().getAccessSecret(), config.getCdnConfiguration().getRegion()); config.getCdnConfiguration().getBucket(), config.getCdnConfiguration().getAccessKey());
PolicySigner profileCdnPolicySigner = new PolicySigner(config.getCdnConfiguration().getAccessSecret(),
config.getCdnConfiguration().getRegion());
ServerSecretParams zkSecretParams = new ServerSecretParams(config.getZkConfig().getServerSecret()); ServerSecretParams zkSecretParams = new ServerSecretParams(config.getZkConfig().getServerSecret());
ServerZkProfileOperations zkProfileOperations = new ServerZkProfileOperations(zkSecretParams); ServerZkProfileOperations zkProfileOperations = new ServerZkProfileOperations(zkSecretParams);
ServerZkAuthOperations zkAuthOperations = new ServerZkAuthOperations(zkSecretParams); ServerZkAuthOperations zkAuthOperations = new ServerZkAuthOperations(zkSecretParams);
boolean isZkEnabled = config.getZkConfig().isEnabled(); boolean isZkEnabled = config.getZkConfig().isEnabled();
AuthFilter<BasicCredentials, Account> accountAuthFilter = new BasicCredentialAuthFilter.Builder<Account>().setAuthenticator(accountAuthenticator).buildAuthFilter (); AuthFilter<BasicCredentials, AuthenticatedAccount> accountAuthFilter = new BasicCredentialAuthFilter.Builder<AuthenticatedAccount>().setAuthenticator(
AuthFilter<BasicCredentials, DisabledPermittedAccount> disabledPermittedAccountAuthFilter = new BasicCredentialAuthFilter.Builder<DisabledPermittedAccount>().setAuthenticator(disabledPermittedAccountAuthenticator).buildAuthFilter(); accountAuthenticator).buildAuthFilter();
AuthFilter<BasicCredentials, DisabledPermittedAuthenticatedAccount> disabledPermittedAccountAuthFilter = new BasicCredentialAuthFilter.Builder<DisabledPermittedAuthenticatedAccount>().setAuthenticator(
disabledPermittedAccountAuthenticator).buildAuthFilter();
environment.servlets().addFilter("RemoteDeprecationFilter", new RemoteDeprecationFilter(dynamicConfigurationManager)) environment.servlets()
.addFilter("RemoteDeprecationFilter", new RemoteDeprecationFilter(dynamicConfigurationManager))
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*"); .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
environment.jersey().register(new ContentLengthFilter(TrafficSource.HTTP)); environment.jersey().register(new ContentLengthFilter(TrafficSource.HTTP));
environment.jersey().register(MultiRecipientMessageProvider.class); environment.jersey().register(MultiRecipientMessageProvider.class);
environment.jersey().register(new MetricsApplicationEventListener(TrafficSource.HTTP)); environment.jersey().register(new MetricsApplicationEventListener(TrafficSource.HTTP));
environment.jersey().register(new PolymorphicAuthDynamicFeature<>(ImmutableMap.of(Account.class, accountAuthFilter, environment.jersey()
DisabledPermittedAccount.class, disabledPermittedAccountAuthFilter))); .register(new PolymorphicAuthDynamicFeature<>(ImmutableMap.of(AuthenticatedAccount.class, accountAuthFilter,
environment.jersey().register(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))); DisabledPermittedAuthenticatedAccount.class, disabledPermittedAccountAuthFilter)));
environment.jersey().register(new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)));
environment.jersey().register(new TimestampResponseFilter()); environment.jersey().register(new TimestampResponseFilter());
environment.jersey().register(new VoiceVerificationController(config.getVoiceVerificationConfiguration().getUrl(), config.getVoiceVerificationConfiguration().getLocales())); environment.jersey().register(new VoiceVerificationController(config.getVoiceVerificationConfiguration().getUrl(),
config.getVoiceVerificationConfiguration().getLocales()));
/// ///
WebSocketEnvironment<Account> webSocketEnvironment = new WebSocketEnvironment<>(environment, config.getWebSocketConfiguration(), 90000); WebSocketEnvironment<AuthenticatedAccount> webSocketEnvironment = new WebSocketEnvironment<>(environment,
config.getWebSocketConfiguration(), 90000);
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator)); webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator));
webSocketEnvironment.setConnectListener( webSocketEnvironment.setConnectListener(
new AuthenticatedConnectListener(receiptSender, messagesManager, messageSender, apnFallbackManager, new AuthenticatedConnectListener(receiptSender, messagesManager, messageSender, apnFallbackManager,
@ -602,15 +611,18 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
new RemoteConfigController(remoteConfigsManager, config.getRemoteConfigConfiguration().getAuthorizedTokens(), config.getRemoteConfigConfiguration().getGlobalConfig()), new RemoteConfigController(remoteConfigsManager, config.getRemoteConfigConfiguration().getAuthorizedTokens(), config.getRemoteConfigConfiguration().getGlobalConfig()),
new SecureBackupController(backupCredentialsGenerator), new SecureBackupController(backupCredentialsGenerator),
new SecureStorageController(storageCredentialsGenerator), new SecureStorageController(storageCredentialsGenerator),
new StickerController(rateLimiters, config.getCdnConfiguration().getAccessKey(), config.getCdnConfiguration().getAccessSecret(), config.getCdnConfiguration().getRegion(), config.getCdnConfiguration().getBucket()) new StickerController(rateLimiters, config.getCdnConfiguration().getAccessKey(),
config.getCdnConfiguration().getAccessSecret(), config.getCdnConfiguration().getRegion(),
config.getCdnConfiguration().getBucket())
); );
for (Object controller : commonControllers) { for (Object controller : commonControllers) {
environment.jersey().register(controller); environment.jersey().register(controller);
webSocketEnvironment.jersey().register(controller); webSocketEnvironment.jersey().register(controller);
} }
WebSocketEnvironment<Account> provisioningEnvironment = new WebSocketEnvironment<>(environment, webSocketEnvironment.getRequestLog(), 60000); WebSocketEnvironment<AuthenticatedAccount> provisioningEnvironment = new WebSocketEnvironment<>(environment,
webSocketEnvironment.getRequestLog(), 60000);
provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(pubSubManager)); provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(pubSubManager));
provisioningEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET)); provisioningEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET));
provisioningEnvironment.jersey().register(new KeepAliveController(clientPresenceManager)); provisioningEnvironment.jersey().register(new KeepAliveController(clientPresenceManager));
@ -618,16 +630,19 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
registerCorsFilter(environment); registerCorsFilter(environment);
registerExceptionMappers(environment, webSocketEnvironment, provisioningEnvironment); registerExceptionMappers(environment, webSocketEnvironment, provisioningEnvironment);
RateLimitChallengeExceptionMapper rateLimitChallengeExceptionMapper = new RateLimitChallengeExceptionMapper(rateLimitChallengeManager); RateLimitChallengeExceptionMapper rateLimitChallengeExceptionMapper = new RateLimitChallengeExceptionMapper(
rateLimitChallengeManager);
environment.jersey().register(rateLimitChallengeExceptionMapper); environment.jersey().register(rateLimitChallengeExceptionMapper);
webSocketEnvironment.jersey().register(rateLimitChallengeExceptionMapper); webSocketEnvironment.jersey().register(rateLimitChallengeExceptionMapper);
provisioningEnvironment.jersey().register(rateLimitChallengeExceptionMapper); provisioningEnvironment.jersey().register(rateLimitChallengeExceptionMapper);
WebSocketResourceProviderFactory<Account> webSocketServlet = new WebSocketResourceProviderFactory<>(webSocketEnvironment, Account.class); WebSocketResourceProviderFactory<AuthenticatedAccount> webSocketServlet = new WebSocketResourceProviderFactory<>(
WebSocketResourceProviderFactory<Account> provisioningServlet = new WebSocketResourceProviderFactory<>(provisioningEnvironment, Account.class); webSocketEnvironment, AuthenticatedAccount.class);
WebSocketResourceProviderFactory<AuthenticatedAccount> provisioningServlet = new WebSocketResourceProviderFactory<>(
provisioningEnvironment, AuthenticatedAccount.class);
ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", webSocketServlet ); ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", webSocketServlet);
ServletRegistration.Dynamic provisioning = environment.servlets().addServlet("Provisioning", provisioningServlet); ServletRegistration.Dynamic provisioning = environment.servlets().addServlet("Provisioning", provisioningServlet);
websocket.addMapping("/v1/websocket/"); websocket.addMapping("/v1/websocket/");
@ -649,14 +664,18 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.metrics().register(name(NetworkReceivedGauge.class, "bytes_received"), new NetworkReceivedGauge()); environment.metrics().register(name(NetworkReceivedGauge.class, "bytes_received"), new NetworkReceivedGauge());
environment.metrics().register(name(FileDescriptorGauge.class, "fd_count"), new FileDescriptorGauge()); environment.metrics().register(name(FileDescriptorGauge.class, "fd_count"), new FileDescriptorGauge());
environment.metrics().register(name(MaxFileDescriptorGauge.class, "max_fd_count"), new MaxFileDescriptorGauge()); environment.metrics().register(name(MaxFileDescriptorGauge.class, "max_fd_count"), new MaxFileDescriptorGauge());
environment.metrics().register(name(OperatingSystemMemoryGauge.class, "buffers"), new OperatingSystemMemoryGauge("Buffers")); environment.metrics()
environment.metrics().register(name(OperatingSystemMemoryGauge.class, "cached"), new OperatingSystemMemoryGauge("Cached")); .register(name(OperatingSystemMemoryGauge.class, "buffers"), new OperatingSystemMemoryGauge("Buffers"));
environment.metrics()
.register(name(OperatingSystemMemoryGauge.class, "cached"), new OperatingSystemMemoryGauge("Cached"));
BufferPoolGauges.registerMetrics(); BufferPoolGauges.registerMetrics();
GarbageCollectionGauges.registerMetrics(); GarbageCollectionGauges.registerMetrics();
} }
private void registerExceptionMappers(Environment environment, WebSocketEnvironment<Account> webSocketEnvironment, WebSocketEnvironment<Account> provisioningEnvironment) { private void registerExceptionMappers(Environment environment,
WebSocketEnvironment<AuthenticatedAccount> webSocketEnvironment,
WebSocketEnvironment<AuthenticatedAccount> provisioningEnvironment) {
environment.jersey().register(new LoggingUnhandledExceptionMapper()); environment.jersey().register(new LoggingUnhandledExceptionMapper());
environment.jersey().register(new IOExceptionMapper()); environment.jersey().register(new IOExceptionMapper());
environment.jersey().register(new RateLimitExceededExceptionMapper()); environment.jersey().register(new RateLimitExceededExceptionMapper());

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
package org.whispersystems.textsecuregcm.auth; package org.whispersystems.textsecuregcm.auth;
@ -7,17 +7,17 @@ package org.whispersystems.textsecuregcm.auth;
import io.dropwizard.auth.Authenticator; import io.dropwizard.auth.Authenticator;
import io.dropwizard.auth.basic.BasicCredentials; import io.dropwizard.auth.basic.BasicCredentials;
import java.util.Optional; import java.util.Optional;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
public class AccountAuthenticator extends BaseAccountAuthenticator implements Authenticator<BasicCredentials, Account> { public class AccountAuthenticator extends BaseAccountAuthenticator implements
Authenticator<BasicCredentials, AuthenticatedAccount> {
public AccountAuthenticator(AccountsManager accountsManager) { public AccountAuthenticator(AccountsManager accountsManager) {
super(accountsManager); super(accountsManager);
} }
@Override @Override
public Optional<Account> authenticate(BasicCredentials basicCredentials) { public Optional<AuthenticatedAccount> authenticate(BasicCredentials basicCredentials) {
return super.authenticate(basicCredentials, true); return super.authenticate(basicCredentials, true);
} }

View File

@ -0,0 +1,42 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import java.security.Principal;
import java.util.function.Supplier;
import javax.security.auth.Subject;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Pair;
public class AuthenticatedAccount implements Principal {
private final Supplier<Pair<Account, Device>> accountAndDevice;
public AuthenticatedAccount(final Supplier<Pair<Account, Device>> accountAndDevice) {
this.accountAndDevice = accountAndDevice;
}
public Account getAccount() {
return accountAndDevice.get().first();
}
public Device getAuthenticatedDevice() {
return accountAndDevice.get().second();
}
// Principal implementation
@Override
public String getName() {
return null;
}
@Override
public boolean implies(final Subject subject) {
return false;
}
}

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
@ -19,6 +19,7 @@ import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.RefreshingAccountAndDeviceSupplier;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
public class BaseAccountAuthenticator { public class BaseAccountAuthenticator {
@ -45,14 +46,15 @@ public class BaseAccountAuthenticator {
this.clock = clock; this.clock = clock;
} }
public Optional<Account> authenticate(BasicCredentials basicCredentials, boolean enabledRequired) { public Optional<AuthenticatedAccount> authenticate(BasicCredentials basicCredentials, boolean enabledRequired) {
boolean succeeded = false; boolean succeeded = false;
String failureReason = null; String failureReason = null;
String credentialType = null; String credentialType = null;
try { try {
AuthorizationHeader authorizationHeader = AuthorizationHeader.fromUserAndPassword(basicCredentials.getUsername(), basicCredentials.getPassword()); AuthorizationHeader authorizationHeader = AuthorizationHeader.fromUserAndPassword(basicCredentials.getUsername(),
Optional<Account> account = accountsManager.get(authorizationHeader.getIdentifier()); basicCredentials.getPassword());
Optional<Account> account = accountsManager.get(authorizationHeader.getIdentifier());
credentialType = authorizationHeader.getIdentifier().hasNumber() ? "e164" : "uuid"; credentialType = authorizationHeader.getIdentifier().hasNumber() ? "e164" : "uuid";
@ -83,9 +85,8 @@ public class BaseAccountAuthenticator {
if (device.get().getAuthenticationCredentials().verify(basicCredentials.getPassword())) { if (device.get().getAuthenticationCredentials().verify(basicCredentials.getPassword())) {
succeeded = true; succeeded = true;
final Account authenticatedAccount = updateLastSeen(account.get(), device.get()); final Account authenticatedAccount = updateLastSeen(account.get(), device.get());
// the device in scope might be stale after the update, so get the latest from the authenticated account return Optional.of(new AuthenticatedAccount(
authenticatedAccount.setAuthenticatedDevice(authenticatedAccount.getDevice(device.get().getId()).orElseThrow()); new RefreshingAccountAndDeviceSupplier(authenticatedAccount, device.get().getId(), accountsManager)));
return Optional.of(authenticatedAccount);
} }
return Optional.empty(); return Optional.empty();

View File

@ -1,36 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import org.whispersystems.textsecuregcm.storage.Account;
import javax.security.auth.Subject;
import java.security.Principal;
public class DisabledPermittedAccount implements Principal {
private final Account account;
public DisabledPermittedAccount(Account account) {
this.account = account;
}
public Account getAccount() {
return account;
}
// Principal implementation
@Override
public String getName() {
return null;
}
@Override
public boolean implies(Subject subject) {
return false;
}
}

View File

@ -1,27 +1,25 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
package org.whispersystems.textsecuregcm.auth; package org.whispersystems.textsecuregcm.auth;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import java.util.Optional;
import io.dropwizard.auth.Authenticator; import io.dropwizard.auth.Authenticator;
import io.dropwizard.auth.basic.BasicCredentials; import io.dropwizard.auth.basic.BasicCredentials;
import java.util.Optional;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
public class DisabledPermittedAccountAuthenticator extends BaseAccountAuthenticator implements Authenticator<BasicCredentials, DisabledPermittedAccount> { public class DisabledPermittedAccountAuthenticator extends BaseAccountAuthenticator implements
Authenticator<BasicCredentials, DisabledPermittedAuthenticatedAccount> {
public DisabledPermittedAccountAuthenticator(AccountsManager accountsManager) { public DisabledPermittedAccountAuthenticator(AccountsManager accountsManager) {
super(accountsManager); super(accountsManager);
} }
@Override @Override
public Optional<DisabledPermittedAccount> authenticate(BasicCredentials credentials) { public Optional<DisabledPermittedAuthenticatedAccount> authenticate(BasicCredentials credentials) {
Optional<Account> account = super.authenticate(credentials, false); Optional<AuthenticatedAccount> account = super.authenticate(credentials, false);
return account.map(DisabledPermittedAccount::new); return account.map(DisabledPermittedAuthenticatedAccount::new);
} }
} }

View File

@ -0,0 +1,40 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import java.security.Principal;
import javax.security.auth.Subject;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
public class DisabledPermittedAuthenticatedAccount implements Principal {
private final AuthenticatedAccount authenticatedAccount;
public DisabledPermittedAuthenticatedAccount(final AuthenticatedAccount authenticatedAccount) {
this.authenticatedAccount = authenticatedAccount;
}
public Account getAccount() {
return authenticatedAccount.getAccount();
}
public Device getAuthenticatedDevice() {
return authenticatedAccount.getAuthenticatedDevice();
}
// Principal implementation
@Override
public String getName() {
return null;
}
@Override
public boolean implies(Subject subject) {
return false;
}
}

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
package org.whispersystems.textsecuregcm.controllers; package org.whispersystems.textsecuregcm.controllers;
@ -39,9 +39,10 @@ import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials; import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.auth.AuthorizationHeader; import org.whispersystems.textsecuregcm.auth.AuthorizationHeader;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.auth.InvalidAuthorizationHeaderException; import org.whispersystems.textsecuregcm.auth.InvalidAuthorizationHeaderException;
@ -404,8 +405,8 @@ public class AccountController {
@GET @GET
@Path("/turn/") @Path("/turn/")
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public TurnToken getTurnToken(@Auth Account account) throws RateLimitExceededException { public TurnToken getTurnToken(@Auth AuthenticatedAccount auth) throws RateLimitExceededException {
rateLimiters.getTurnLimiter().validate(account.getUuid()); rateLimiters.getTurnLimiter().validate(auth.getAccount().getUuid());
return turnTokenGenerator.generate(); return turnTokenGenerator.generate();
} }
@ -413,13 +414,13 @@ public class AccountController {
@PUT @PUT
@Path("/gcm/") @Path("/gcm/")
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
public void setGcmRegistrationId(@Auth DisabledPermittedAccount disabledPermittedAccount, @Valid GcmRegistrationId registrationId) { public void setGcmRegistrationId(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth,
Account account = disabledPermittedAccount.getAccount(); @Valid GcmRegistrationId registrationId) {
Device device = account.getAuthenticatedDevice().get(); Account account = disabledPermittedAuth.getAccount();
Device device = disabledPermittedAuth.getAuthenticatedDevice();
if (device.getGcmId() != null && if (device.getGcmId() != null &&
device.getGcmId().equals(registrationId.getGcmRegistrationId())) device.getGcmId().equals(registrationId.getGcmRegistrationId())) {
{
return; return;
} }
@ -434,9 +435,9 @@ public class AccountController {
@Timed @Timed
@DELETE @DELETE
@Path("/gcm/") @Path("/gcm/")
public void deleteGcmRegistrationId(@Auth DisabledPermittedAccount disabledPermittedAccount) { public void deleteGcmRegistrationId(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth) {
Account account = disabledPermittedAccount.getAccount(); Account account = disabledPermittedAuth.getAccount();
Device device = account.getAuthenticatedDevice().get(); Device device = disabledPermittedAuth.getAuthenticatedDevice();
accounts.updateDevice(account, device.getId(), d -> { accounts.updateDevice(account, device.getId(), d -> {
d.setGcmId(null); d.setGcmId(null);
@ -449,9 +450,10 @@ public class AccountController {
@PUT @PUT
@Path("/apn/") @Path("/apn/")
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
public void setApnRegistrationId(@Auth DisabledPermittedAccount disabledPermittedAccount, @Valid ApnRegistrationId registrationId) { public void setApnRegistrationId(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth,
Account account = disabledPermittedAccount.getAccount(); @Valid ApnRegistrationId registrationId) {
Device device = account.getAuthenticatedDevice().get(); Account account = disabledPermittedAuth.getAccount();
Device device = disabledPermittedAuth.getAuthenticatedDevice();
accounts.updateDevice(account, device.getId(), d -> { accounts.updateDevice(account, device.getId(), d -> {
d.setApnId(registrationId.getApnRegistrationId()); d.setApnId(registrationId.getApnRegistrationId());
@ -464,9 +466,9 @@ public class AccountController {
@Timed @Timed
@DELETE @DELETE
@Path("/apn/") @Path("/apn/")
public void deleteApnRegistrationId(@Auth DisabledPermittedAccount disabledPermittedAccount) { public void deleteApnRegistrationId(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth) {
Account account = disabledPermittedAccount.getAccount(); Account account = disabledPermittedAuth.getAccount();
Device device = account.getAuthenticatedDevice().get(); Device device = disabledPermittedAuth.getAuthenticatedDevice();
accounts.updateDevice(account, device.getId(), d -> { accounts.updateDevice(account, device.getId(), d -> {
d.setApnId(null); d.setApnId(null);
@ -483,57 +485,54 @@ public class AccountController {
@PUT @PUT
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Path("/registration_lock") @Path("/registration_lock")
public void setRegistrationLock(@Auth Account account, @Valid RegistrationLock accountLock) { public void setRegistrationLock(@Auth AuthenticatedAccount auth, @Valid RegistrationLock accountLock) {
AuthenticationCredentials credentials = new AuthenticationCredentials(accountLock.getRegistrationLock()); AuthenticationCredentials credentials = new AuthenticationCredentials(accountLock.getRegistrationLock());
accounts.update(account, a -> { accounts.update(auth.getAccount(),
a.setRegistrationLock(credentials.getHashedAuthenticationToken(), credentials.getSalt()); a -> a.setRegistrationLock(credentials.getHashedAuthenticationToken(), credentials.getSalt()));
});
} }
@Timed @Timed
@DELETE @DELETE
@Path("/registration_lock") @Path("/registration_lock")
public void removeRegistrationLock(@Auth Account account) { public void removeRegistrationLock(@Auth AuthenticatedAccount auth) {
accounts.update(account, a -> a.setRegistrationLock(null, null)); accounts.update(auth.getAccount(), a -> a.setRegistrationLock(null, null));
} }
@Timed @Timed
@PUT @PUT
@Path("/name/") @Path("/name/")
public void setName(@Auth DisabledPermittedAccount disabledPermittedAccount, @Valid DeviceName deviceName) { public void setName(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth, @Valid DeviceName deviceName) {
Account account = disabledPermittedAccount.getAccount(); Account account = disabledPermittedAuth.getAccount();
Device device = account.getAuthenticatedDevice().get(); Device device = disabledPermittedAuth.getAuthenticatedDevice();
accounts.updateDevice(account, device.getId(), d -> d.setName(deviceName.getDeviceName())); accounts.updateDevice(account, device.getId(), d -> d.setName(deviceName.getDeviceName()));
} }
@Timed @Timed
@DELETE @DELETE
@Path("/signaling_key") @Path("/signaling_key")
public void removeSignalingKey(@Auth DisabledPermittedAccount disabledPermittedAccount) { public void removeSignalingKey(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth) {
} }
@Timed @Timed
@PUT @PUT
@Path("/attributes/") @Path("/attributes/")
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
public void setAccountAttributes(@Auth DisabledPermittedAccount disabledPermittedAccount, public void setAccountAttributes(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth,
@HeaderParam("X-Signal-Agent") String userAgent, @HeaderParam("X-Signal-Agent") String userAgent,
@Valid AccountAttributes attributes) @Valid AccountAttributes attributes) {
{ Account account = disabledPermittedAuth.getAccount();
Account account = disabledPermittedAccount.getAccount(); long deviceId = disabledPermittedAuth.getAuthenticatedDevice().getId();
long deviceId = account.getAuthenticatedDevice().get().getId();
accounts.update(account, a-> {
accounts.update(account, a -> {
a.getDevice(deviceId).ifPresent(d -> { a.getDevice(deviceId).ifPresent(d -> {
d.setFetchesMessages(attributes.getFetchesMessages()); d.setFetchesMessages(attributes.getFetchesMessages());
d.setName(attributes.getName()); d.setName(attributes.getName());
d.setLastSeen(Util.todayInMillis()); d.setLastSeen(Util.todayInMillis());
d.setCapabilities(attributes.getCapabilities()); d.setCapabilities(attributes.getCapabilities());
d.setRegistrationId(attributes.getRegistrationId()); d.setRegistrationId(attributes.getRegistrationId());
d.setUserAgent(userAgent); d.setUserAgent(userAgent);
}); });
a.setRegistrationLockFromAttributes(attributes); a.setRegistrationLockFromAttributes(attributes);
@ -546,29 +545,30 @@ public class AccountController {
@GET @GET
@Path("/me") @Path("/me")
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public AccountCreationResult getMe(@Auth Account account) { public AccountCreationResult getMe(@Auth AuthenticatedAccount auth) {
return whoAmI(account); return whoAmI(auth);
} }
@GET @GET
@Path("/whoami") @Path("/whoami")
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public AccountCreationResult whoAmI(@Auth Account account) { public AccountCreationResult whoAmI(@Auth AuthenticatedAccount auth) {
return new AccountCreationResult(account.getUuid(), account.isStorageSupported()); return new AccountCreationResult(auth.getAccount().getUuid(), auth.getAccount().isStorageSupported());
} }
@DELETE @DELETE
@Path("/username") @Path("/username")
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public void deleteUsername(@Auth Account account) { public void deleteUsername(@Auth AuthenticatedAccount auth) {
usernames.delete(account.getUuid()); usernames.delete(auth.getAccount().getUuid());
} }
@PUT @PUT
@Path("/username/{username}") @Path("/username/{username}")
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public Response setUsername(@Auth Account account, @PathParam("username") String username) throws RateLimitExceededException { public Response setUsername(@Auth AuthenticatedAccount auth, @PathParam("username") String username)
rateLimiters.getUsernameSetLimiter().validate(account.getUuid()); throws RateLimitExceededException {
rateLimiters.getUsernameSetLimiter().validate(auth.getAccount().getUuid());
if (username == null || username.isEmpty()) { if (username == null || username.isEmpty()) {
return Response.status(Response.Status.BAD_REQUEST).build(); return Response.status(Response.Status.BAD_REQUEST).build();
@ -580,7 +580,7 @@ public class AccountController {
return Response.status(Response.Status.BAD_REQUEST).build(); return Response.status(Response.Status.BAD_REQUEST).build();
} }
if (!usernames.put(account.getUuid(), username)) { if (!usernames.put(auth.getAccount().getUuid(), username)) {
return Response.status(Response.Status.CONFLICT).build(); return Response.status(Response.Status.CONFLICT).build();
} }
@ -678,8 +678,8 @@ public class AccountController {
@Timed @Timed
@DELETE @DELETE
@Path("/me") @Path("/me")
public void deleteAccount(@Auth Account account) throws InterruptedException { public void deleteAccount(@Auth AuthenticatedAccount auth) throws InterruptedException {
accounts.delete(account, AccountsManager.DeletionReason.USER_REQUEST); accounts.delete(auth.getAccount(), AccountsManager.DeletionReason.USER_REQUEST);
} }
private boolean shouldAutoBlock(String sourceHost) { private boolean shouldAutoBlock(String sourceHost) {

View File

@ -1,29 +1,27 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
package org.whispersystems.textsecuregcm.controllers; package org.whispersystems.textsecuregcm.controllers;
import com.amazonaws.HttpMethod; import com.amazonaws.HttpMethod;
import com.codahale.metrics.annotation.Timed; import com.codahale.metrics.annotation.Timed;
import org.slf4j.Logger; import io.dropwizard.auth.Auth;
import org.slf4j.LoggerFactory; import java.io.IOException;
import org.whispersystems.textsecuregcm.entities.AttachmentDescriptorV1; import java.net.URL;
import org.whispersystems.textsecuregcm.entities.AttachmentUri; import java.util.stream.Stream;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.s3.UrlSigner;
import org.whispersystems.textsecuregcm.storage.Account;
import javax.ws.rs.GET; import javax.ws.rs.GET;
import javax.ws.rs.Path; import javax.ws.rs.Path;
import javax.ws.rs.PathParam; import javax.ws.rs.PathParam;
import javax.ws.rs.Produces; import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import java.io.IOException; import org.slf4j.Logger;
import java.net.URL; import org.slf4j.LoggerFactory;
import java.util.stream.Stream; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.entities.AttachmentDescriptorV1;
import io.dropwizard.auth.Auth; import org.whispersystems.textsecuregcm.entities.AttachmentUri;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.s3.UrlSigner;
@Path("/v1/attachments") @Path("/v1/attachments")
@ -35,25 +33,25 @@ public class AttachmentControllerV1 extends AttachmentControllerBase {
private static final String[] UNACCELERATED_REGIONS = {"+20", "+971", "+968", "+974"}; private static final String[] UNACCELERATED_REGIONS = {"+20", "+971", "+968", "+974"};
private final RateLimiters rateLimiters; private final RateLimiters rateLimiters;
private final UrlSigner urlSigner; private final UrlSigner urlSigner;
public AttachmentControllerV1(RateLimiters rateLimiters, String accessKey, String accessSecret, String bucket) { public AttachmentControllerV1(RateLimiters rateLimiters, String accessKey, String accessSecret, String bucket) {
this.rateLimiters = rateLimiters; this.rateLimiters = rateLimiters;
this.urlSigner = new UrlSigner(accessKey, accessSecret, bucket); this.urlSigner = new UrlSigner(accessKey, accessSecret, bucket);
} }
@Timed @Timed
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public AttachmentDescriptorV1 allocateAttachment(@Auth Account account) public AttachmentDescriptorV1 allocateAttachment(@Auth AuthenticatedAccount auth)
throws RateLimitExceededException throws RateLimitExceededException {
{ if (auth.getAccount().isRateLimited()) {
if (account.isRateLimited()) { rateLimiters.getAttachmentLimiter().validate(auth.getAccount().getUuid());
rateLimiters.getAttachmentLimiter().validate(account.getUuid());
} }
long attachmentId = generateAttachmentId(); long attachmentId = generateAttachmentId();
URL url = urlSigner.getPreSignedUrl(attachmentId, HttpMethod.PUT, Stream.of(UNACCELERATED_REGIONS).anyMatch(region -> account.getNumber().startsWith(region))); URL url = urlSigner.getPreSignedUrl(attachmentId, HttpMethod.PUT,
Stream.of(UNACCELERATED_REGIONS).anyMatch(region -> auth.getAccount().getNumber().startsWith(region)));
return new AttachmentDescriptorV1(attachmentId, url.toExternalForm()); return new AttachmentDescriptorV1(attachmentId, url.toExternalForm());
@ -63,11 +61,11 @@ public class AttachmentControllerV1 extends AttachmentControllerBase {
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Path("/{attachmentId}") @Path("/{attachmentId}")
public AttachmentUri redirectToAttachment(@Auth Account account, public AttachmentUri redirectToAttachment(@Auth AuthenticatedAccount auth,
@PathParam("attachmentId") long attachmentId) @PathParam("attachmentId") long attachmentId)
throws IOException throws IOException {
{ return new AttachmentUri(urlSigner.getPreSignedUrl(attachmentId, HttpMethod.GET,
return new AttachmentUri(urlSigner.getPreSignedUrl(attachmentId, HttpMethod.GET, Stream.of(UNACCELERATED_REGIONS).anyMatch(region -> account.getNumber().startsWith(region)))); Stream.of(UNACCELERATED_REGIONS).anyMatch(region -> auth.getAccount().getNumber().startsWith(region))));
} }
} }

View File

@ -1,28 +1,26 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
package org.whispersystems.textsecuregcm.controllers; package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed; import com.codahale.metrics.annotation.Timed;
import io.dropwizard.auth.Auth;
import java.time.ZoneOffset;
import java.time.ZonedDateTime;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.entities.AttachmentDescriptorV2; import org.whispersystems.textsecuregcm.entities.AttachmentDescriptorV2;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.s3.PolicySigner; import org.whispersystems.textsecuregcm.s3.PolicySigner;
import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator; import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Pair;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import java.time.ZoneOffset;
import java.time.ZonedDateTime;
import io.dropwizard.auth.Auth;
@Path("/v2/attachments") @Path("/v2/attachments")
public class AttachmentControllerV2 extends AttachmentControllerBase { public class AttachmentControllerV2 extends AttachmentControllerBase {
@ -40,19 +38,20 @@ public class AttachmentControllerV2 extends AttachmentControllerBase {
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Path("/form/upload") @Path("/form/upload")
public AttachmentDescriptorV2 getAttachmentUploadForm(@Auth Account account) throws RateLimitExceededException { public AttachmentDescriptorV2 getAttachmentUploadForm(@Auth AuthenticatedAccount auth)
rateLimiter.validate(account.getUuid()); throws RateLimitExceededException {
rateLimiter.validate(auth.getAccount().getUuid());
ZonedDateTime now = ZonedDateTime.now(ZoneOffset.UTC); ZonedDateTime now = ZonedDateTime.now(ZoneOffset.UTC);
long attachmentId = generateAttachmentId(); long attachmentId = generateAttachmentId();
String objectName = String.valueOf(attachmentId); String objectName = String.valueOf(attachmentId);
Pair<String, String> policy = policyGenerator.createFor(now, String.valueOf(objectName), 100 * 1024 * 1024); Pair<String, String> policy = policyGenerator.createFor(now, String.valueOf(objectName), 100 * 1024 * 1024);
String signature = policySigner.getSignature(now, policy.second()); String signature = policySigner.getSignature(now, policy.second());
return new AttachmentDescriptorV2(attachmentId, objectName, policy.first(), return new AttachmentDescriptorV2(attachmentId, objectName, policy.first(),
"private", "AWS4-HMAC-SHA256", "private", "AWS4-HMAC-SHA256",
now.format(PostPolicyGenerator.AWS_DATE_TIME), now.format(PostPolicyGenerator.AWS_DATE_TIME),
policy.second(), signature); policy.second(), signature);
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
@ -7,19 +7,6 @@ package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed; import com.codahale.metrics.annotation.Timed;
import io.dropwizard.auth.Auth; import io.dropwizard.auth.Auth;
import org.whispersystems.textsecuregcm.entities.AttachmentDescriptorV3;
import org.whispersystems.textsecuregcm.gcp.CanonicalRequest;
import org.whispersystems.textsecuregcm.gcp.CanonicalRequestGenerator;
import org.whispersystems.textsecuregcm.gcp.CanonicalRequestSigner;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import javax.annotation.Nonnull;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import java.io.IOException; import java.io.IOException;
import java.security.InvalidKeyException; import java.security.InvalidKeyException;
import java.security.SecureRandom; import java.security.SecureRandom;
@ -29,6 +16,18 @@ import java.time.ZonedDateTime;
import java.util.Base64; import java.util.Base64;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import javax.annotation.Nonnull;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.entities.AttachmentDescriptorV3;
import org.whispersystems.textsecuregcm.gcp.CanonicalRequest;
import org.whispersystems.textsecuregcm.gcp.CanonicalRequestGenerator;
import org.whispersystems.textsecuregcm.gcp.CanonicalRequestSigner;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
@Path("/v3/attachments") @Path("/v3/attachments")
public class AttachmentControllerV3 extends AttachmentControllerBase { public class AttachmentControllerV3 extends AttachmentControllerBase {
@ -45,26 +44,29 @@ public class AttachmentControllerV3 extends AttachmentControllerBase {
@Nonnull @Nonnull
private final SecureRandom secureRandom; private final SecureRandom secureRandom;
public AttachmentControllerV3(@Nonnull RateLimiters rateLimiters, @Nonnull String domain, @Nonnull String email, int maxSizeInBytes, @Nonnull String pathPrefix, @Nonnull String rsaSigningKey) public AttachmentControllerV3(@Nonnull RateLimiters rateLimiters, @Nonnull String domain, @Nonnull String email,
int maxSizeInBytes, @Nonnull String pathPrefix, @Nonnull String rsaSigningKey)
throws IOException, InvalidKeyException, InvalidKeySpecException { throws IOException, InvalidKeyException, InvalidKeySpecException {
this.rateLimiter = rateLimiters.getAttachmentLimiter(); this.rateLimiter = rateLimiters.getAttachmentLimiter();
this.canonicalRequestGenerator = new CanonicalRequestGenerator(domain, email, maxSizeInBytes, pathPrefix); this.canonicalRequestGenerator = new CanonicalRequestGenerator(domain, email, maxSizeInBytes, pathPrefix);
this.canonicalRequestSigner = new CanonicalRequestSigner(rsaSigningKey); this.canonicalRequestSigner = new CanonicalRequestSigner(rsaSigningKey);
this.secureRandom = new SecureRandom(); this.secureRandom = new SecureRandom();
} }
@Timed @Timed
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Path("/form/upload") @Path("/form/upload")
public AttachmentDescriptorV3 getAttachmentUploadForm(@Auth Account account) throws RateLimitExceededException { public AttachmentDescriptorV3 getAttachmentUploadForm(@Auth AuthenticatedAccount auth)
rateLimiter.validate(account.getUuid()); throws RateLimitExceededException {
rateLimiter.validate(auth.getAccount().getUuid());
final ZonedDateTime now = ZonedDateTime.now(ZoneOffset.UTC); final ZonedDateTime now = ZonedDateTime.now(ZoneOffset.UTC);
final String key = generateAttachmentKey(); final String key = generateAttachmentKey();
final CanonicalRequest canonicalRequest = canonicalRequestGenerator.createFor(key, now); final CanonicalRequest canonicalRequest = canonicalRequestGenerator.createFor(key, now);
return new AttachmentDescriptorV3(2, key, getHeaderMap(canonicalRequest), getSignedUploadLocation(canonicalRequest)); return new AttachmentDescriptorV3(2, key, getHeaderMap(canonicalRequest),
getSignedUploadLocation(canonicalRequest));
} }
private String getSignedUploadLocation(@Nonnull CanonicalRequest canonicalRequest) { private String getSignedUploadLocation(@Nonnull CanonicalRequest canonicalRequest) {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
@ -10,7 +10,6 @@ import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.annotation.Timed; import com.codahale.metrics.annotation.Timed;
import io.dropwizard.auth.Auth; import io.dropwizard.auth.Auth;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import java.io.IOException;
import java.security.InvalidKeyException; import java.security.InvalidKeyException;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
@ -24,10 +23,10 @@ import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import org.signal.zkgroup.auth.ServerZkAuthOperations; import org.signal.zkgroup.auth.ServerZkAuthOperations;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.CertificateGenerator; import org.whispersystems.textsecuregcm.auth.CertificateGenerator;
import org.whispersystems.textsecuregcm.entities.DeliveryCertificate; import org.whispersystems.textsecuregcm.entities.DeliveryCertificate;
import org.whispersystems.textsecuregcm.entities.GroupCredentials; import org.whispersystems.textsecuregcm.entities.GroupCredentials;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@ -51,43 +50,49 @@ public class CertificateController {
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Path("/delivery") @Path("/delivery")
public DeliveryCertificate getDeliveryCertificate(@Auth Account account, public DeliveryCertificate getDeliveryCertificate(@Auth AuthenticatedAccount auth,
@QueryParam("includeE164") Optional<Boolean> maybeIncludeE164) @QueryParam("includeE164") Optional<Boolean> maybeIncludeE164)
throws InvalidKeyException throws InvalidKeyException {
{ if (Util.isEmpty(auth.getAccount().getIdentityKey())) {
if (account.getAuthenticatedDevice().isEmpty()) {
throw new AssertionError();
}
if (Util.isEmpty(account.getIdentityKey())) {
throw new WebApplicationException(Response.Status.BAD_REQUEST); throw new WebApplicationException(Response.Status.BAD_REQUEST);
} }
final boolean includeE164 = maybeIncludeE164.orElse(true); final boolean includeE164 = maybeIncludeE164.orElse(true);
Metrics.counter(GENERATE_DELIVERY_CERTIFICATE_COUNTER_NAME, INCLUDE_E164_TAG_NAME, String.valueOf(includeE164)).increment(); Metrics.counter(GENERATE_DELIVERY_CERTIFICATE_COUNTER_NAME, INCLUDE_E164_TAG_NAME, String.valueOf(includeE164))
.increment();
return new DeliveryCertificate(certificateGenerator.createFor(account, account.getAuthenticatedDevice().get(), includeE164)); return new DeliveryCertificate(
certificateGenerator.createFor(auth.getAccount(), auth.getAuthenticatedDevice(), includeE164));
} }
@Timed @Timed
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Path("/group/{startRedemptionTime}/{endRedemptionTime}") @Path("/group/{startRedemptionTime}/{endRedemptionTime}")
public GroupCredentials getAuthenticationCredentials(@Auth Account account, public GroupCredentials getAuthenticationCredentials(@Auth AuthenticatedAccount auth,
@PathParam("startRedemptionTime") int startRedemptionTime, @PathParam("startRedemptionTime") int startRedemptionTime,
@PathParam("endRedemptionTime") int endRedemptionTime) @PathParam("endRedemptionTime") int endRedemptionTime) {
{ if (!isZkEnabled) {
if (!isZkEnabled) throw new WebApplicationException(Response.Status.NOT_FOUND); throw new WebApplicationException(Response.Status.NOT_FOUND);
if (startRedemptionTime > endRedemptionTime) throw new WebApplicationException(Response.Status.BAD_REQUEST); }
if (endRedemptionTime > Util.currentDaysSinceEpoch() + 7) throw new WebApplicationException(Response.Status.BAD_REQUEST); if (startRedemptionTime > endRedemptionTime) {
if (startRedemptionTime < Util.currentDaysSinceEpoch()) throw new WebApplicationException(Response.Status.BAD_REQUEST); throw new WebApplicationException(Response.Status.BAD_REQUEST);
}
if (endRedemptionTime > Util.currentDaysSinceEpoch() + 7) {
throw new WebApplicationException(Response.Status.BAD_REQUEST);
}
if (startRedemptionTime < Util.currentDaysSinceEpoch()) {
throw new WebApplicationException(Response.Status.BAD_REQUEST);
}
List<GroupCredentials.GroupCredential> credentials = new LinkedList<>(); List<GroupCredentials.GroupCredential> credentials = new LinkedList<>();
for (int i=startRedemptionTime;i<=endRedemptionTime;i++) { for (int i = startRedemptionTime; i <= endRedemptionTime; i++) {
credentials.add(new GroupCredentials.GroupCredential(serverZkAuthOperations.issueAuthCredential(account.getUuid(), i) credentials.add(new GroupCredentials.GroupCredential(
.serialize(), serverZkAuthOperations.issueAuthCredential(auth.getAccount().getUuid(), i)
i)); .serialize(),
i));
} }
return new GroupCredentials(credentials); return new GroupCredentials(credentials);

View File

@ -17,12 +17,12 @@ import javax.ws.rs.Path;
import javax.ws.rs.Produces; import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.entities.AnswerChallengeRequest; import org.whispersystems.textsecuregcm.entities.AnswerChallengeRequest;
import org.whispersystems.textsecuregcm.entities.AnswerPushChallengeRequest; import org.whispersystems.textsecuregcm.entities.AnswerPushChallengeRequest;
import org.whispersystems.textsecuregcm.entities.AnswerRecaptchaChallengeRequest; import org.whispersystems.textsecuregcm.entities.AnswerRecaptchaChallengeRequest;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager; import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.ForwardedIpUtil; import org.whispersystems.textsecuregcm.util.ForwardedIpUtil;
@Path("/v1/challenge") @Path("/v1/challenge")
@ -38,7 +38,7 @@ public class ChallengeController {
@PUT @PUT
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
public Response handleChallengeResponse(@Auth final Account account, public Response handleChallengeResponse(@Auth final AuthenticatedAccount auth,
@Valid final AnswerChallengeRequest answerRequest, @Valid final AnswerChallengeRequest answerRequest,
@HeaderParam("X-Forwarded-For") String forwardedFor) throws RetryLaterException { @HeaderParam("X-Forwarded-For") String forwardedFor) throws RetryLaterException {
@ -46,14 +46,15 @@ public class ChallengeController {
if (answerRequest instanceof AnswerPushChallengeRequest) { if (answerRequest instanceof AnswerPushChallengeRequest) {
final AnswerPushChallengeRequest pushChallengeRequest = (AnswerPushChallengeRequest) answerRequest; final AnswerPushChallengeRequest pushChallengeRequest = (AnswerPushChallengeRequest) answerRequest;
rateLimitChallengeManager.answerPushChallenge(account, pushChallengeRequest.getChallenge()); rateLimitChallengeManager.answerPushChallenge(auth.getAccount(), pushChallengeRequest.getChallenge());
} else if (answerRequest instanceof AnswerRecaptchaChallengeRequest) { } else if (answerRequest instanceof AnswerRecaptchaChallengeRequest) {
try { try {
final AnswerRecaptchaChallengeRequest recaptchaChallengeRequest = (AnswerRecaptchaChallengeRequest) answerRequest; final AnswerRecaptchaChallengeRequest recaptchaChallengeRequest = (AnswerRecaptchaChallengeRequest) answerRequest;
final String mostRecentProxy = ForwardedIpUtil.getMostRecentProxy(forwardedFor).orElseThrow(); final String mostRecentProxy = ForwardedIpUtil.getMostRecentProxy(forwardedFor).orElseThrow();
rateLimitChallengeManager.answerRecaptchaChallenge(account, recaptchaChallengeRequest.getCaptcha(), mostRecentProxy); rateLimitChallengeManager.answerRecaptchaChallenge(auth.getAccount(), recaptchaChallengeRequest.getCaptcha(),
mostRecentProxy);
} catch (final NoSuchElementException e) { } catch (final NoSuchElementException e) {
return Response.status(400).build(); return Response.status(400).build();
@ -69,9 +70,9 @@ public class ChallengeController {
@Timed @Timed
@POST @POST
@Path("/push") @Path("/push")
public Response requestPushChallenge(@Auth final Account account) { public Response requestPushChallenge(@Auth final AuthenticatedAccount auth) {
try { try {
rateLimitChallengeManager.sendPushChallenge(account); rateLimitChallengeManager.sendPushChallenge(auth.getAccount());
return Response.status(200).build(); return Response.status(200).build();
} catch (final NotPushRegisteredException e) { } catch (final NotPushRegisteredException e) {
return Response.status(404).build(); return Response.status(404).build();

View File

@ -26,6 +26,7 @@ import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials; import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.auth.AuthorizationHeader; import org.whispersystems.textsecuregcm.auth.AuthorizationHeader;
import org.whispersystems.textsecuregcm.auth.InvalidAuthorizationHeaderException; import org.whispersystems.textsecuregcm.auth.InvalidAuthorizationHeaderException;
@ -79,12 +80,12 @@ public class DeviceController {
@Timed @Timed
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public DeviceInfoList getDevices(@Auth Account account) { public DeviceInfoList getDevices(@Auth AuthenticatedAccount auth) {
List<DeviceInfo> devices = new LinkedList<>(); List<DeviceInfo> devices = new LinkedList<>();
for (Device device : account.getDevices()) { for (Device device : auth.getAccount().getDevices()) {
devices.add(new DeviceInfo(device.getId(), device.getName(), devices.add(new DeviceInfo(device.getId(), device.getName(),
device.getLastSeen(), device.getCreated())); device.getLastSeen(), device.getCreated()));
} }
return new DeviceInfoList(devices); return new DeviceInfoList(devices);
@ -93,8 +94,9 @@ public class DeviceController {
@Timed @Timed
@DELETE @DELETE
@Path("/{device_id}") @Path("/{device_id}")
public void removeDevice(@Auth Account account, @PathParam("device_id") long deviceId) { public void removeDevice(@Auth AuthenticatedAccount auth, @PathParam("device_id") long deviceId) {
if (account.getAuthenticatedDevice().get().getId() != Device.MASTER_ID) { Account account = auth.getAccount();
if (auth.getAuthenticatedDevice().getId() != Device.MASTER_ID) {
throw new WebApplicationException(Response.Status.UNAUTHORIZED); throw new WebApplicationException(Response.Status.UNAUTHORIZED);
} }
@ -109,9 +111,11 @@ public class DeviceController {
@GET @GET
@Path("/provisioning/code") @Path("/provisioning/code")
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public VerificationCode createDeviceToken(@Auth Account account) public VerificationCode createDeviceToken(@Auth AuthenticatedAccount auth)
throws RateLimitExceededException, DeviceLimitExceededException throws RateLimitExceededException, DeviceLimitExceededException {
{
final Account account = auth.getAccount();
rateLimiters.getAllocateDeviceLimiter().validate(account.getUuid()); rateLimiters.getAllocateDeviceLimiter().validate(account.getUuid());
int maxDeviceLimit = MAX_DEVICES; int maxDeviceLimit = MAX_DEVICES;
@ -124,7 +128,7 @@ public class DeviceController {
throw new DeviceLimitExceededException(account.getDevices().size(), MAX_DEVICES); throw new DeviceLimitExceededException(account.getDevices().size(), MAX_DEVICES);
} }
if (account.getAuthenticatedDevice().get().getId() != Device.MASTER_ID) { if (auth.getAuthenticatedDevice().getId() != Device.MASTER_ID) {
throw new WebApplicationException(Response.Status.UNAUTHORIZED); throw new WebApplicationException(Response.Status.UNAUTHORIZED);
} }
@ -213,18 +217,18 @@ public class DeviceController {
@Timed @Timed
@PUT @PUT
@Path("/unauthenticated_delivery") @Path("/unauthenticated_delivery")
public void setUnauthenticatedDelivery(@Auth Account account) { public void setUnauthenticatedDelivery(@Auth AuthenticatedAccount auth) {
assert(account.getAuthenticatedDevice().isPresent()); assert (auth.getAuthenticatedDevice() != null);
// Deprecated // Deprecated
} }
@Timed @Timed
@PUT @PUT
@Path("/capabilities") @Path("/capabilities")
public void setCapabiltities(@Auth Account account, @Valid DeviceCapabilities capabilities) { public void setCapabiltities(@Auth AuthenticatedAccount auth, @Valid DeviceCapabilities capabilities) {
assert(account.getAuthenticatedDevice().isPresent()); assert (auth.getAuthenticatedDevice() != null);
final long deviceId = account.getAuthenticatedDevice().get().getId(); final long deviceId = auth.getAuthenticatedDevice().getId();
accounts.updateDevice(account, deviceId, d -> d.setCapabilities(capabilities)); accounts.updateDevice(auth.getAccount(), deviceId, d -> d.setCapabilities(capabilities));
} }
@VisibleForTesting protected VerificationCode generateVerificationCode() { @VisibleForTesting protected VerificationCode generateVerificationCode() {

View File

@ -1,14 +1,11 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
package org.whispersystems.textsecuregcm.controllers; package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed; import com.codahale.metrics.annotation.Timed;
import io.dropwizard.auth.Auth; import io.dropwizard.auth.Auth;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.storage.Account;
import javax.ws.rs.Consumes; import javax.ws.rs.Consumes;
import javax.ws.rs.GET; import javax.ws.rs.GET;
import javax.ws.rs.PUT; import javax.ws.rs.PUT;
@ -16,6 +13,8 @@ import javax.ws.rs.Path;
import javax.ws.rs.Produces; import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
@Path("/v1/directory") @Path("/v1/directory")
public class DirectoryController { public class DirectoryController {
@ -30,15 +29,15 @@ public class DirectoryController {
@GET @GET
@Path("/auth") @Path("/auth")
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public Response getAuthToken(@Auth Account account) { public Response getAuthToken(@Auth AuthenticatedAccount auth) {
return Response.ok().entity(directoryServiceTokenGenerator.generateFor(account.getNumber())).build(); return Response.ok().entity(directoryServiceTokenGenerator.generateFor(auth.getAccount().getNumber())).build();
} }
@PUT @PUT
@Path("/feedback-v3/{status}") @Path("/feedback-v3/{status}")
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public Response setFeedback(@Auth Account account) { public Response setFeedback(@Auth AuthenticatedAccount auth) {
return Response.ok().build(); return Response.ok().build();
} }
@ -47,7 +46,7 @@ public class DirectoryController {
@GET @GET
@Path("/{token}") @Path("/{token}")
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public Response getTokenPresence(@Auth Account account) { public Response getTokenPresence(@Auth AuthenticatedAccount auth) {
return Response.status(429).build(); return Response.status(429).build();
} }
@ -56,7 +55,7 @@ public class DirectoryController {
@Path("/tokens") @Path("/tokens")
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
public Response getContactIntersection(@Auth Account account) { public Response getContactIntersection(@Auth AuthenticatedAccount auth) {
return Response.status(429).build(); return Response.status(429).build();
} }
} }

View File

@ -34,12 +34,12 @@ import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.configuration.DonationConfiguration; import org.whispersystems.textsecuregcm.configuration.DonationConfiguration;
import org.whispersystems.textsecuregcm.entities.ApplePayAuthorizationRequest; import org.whispersystems.textsecuregcm.entities.ApplePayAuthorizationRequest;
import org.whispersystems.textsecuregcm.entities.ApplePayAuthorizationResponse; import org.whispersystems.textsecuregcm.entities.ApplePayAuthorizationResponse;
import org.whispersystems.textsecuregcm.http.FaultTolerantHttpClient; import org.whispersystems.textsecuregcm.http.FaultTolerantHttpClient;
import org.whispersystems.textsecuregcm.http.FormDataBodyPublisher; import org.whispersystems.textsecuregcm.http.FormDataBodyPublisher;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
@Path("/v1/donation") @Path("/v1/donation")
@ -75,7 +75,7 @@ public class DonationController {
@Path("/authorize-apple-pay") @Path("/authorize-apple-pay")
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> getApplePayAuthorization(@Auth Account account, @Valid ApplePayAuthorizationRequest request) { public CompletableFuture<Response> getApplePayAuthorization(@Auth AuthenticatedAccount auth, @Valid ApplePayAuthorizationRequest request) {
if (!supportedCurrencies.contains(request.getCurrency())) { if (!supportedCurrencies.contains(request.getCurrency())) {
return CompletableFuture.completedFuture(Response.status(422).build()); return CompletableFuture.completedFuture(Response.status(422).build());
} }

View File

@ -1,28 +1,27 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
package org.whispersystems.textsecuregcm.controllers; package org.whispersystems.textsecuregcm.controllers;
import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.annotation.Timed; import com.codahale.metrics.annotation.Timed;
import io.dropwizard.auth.Auth; import io.dropwizard.auth.Auth;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.core.Response;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import org.whispersystems.websocket.session.WebSocketSession; import org.whispersystems.websocket.session.WebSocketSession;
import org.whispersystems.websocket.session.WebSocketSessionContext; import org.whispersystems.websocket.session.WebSocketSessionContext;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.core.Response;
import static com.codahale.metrics.MetricRegistry.name;
@Path("/v1/keepalive") @Path("/v1/keepalive")
public class KeepAliveController { public class KeepAliveController {
@ -40,15 +39,14 @@ public class KeepAliveController {
@Timed @Timed
@GET @GET
public Response getKeepAlive(@Auth Account account, public Response getKeepAlive(@Auth AuthenticatedAccount auth,
@WebSocketSession WebSocketSessionContext context) @WebSocketSession WebSocketSessionContext context) {
{ if (auth != null) {
if (account != null) { if (!clientPresenceManager.isLocallyPresent(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId())) {
if (!clientPresenceManager.isLocallyPresent(account.getUuid(), account.getAuthenticatedDevice().get().getId())) {
logger.warn("***** No local subscription found for {}::{}; age = {}ms, User-Agent = {}", logger.warn("***** No local subscription found for {}::{}; age = {}ms, User-Agent = {}",
account.getUuid(), account.getAuthenticatedDevice().get().getId(), auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(),
System.currentTimeMillis() - context.getClient().getCreatedTimestamp(), System.currentTimeMillis() - context.getClient().getCreatedTimestamp(),
context.getClient().getUserAgent()); context.getClient().getUserAgent());
context.getClient().close(1000, "OK"); context.getClient().close(1000, "OK");

View File

@ -28,7 +28,8 @@ import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier; import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.Anonymous; import org.whispersystems.textsecuregcm.auth.Anonymous;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.entities.PreKey; import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.PreKeyCount; import org.whispersystems.textsecuregcm.entities.PreKeyCount;
@ -76,8 +77,8 @@ public class KeysController {
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public PreKeyCount getStatus(@Auth Account account) { public PreKeyCount getStatus(@Auth AuthenticatedAccount auth) {
int count = keysDynamoDb.getCount(account, account.getAuthenticatedDevice().get().getId()); int count = keysDynamoDb.getCount(auth.getAccount(), auth.getAuthenticatedDevice().getId());
if (count > 0) { if (count > 0) {
count = count - 1; count = count - 1;
@ -89,10 +90,10 @@ public class KeysController {
@Timed @Timed
@PUT @PUT
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
public void setKeys(@Auth DisabledPermittedAccount disabledPermittedAccount, @Valid PreKeyState preKeys) { public void setKeys(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth, @Valid PreKeyState preKeys) {
Account account = disabledPermittedAccount.getAccount(); Account account = disabledPermittedAuth.getAccount();
Device device = account.getAuthenticatedDevice().get(); Device device = disabledPermittedAuth.getAuthenticatedDevice();
boolean updateAccount = false; boolean updateAccount = false;
if (!preKeys.getSignedPreKey().equals(device.getSignedPreKey())) { if (!preKeys.getSignedPreKey().equals(device.getSignedPreKey())) {
updateAccount = true; updateAccount = true;
@ -116,7 +117,7 @@ public class KeysController {
@GET @GET
@Path("/{identifier}/{device_id}") @Path("/{identifier}/{device_id}")
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public Response getDeviceKeys(@Auth Optional<Account> account, public Response getDeviceKeys(@Auth Optional<AuthenticatedAccount> auth,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey, @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@PathParam("identifier") AmbiguousIdentifier targetName, @PathParam("identifier") AmbiguousIdentifier targetName,
@PathParam("device_id") String deviceId, @PathParam("device_id") String deviceId,
@ -125,14 +126,16 @@ public class KeysController {
targetName.incrementRequestCounter("getDeviceKeys", userAgent); targetName.incrementRequestCounter("getDeviceKeys", userAgent);
if (!account.isPresent() && !accessKey.isPresent()) { if (auth.isEmpty() && accessKey.isEmpty()) {
throw new WebApplicationException(Response.Status.UNAUTHORIZED); throw new WebApplicationException(Response.Status.UNAUTHORIZED);
} }
final Optional<Account> account = auth.map(AuthenticatedAccount::getAccount);
Optional<Account> target = accounts.get(targetName); Optional<Account> target = accounts.get(targetName);
OptionalAccess.verify(account, accessKey, target, deviceId); OptionalAccess.verify(account, accessKey, target, deviceId);
assert(target.isPresent()); assert (target.isPresent());
{ {
final String sourceCountryCode = account.map(a -> Util.getCountryCode(a.getNumber())).orElse("0"); final String sourceCountryCode = account.map(a -> Util.getCountryCode(a.getNumber())).orElse("0");
@ -146,7 +149,9 @@ public class KeysController {
} }
if (account.isPresent()) { if (account.isPresent()) {
rateLimiters.getPreKeysLimiter().validate(account.get().getUuid() + "." + account.get().getAuthenticatedDevice().get().getId() + "__" + target.get().getUuid() + "." + deviceId); rateLimiters.getPreKeysLimiter().validate(
account.get().getUuid() + "." + auth.get().getAuthenticatedDevice().getId() + "__" + target.get().getUuid()
+ "." + deviceId);
try { try {
preKeyRateLimiter.validate(account.get()); preKeyRateLimiter.validate(account.get());
@ -188,22 +193,25 @@ public class KeysController {
@PUT @PUT
@Path("/signed") @Path("/signed")
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
public void setSignedKey(@Auth Account account, @Valid SignedPreKey signedPreKey) { public void setSignedKey(@Auth AuthenticatedAccount auth, @Valid SignedPreKey signedPreKey) {
Device device = account.getAuthenticatedDevice().get(); Device device = auth.getAuthenticatedDevice();
accounts.updateDevice(account, device.getId(), d -> d.setSignedPreKey(signedPreKey)); accounts.updateDevice(auth.getAccount(), device.getId(), d -> d.setSignedPreKey(signedPreKey));
} }
@Timed @Timed
@GET @GET
@Path("/signed") @Path("/signed")
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public Optional<SignedPreKey> getSignedKey(@Auth Account account) { public Optional<SignedPreKey> getSignedKey(@Auth AuthenticatedAccount auth) {
Device device = account.getAuthenticatedDevice().get(); Device device = auth.getAuthenticatedDevice();
SignedPreKey signedPreKey = device.getSignedPreKey(); SignedPreKey signedPreKey = device.getSignedPreKey();
if (signedPreKey != null) return Optional.of(signedPreKey); if (signedPreKey != null) {
else return Optional.empty(); return Optional.of(signedPreKey);
} else {
return Optional.empty();
}
} }
private Map<Long, PreKey> getLocalKeys(Account destination, String deviceIdSelector) { private Map<Long, PreKey> getLocalKeys(Account destination, String deviceIdSelector) {

View File

@ -62,6 +62,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier; import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.Anonymous; import org.whispersystems.textsecuregcm.auth.Anonymous;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.CombinedUnidentifiedSenderAccessKeys; import org.whispersystems.textsecuregcm.auth.CombinedUnidentifiedSenderAccessKeys;
import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration;
@ -189,12 +190,12 @@ public class MessageController {
@PUT @PUT
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public Response sendMessage(@Auth Optional<Account> source, public Response sendMessage(@Auth Optional<AuthenticatedAccount> source,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey, @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@HeaderParam("User-Agent") String userAgent, @HeaderParam("User-Agent") String userAgent,
@HeaderParam("X-Forwarded-For") String forwardedFor, @HeaderParam("X-Forwarded-For") String forwardedFor,
@PathParam("destination") AmbiguousIdentifier destinationName, @PathParam("destination") AmbiguousIdentifier destinationName,
@Valid IncomingMessageList messages) @Valid IncomingMessageList messages)
throws RateLimitExceededException, RateLimitChallengeException { throws RateLimitExceededException, RateLimitChallengeException {
destinationName.incrementRequestCounter("sendMessage", userAgent); destinationName.incrementRequestCounter("sendMessage", userAgent);
@ -203,20 +204,22 @@ public class MessageController {
throw new WebApplicationException(Response.Status.UNAUTHORIZED); throw new WebApplicationException(Response.Status.UNAUTHORIZED);
} }
if (source.isPresent() && !source.get().isFor(destinationName)) { if (source.isPresent() && !source.get().getAccount().isFor(destinationName)) {
assert source.get().getMasterDevice().isPresent(); assert source.get().getAccount().getMasterDevice().isPresent();
final Device masterDevice = source.get().getMasterDevice().get(); final Device masterDevice = source.get().getAccount().getMasterDevice().get();
final String senderCountryCode = Util.getCountryCode(source.get().getNumber()); final String senderCountryCode = Util.getCountryCode(source.get().getAccount().getNumber());
if (StringUtils.isAllBlank(masterDevice.getApnId(), masterDevice.getVoipApnId(), masterDevice.getGcmId()) || masterDevice.getUninstalledFeedbackTimestamp() > 0) { if (StringUtils.isAllBlank(masterDevice.getApnId(), masterDevice.getVoipApnId(), masterDevice.getGcmId())
Metrics.counter(UNSEALED_SENDER_WITHOUT_PUSH_TOKEN_COUNTER_NAME, SENDER_COUNTRY_TAG_NAME, senderCountryCode).increment(); || masterDevice.getUninstalledFeedbackTimestamp() > 0) {
Metrics.counter(UNSEALED_SENDER_WITHOUT_PUSH_TOKEN_COUNTER_NAME, SENDER_COUNTRY_TAG_NAME, senderCountryCode)
.increment();
} }
} }
final String senderType; final String senderType;
if (source.isPresent() && !source.get().isFor(destinationName)) { if (source.isPresent() && !source.get().getAccount().isFor(destinationName)) {
identifiedMeter.mark(); identifiedMeter.mark();
senderType = "identified"; senderType = "identified";
} else if (source.isEmpty()) { } else if (source.isEmpty()) {
@ -246,23 +249,26 @@ public class MessageController {
} }
try { try {
boolean isSyncMessage = source.isPresent() && source.get().isFor(destinationName); boolean isSyncMessage = source.isPresent() && source.get().getAccount().isFor(destinationName);
Optional<Account> destination; Optional<Account> destination;
if (!isSyncMessage) destination = accountsManager.get(destinationName); if (!isSyncMessage) {
else destination = source; destination = accountsManager.get(destinationName);
} else {
destination = source.map(AuthenticatedAccount::getAccount);
}
OptionalAccess.verify(source, accessKey, destination); OptionalAccess.verify(source.map(AuthenticatedAccount::getAccount), accessKey, destination);
assert(destination.isPresent()); assert (destination.isPresent());
if (source.isPresent() && !source.get().isFor(destinationName)) { if (source.isPresent() && !source.get().getAccount().isFor(destinationName)) {
rateLimiters.getMessagesLimiter().validate(source.get().getUuid(), destination.get().getUuid()); rateLimiters.getMessagesLimiter().validate(source.get().getAccount().getUuid(), destination.get().getUuid());
final String senderCountryCode = Util.getCountryCode(source.get().getNumber()); final String senderCountryCode = Util.getCountryCode(source.get().getAccount().getNumber());
try { try {
unsealedSenderRateLimiter.validate(source.get(), destination.get()); unsealedSenderRateLimiter.validate(source.get().getAccount(), destination.get());
} catch (final RateLimitExceededException e) { } catch (final RateLimitExceededException e) {
final boolean legacyClient = rateLimitChallengeManager.isClientBelowMinimumVersion(userAgent); final boolean legacyClient = rateLimitChallengeManager.isClientBelowMinimumVersion(userAgent);
@ -276,11 +282,11 @@ public class MessageController {
throw e; throw e;
} }
throw new RateLimitChallengeException(source.get(), e.getRetryDuration()); throw new RateLimitChallengeException(source.get().getAccount(), e.getRetryDuration());
} }
final String destinationCountryCode = Util.getCountryCode(destination.get().getNumber()); final String destinationCountryCode = Util.getCountryCode(destination.get().getNumber());
final Device masterDevice = source.get().getMasterDevice().get(); final Device masterDevice = source.get().getAccount().getMasterDevice().get();
if (!senderCountryCode.equals(destinationCountryCode)) { if (!senderCountryCode.equals(destinationCountryCode)) {
recordInternationalUnsealedSenderMetrics(forwardedFor, senderCountryCode, destination.get().getNumber()); recordInternationalUnsealedSenderMetrics(forwardedFor, senderCountryCode, destination.get().getNumber());
@ -293,31 +299,34 @@ public class MessageController {
.orElse(false); .orElse(false);
if (isRateLimitedHost) { if (isRateLimitedHost) {
return declineDelivery(messages, source.get(), destination.get()); return declineDelivery(messages, source.get().getAccount(), destination.get());
} }
} }
} }
} }
} }
validateCompleteDeviceList(destination.get(), messages.getMessages(), isSyncMessage); validateCompleteDeviceList(destination.get(), messages.getMessages(), isSyncMessage,
source.map(AuthenticatedAccount::getAuthenticatedDevice).map(Device::getId));
validateRegistrationIds(destination.get(), messages.getMessages()); validateRegistrationIds(destination.get(), messages.getMessages());
final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent), final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(messages.isOnline())), Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(messages.isOnline())),
Tag.of(SENDER_TYPE_TAG_NAME, senderType), Tag.of(SENDER_TYPE_TAG_NAME, senderType),
Tag.of(DESTINATION_TYPE_TAG_NAME, destinationName.hasNumber() ? "e164" : "uuid")); Tag.of(DESTINATION_TYPE_TAG_NAME, destinationName.hasNumber() ? "e164" : "uuid"));
for (IncomingMessage incomingMessage : messages.getMessages()) { for (IncomingMessage incomingMessage : messages.getMessages()) {
Optional<Device> destinationDevice = destination.get().getDevice(incomingMessage.getDestinationDeviceId()); Optional<Device> destinationDevice = destination.get().getDevice(incomingMessage.getDestinationDeviceId());
if (destinationDevice.isPresent()) { if (destinationDevice.isPresent()) {
Metrics.counter(SENT_MESSAGE_COUNTER_NAME, tags).increment(); Metrics.counter(SENT_MESSAGE_COUNTER_NAME, tags).increment();
sendMessage(source, destination.get(), destinationDevice.get(), messages.getTimestamp(), messages.isOnline(), incomingMessage); sendMessage(source, destination.get(), destinationDevice.get(), messages.getTimestamp(), messages.isOnline(),
incomingMessage);
} }
} }
return Response.ok(new SendMessageResponse(!isSyncMessage && source.isPresent() && source.get().getEnabledDeviceCount() > 1)).build(); return Response.ok(new SendMessageResponse(
!isSyncMessage && source.isPresent() && source.get().getAccount().getEnabledDeviceCount() > 1)).build();
} catch (NoSuchUserException e) { } catch (NoSuchUserException e) {
throw new WebApplicationException(Response.status(404).build()); throw new WebApplicationException(Response.status(404).build());
} catch (MismatchedDevicesException e) { } catch (MismatchedDevicesException e) {
@ -380,7 +389,7 @@ public class MessageController {
final Set<Pair<Long, Integer>> deviceIdAndRegistrationIdSet = accountToDeviceIdAndRegistrationIdMap.get(account); final Set<Pair<Long, Integer>> deviceIdAndRegistrationIdSet = accountToDeviceIdAndRegistrationIdMap.get(account);
final Set<Long> deviceIds = deviceIdAndRegistrationIdSet.stream().map(Pair::first).collect(Collectors.toSet()); final Set<Long> deviceIds = deviceIdAndRegistrationIdSet.stream().map(Pair::first).collect(Collectors.toSet());
try { try {
validateCompleteDeviceList(account, deviceIds, false); validateCompleteDeviceList(account, deviceIds, false, Optional.empty());
validateRegistrationIds(account, deviceIdAndRegistrationIdSet.stream()); validateRegistrationIds(account, deviceIdAndRegistrationIdSet.stream());
} catch (MismatchedDevicesException e) { } catch (MismatchedDevicesException e) {
accountMismatchedDevices.add(new AccountMismatchedDevices(account.getUuid(), accountMismatchedDevices.add(new AccountMismatchedDevices(account.getUuid(),
@ -476,7 +485,9 @@ public class MessageController {
if (random.nextDouble() <= messageRateConfiguration.getReceiptProbability()) { if (random.nextDouble() <= messageRateConfiguration.getReceiptProbability()) {
receiptExecutorService.schedule(() -> { receiptExecutorService.schedule(() -> {
try { try {
receiptSender.sendReceipt(destination, source.getNumber(), timestamp); receiptSender.sendReceipt(
new AuthenticatedAccount(() -> new Pair<>(destination, destination.getMasterDevice().get())),
source.getNumber(), timestamp);
} catch (final NoSuchUserException ignored) { } catch (final NoSuchUserException ignored) {
} }
}, receiptDelay.toMillis(), TimeUnit.MILLISECONDS); }, receiptDelay.toMillis(), TimeUnit.MILLISECONDS);
@ -503,16 +514,17 @@ public class MessageController {
@Timed @Timed
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public OutgoingMessageEntityList getPendingMessages(@Auth Account account, @HeaderParam("User-Agent") String userAgent) { public OutgoingMessageEntityList getPendingMessages(@Auth AuthenticatedAccount auth,
assert account.getAuthenticatedDevice().isPresent(); @HeaderParam("User-Agent") String userAgent) {
assert auth.getAuthenticatedDevice() != null;
if (!Util.isEmpty(account.getAuthenticatedDevice().get().getApnId())) { if (!Util.isEmpty(auth.getAuthenticatedDevice().getApnId())) {
RedisOperation.unchecked(() -> apnFallbackManager.cancel(account, account.getAuthenticatedDevice().get())); RedisOperation.unchecked(() -> apnFallbackManager.cancel(auth.getAccount(), auth.getAuthenticatedDevice()));
} }
final OutgoingMessageEntityList outgoingMessages = messagesManager.getMessagesForDevice( final OutgoingMessageEntityList outgoingMessages = messagesManager.getMessagesForDevice(
account.getUuid(), auth.getAccount().getUuid(),
account.getAuthenticatedDevice().get().getId(), auth.getAuthenticatedDevice().getId(),
userAgent, userAgent,
false); false);
@ -549,21 +561,20 @@ public class MessageController {
@Timed @Timed
@DELETE @DELETE
@Path("/{source}/{timestamp}") @Path("/{source}/{timestamp}")
public void removePendingMessage(@Auth Account account, public void removePendingMessage(@Auth AuthenticatedAccount auth,
@PathParam("source") String source, @PathParam("source") String source,
@PathParam("timestamp") long timestamp) @PathParam("timestamp") long timestamp) {
{
try { try {
WebSocketConnection.recordMessageDeliveryDuration(timestamp, account.getAuthenticatedDevice().get()); WebSocketConnection.recordMessageDeliveryDuration(timestamp, auth.getAuthenticatedDevice());
Optional<OutgoingMessageEntity> message = messagesManager.delete( Optional<OutgoingMessageEntity> message = messagesManager.delete(
account.getUuid(), auth.getAccount().getUuid(),
account.getAuthenticatedDevice().get().getId(), auth.getAuthenticatedDevice().getId(),
source, timestamp); source, timestamp);
if (message.isPresent() && message.get().getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE) { if (message.isPresent() && message.get().getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE) {
receiptSender.sendReceipt(account, receiptSender.sendReceipt(auth,
message.get().getSource(), message.get().getSource(),
message.get().getTimestamp()); message.get().getTimestamp());
} }
} catch (NoSuchUserException e) { } catch (NoSuchUserException e) {
logger.warn("Sending delivery receipt", e); logger.warn("Sending delivery receipt", e);
@ -573,17 +584,18 @@ public class MessageController {
@Timed @Timed
@DELETE @DELETE
@Path("/uuid/{uuid}") @Path("/uuid/{uuid}")
public void removePendingMessage(@Auth Account account, @PathParam("uuid") UUID uuid) { public void removePendingMessage(@Auth AuthenticatedAccount auth, @PathParam("uuid") UUID uuid) {
try { try {
Optional<OutgoingMessageEntity> message = messagesManager.delete( Optional<OutgoingMessageEntity> message = messagesManager.delete(
account.getUuid(), auth.getAccount().getUuid(),
account.getAuthenticatedDevice().get().getId(), auth.getAuthenticatedDevice().getId(),
uuid); uuid);
if (message.isPresent()) { if (message.isPresent()) {
WebSocketConnection.recordMessageDeliveryDuration(message.get().getTimestamp(), account.getAuthenticatedDevice().get()); WebSocketConnection.recordMessageDeliveryDuration(message.get().getTimestamp(), auth.getAuthenticatedDevice());
if (!Util.isEmpty(message.get().getSource()) && message.get().getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE) { if (!Util.isEmpty(message.get().getSource())
receiptSender.sendReceipt(account, message.get().getSource(), message.get().getTimestamp()); && message.get().getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE) {
receiptSender.sendReceipt(auth, message.get().getSource(), message.get().getTimestamp());
} }
} }
@ -595,7 +607,8 @@ public class MessageController {
@Timed @Timed
@POST @POST
@Path("/report/{sourceNumber}/{messageGuid}") @Path("/report/{sourceNumber}/{messageGuid}")
public Response reportMessage(@Auth Account account, @PathParam("sourceNumber") String sourceNumber, @PathParam("messageGuid") UUID messageGuid) { public Response reportMessage(@Auth AuthenticatedAccount auth, @PathParam("sourceNumber") String sourceNumber,
@PathParam("messageGuid") UUID messageGuid) {
reportMessageManager.report(sourceNumber, messageGuid); reportMessageManager.report(sourceNumber, messageGuid);
@ -603,27 +616,26 @@ public class MessageController {
.build(); .build();
} }
private void sendMessage(Optional<Account> source, private void sendMessage(Optional<AuthenticatedAccount> source,
Account destinationAccount, Account destinationAccount,
Device destinationDevice, Device destinationDevice,
long timestamp, long timestamp,
boolean online, boolean online,
IncomingMessage incomingMessage) IncomingMessage incomingMessage)
throws NoSuchUserException throws NoSuchUserException {
{
try (final Timer.Context ignored = sendMessageInternalTimer.time()) { try (final Timer.Context ignored = sendMessageInternalTimer.time()) {
Optional<byte[]> messageBody = getMessageBody(incomingMessage); Optional<byte[]> messageBody = getMessageBody(incomingMessage);
Optional<byte[]> messageContent = getMessageContent(incomingMessage); Optional<byte[]> messageContent = getMessageContent(incomingMessage);
Envelope.Builder messageBuilder = Envelope.newBuilder(); Envelope.Builder messageBuilder = Envelope.newBuilder();
messageBuilder.setType(Envelope.Type.forNumber(incomingMessage.getType())) messageBuilder.setType(Envelope.Type.forNumber(incomingMessage.getType()))
.setTimestamp(timestamp == 0 ? System.currentTimeMillis() : timestamp) .setTimestamp(timestamp == 0 ? System.currentTimeMillis() : timestamp)
.setServerTimestamp(System.currentTimeMillis()); .setServerTimestamp(System.currentTimeMillis());
if (source.isPresent()) { if (source.isPresent()) {
messageBuilder.setSource(source.get().getNumber()) messageBuilder.setSource(source.get().getAccount().getNumber())
.setSourceUuid(source.get().getUuid().toString()) .setSourceUuid(source.get().getAccount().getUuid().toString())
.setSourceDevice((int)source.get().getAuthenticatedDevice().get().getId()); .setSourceDevice((int) source.get().getAuthenticatedDevice().getId());
} }
if (messageBody.isPresent()) { if (messageBody.isPresent()) {
@ -697,24 +709,26 @@ public class MessageController {
} }
@VisibleForTesting @VisibleForTesting
public static void validateCompleteDeviceList(Account account, List<IncomingMessage> messages, boolean isSyncMessage) public static void validateCompleteDeviceList(Account account, List<IncomingMessage> messages, boolean isSyncMessage,
Optional<Long> authenticatedDeviceId)
throws MismatchedDevicesException { throws MismatchedDevicesException {
Set<Long> messageDeviceIds = messages.stream().map(IncomingMessage::getDestinationDeviceId).collect(Collectors.toSet()); Set<Long> messageDeviceIds = messages.stream().map(IncomingMessage::getDestinationDeviceId)
validateCompleteDeviceList(account, messageDeviceIds, isSyncMessage); .collect(Collectors.toSet());
validateCompleteDeviceList(account, messageDeviceIds, isSyncMessage, authenticatedDeviceId);
} }
@VisibleForTesting @VisibleForTesting
public static void validateCompleteDeviceList(Account account, Set<Long> messageDeviceIds, boolean isSyncMessage) public static void validateCompleteDeviceList(Account account, Set<Long> messageDeviceIds, boolean isSyncMessage,
Optional<Long> authenticatedDeviceId)
throws MismatchedDevicesException { throws MismatchedDevicesException {
Set<Long> accountDeviceIds = new HashSet<>(); Set<Long> accountDeviceIds = new HashSet<>();
List<Long> missingDeviceIds = new LinkedList<>(); List<Long> missingDeviceIds = new LinkedList<>();
List<Long> extraDeviceIds = new LinkedList<>(); List<Long> extraDeviceIds = new LinkedList<>();
for (Device device : account.getDevices()) { for (Device device : account.getDevices()) {
if (device.isEnabled() && if (device.isEnabled() &&
!(isSyncMessage && device.getId() == account.getAuthenticatedDevice().get().getId())) !(isSyncMessage && device.getId() == authenticatedDeviceId.get())) {
{
accountDeviceIds.add(device.getId()); accountDeviceIds.add(device.getId());
if (!messageDeviceIds.contains(device.getId())) { if (!messageDeviceIds.contains(device.getId())) {

View File

@ -1,23 +1,21 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
package org.whispersystems.textsecuregcm.controllers; package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed; import com.codahale.metrics.annotation.Timed;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator; import io.dropwizard.auth.Auth;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.entities.CurrencyConversionEntityList;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.currency.CurrencyConversionManager;
import javax.ws.rs.GET; import javax.ws.rs.GET;
import javax.ws.rs.Path; import javax.ws.rs.Path;
import javax.ws.rs.Produces; import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import io.dropwizard.auth.Auth; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.currency.CurrencyConversionManager;
import org.whispersystems.textsecuregcm.entities.CurrencyConversionEntityList;
@Path("/v1/payments") @Path("/v1/payments")
public class PaymentsController { public class PaymentsController {
@ -34,15 +32,15 @@ public class PaymentsController {
@GET @GET
@Path("/auth") @Path("/auth")
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public ExternalServiceCredentials getAuth(@Auth Account account) { public ExternalServiceCredentials getAuth(@Auth AuthenticatedAccount auth) {
return paymentsServiceCredentialGenerator.generateFor(account.getUuid().toString()); return paymentsServiceCredentialGenerator.generateFor(auth.getAccount().getUuid().toString());
} }
@Timed @Timed
@GET @GET
@Path("/conversions") @Path("/conversions")
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public CurrencyConversionEntityList getConversions(@Auth Account account) { public CurrencyConversionEntityList getConversions(@Auth AuthenticatedAccount auth) {
return currencyManager.getCurrencyConversions().orElseThrow(); return currencyManager.getCurrencyConversions().orElseThrow();
} }
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
@ -41,6 +41,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier; import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.Anonymous; import org.whispersystems.textsecuregcm.auth.Anonymous;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessChecksum; import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessChecksum;
import org.whispersystems.textsecuregcm.entities.CreateProfileRequest; import org.whispersystems.textsecuregcm.entities.CreateProfileRequest;
@ -107,60 +108,64 @@ public class ProfileController {
this.isZkEnabled = isZkEnabled; this.isZkEnabled = isZkEnabled;
} }
@Timed @Timed
@PUT @PUT
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
public Response setProfile(@Auth Account account, @Valid CreateProfileRequest request) { public Response setProfile(@Auth AuthenticatedAccount auth, @Valid CreateProfileRequest request) {
if (!isZkEnabled) throw new WebApplicationException(Response.Status.NOT_FOUND); if (!isZkEnabled) {
throw new WebApplicationException(Response.Status.NOT_FOUND);
}
final Set<String> allowedPaymentsCountryCodes = final Set<String> allowedPaymentsCountryCodes =
dynamicConfigurationManager.getConfiguration().getPaymentsConfiguration().getAllowedCountryCodes(); dynamicConfigurationManager.getConfiguration().getPaymentsConfiguration().getAllowedCountryCodes();
if (StringUtils.isNotBlank(request.getPaymentAddress()) && if (StringUtils.isNotBlank(request.getPaymentAddress()) &&
!allowedPaymentsCountryCodes.contains(Util.getCountryCode(account.getNumber()))) { !allowedPaymentsCountryCodes.contains(Util.getCountryCode(auth.getAccount().getNumber()))) {
return Response.status(Status.FORBIDDEN).build(); return Response.status(Status.FORBIDDEN).build();
} }
Optional<VersionedProfile> currentProfile = profilesManager.get(account.getUuid(), request.getVersion()); Optional<VersionedProfile> currentProfile = profilesManager.get(auth.getAccount().getUuid(), request.getVersion());
String avatar = request.isAvatar() ? generateAvatarObjectName() : null; String avatar = request.isAvatar() ? generateAvatarObjectName() : null;
Optional<ProfileAvatarUploadAttributes> response = Optional.empty(); Optional<ProfileAvatarUploadAttributes> response = Optional.empty();
profilesManager.set(account.getUuid(), profilesManager.set(auth.getAccount().getUuid(),
new VersionedProfile( new VersionedProfile(
request.getVersion(), request.getVersion(),
request.getName(), request.getName(),
avatar, avatar,
request.getAboutEmoji(), request.getAboutEmoji(),
request.getAbout(), request.getAbout(),
request.getPaymentAddress(), request.getPaymentAddress(),
request.getCommitment().serialize())); request.getCommitment().serialize()));
if (request.isAvatar()) { if (request.isAvatar()) {
Optional<String> currentAvatar = Optional.empty(); Optional<String> currentAvatar = Optional.empty();
if (currentProfile.isPresent() && currentProfile.get().getAvatar() != null && currentProfile.get().getAvatar().startsWith("profiles/")) { if (currentProfile.isPresent() && currentProfile.get().getAvatar() != null && currentProfile.get().getAvatar()
currentAvatar = Optional.of(currentProfile.get().getAvatar()); .startsWith("profiles/")) {
} currentAvatar = Optional.of(currentProfile.get().getAvatar());
}
if (currentAvatar.isEmpty() && account.getAvatar() != null && account.getAvatar().startsWith("profiles/")) { if (currentAvatar.isEmpty() && auth.getAccount().getAvatar() != null && auth.getAccount().getAvatar()
currentAvatar = Optional.of(account.getAvatar()); .startsWith("profiles/")) {
} currentAvatar = Optional.of(auth.getAccount().getAvatar());
}
currentAvatar.ifPresent(s -> s3client.deleteObject(DeleteObjectRequest.builder() currentAvatar.ifPresent(s -> s3client.deleteObject(DeleteObjectRequest.builder()
.bucket(bucket) .bucket(bucket)
.key(s) .key(s)
.build())); .build()));
response = Optional.of(generateAvatarUploadForm(avatar)); response = Optional.of(generateAvatarUploadForm(avatar));
} }
accountsManager.update(account, a -> { accountsManager.update(auth.getAccount(), a -> {
a.setProfileName(request.getName()); a.setProfileName(request.getName());
a.setAvatar(avatar); a.setAvatar(avatar);
a.setCurrentProfileVersion(request.getVersion()); a.setCurrentProfileVersion(request.getVersion());
}); });
if (response.isPresent()) return Response.ok(response).build(); if (response.isPresent()) return Response.ok(response).build();
else return Response.ok().build(); else return Response.ok().build();
@ -170,29 +175,32 @@ public class ProfileController {
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Path("/{uuid}/{version}") @Path("/{uuid}/{version}")
public Optional<Profile> getProfile(@Auth Optional<Account> requestAccount, public Optional<Profile> getProfile(@Auth Optional<AuthenticatedAccount> auth,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey, @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@PathParam("uuid") UUID uuid, @PathParam("uuid") UUID uuid,
@PathParam("version") String version) @PathParam("version") String version)
throws RateLimitExceededException throws RateLimitExceededException {
{ if (!isZkEnabled) {
if (!isZkEnabled) throw new WebApplicationException(Response.Status.NOT_FOUND); throw new WebApplicationException(Response.Status.NOT_FOUND);
return getVersionedProfile(requestAccount, accessKey, uuid, version, Optional.empty()); }
return getVersionedProfile(auth.map(AuthenticatedAccount::getAccount), accessKey, uuid, version, Optional.empty());
} }
@Timed @Timed
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Path("/{uuid}/{version}/{credentialRequest}") @Path("/{uuid}/{version}/{credentialRequest}")
public Optional<Profile> getProfile(@Auth Optional<Account> requestAccount, public Optional<Profile> getProfile(@Auth Optional<AuthenticatedAccount> auth,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey, @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@PathParam("uuid") UUID uuid, @PathParam("uuid") UUID uuid,
@PathParam("version") String version, @PathParam("version") String version,
@PathParam("credentialRequest") String credentialRequest) @PathParam("credentialRequest") String credentialRequest)
throws RateLimitExceededException throws RateLimitExceededException {
{ if (!isZkEnabled) {
if (!isZkEnabled) throw new WebApplicationException(Response.Status.NOT_FOUND); throw new WebApplicationException(Response.Status.NOT_FOUND);
return getVersionedProfile(requestAccount, accessKey, uuid, version, Optional.of(credentialRequest)); }
return getVersionedProfile(auth.map(AuthenticatedAccount::getAccount), accessKey, uuid, version,
Optional.of(credentialRequest));
} }
private Optional<Profile> getVersionedProfile(Optional<Account> requestAccount, private Optional<Profile> getVersionedProfile(Optional<Account> requestAccount,
@ -255,22 +263,23 @@ public class ProfileController {
} }
@Timed @Timed
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Path("/username/{username}") @Path("/username/{username}")
public Profile getProfileByUsername(@Auth Account account, @PathParam("username") String username) throws RateLimitExceededException { public Profile getProfileByUsername(@Auth AuthenticatedAccount auth, @PathParam("username") String username)
rateLimiters.getUsernameLookupLimiter().validate(account.getUuid()); throws RateLimitExceededException {
rateLimiters.getUsernameLookupLimiter().validate(auth.getAccount().getUuid());
username = username.toLowerCase(); username = username.toLowerCase();
Optional<UUID> uuid = usernamesManager.get(username); Optional<UUID> uuid = usernamesManager.get(username);
if (uuid.isEmpty()) { if (uuid.isEmpty()) {
throw new WebApplicationException(Response.status(Response.Status.NOT_FOUND).build()); throw new WebApplicationException(Response.status(Response.Status.NOT_FOUND).build());
} }
Optional<Account> accountProfile = accountsManager.get(uuid.get()); Optional<Account> accountProfile = accountsManager.get(uuid.get());
if (accountProfile.isEmpty()) { if (accountProfile.isEmpty()) {
throw new WebApplicationException(Response.status(Response.Status.NOT_FOUND).build()); throw new WebApplicationException(Response.status(Response.Status.NOT_FOUND).build());
@ -312,40 +321,40 @@ public class ProfileController {
// Old profile endpoints. Replaced by versioned profile endpoints (above) // Old profile endpoints. Replaced by versioned profile endpoints (above)
@Deprecated @Deprecated
@Timed @Timed
@PUT @PUT
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Path("/name/{name}") @Path("/name/{name}")
public void setProfile(@Auth Account account, @PathParam("name") @ExactlySize(value = {72, 108}, payload = {Unwrapping.Unwrap.class}) Optional<String> name) { public void setProfile(@Auth AuthenticatedAccount auth,
accountsManager.update(account, a -> a.setProfileName(name.orElse(null))); @PathParam("name") @ExactlySize(value = {72, 108}, payload = {Unwrapping.Unwrap.class}) Optional<String> name) {
} accountsManager.update(auth.getAccount(), a -> a.setProfileName(name.orElse(null)));
}
@Deprecated @Deprecated
@Timed @Timed
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Path("/{identifier}") @Path("/{identifier}")
public Profile getProfile(@Auth Optional<Account> requestAccount, public Profile getProfile(@Auth Optional<AuthenticatedAccount> auth,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey, @HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@HeaderParam("User-Agent") String userAgent, @HeaderParam("User-Agent") String userAgent,
@PathParam("identifier") AmbiguousIdentifier identifier, @PathParam("identifier") AmbiguousIdentifier identifier,
@QueryParam("ca") boolean useCaCertificate) @QueryParam("ca") boolean useCaCertificate)
throws RateLimitExceededException throws RateLimitExceededException {
{
identifier.incrementRequestCounter("getProfile", userAgent); identifier.incrementRequestCounter("getProfile", userAgent);
if (requestAccount.isEmpty() && accessKey.isEmpty()) { if (auth.isEmpty() && accessKey.isEmpty()) {
throw new WebApplicationException(Response.Status.UNAUTHORIZED); throw new WebApplicationException(Response.Status.UNAUTHORIZED);
} }
if (requestAccount.isPresent()) { if (auth.isPresent()) {
rateLimiters.getProfileLimiter().validate(requestAccount.get().getUuid()); rateLimiters.getProfileLimiter().validate(auth.get().getAccount().getUuid());
} }
Optional<Account> accountProfile = accountsManager.get(identifier); Optional<Account> accountProfile = accountsManager.get(identifier);
OptionalAccess.verify(requestAccount, accessKey, accountProfile); OptionalAccess.verify(auth.map(AuthenticatedAccount::getAccount), accessKey, accountProfile);
Optional<String> username = Optional.empty(); Optional<String> username = Optional.empty();
@ -369,24 +378,24 @@ public class ProfileController {
} }
@Deprecated @Deprecated
@Timed @Timed
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Path("/form/avatar") @Path("/form/avatar")
public ProfileAvatarUploadAttributes getAvatarUploadForm(@Auth Account account) { public ProfileAvatarUploadAttributes getAvatarUploadForm(@Auth AuthenticatedAccount auth) {
String previousAvatar = account.getAvatar(); String previousAvatar = auth.getAccount().getAvatar();
String objectName = generateAvatarObjectName(); String objectName = generateAvatarObjectName();
ProfileAvatarUploadAttributes profileAvatarUploadAttributes = generateAvatarUploadForm(objectName); ProfileAvatarUploadAttributes profileAvatarUploadAttributes = generateAvatarUploadForm(objectName);
if (previousAvatar != null && previousAvatar.startsWith("profiles/")) { if (previousAvatar != null && previousAvatar.startsWith("profiles/")) {
s3client.deleteObject(DeleteObjectRequest.builder() s3client.deleteObject(DeleteObjectRequest.builder()
.bucket(bucket) .bucket(bucket)
.key(previousAvatar) .key(previousAvatar)
.build()); .build());
} }
accountsManager.update(account, a -> a.setAvatar(objectName)); accountsManager.update(auth.getAccount(), a -> a.setAvatar(objectName));
return profileAvatarUploadAttributes; return profileAvatarUploadAttributes;
} }

View File

@ -17,10 +17,10 @@ import javax.ws.rs.Produces;
import javax.ws.rs.WebApplicationException; import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.entities.ProvisioningMessage; import org.whispersystems.textsecuregcm.entities.ProvisioningMessage;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.ProvisioningManager; import org.whispersystems.textsecuregcm.push.ProvisioningManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress; import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress;
@Path("/v1/provisioning") @Path("/v1/provisioning")
@ -39,16 +39,15 @@ public class ProvisioningController {
@PUT @PUT
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public void sendProvisioningMessage(@Auth Account source, public void sendProvisioningMessage(@Auth AuthenticatedAccount auth,
@PathParam("destination") String destinationName, @PathParam("destination") String destinationName,
@Valid ProvisioningMessage message) @Valid ProvisioningMessage message)
throws RateLimitExceededException { throws RateLimitExceededException {
rateLimiters.getMessagesLimiter().validate(source.getUuid()); rateLimiters.getMessagesLimiter().validate(auth.getAccount().getUuid());
if (!provisioningManager.sendProvisioningMessage(new ProvisioningAddress(destinationName, 0), if (!provisioningManager.sendProvisioningMessage(new ProvisioningAddress(destinationName, 0),
Base64.getDecoder().decode(message.getBody()))) Base64.getDecoder().decode(message.getBody()))) {
{
throw new WebApplicationException(Response.Status.NOT_FOUND); throw new WebApplicationException(Response.Status.NOT_FOUND);
} }
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
@ -8,13 +8,16 @@ package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed; import com.codahale.metrics.annotation.Timed;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.auth.Auth; import io.dropwizard.auth.Auth;
import org.whispersystems.textsecuregcm.entities.UserRemoteConfig; import java.nio.ByteBuffer;
import org.whispersystems.textsecuregcm.entities.UserRemoteConfigList; import java.nio.charset.StandardCharsets;
import org.whispersystems.textsecuregcm.storage.Account; import java.security.MessageDigest;
import org.whispersystems.textsecuregcm.storage.RemoteConfig; import java.security.NoSuchAlgorithmException;
import org.whispersystems.textsecuregcm.storage.RemoteConfigsManager; import java.util.List;
import org.whispersystems.textsecuregcm.util.Conversions; import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.validation.Valid; import javax.validation.Valid;
import javax.ws.rs.Consumes; import javax.ws.rs.Consumes;
import javax.ws.rs.DELETE; import javax.ws.rs.DELETE;
@ -27,16 +30,12 @@ import javax.ws.rs.Produces;
import javax.ws.rs.WebApplicationException; import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import java.nio.ByteBuffer; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import java.nio.charset.StandardCharsets; import org.whispersystems.textsecuregcm.entities.UserRemoteConfig;
import java.security.MessageDigest; import org.whispersystems.textsecuregcm.entities.UserRemoteConfigList;
import java.security.NoSuchAlgorithmException; import org.whispersystems.textsecuregcm.storage.RemoteConfig;
import java.util.List; import org.whispersystems.textsecuregcm.storage.RemoteConfigsManager;
import java.util.Map; import org.whispersystems.textsecuregcm.util.Conversions;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@Path("/v1/config") @Path("/v1/config")
public class RemoteConfigController { public class RemoteConfigController {
@ -57,15 +56,19 @@ public class RemoteConfigController {
@GET @GET
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public UserRemoteConfigList getAll(@Auth Account account) { public UserRemoteConfigList getAll(@Auth AuthenticatedAccount auth) {
try { try {
MessageDigest digest = MessageDigest.getInstance("SHA1"); MessageDigest digest = MessageDigest.getInstance("SHA1");
final Stream<UserRemoteConfig> globalConfigStream = globalConfig.entrySet().stream().map(entry -> new UserRemoteConfig(GLOBAL_CONFIG_PREFIX + entry.getKey(), true, entry.getValue())); final Stream<UserRemoteConfig> globalConfigStream = globalConfig.entrySet().stream()
.map(entry -> new UserRemoteConfig(GLOBAL_CONFIG_PREFIX + entry.getKey(), true, entry.getValue()));
return new UserRemoteConfigList(Stream.concat(remoteConfigsManager.getAll().stream().map(config -> { return new UserRemoteConfigList(Stream.concat(remoteConfigsManager.getAll().stream().map(config -> {
final byte[] hashKey = config.getHashKey() != null ? config.getHashKey().getBytes(StandardCharsets.UTF_8) : config.getName().getBytes(StandardCharsets.UTF_8); final byte[] hashKey = config.getHashKey() != null ? config.getHashKey().getBytes(StandardCharsets.UTF_8)
boolean inBucket = isInBucket(digest, account.getUuid(), hashKey, config.getPercentage(), config.getUuids()); : config.getName().getBytes(StandardCharsets.UTF_8);
return new UserRemoteConfig(config.getName(), inBucket, inBucket ? config.getValue() : config.getDefaultValue()); boolean inBucket = isInBucket(digest, auth.getAccount().getUuid(), hashKey, config.getPercentage(),
config.getUuids());
return new UserRemoteConfig(config.getName(), inBucket,
inBucket ? config.getValue() : config.getDefaultValue());
}), globalConfigStream).collect(Collectors.toList())); }), globalConfigStream).collect(Collectors.toList()));
} catch (NoSuchAlgorithmException e) { } catch (NoSuchAlgorithmException e) {
throw new AssertionError(e); throw new AssertionError(e);

View File

@ -1,21 +1,19 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
package org.whispersystems.textsecuregcm.controllers; package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed; import com.codahale.metrics.annotation.Timed;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator; import io.dropwizard.auth.Auth;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.storage.Account;
import javax.ws.rs.GET; import javax.ws.rs.GET;
import javax.ws.rs.Path; import javax.ws.rs.Path;
import javax.ws.rs.Produces; import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import io.dropwizard.auth.Auth; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
@Path("/v1/backup") @Path("/v1/backup")
public class SecureBackupController { public class SecureBackupController {
@ -30,7 +28,7 @@ public class SecureBackupController {
@GET @GET
@Path("/auth") @Path("/auth")
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public ExternalServiceCredentials getAuth(@Auth Account account) { public ExternalServiceCredentials getAuth(@Auth AuthenticatedAccount auth) {
return backupServiceCredentialGenerator.generateFor(account.getUuid().toString()); return backupServiceCredentialGenerator.generateFor(auth.getAccount().getUuid().toString());
} }
} }

View File

@ -1,21 +1,19 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
package org.whispersystems.textsecuregcm.controllers; package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed; import com.codahale.metrics.annotation.Timed;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator; import io.dropwizard.auth.Auth;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.storage.Account;
import javax.ws.rs.GET; import javax.ws.rs.GET;
import javax.ws.rs.Path; import javax.ws.rs.Path;
import javax.ws.rs.Produces; import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import io.dropwizard.auth.Auth; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
@Path("/v1/storage") @Path("/v1/storage")
public class SecureStorageController { public class SecureStorageController {
@ -30,7 +28,7 @@ public class SecureStorageController {
@GET @GET
@Path("/auth") @Path("/auth")
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
public ExternalServiceCredentials getAuth(@Auth Account account) { public ExternalServiceCredentials getAuth(@Auth AuthenticatedAccount auth) {
return storageServiceCredentialGenerator.generateFor(account.getUuid().toString()); return storageServiceCredentialGenerator.generateFor(auth.getAccount().getUuid().toString());
} }
} }

View File

@ -1,21 +1,16 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
package org.whispersystems.textsecuregcm.controllers; package org.whispersystems.textsecuregcm.controllers;
import io.dropwizard.auth.Auth; import io.dropwizard.auth.Auth;
import org.whispersystems.textsecuregcm.entities.StickerPackFormUploadAttributes; import java.security.SecureRandom;
import org.whispersystems.textsecuregcm.entities.StickerPackFormUploadAttributes.StickerPackFormUploadItem; import java.time.ZoneOffset;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import java.time.ZonedDateTime;
import org.whispersystems.textsecuregcm.s3.PolicySigner; import java.util.LinkedList;
import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator; import java.util.List;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Hex;
import org.whispersystems.textsecuregcm.util.Pair;
import javax.validation.constraints.Max; import javax.validation.constraints.Max;
import javax.validation.constraints.Min; import javax.validation.constraints.Min;
import javax.ws.rs.GET; import javax.ws.rs.GET;
@ -23,11 +18,15 @@ import javax.ws.rs.Path;
import javax.ws.rs.PathParam; import javax.ws.rs.PathParam;
import javax.ws.rs.Produces; import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import java.security.SecureRandom; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import java.time.ZoneOffset; import org.whispersystems.textsecuregcm.entities.StickerPackFormUploadAttributes;
import java.time.ZonedDateTime; import org.whispersystems.textsecuregcm.entities.StickerPackFormUploadAttributes.StickerPackFormUploadItem;
import java.util.LinkedList; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import java.util.List; import org.whispersystems.textsecuregcm.s3.PolicySigner;
import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Hex;
import org.whispersystems.textsecuregcm.util.Pair;
@Path("/v1/sticker") @Path("/v1/sticker")
public class StickerController { public class StickerController {
@ -45,30 +44,31 @@ public class StickerController {
@GET @GET
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@Path("/pack/form/{count}") @Path("/pack/form/{count}")
public StickerPackFormUploadAttributes getStickersForm(@Auth Account account, public StickerPackFormUploadAttributes getStickersForm(@Auth AuthenticatedAccount auth,
@PathParam("count") @Min(1) @Max(201) int stickerCount) @PathParam("count") @Min(1) @Max(201) int stickerCount)
throws RateLimitExceededException throws RateLimitExceededException {
{ rateLimiters.getStickerPackLimiter().validate(auth.getAccount().getUuid());
rateLimiters.getStickerPackLimiter().validate(account.getUuid());
ZonedDateTime now = ZonedDateTime.now(ZoneOffset.UTC);
String packId = generatePackId();
String packLocation = "stickers/" + packId;
String manifestKey = packLocation + "/manifest.proto";
Pair<String, String> manifestPolicy = policyGenerator.createFor(now, manifestKey, Constants.MAXIMUM_STICKER_MANIFEST_SIZE_BYTES);
String manifestSignature = policySigner.getSignature(now, manifestPolicy.second());
StickerPackFormUploadItem manifest = new StickerPackFormUploadItem(-1, manifestKey, manifestPolicy.first(), "private", "AWS4-HMAC-SHA256",
now.format(PostPolicyGenerator.AWS_DATE_TIME), manifestPolicy.second(), manifestSignature);
ZonedDateTime now = ZonedDateTime.now(ZoneOffset.UTC);
String packId = generatePackId();
String packLocation = "stickers/" + packId;
String manifestKey = packLocation + "/manifest.proto";
Pair<String, String> manifestPolicy = policyGenerator.createFor(now, manifestKey,
Constants.MAXIMUM_STICKER_MANIFEST_SIZE_BYTES);
String manifestSignature = policySigner.getSignature(now, manifestPolicy.second());
StickerPackFormUploadItem manifest = new StickerPackFormUploadItem(-1, manifestKey, manifestPolicy.first(),
"private", "AWS4-HMAC-SHA256",
now.format(PostPolicyGenerator.AWS_DATE_TIME), manifestPolicy.second(), manifestSignature);
List<StickerPackFormUploadItem> stickers = new LinkedList<>(); List<StickerPackFormUploadItem> stickers = new LinkedList<>();
for (int i=0;i<stickerCount;i++) { for (int i = 0; i < stickerCount; i++) {
String stickerKey = packLocation + "/full/" + i; String stickerKey = packLocation + "/full/" + i;
Pair<String, String> stickerPolicy = policyGenerator.createFor(now, stickerKey, Constants.MAXIMUM_STICKER_SIZE_BYTES); Pair<String, String> stickerPolicy = policyGenerator.createFor(now, stickerKey,
String stickerSignature = policySigner.getSignature(now, stickerPolicy.second()); Constants.MAXIMUM_STICKER_SIZE_BYTES);
String stickerSignature = policySigner.getSignature(now, stickerPolicy.second());
stickers.add(new StickerPackFormUploadItem(i, stickerKey, stickerPolicy.first(), "private", "AWS4-HMAC-SHA256", stickers.add(new StickerPackFormUploadItem(i, stickerKey, stickerPolicy.first(), "private", "AWS4-HMAC-SHA256",
now.format(PostPolicyGenerator.AWS_DATE_TIME), stickerPolicy.second(), stickerSignature)); now.format(PostPolicyGenerator.AWS_DATE_TIME), stickerPolicy.second(), stickerSignature));
} }
return new StickerPackFormUploadAttributes(packId, manifest, stickers); return new StickerPackFormUploadAttributes(packId, manifest, stickers);

View File

@ -1,20 +1,20 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
package org.whispersystems.textsecuregcm.push; package org.whispersystems.textsecuregcm.push;
import java.util.Optional;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.controllers.NoSuchUserException; import org.whispersystems.textsecuregcm.controllers.NoSuchUserException;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import java.util.Optional;
public class ReceiptSender { public class ReceiptSender {
private final MessageSender messageSender; private final MessageSender messageSender;
@ -23,30 +23,29 @@ public class ReceiptSender {
private static final Logger logger = LoggerFactory.getLogger(ReceiptSender.class); private static final Logger logger = LoggerFactory.getLogger(ReceiptSender.class);
public ReceiptSender(AccountsManager accountManager, public ReceiptSender(AccountsManager accountManager,
MessageSender messageSender) MessageSender messageSender) {
{
this.accountManager = accountManager; this.accountManager = accountManager;
this.messageSender = messageSender; this.messageSender = messageSender;
} }
public void sendReceipt(Account source, String destination, long messageId) public void sendReceipt(AuthenticatedAccount source, String destination, long messageId)
throws NoSuchUserException throws NoSuchUserException {
{ final Account sourceAccount = source.getAccount();
if (source.getNumber().equals(destination)) { if (sourceAccount.getNumber().equals(destination)) {
return; return;
} }
Account destinationAccount = getDestinationAccount(destination); Account destinationAccount = getDestinationAccount(destination);
Envelope.Builder message = Envelope.newBuilder() Envelope.Builder message = Envelope.newBuilder()
.setServerTimestamp(System.currentTimeMillis()) .setServerTimestamp(System.currentTimeMillis())
.setSource(source.getNumber()) .setSource(sourceAccount.getNumber())
.setSourceUuid(source.getUuid().toString()) .setSourceUuid(sourceAccount.getUuid().toString())
.setSourceDevice((int) source.getAuthenticatedDevice().get().getId()) .setSourceDevice((int) source.getAuthenticatedDevice().getId())
.setTimestamp(messageId) .setTimestamp(messageId)
.setType(Envelope.Type.SERVER_DELIVERY_RECEIPT); .setType(Envelope.Type.SERVER_DELIVERY_RECEIPT);
if (source.getRelay().isPresent()) { if (sourceAccount.getRelay().isPresent()) {
message.setRelay(source.getRelay().get()); message.setRelay(sourceAccount.getRelay().get());
} }
for (final Device destinationDevice : destinationAccount.getDevices()) { for (final Device destinationDevice : destinationAccount.getDevices()) {
@ -63,7 +62,7 @@ public class ReceiptSender {
{ {
Optional<Account> account = accountManager.get(destination); Optional<Account> account = accountManager.get(destination);
if (!account.isPresent()) { if (account.isEmpty()) {
throw new NoSuchUserException(destination); throw new NoSuchUserException(destination);
} }

View File

@ -8,12 +8,10 @@ package org.whispersystems.textsecuregcm.storage;
import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import java.security.Principal;
import java.util.HashSet; import java.util.HashSet;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.UUID; import java.util.UUID;
import javax.security.auth.Subject;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier; import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
@ -22,7 +20,7 @@ import org.whispersystems.textsecuregcm.auth.StoredRegistrationLock;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
public class Account implements Principal { public class Account {
@JsonIgnore @JsonIgnore
private static final Logger logger = LoggerFactory.getLogger(Account.class); private static final Logger logger = LoggerFactory.getLogger(Account.class);
@ -63,9 +61,6 @@ public class Account implements Principal {
@JsonProperty("inCds") @JsonProperty("inCds")
private boolean discoverableByPhoneNumber = true; private boolean discoverableByPhoneNumber = true;
@JsonIgnore
private Device authenticatedDevice;
@JsonProperty @JsonProperty
private int version; private int version;
@ -82,18 +77,6 @@ public class Account implements Principal {
this.unidentifiedAccessKey = unidentifiedAccessKey; this.unidentifiedAccessKey = unidentifiedAccessKey;
} }
public Optional<Device> getAuthenticatedDevice() {
requireNotStale();
return Optional.ofNullable(authenticatedDevice);
}
public void setAuthenticatedDevice(Device device) {
requireNotStale();
this.authenticatedDevice = device;
}
public UUID getUuid() { public UUID getUuid() {
// this is the one method that may be called on a stale account // this is the one method that may be called on a stale account
return uuid; return uuid;
@ -390,6 +373,10 @@ public class Account implements Principal {
this.version = version; this.version = version;
} }
boolean isStale() {
return stale;
}
public void markStale() { public void markStale() {
stale = true; stale = true;
} }
@ -403,17 +390,4 @@ public class Account implements Principal {
} }
} }
// Principal implementation
@Override
@JsonIgnore
public String getName() {
return null;
}
@Override
@JsonIgnore
public boolean implies(Subject subject) {
return false;
}
} }

View File

@ -0,0 +1,14 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
public class RefreshingAccountAndDeviceNotFoundException extends RuntimeException {
public RefreshingAccountAndDeviceNotFoundException(final String message) {
super(message);
}
}

View File

@ -0,0 +1,35 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import java.util.function.Supplier;
import org.whispersystems.textsecuregcm.util.Pair;
public class RefreshingAccountAndDeviceSupplier implements Supplier<Pair<Account, Device>> {
private Account account;
private Device device;
private final AccountsManager accountsManager;
public RefreshingAccountAndDeviceSupplier(Account account, long deviceId, AccountsManager accountsManager) {
this.account = account;
this.device = account.getDevice(deviceId)
.orElseThrow(() -> new RefreshingAccountAndDeviceNotFoundException("Could not find device"));
this.accountsManager = accountsManager;
}
@Override
public Pair<Account, Device> get() {
if (account.isStale()) {
account = accountsManager.get(account.getUuid())
.orElseThrow(() -> new RuntimeException("Could not find account"));
device = account.getDevice(device.getId())
.orElseThrow(() -> new RefreshingAccountAndDeviceNotFoundException("Could not find device"));
}
return new Pair<>(account, device);
}
}

View File

@ -1,32 +1,31 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
package org.whispersystems.textsecuregcm.websocket; package org.whispersystems.textsecuregcm.websocket;
import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.Counter; import com.codahale.metrics.Counter;
import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer; import com.codahale.metrics.Timer;
import java.util.concurrent.ScheduledExecutorService;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager; import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.MessageSender; import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.redis.RedisOperation; import org.whispersystems.textsecuregcm.redis.RedisOperation;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.websocket.session.WebSocketSessionContext; import org.whispersystems.websocket.session.WebSocketSessionContext;
import org.whispersystems.websocket.setup.WebSocketConnectListener; import org.whispersystems.websocket.setup.WebSocketConnectListener;
import java.util.concurrent.ScheduledExecutorService;
import static com.codahale.metrics.MetricRegistry.name;
public class AuthenticatedConnectListener implements WebSocketConnectListener { public class AuthenticatedConnectListener implements WebSocketConnectListener {
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
@ -60,16 +59,16 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
@Override @Override
public void onWebSocketConnect(WebSocketSessionContext context) { public void onWebSocketConnect(WebSocketSessionContext context) {
if (context.getAuthenticated() != null) { if (context.getAuthenticated() != null) {
final Account account = context.getAuthenticated(Account.class); final AuthenticatedAccount auth = context.getAuthenticated(AuthenticatedAccount.class);
final Device device = account.getAuthenticatedDevice().get(); final Device device = auth.getAuthenticatedDevice();
final Timer.Context timer = durationTimer.time(); final Timer.Context timer = durationTimer.time();
final WebSocketConnection connection = new WebSocketConnection(receiptSender, final WebSocketConnection connection = new WebSocketConnection(receiptSender,
messagesManager, account, device, messagesManager, auth, device,
context.getClient(), context.getClient(),
retrySchedulingExecutor); retrySchedulingExecutor);
openWebsocketCounter.inc(); openWebsocketCounter.inc();
RedisOperation.unchecked(() -> apnFallbackManager.cancel(account, device)); RedisOperation.unchecked(() -> apnFallbackManager.cancel(auth.getAccount(), device));
context.addListener(new WebSocketSessionContext.WebSocketEventListener() { context.addListener(new WebSocketSessionContext.WebSocketEventListener() {
@Override @Override
@ -79,20 +78,21 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
connection.stop(); connection.stop();
RedisOperation.unchecked(() -> clientPresenceManager.clearPresence(account.getUuid(), device.getId())); RedisOperation.unchecked(
() -> clientPresenceManager.clearPresence(auth.getAccount().getUuid(), device.getId()));
RedisOperation.unchecked(() -> { RedisOperation.unchecked(() -> {
messagesManager.removeMessageAvailabilityListener(connection); messagesManager.removeMessageAvailabilityListener(connection);
if (messagesManager.hasCachedMessages(account.getUuid(), device.getId())) { if (messagesManager.hasCachedMessages(auth.getAccount().getUuid(), device.getId())) {
messageSender.sendNewMessageNotification(account, device); messageSender.sendNewMessageNotification(auth.getAccount(), device);
} }
}); });
} }
}); });
try { try {
clientPresenceManager.setPresent(account.getUuid(), device.getId(), connection); clientPresenceManager.setPresent(auth.getAccount().getUuid(), device.getId(), connection);
messagesManager.addMessageAvailabilityListener(account.getUuid(), device.getId(), connection); messagesManager.addMessageAvailabilityListener(auth.getAccount().getUuid(), device.getId(), connection);
connection.start(); connection.start();
} catch (final Exception e) { } catch (final Exception e) {
log.warn("Failed to initialize websocket", e); log.warn("Failed to initialize websocket", e);

View File

@ -1,23 +1,21 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
package org.whispersystems.textsecuregcm.websocket; package org.whispersystems.textsecuregcm.websocket;
import org.eclipse.jetty.websocket.api.UpgradeRequest; import io.dropwizard.auth.basic.BasicCredentials;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import io.dropwizard.auth.basic.BasicCredentials; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<Account> { public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<AuthenticatedAccount> {
private final AccountAuthenticator accountAuthenticator; private final AccountAuthenticator accountAuthenticator;
@ -26,19 +24,18 @@ public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<Acc
} }
@Override @Override
public AuthenticationResult<Account> authenticate(UpgradeRequest request) { public AuthenticationResult<AuthenticatedAccount> authenticate(UpgradeRequest request) {
Map<String, List<String>> parameters = request.getParameterMap(); Map<String, List<String>> parameters = request.getParameterMap();
List<String> usernames = parameters.get("login"); List<String> usernames = parameters.get("login");
List<String> passwords = parameters.get("password"); List<String> passwords = parameters.get("password");
if (usernames == null || usernames.size() == 0 || if (usernames == null || usernames.size() == 0 ||
passwords == null || passwords.size() == 0) passwords == null || passwords.size() == 0) {
{
return new AuthenticationResult<>(Optional.empty(), false); return new AuthenticationResult<>(Optional.empty(), false);
} }
BasicCredentials credentials = new BasicCredentials(usernames.get(0).replace(" ", "+"), BasicCredentials credentials = new BasicCredentials(usernames.get(0).replace(" ", "+"),
passwords.get(0).replace(" ", "+")); passwords.get(0).replace(" ", "+"));
return new AuthenticationResult<>(accountAuthenticator.authenticate(credentials), true); return new AuthenticationResult<>(accountAuthenticator.authenticate(credentials), true);
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
@ -36,6 +36,7 @@ import javax.ws.rs.WebApplicationException;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.controllers.MessageController; import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.controllers.NoSuchUserException; import org.whispersystems.textsecuregcm.controllers.NoSuchUserException;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
@ -43,7 +44,6 @@ import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.DisplacedPresenceListener; import org.whispersystems.textsecuregcm.push.DisplacedPresenceListener;
import org.whispersystems.textsecuregcm.push.ReceiptSender; import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessageAvailabilityListener; import org.whispersystems.textsecuregcm.storage.MessageAvailabilityListener;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
@ -90,21 +90,22 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class); private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class);
private final ReceiptSender receiptSender; private final ReceiptSender receiptSender;
private final MessagesManager messagesManager; private final MessagesManager messagesManager;
private final Account account; private final AuthenticatedAccount auth;
private final Device device; private final Device device;
private final WebSocketClient client; private final WebSocketClient client;
private final ScheduledExecutorService retrySchedulingExecutor; private final ScheduledExecutorService retrySchedulingExecutor;
private final boolean isDesktopClient; private final boolean isDesktopClient;
private final Semaphore processStoredMessagesSemaphore = new Semaphore(1); private final Semaphore processStoredMessagesSemaphore = new Semaphore(1);
private final AtomicReference<StoredMessageState> storedMessageState = new AtomicReference<>(StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE); private final AtomicReference<StoredMessageState> storedMessageState = new AtomicReference<>(
private final AtomicBoolean sentInitialQueueEmptyMessage = new AtomicBoolean(false); StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE);
private final LongAdder sentMessageCounter = new LongAdder(); private final AtomicBoolean sentInitialQueueEmptyMessage = new AtomicBoolean(false);
private final AtomicLong queueDrainStartTime = new AtomicLong(); private final LongAdder sentMessageCounter = new LongAdder();
private final AtomicLong queueDrainStartTime = new AtomicLong();
private final AtomicInteger consecutiveRetries = new AtomicInteger(); private final AtomicInteger consecutiveRetries = new AtomicInteger();
private final AtomicReference<ScheduledFuture<?>> retryFuture = new AtomicReference<>(); private final AtomicReference<ScheduledFuture<?>> retryFuture = new AtomicReference<>();
@ -118,16 +119,15 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
public WebSocketConnection(ReceiptSender receiptSender, public WebSocketConnection(ReceiptSender receiptSender,
MessagesManager messagesManager, MessagesManager messagesManager,
Account account, AuthenticatedAccount auth,
Device device, Device device,
WebSocketClient client, WebSocketClient client,
ScheduledExecutorService retrySchedulingExecutor) ScheduledExecutorService retrySchedulingExecutor) {
{ this.receiptSender = receiptSender;
this.receiptSender = receiptSender;
this.messagesManager = messagesManager; this.messagesManager = messagesManager;
this.account = account; this.auth = auth;
this.device = device; this.device = device;
this.client = client; this.client = client;
this.retrySchedulingExecutor = retrySchedulingExecutor; this.retrySchedulingExecutor = retrySchedulingExecutor;
Optional<ClientPlatform> maybePlatform; Optional<ClientPlatform> maybePlatform;
@ -168,7 +168,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
if (throwable == null) { if (throwable == null) {
if (isSuccessResponse(response)) { if (isSuccessResponse(response)) {
if (storedMessageInfo.isPresent()) { if (storedMessageInfo.isPresent()) {
messagesManager.delete(account.getUuid(), device.getId(), storedMessageInfo.get().getGuid()); messagesManager.delete(auth.getAccount().getUuid(), device.getId(), storedMessageInfo.get().getGuid());
} }
if (message.getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT) { if (message.getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT) {
@ -204,7 +204,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
if (!message.hasSource()) return; if (!message.hasSource()) return;
try { try {
receiptSender.sendReceipt(account, message.getSource(), message.getTimestamp()); receiptSender.sendReceipt(auth, message.getSource(), message.getTimestamp());
} catch (NoSuchUserException e) { } catch (NoSuchUserException e) {
logger.info("No longer registered " + e.getMessage()); logger.info("No longer registered " + e.getMessage());
} catch (WebApplicationException e) { } catch (WebApplicationException e) {
@ -267,7 +267,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
private void sendNextMessagePage(final boolean cachedMessagesOnly, final CompletableFuture<Void> queueClearedFuture) { private void sendNextMessagePage(final boolean cachedMessagesOnly, final CompletableFuture<Void> queueClearedFuture) {
try { try {
final OutgoingMessageEntityList messages = messagesManager final OutgoingMessageEntityList messages = messagesManager
.getMessagesForDevice(account.getUuid(), device.getId(), client.getUserAgent(), cachedMessagesOnly); .getMessagesForDevice(auth.getAccount().getUuid(), device.getId(), client.getUserAgent(), cachedMessagesOnly);
final CompletableFuture<?>[] sendFutures = new CompletableFuture[messages.getMessages().size()]; final CompletableFuture<?>[] sendFutures = new CompletableFuture[messages.getMessages().size()];
@ -303,7 +303,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
final Envelope envelope = builder.build(); final Envelope envelope = builder.build();
if (envelope.getSerializedSize() > MAX_DESKTOP_MESSAGE_SIZE && isDesktopClient) { if (envelope.getSerializedSize() > MAX_DESKTOP_MESSAGE_SIZE && isDesktopClient) {
messagesManager.delete(account.getUuid(), device.getId(), message.getGuid()); messagesManager.delete(auth.getAccount().getUuid(), device.getId(), message.getGuid());
discardedMessagesMeter.mark(); discardedMessagesMeter.mark();
sendFutures[i] = CompletableFuture.completedFuture(null); sendFutures[i] = CompletableFuture.completedFuture(null);
@ -340,7 +340,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
public void handleNewEphemeralMessageAvailable() { public void handleNewEphemeralMessageAvailable() {
ephemeralMessageAvailableMeter.mark(); ephemeralMessageAvailableMeter.mark();
messagesManager.takeEphemeralMessage(account.getUuid(), device.getId()) messagesManager.takeEphemeralMessage(auth.getAccount().getUuid(), device.getId())
.ifPresent(message -> sendMessage(message, Optional.empty())); .ifPresent(message -> sendMessage(message, Optional.empty()));
} }

View File

@ -24,11 +24,11 @@ import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager; import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.mappers.RetryLaterExceptionMapper; import org.whispersystems.textsecuregcm.mappers.RetryLaterExceptionMapper;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -41,7 +41,8 @@ class ChallengeControllerTest {
private static final ResourceExtension EXTENSION = ResourceExtension.builder() private static final ResourceExtension EXTENSION = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter()) .addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(Set.of(Account.class, DisabledPermittedAccount.class))) .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
Set.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.setMapper(SystemMapper.getMapper()) .setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new RetryLaterExceptionMapper()) .addResource(new RetryLaterExceptionMapper())

View File

@ -0,0 +1,71 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.util.Optional;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.util.Pair;
class RefreshingAccountAndDeviceSupplierTest {
@Test
void test() {
final AccountsManager accountsManager = mock(AccountsManager.class);
final UUID uuid = UUID.randomUUID();
final long deviceId = 2L;
final Account initialAccount = mock(Account.class);
final Device initialDevice = mock(Device.class);
when(initialAccount.getUuid()).thenReturn(uuid);
when(initialDevice.getId()).thenReturn(deviceId);
when(initialAccount.getDevice(deviceId)).thenReturn(Optional.of(initialDevice));
when(accountsManager.get(any(UUID.class))).thenAnswer(answer -> {
final Account account = mock(Account.class);
final Device device = mock(Device.class);
when(account.getUuid()).thenReturn(answer.getArgument(0, UUID.class));
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(device.getId()).thenReturn(deviceId);
return Optional.of(account);
});
final RefreshingAccountAndDeviceSupplier refreshingAccountAndDeviceSupplier = new RefreshingAccountAndDeviceSupplier(
initialAccount, deviceId, accountsManager);
Pair<Account, Device> accountAndDevice = refreshingAccountAndDeviceSupplier.get();
assertSame(initialAccount, accountAndDevice.first());
assertSame(initialDevice, accountAndDevice.second());
accountAndDevice = refreshingAccountAndDeviceSupplier.get();
assertSame(initialAccount, accountAndDevice.first());
assertSame(initialDevice, accountAndDevice.second());
when(initialAccount.isStale()).thenReturn(true);
accountAndDevice = refreshingAccountAndDeviceSupplier.get();
assertNotSame(initialAccount, accountAndDevice.first());
assertNotSame(initialDevice, accountAndDevice.second());
assertEquals(uuid, accountAndDevice.first().getUuid());
}
}

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
@ -52,8 +52,9 @@ import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials; import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.auth.StoredRegistrationLock; import org.whispersystems.textsecuregcm.auth.StoredRegistrationLock;
import org.whispersystems.textsecuregcm.auth.StoredVerificationCode; import org.whispersystems.textsecuregcm.auth.StoredVerificationCode;
@ -145,27 +146,29 @@ class AccountControllerTest {
private static ExternalServiceCredentialGenerator storageCredentialGenerator = new ExternalServiceCredentialGenerator(new byte[32], new byte[32], false); private static ExternalServiceCredentialGenerator storageCredentialGenerator = new ExternalServiceCredentialGenerator(new byte[32], new byte[32], false);
private static final ResourceExtension resources = ResourceExtension.builder() private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter()) .addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( .addProvider(
ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) new PolymorphicAuthValueFactoryProvider.Binder<>(
.addProvider(new RateLimitExceededExceptionMapper()) ImmutableSet.of(AuthenticatedAccount.class,
.setMapper(SystemMapper.getMapper()) DisabledPermittedAuthenticatedAccount.class)))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .addProvider(new RateLimitExceededExceptionMapper())
.addResource(new AccountController(pendingAccountsManager, .setMapper(SystemMapper.getMapper())
accountsManager, .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
usernamesManager, .addResource(new AccountController(pendingAccountsManager,
abusiveHostRules, accountsManager,
rateLimiters, usernamesManager,
smsSender, abusiveHostRules,
dynamicConfigurationManager, rateLimiters,
turnTokenGenerator, smsSender,
new HashMap<>(), dynamicConfigurationManager,
recaptchaClient, turnTokenGenerator,
gcmSender, new HashMap<>(),
apnSender, recaptchaClient,
storageCredentialGenerator, gcmSender,
verifyExperimentEnrollmentManager)) apnSender,
.build(); storageCredentialGenerator,
verifyExperimentEnrollmentManager))
.build();
@BeforeEach @BeforeEach

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
@ -33,7 +33,8 @@ import org.assertj.core.api.InstanceOfAssertFactories;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV1; import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV1;
import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV2; import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV2;
import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV3; import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV3;
@ -43,7 +44,6 @@ import org.whispersystems.textsecuregcm.entities.AttachmentDescriptorV3;
import org.whispersystems.textsecuregcm.entities.AttachmentUri; import org.whispersystems.textsecuregcm.entities.AttachmentUri;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -78,8 +78,9 @@ class AttachmentControllerTest {
static { static {
try { try {
resources = ResourceExtension.builder() resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter()) .addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.setMapper(SystemMapper.getMapper()) .setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new AttachmentControllerV1(rateLimiters, "accessKey", "accessSecret", "attachment-bucket")) .addResource(new AttachmentControllerV1(rateLimiters, "accessKey", "accessSecret", "attachment-bucket"))

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
@ -28,8 +28,9 @@ import org.signal.zkgroup.auth.AuthCredential;
import org.signal.zkgroup.auth.AuthCredentialResponse; import org.signal.zkgroup.auth.AuthCredentialResponse;
import org.signal.zkgroup.auth.ClientZkAuthOperations; import org.signal.zkgroup.auth.ClientZkAuthOperations;
import org.signal.zkgroup.auth.ServerZkAuthOperations; import org.signal.zkgroup.auth.ServerZkAuthOperations;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.CertificateGenerator; import org.whispersystems.textsecuregcm.auth.CertificateGenerator;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.controllers.CertificateController; import org.whispersystems.textsecuregcm.controllers.CertificateController;
import org.whispersystems.textsecuregcm.crypto.Curve; import org.whispersystems.textsecuregcm.crypto.Curve;
@ -37,7 +38,6 @@ import org.whispersystems.textsecuregcm.entities.DeliveryCertificate;
import org.whispersystems.textsecuregcm.entities.GroupCredentials; import org.whispersystems.textsecuregcm.entities.GroupCredentials;
import org.whispersystems.textsecuregcm.entities.MessageProtos.SenderCertificate; import org.whispersystems.textsecuregcm.entities.MessageProtos.SenderCertificate;
import org.whispersystems.textsecuregcm.entities.MessageProtos.ServerCertificate; import org.whispersystems.textsecuregcm.entities.MessageProtos.ServerCertificate;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
@ -66,12 +66,13 @@ class CertificateControllerTest {
private static final ResourceExtension resources = ResourceExtension.builder() private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter()) .addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
.setMapper(SystemMapper.getMapper()) ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setMapper(SystemMapper.getMapper())
.addResource(new CertificateController(certificateGenerator, serverZkAuthOperations, true)) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.build(); .addResource(new CertificateController(certificateGenerator, serverZkAuthOperations, true))
.build();
@Test @Test
void testValidCertificate() throws Exception { void testValidCertificate() throws Exception {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
package org.whispersystems.textsecuregcm.tests.controllers; package org.whispersystems.textsecuregcm.tests.controllers;
@ -36,7 +36,8 @@ import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.StoredVerificationCode; import org.whispersystems.textsecuregcm.auth.StoredVerificationCode;
import org.whispersystems.textsecuregcm.controllers.DeviceController; import org.whispersystems.textsecuregcm.controllers.DeviceController;
import org.whispersystems.textsecuregcm.entities.AccountAttributes; import org.whispersystems.textsecuregcm.entities.AccountAttributes;
@ -88,17 +89,18 @@ class DeviceControllerTest {
private static Map<String, Integer> deviceConfiguration = new HashMap<>(); private static Map<String, Integer> deviceConfiguration = new HashMap<>();
private static final ResourceExtension resources = ResourceExtension.builder() private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter()) .addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.addProvider(new DeviceLimitExceededExceptionMapper()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new DumbVerificationDeviceController(pendingDevicesManager, .addProvider(new DeviceLimitExceededExceptionMapper())
accountsManager, .addResource(new DumbVerificationDeviceController(pendingDevicesManager,
messagesManager, accountsManager,
keys, messagesManager,
rateLimiters, keys,
deviceConfiguration)) rateLimiters,
.build(); deviceConfiguration))
.build();
@BeforeEach @BeforeEach
@ -114,15 +116,14 @@ class DeviceControllerTest {
when(account.getNextDeviceId()).thenReturn(42L); when(account.getNextDeviceId()).thenReturn(42L);
when(account.getNumber()).thenReturn(AuthHelper.VALID_NUMBER); when(account.getNumber()).thenReturn(AuthHelper.VALID_NUMBER);
when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID); when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID);
// when(maxedAccount.getActiveDeviceCount()).thenReturn(6);
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(masterDevice));
when(account.isEnabled()).thenReturn(false); when(account.isEnabled()).thenReturn(false);
when(account.isGroupsV2Supported()).thenReturn(true); when(account.isGroupsV2Supported()).thenReturn(true);
when(account.isGv1MigrationSupported()).thenReturn(true); when(account.isGv1MigrationSupported()).thenReturn(true);
when(account.isSenderKeySupported()).thenReturn(true); when(account.isSenderKeySupported()).thenReturn(true);
when(account.isAnnouncementGroupSupported()).thenReturn(true); when(account.isAnnouncementGroupSupported()).thenReturn(true);
when(pendingDevicesManager.getCodeForNumber(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(new StoredVerificationCode("5678901", System.currentTimeMillis(), null, null))); when(pendingDevicesManager.getCodeForNumber(AuthHelper.VALID_NUMBER)).thenReturn(
Optional.of(new StoredVerificationCode("5678901", System.currentTimeMillis(), null, null)));
when(pendingDevicesManager.getCodeForNumber(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.empty()); when(pendingDevicesManager.getCodeForNumber(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.empty());
when(accountsManager.get(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(account)); when(accountsManager.get(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(account));
when(accountsManager.get(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.of(maxedAccount)); when(accountsManager.get(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.of(maxedAccount));

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
@ -20,14 +20,14 @@ import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.Status.Family; import javax.ws.rs.core.Response.Status.Family;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.controllers.DirectoryController; import org.whispersystems.textsecuregcm.controllers.DirectoryController;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
@ExtendWith(DropwizardExtensionsSupport.class) @ExtendWith(DropwizardExtensionsSupport.class)
@ -37,11 +37,12 @@ class DirectoryControllerTest {
private static final ExternalServiceCredentials validCredentials = new ExternalServiceCredentials("username", "password"); private static final ExternalServiceCredentials validCredentials = new ExternalServiceCredentials("username", "password");
private static final ResourceExtension resources = ResourceExtension.builder() private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter()) .addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.addResource(new DirectoryController(directoryCredentialsGenerator)) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.build(); .addResource(new DirectoryController(directoryCredentialsGenerator))
.build();
@BeforeEach @BeforeEach
void setup() { void setup() {

View File

@ -26,14 +26,14 @@ import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
import org.whispersystems.textsecuregcm.configuration.DonationConfiguration; import org.whispersystems.textsecuregcm.configuration.DonationConfiguration;
import org.whispersystems.textsecuregcm.configuration.RetryConfiguration; import org.whispersystems.textsecuregcm.configuration.RetryConfiguration;
import org.whispersystems.textsecuregcm.controllers.DonationController; import org.whispersystems.textsecuregcm.controllers.DonationController;
import org.whispersystems.textsecuregcm.entities.ApplePayAuthorizationRequest; import org.whispersystems.textsecuregcm.entities.ApplePayAuthorizationRequest;
import org.whispersystems.textsecuregcm.entities.ApplePayAuthorizationResponse; import org.whispersystems.textsecuregcm.entities.ApplePayAuthorizationResponse;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -57,7 +57,8 @@ public class DonationControllerTest {
configuration.setSupportedCurrencies(Set.of("usd", "gbp")); configuration.setSupportedCurrencies(Set.of("usd", "gbp"));
resources = ResourceExtension.builder() resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter()) .addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.setMapper(SystemMapper.getMapper()) .setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new DonationController(executor, configuration)) .addResource(new DonationController(executor, configuration))

View File

@ -44,7 +44,8 @@ import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatcher; import org.mockito.ArgumentMatcher;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier; import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.controllers.KeysController; import org.whispersystems.textsecuregcm.controllers.KeysController;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
@ -102,8 +103,8 @@ class KeysControllerTest {
private static final ResourceExtension resources = ResourceExtension.builder() private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter()) .addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>( .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(
ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new RateLimitChallengeExceptionMapper(rateLimitChallengeManager)) .addResource(new RateLimitChallengeExceptionMapper(rateLimitChallengeManager))
.addResource(new ServerRejectedExceptionMapper()) .addResource(new ServerRejectedExceptionMapper())

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
@ -68,7 +68,8 @@ import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatcher; import org.mockito.ArgumentMatcher;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier; import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration;
@ -142,15 +143,16 @@ class MessageControllerTest {
private final ObjectMapper mapper = new ObjectMapper(); private final ObjectMapper mapper = new ObjectMapper();
private static final ResourceExtension resources = ResourceExtension.builder() private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter()) .addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
.addProvider(RateLimitExceededExceptionMapper.class) ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.addProvider(new RateLimitChallengeExceptionMapper(rateLimitChallengeManager)) .addProvider(RateLimitExceededExceptionMapper.class)
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .addProvider(new RateLimitChallengeExceptionMapper(rateLimitChallengeManager))
.addResource(new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
messagesManager, unsealedSenderRateLimiter, apnFallbackManager, dynamicConfigurationManager, .addResource(new MessageController(rateLimiters, messageSender, receiptSender, accountsManager,
rateLimitChallengeManager, reportMessageManager, metricsCluster, receiptExecutor)) messagesManager, unsealedSenderRateLimiter, apnFallbackManager, dynamicConfigurationManager,
.build(); rateLimitChallengeManager, reportMessageManager, metricsCluster, receiptExecutor))
.build();
@BeforeEach @BeforeEach
void setup() throws Exception { void setup() throws Exception {
@ -576,7 +578,7 @@ class MessageControllerTest {
.delete(); .delete();
assertThat("Good Response Code", response.getStatus(), is(equalTo(204))); assertThat("Good Response Code", response.getStatus(), is(equalTo(204)));
verify(receiptSender).sendReceipt(any(Account.class), eq("+14152222222"), eq(timestamp)); verify(receiptSender).sendReceipt(any(AuthenticatedAccount.class), eq("+14152222222"), eq(timestamp));
response = resources.getJerseyTest() response = resources.getJerseyTest()
.target(String.format("/v1/messages/%s/%d", "+14152222222", 31338)) .target(String.format("/v1/messages/%s/%d", "+14152222222", 31338))
@ -731,22 +733,54 @@ class MessageControllerTest {
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L, 3L), Set.of(1L, 3L),
null, null,
null,
false,
null), null),
arguments( arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L, 2L, 3L), Set.of(1L, 2L, 3L),
null, null,
Set.of(2L)), Set.of(2L),
false,
null),
arguments( arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L), Set.of(1L),
Set.of(3L), Set.of(3L),
null,
false,
null), null),
arguments( arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true), mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L, 2L), Set.of(1L, 2L),
Set.of(3L), Set.of(3L),
Set.of(2L)) Set.of(2L),
false,
null),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L),
Set.of(3L),
Set.of(1L),
true,
1L
),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(2L),
Set.of(3L),
Set.of(2L),
true,
1L
),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(3L),
null,
null,
true,
1L
)
); );
} }
@ -756,10 +790,13 @@ class MessageControllerTest {
Account account, Account account,
Set<Long> deviceIds, Set<Long> deviceIds,
Collection<Long> expectedMissingDeviceIds, Collection<Long> expectedMissingDeviceIds,
Collection<Long> expectedExtraDeviceIds) throws Exception { Collection<Long> expectedExtraDeviceIds,
boolean isSyncMessage,
Long authenticatedDeviceId) throws Exception {
if (expectedMissingDeviceIds != null || expectedExtraDeviceIds != null) { if (expectedMissingDeviceIds != null || expectedExtraDeviceIds != null) {
final MismatchedDevicesException mismatchedDevicesException = assertThrows(MismatchedDevicesException.class, final MismatchedDevicesException mismatchedDevicesException = assertThrows(MismatchedDevicesException.class,
() -> MessageController.validateCompleteDeviceList(account, deviceIds, false)); () -> MessageController.validateCompleteDeviceList(account, deviceIds, isSyncMessage,
Optional.ofNullable(authenticatedDeviceId)));
if (expectedMissingDeviceIds != null) { if (expectedMissingDeviceIds != null) {
Assertions.assertThat(mismatchedDevicesException.getMissingDevices()) Assertions.assertThat(mismatchedDevicesException.getMissingDevices())
.hasSameElementsAs(expectedMissingDeviceIds); .hasSameElementsAs(expectedMissingDeviceIds);
@ -768,7 +805,8 @@ class MessageControllerTest {
Assertions.assertThat(mismatchedDevicesException.getExtraDevices()).hasSameElementsAs(expectedExtraDeviceIds); Assertions.assertThat(mismatchedDevicesException.getExtraDevices()).hasSameElementsAs(expectedExtraDeviceIds);
} }
} else { } else {
MessageController.validateCompleteDeviceList(account, deviceIds, false); MessageController.validateCompleteDeviceList(account, deviceIds, isSyncMessage,
Optional.ofNullable(authenticatedDeviceId));
} }
} }
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
@ -22,14 +22,14 @@ import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.controllers.PaymentsController; import org.whispersystems.textsecuregcm.controllers.PaymentsController;
import org.whispersystems.textsecuregcm.currency.CurrencyConversionManager; import org.whispersystems.textsecuregcm.currency.CurrencyConversionManager;
import org.whispersystems.textsecuregcm.entities.CurrencyConversionEntity; import org.whispersystems.textsecuregcm.entities.CurrencyConversionEntity;
import org.whispersystems.textsecuregcm.entities.CurrencyConversionEntityList; import org.whispersystems.textsecuregcm.entities.CurrencyConversionEntityList;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
@ExtendWith(DropwizardExtensionsSupport.class) @ExtendWith(DropwizardExtensionsSupport.class)
@ -41,11 +41,12 @@ class PaymentsControllerTest {
private final ExternalServiceCredentials validCredentials = new ExternalServiceCredentials("username", "password"); private final ExternalServiceCredentials validCredentials = new ExternalServiceCredentials("username", "password");
private static final ResourceExtension resources = ResourceExtension.builder() private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter()) .addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.addResource(new PaymentsController(currencyManager, paymentsCredentialGenerator)) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.build(); .addResource(new PaymentsController(currencyManager, paymentsCredentialGenerator))
.build();
@BeforeEach @BeforeEach

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
@ -41,7 +41,8 @@ import org.signal.zkgroup.profiles.ProfileKey;
import org.signal.zkgroup.profiles.ProfileKeyCommitment; import org.signal.zkgroup.profiles.ProfileKeyCommitment;
import org.signal.zkgroup.profiles.ServerZkProfileOperations; import org.signal.zkgroup.profiles.ServerZkProfileOperations;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier; import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicPaymentsConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicPaymentsConfiguration;
import org.whispersystems.textsecuregcm.controllers.ProfileController; import org.whispersystems.textsecuregcm.controllers.ProfileController;
@ -87,22 +88,23 @@ class ProfileControllerTest {
private Account profileAccount; private Account profileAccount;
private static final ResourceExtension resources = ResourceExtension.builder() private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter()) .addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
.setMapper(SystemMapper.getMapper()) ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setMapper(SystemMapper.getMapper())
.addResource(new ProfileController(rateLimiters, .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
accountsManager, .addResource(new ProfileController(rateLimiters,
profilesManager, accountsManager,
usernamesManager, profilesManager,
dynamicConfigurationManager, usernamesManager,
s3client, dynamicConfigurationManager,
postPolicyGenerator, s3client,
policySigner, postPolicyGenerator,
"profilesBucket", policySigner,
zkProfileOperations, "profilesBucket",
true)) zkProfileOperations,
.build(); true))
.build();
@BeforeEach @BeforeEach
void setup() throws Exception { void setup() throws Exception {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
@ -30,17 +30,17 @@ import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.controllers.RemoteConfigController; import org.whispersystems.textsecuregcm.controllers.RemoteConfigController;
import org.whispersystems.textsecuregcm.entities.UserRemoteConfig; import org.whispersystems.textsecuregcm.entities.UserRemoteConfig;
import org.whispersystems.textsecuregcm.entities.UserRemoteConfigList; import org.whispersystems.textsecuregcm.entities.UserRemoteConfigList;
import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.RemoteConfig; import org.whispersystems.textsecuregcm.storage.RemoteConfig;
import org.whispersystems.textsecuregcm.storage.RemoteConfigsManager; import org.whispersystems.textsecuregcm.storage.RemoteConfigsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
@ -52,12 +52,13 @@ class RemoteConfigControllerTest {
private static final List<String> remoteConfigsAuth = List.of("foo", "bar"); private static final List<String> remoteConfigsAuth = List.of("foo", "bar");
private static final ResourceExtension resources = ResourceExtension.builder() private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter()) .addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.addProvider(new DeviceLimitExceededExceptionMapper()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new RemoteConfigController(remoteConfigsManager, remoteConfigsAuth, Map.of("maxGroupSize", "42"))) .addProvider(new DeviceLimitExceededExceptionMapper())
.build(); .addResource(new RemoteConfigController(remoteConfigsManager, remoteConfigsAuth, Map.of("maxGroupSize", "42")))
.build();
@BeforeEach @BeforeEach

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
@ -15,11 +15,11 @@ import javax.ws.rs.core.Response;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials; import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.controllers.SecureStorageController; import org.whispersystems.textsecuregcm.controllers.SecureStorageController;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -29,12 +29,13 @@ class SecureStorageControllerTest {
private static final ExternalServiceCredentialGenerator storageCredentialGenerator = new ExternalServiceCredentialGenerator(new byte[32], new byte[32], false); private static final ExternalServiceCredentialGenerator storageCredentialGenerator = new ExternalServiceCredentialGenerator(new byte[32], new byte[32], false);
private static final ResourceExtension resources = ResourceExtension.builder() private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter()) .addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
.setMapper(SystemMapper.getMapper()) ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setMapper(SystemMapper.getMapper())
.addResource(new SecureStorageController(storageCredentialGenerator)) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.build(); .addResource(new SecureStorageController(storageCredentialGenerator))
.build();
@Test @Test

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
@ -21,13 +21,13 @@ import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.controllers.StickerController; import org.whispersystems.textsecuregcm.controllers.StickerController;
import org.whispersystems.textsecuregcm.entities.StickerPackFormUploadAttributes; import org.whispersystems.textsecuregcm.entities.StickerPackFormUploadAttributes;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -38,12 +38,13 @@ class StickerControllerTest {
private static final RateLimiters rateLimiters = mock(RateLimiters.class); private static final RateLimiters rateLimiters = mock(RateLimiters.class);
private static final ResourceExtension resources = ResourceExtension.builder() private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter()) .addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
.setMapper(SystemMapper.getMapper()) ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setMapper(SystemMapper.getMapper())
.addResource(new StickerController(rateLimiters, "foo", "bar", "us-east-1", "mybucket")) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.build(); .addResource(new StickerController(rateLimiters, "foo", "bar", "us-east-1", "mybucket"))
.build();
@BeforeEach @BeforeEach
void setup() { void setup() {

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
@ -18,10 +18,10 @@ import javax.ws.rs.core.Response;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.controllers.VoiceVerificationController; import org.whispersystems.textsecuregcm.controllers.VoiceVerificationController;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -29,14 +29,15 @@ import org.whispersystems.textsecuregcm.util.SystemMapper;
class VoiceVerificationControllerTest { class VoiceVerificationControllerTest {
private static final ResourceExtension resources = ResourceExtension.builder() private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter()) .addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class))) .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
.addProvider(new RateLimitExceededExceptionMapper()) ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.setMapper(SystemMapper.getMapper()) .addProvider(new RateLimitExceededExceptionMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setMapper(SystemMapper.getMapper())
.addResource(new VoiceVerificationController("https://foo.com/bar", .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
new HashSet<>(Arrays.asList("pt-BR", "ru")))) .addResource(new VoiceVerificationController("https://foo.com/bar",
.build(); new HashSet<>(Arrays.asList("pt-BR", "ru"))))
.build();
@Test @Test
void testTwimlLocale() { void testTwimlLocale() {

View File

@ -113,10 +113,6 @@ public class AccountsHelper {
when(updatedAccount.getMasterDevice()).thenAnswer(stubbing); when(updatedAccount.getMasterDevice()).thenAnswer(stubbing);
break; break;
} }
case "getAuthenticatedDevice": {
when(updatedAccount.getAuthenticatedDevice()).thenAnswer(stubbing);
break;
}
case "isEnabled": { case "isEnabled": {
when(updatedAccount.isEnabled()).thenAnswer(stubbing); when(updatedAccount.isEnabled()).thenAnswer(stubbing);
break; break;

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
@ -24,8 +24,9 @@ import java.util.UUID;
import org.mockito.ArgumentMatcher; import org.mockito.ArgumentMatcher;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier; import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials; import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccountAuthenticator; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccountAuthenticator;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
@ -121,11 +122,6 @@ public class AuthHelper {
when(UNDISCOVERABLE_ACCOUNT.getNumber()).thenReturn(UNDISCOVERABLE_NUMBER); when(UNDISCOVERABLE_ACCOUNT.getNumber()).thenReturn(UNDISCOVERABLE_NUMBER);
when(UNDISCOVERABLE_ACCOUNT.getUuid()).thenReturn(UNDISCOVERABLE_UUID); when(UNDISCOVERABLE_ACCOUNT.getUuid()).thenReturn(UNDISCOVERABLE_UUID);
when(VALID_ACCOUNT.getAuthenticatedDevice()).thenReturn(Optional.of(VALID_DEVICE));
when(VALID_ACCOUNT_TWO.getAuthenticatedDevice()).thenReturn(Optional.of(VALID_DEVICE_TWO));
when(DISABLED_ACCOUNT.getAuthenticatedDevice()).thenReturn(Optional.of(DISABLED_DEVICE));
when(UNDISCOVERABLE_ACCOUNT.getAuthenticatedDevice()).thenReturn(Optional.of(UNDISCOVERABLE_DEVICE));
when(VALID_ACCOUNT.getRelay()).thenReturn(Optional.empty()); when(VALID_ACCOUNT.getRelay()).thenReturn(Optional.empty());
when(VALID_ACCOUNT_TWO.getRelay()).thenReturn(Optional.empty()); when(VALID_ACCOUNT_TWO.getRelay()).thenReturn(Optional.empty());
when(UNDISCOVERABLE_ACCOUNT.getRelay()).thenReturn(Optional.empty()); when(UNDISCOVERABLE_ACCOUNT.getRelay()).thenReturn(Optional.empty());
@ -151,18 +147,30 @@ public class AuthHelper {
when(ACCOUNTS_MANAGER.get(VALID_NUMBER_TWO)).thenReturn(Optional.of(VALID_ACCOUNT_TWO)); when(ACCOUNTS_MANAGER.get(VALID_NUMBER_TWO)).thenReturn(Optional.of(VALID_ACCOUNT_TWO));
when(ACCOUNTS_MANAGER.get(VALID_UUID_TWO)).thenReturn(Optional.of(VALID_ACCOUNT_TWO)); when(ACCOUNTS_MANAGER.get(VALID_UUID_TWO)).thenReturn(Optional.of(VALID_ACCOUNT_TWO));
when(ACCOUNTS_MANAGER.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(VALID_NUMBER_TWO)))).thenReturn(Optional.of(VALID_ACCOUNT_TWO)); when(ACCOUNTS_MANAGER.get(argThat(
when(ACCOUNTS_MANAGER.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasUuid() && identifier.getUuid().equals(VALID_UUID_TWO)))).thenReturn(Optional.of(VALID_ACCOUNT_TWO)); (ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber()
&& identifier.getNumber().equals(VALID_NUMBER_TWO)))).thenReturn(Optional.of(VALID_ACCOUNT_TWO));
when(ACCOUNTS_MANAGER.get(argThat(
(ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasUuid()
&& identifier.getUuid().equals(VALID_UUID_TWO)))).thenReturn(Optional.of(VALID_ACCOUNT_TWO));
when(ACCOUNTS_MANAGER.get(DISABLED_NUMBER)).thenReturn(Optional.of(DISABLED_ACCOUNT)); when(ACCOUNTS_MANAGER.get(DISABLED_NUMBER)).thenReturn(Optional.of(DISABLED_ACCOUNT));
when(ACCOUNTS_MANAGER.get(DISABLED_UUID)).thenReturn(Optional.of(DISABLED_ACCOUNT)); when(ACCOUNTS_MANAGER.get(DISABLED_UUID)).thenReturn(Optional.of(DISABLED_ACCOUNT));
when(ACCOUNTS_MANAGER.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(DISABLED_NUMBER)))).thenReturn(Optional.of(DISABLED_ACCOUNT)); when(ACCOUNTS_MANAGER.get(argThat(
when(ACCOUNTS_MANAGER.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasUuid() && identifier.getUuid().equals(DISABLED_UUID)))).thenReturn(Optional.of(DISABLED_ACCOUNT)); (ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber()
&& identifier.getNumber().equals(DISABLED_NUMBER)))).thenReturn(Optional.of(DISABLED_ACCOUNT));
when(ACCOUNTS_MANAGER.get(argThat(
(ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasUuid()
&& identifier.getUuid().equals(DISABLED_UUID)))).thenReturn(Optional.of(DISABLED_ACCOUNT));
when(ACCOUNTS_MANAGER.get(UNDISCOVERABLE_NUMBER)).thenReturn(Optional.of(UNDISCOVERABLE_ACCOUNT)); when(ACCOUNTS_MANAGER.get(UNDISCOVERABLE_NUMBER)).thenReturn(Optional.of(UNDISCOVERABLE_ACCOUNT));
when(ACCOUNTS_MANAGER.get(UNDISCOVERABLE_UUID)).thenReturn(Optional.of(UNDISCOVERABLE_ACCOUNT)); when(ACCOUNTS_MANAGER.get(UNDISCOVERABLE_UUID)).thenReturn(Optional.of(UNDISCOVERABLE_ACCOUNT));
when(ACCOUNTS_MANAGER.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(UNDISCOVERABLE_NUMBER)))).thenReturn(Optional.of(UNDISCOVERABLE_ACCOUNT)); when(ACCOUNTS_MANAGER.get(argThat(
when(ACCOUNTS_MANAGER.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasUuid() && identifier.getUuid().equals(UNDISCOVERABLE_UUID)))).thenReturn(Optional.of(UNDISCOVERABLE_ACCOUNT)); (ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber()
&& identifier.getNumber().equals(UNDISCOVERABLE_NUMBER)))).thenReturn(Optional.of(UNDISCOVERABLE_ACCOUNT));
when(ACCOUNTS_MANAGER.get(argThat(
(ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasUuid()
&& identifier.getUuid().equals(UNDISCOVERABLE_UUID)))).thenReturn(Optional.of(UNDISCOVERABLE_ACCOUNT));
AccountsHelper.setupMockUpdateForAuthHelper(ACCOUNTS_MANAGER); AccountsHelper.setupMockUpdateForAuthHelper(ACCOUNTS_MANAGER);
@ -170,11 +178,13 @@ public class AuthHelper {
testAccount.setup(ACCOUNTS_MANAGER); testAccount.setup(ACCOUNTS_MANAGER);
} }
AuthFilter<BasicCredentials, Account> accountAuthFilter = new BasicCredentialAuthFilter.Builder<Account>().setAuthenticator(new AccountAuthenticator(ACCOUNTS_MANAGER)).buildAuthFilter (); AuthFilter<BasicCredentials, AuthenticatedAccount> accountAuthFilter = new BasicCredentialAuthFilter.Builder<AuthenticatedAccount>().setAuthenticator(
AuthFilter<BasicCredentials, DisabledPermittedAccount> disabledPermittedAccountAuthFilter = new BasicCredentialAuthFilter.Builder<DisabledPermittedAccount>().setAuthenticator(new DisabledPermittedAccountAuthenticator(ACCOUNTS_MANAGER)).buildAuthFilter(); new AccountAuthenticator(ACCOUNTS_MANAGER)).buildAuthFilter();
AuthFilter<BasicCredentials, DisabledPermittedAuthenticatedAccount> disabledPermittedAccountAuthFilter = new BasicCredentialAuthFilter.Builder<DisabledPermittedAuthenticatedAccount>().setAuthenticator(
new DisabledPermittedAccountAuthenticator(ACCOUNTS_MANAGER)).buildAuthFilter();
return new PolymorphicAuthDynamicFeature<>(ImmutableMap.of(Account.class, accountAuthFilter, return new PolymorphicAuthDynamicFeature<>(ImmutableMap.of(AuthenticatedAccount.class, accountAuthFilter,
DisabledPermittedAccount.class, disabledPermittedAccountAuthFilter)); DisabledPermittedAuthenticatedAccount.class, disabledPermittedAccountAuthFilter));
} }
public static String getAuthHeader(String number, String password) { public static String getAuthHeader(String number, String password) {
@ -223,13 +233,16 @@ public class AuthHelper {
when(account.getMasterDevice()).thenReturn(Optional.of(device)); when(account.getMasterDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn(number); when(account.getNumber()).thenReturn(number);
when(account.getUuid()).thenReturn(uuid); when(account.getUuid()).thenReturn(uuid);
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getRelay()).thenReturn(Optional.empty()); when(account.getRelay()).thenReturn(Optional.empty());
when(account.isEnabled()).thenReturn(true); when(account.isEnabled()).thenReturn(true);
when(accountsManager.get(number)).thenReturn(Optional.of(account)); when(accountsManager.get(number)).thenReturn(Optional.of(account));
when(accountsManager.get(uuid)).thenReturn(Optional.of(account)); when(accountsManager.get(uuid)).thenReturn(Optional.of(account));
when(accountsManager.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(number)))).thenReturn(Optional.of(account)); when(accountsManager.get(argThat(
when(accountsManager.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasUuid() && identifier.getUuid().equals(uuid)))).thenReturn(Optional.of(account)); (ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber()
&& identifier.getNumber().equals(number)))).thenReturn(Optional.of(account));
when(accountsManager.get(argThat(
(ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasUuid()
&& identifier.getUuid().equals(uuid)))).thenReturn(Optional.of(account));
} }
} }

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
@ -40,6 +40,7 @@ import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager; import org.whispersystems.textsecuregcm.metrics.PushLatencyManager;
@ -52,6 +53,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager; import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbRule; import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbRule;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.websocket.WebSocketClient; import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.messages.WebSocketResponseMessage; import org.whispersystems.websocket.messages.WebSocketResponseMessage;
@ -90,13 +92,13 @@ public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest
when(account.getUuid()).thenReturn(UUID.randomUUID()); when(account.getUuid()).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(1L); when(device.getId()).thenReturn(1L);
webSocketConnection = new WebSocketConnection( webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class), mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, mock(PushLatencyManager.class), reportMessageManager), new MessagesManager(messagesDynamoDb, messagesCache, mock(PushLatencyManager.class), reportMessageManager),
account, new AuthenticatedAccount(() -> new Pair<>(account, device)),
device, device,
webSocketClient, webSocketClient,
retrySchedulingExecutor); retrySchedulingExecutor);
} }
@After @After

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
@ -24,6 +24,7 @@ import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.InvalidProtocolBufferException;
import io.dropwizard.auth.basic.BasicCredentials; import io.dropwizard.auth.basic.BasicCredentials;
import io.lettuce.core.RedisException;
import java.io.IOException; import java.io.IOException;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
@ -39,7 +40,6 @@ import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture; import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import io.lettuce.core.RedisException;
import org.apache.commons.lang3.RandomStringUtils; import org.apache.commons.lang3.RandomStringUtils;
import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.junit.Before; import org.junit.Before;
@ -48,6 +48,7 @@ import org.mockito.ArgumentMatchers;
import org.mockito.invocation.InvocationOnMock; import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer; import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager; import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
@ -58,6 +59,7 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.websocket.WebSocketClient; import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult; import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult;
import org.whispersystems.websocket.messages.WebSocketResponseMessage; import org.whispersystems.websocket.messages.WebSocketResponseMessage;
@ -75,6 +77,7 @@ public class WebSocketConnectionTest {
private AccountsManager accountsManager; private AccountsManager accountsManager;
private Account account; private Account account;
private Device device; private Device device;
private AuthenticatedAccount auth;
private UpgradeRequest upgradeRequest; private UpgradeRequest upgradeRequest;
private ReceiptSender receiptSender; private ReceiptSender receiptSender;
private ApnFallbackManager apnFallbackManager; private ApnFallbackManager apnFallbackManager;
@ -86,6 +89,7 @@ public class WebSocketConnectionTest {
accountsManager = mock(AccountsManager.class); accountsManager = mock(AccountsManager.class);
account = mock(Account.class); account = mock(Account.class);
device = mock(Device.class); device = mock(Device.class);
auth = new AuthenticatedAccount(() -> new Pair<>(account, device));
upgradeRequest = mock(UpgradeRequest.class); upgradeRequest = mock(UpgradeRequest.class);
receiptSender = mock(ReceiptSender.class); receiptSender = mock(ReceiptSender.class);
apnFallbackManager = mock(ApnFallbackManager.class); apnFallbackManager = mock(ApnFallbackManager.class);
@ -94,35 +98,42 @@ public class WebSocketConnectionTest {
@Test @Test
public void testCredentials() throws Exception { public void testCredentials() throws Exception {
MessagesManager storedMessages = mock(MessagesManager.class); MessagesManager storedMessages = mock(MessagesManager.class);
WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator); WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator);
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, storedMessages, mock(MessageSender.class), apnFallbackManager, mock(ClientPresenceManager.class), AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, storedMessages,
mock(MessageSender.class), apnFallbackManager, mock(ClientPresenceManager.class),
retrySchedulingExecutor); retrySchedulingExecutor);
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD)))) when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
.thenReturn(Optional.of(account)); .thenReturn(Optional.of(new AuthenticatedAccount(() -> new Pair<>(account, device))));
when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD)))) when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD))))
.thenReturn(Optional.<Account>empty()); .thenReturn(Optional.empty());
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device)); when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<>() {{
put("login", new LinkedList<>() {{
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<String, List<String>>() {{ add(VALID_USER);
put("login", new LinkedList<String>() {{add(VALID_USER);}}); }});
put("password", new LinkedList<String>() {{add(VALID_PASSWORD);}}); put("password", new LinkedList<>() {{
add(VALID_PASSWORD);
}});
}}); }});
AuthenticationResult<Account> account = webSocketAuthenticator.authenticate(upgradeRequest); AuthenticationResult<AuthenticatedAccount> account = webSocketAuthenticator.authenticate(upgradeRequest);
when(sessionContext.getAuthenticated(Account.class)).thenReturn(account.getUser().orElse(null)); when(sessionContext.getAuthenticated(AuthenticatedAccount.class)).thenReturn(account.getUser().orElse(null));
connectListener.onWebSocketConnect(sessionContext); connectListener.onWebSocketConnect(sessionContext);
verify(sessionContext).addListener(any(WebSocketSessionContext.WebSocketEventListener.class)); verify(sessionContext).addListener(any(WebSocketSessionContext.WebSocketEventListener.class));
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<String, List<String>>() {{ when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<String, List<String>>() {{
put("login", new LinkedList<String>() {{add(INVALID_USER);}}); put("login", new LinkedList<String>() {{
put("password", new LinkedList<String>() {{add(INVALID_PASSWORD);}}); add(INVALID_USER);
}});
put("password", new LinkedList<String>() {{
add(INVALID_PASSWORD);
}});
}}); }});
account = webSocketAuthenticator.authenticate(upgradeRequest); account = webSocketAuthenticator.authenticate(upgradeRequest);
@ -148,7 +159,6 @@ public class WebSocketConnectionTest {
when(device.getId()).thenReturn(2L); when(device.getId()).thenReturn(2L);
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn("+14152222222"); when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid); when(account.getUuid()).thenReturn(accountUuid);
@ -184,12 +194,13 @@ public class WebSocketConnectionTest {
}); });
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages,
account, device, client, retrySchedulingExecutor); auth, device, client, retrySchedulingExecutor);
connection.start(); connection.start();
verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any()); verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class),
ArgumentMatchers.<Optional<byte[]>>any());
assertTrue(futures.size() == 3); assertEquals(3, futures.size());
WebSocketResponseMessage response = mock(WebSocketResponseMessage.class); WebSocketResponseMessage response = mock(WebSocketResponseMessage.class);
when(response.getStatus()).thenReturn(200); when(response.getStatus()).thenReturn(200);
@ -199,7 +210,7 @@ public class WebSocketConnectionTest {
futures.get(2).completeExceptionally(new IOException()); futures.get(2).completeExceptionally(new IOException());
verify(storedMessages, times(1)).delete(eq(accountUuid), eq(2L), eq(outgoingMessages.get(1).getGuid())); verify(storedMessages, times(1)).delete(eq(accountUuid), eq(2L), eq(outgoingMessages.get(1).getGuid()));
verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender1"), eq(2222L)); verify(receiptSender, times(1)).sendReceipt(eq(auth), eq("sender1"), eq(2222L));
connection.stop(); connection.stop();
verify(client).close(anyInt(), anyString()); verify(client).close(anyInt(), anyString());
@ -207,9 +218,10 @@ public class WebSocketConnectionTest {
@Test(timeout = 5_000L) @Test(timeout = 5_000L)
public void testOnlineSend() throws Exception { public void testOnlineSend() throws Exception {
final MessagesManager messagesManager = mock(MessagesManager.class); final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class); final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
retrySchedulingExecutor);
final UUID accountUuid = UUID.randomUUID(); final UUID accountUuid = UUID.randomUUID();
@ -219,7 +231,7 @@ public class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA"); when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)) .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false))
.thenReturn(new OutgoingMessageEntityList(List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first")), false)) .thenReturn(new OutgoingMessageEntityList(List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first")), false))
.thenReturn(new OutgoingMessageEntityList(List.of(createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second")), false)); .thenReturn(new OutgoingMessageEntityList(List.of(createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second")), false));
@ -300,7 +312,6 @@ public class WebSocketConnectionTest {
when(device.getId()).thenReturn(2L); when(device.getId()).thenReturn(2L);
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn("+14152222222"); when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(UUID.randomUUID()); when(account.getUuid()).thenReturn(UUID.randomUUID());
@ -336,11 +347,12 @@ public class WebSocketConnectionTest {
}); });
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages,
account, device, client, retrySchedulingExecutor); auth, device, client, retrySchedulingExecutor);
connection.start(); connection.start();
verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any()); verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class),
ArgumentMatchers.<Optional<byte[]>>any());
assertEquals(futures.size(), 2); assertEquals(futures.size(), 2);
@ -349,7 +361,7 @@ public class WebSocketConnectionTest {
futures.get(1).complete(response); futures.get(1).complete(response);
futures.get(0).completeExceptionally(new IOException()); futures.get(0).completeExceptionally(new IOException());
verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender2"), eq(secondMessage.getTimestamp())); verify(receiptSender, times(1)).sendReceipt(eq(auth), eq("sender2"), eq(secondMessage.getTimestamp()));
connection.stop(); connection.stop();
verify(client).close(anyInt(), anyString()); verify(client).close(anyInt(), anyString());
@ -357,19 +369,21 @@ public class WebSocketConnectionTest {
@Test(timeout = 5000L) @Test(timeout = 5000L)
public void testProcessStoredMessageConcurrency() throws InterruptedException { public void testProcessStoredMessageConcurrency() throws InterruptedException {
final MessagesManager messagesManager = mock(MessagesManager.class); final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class); final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
retrySchedulingExecutor);
when(account.getNumber()).thenReturn("+18005551234"); when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID()); when(account.getUuid()).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(1L); when(device.getId()).thenReturn(1L);
when(client.getUserAgent()).thenReturn("Test-UA"); when(client.getUserAgent()).thenReturn("Test-UA");
final AtomicBoolean threadWaiting = new AtomicBoolean(false); final AtomicBoolean threadWaiting = new AtomicBoolean(false);
final AtomicBoolean returnMessageList = new AtomicBoolean(false); final AtomicBoolean returnMessageList = new AtomicBoolean(false);
when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false)).thenAnswer((Answer<OutgoingMessageEntityList>)invocation -> { when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false)).thenAnswer(
(Answer<OutgoingMessageEntityList>) invocation -> {
synchronized (threadWaiting) { synchronized (threadWaiting) {
threadWaiting.set(true); threadWaiting.set(true);
threadWaiting.notifyAll(); threadWaiting.notifyAll();
@ -418,9 +432,10 @@ public class WebSocketConnectionTest {
@Test(timeout = 5000L) @Test(timeout = 5000L)
public void testProcessStoredMessagesMultiplePages() throws InterruptedException { public void testProcessStoredMessagesMultiplePages() throws InterruptedException {
final MessagesManager messagesManager = mock(MessagesManager.class); final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class); final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
retrySchedulingExecutor);
when(account.getNumber()).thenReturn("+18005551234"); when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID()); when(account.getUuid()).thenReturn(UUID.randomUUID());
@ -428,8 +443,8 @@ public class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA"); when(client.getUserAgent()).thenReturn("Test-UA");
final List<OutgoingMessageEntity> firstPageMessages = final List<OutgoingMessageEntity> firstPageMessages =
List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first"), List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first"),
createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second")); createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second"));
final List<OutgoingMessageEntity> secondPageMessages = final List<OutgoingMessageEntity> secondPageMessages =
List.of(createMessage(3L, false, "sender1", UUID.randomUUID(), 3333, false, "third")); List.of(createMessage(3L, false, "sender1", UUID.randomUUID(), 3333, false, "third"));
@ -463,7 +478,8 @@ public class WebSocketConnectionTest {
public void testProcessStoredMessagesContainsSenderUuid() throws InterruptedException { public void testProcessStoredMessagesContainsSenderUuid() throws InterruptedException {
final MessagesManager messagesManager = mock(MessagesManager.class); final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class); final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
retrySchedulingExecutor);
when(account.getNumber()).thenReturn("+18005551234"); when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID()); when(account.getUuid()).thenReturn(UUID.randomUUID());
@ -471,7 +487,8 @@ public class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA"); when(client.getUserAgent()).thenReturn("Test-UA");
final UUID senderUuid = UUID.randomUUID(); final UUID senderUuid = UUID.randomUUID();
final List<OutgoingMessageEntity> messages = List.of(createMessage(1L, false, "senderE164", senderUuid, 1111L, false, "message the first")); final List<OutgoingMessageEntity> messages = List.of(
createMessage(1L, false, "senderE164", senderUuid, 1111L, false, "message the first"));
final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(messages, false); final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(messages, false);
when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false)).thenReturn(firstPage); when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false)).thenReturn(firstPage);
@ -511,9 +528,10 @@ public class WebSocketConnectionTest {
@Test @Test
public void testProcessStoredMessagesSingleEmptyCall() { public void testProcessStoredMessagesSingleEmptyCall() {
final MessagesManager messagesManager = mock(MessagesManager.class); final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class); final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
retrySchedulingExecutor);
final UUID accountUuid = UUID.randomUUID(); final UUID accountUuid = UUID.randomUUID();
@ -523,7 +541,7 @@ public class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA"); when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200); when(successResponse.getStatus()).thenReturn(200);
@ -540,10 +558,11 @@ public class WebSocketConnectionTest {
@Test(timeout = 5000L) @Test(timeout = 5000L)
public void testRequeryOnStateMismatch() throws InterruptedException { public void testRequeryOnStateMismatch() throws InterruptedException {
final MessagesManager messagesManager = mock(MessagesManager.class); final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class); final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
final UUID accountUuid = UUID.randomUUID(); retrySchedulingExecutor);
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234"); when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid); when(account.getUuid()).thenReturn(accountUuid);
@ -551,8 +570,8 @@ public class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA"); when(client.getUserAgent()).thenReturn("Test-UA");
final List<OutgoingMessageEntity> firstPageMessages = final List<OutgoingMessageEntity> firstPageMessages =
List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first"), List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first"),
createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second")); createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second"));
final List<OutgoingMessageEntity> secondPageMessages = final List<OutgoingMessageEntity> secondPageMessages =
List.of(createMessage(3L, false, "sender1", UUID.randomUUID(), 3333, false, "third")); List.of(createMessage(3L, false, "sender1", UUID.randomUUID(), 3333, false, "third"));
@ -587,9 +606,10 @@ public class WebSocketConnectionTest {
@Test @Test
public void testProcessCachedMessagesOnly() { public void testProcessCachedMessagesOnly() {
final MessagesManager messagesManager = mock(MessagesManager.class); final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class); final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
retrySchedulingExecutor);
final UUID accountUuid = UUID.randomUUID(); final UUID accountUuid = UUID.randomUUID();
@ -599,7 +619,7 @@ public class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA"); when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200); when(successResponse.getStatus()).thenReturn(200);
@ -619,9 +639,10 @@ public class WebSocketConnectionTest {
@Test @Test
public void testProcessDatabaseMessagesAfterPersist() { public void testProcessDatabaseMessagesAfterPersist() {
final MessagesManager messagesManager = mock(MessagesManager.class); final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class); final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
retrySchedulingExecutor);
final UUID accountUuid = UUID.randomUUID(); final UUID accountUuid = UUID.randomUUID();
@ -631,7 +652,7 @@ public class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA"); when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean())) when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false)); .thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class); final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200); when(successResponse.getStatus()).thenReturn(200);
@ -664,7 +685,6 @@ public class WebSocketConnectionTest {
when(device.getId()).thenReturn(2L); when(device.getId()).thenReturn(2L);
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn("+14152222222"); when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid); when(account.getUuid()).thenReturn(accountUuid);
@ -689,20 +709,24 @@ public class WebSocketConnectionTest {
final WebSocketClient client = mock(WebSocketClient.class); final WebSocketClient client = mock(WebSocketClient.class);
when(client.getUserAgent()).thenReturn(userAgent); when(client.getUserAgent()).thenReturn(userAgent);
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any())) when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class),
.thenAnswer(new Answer<CompletableFuture<WebSocketResponseMessage>>() { ArgumentMatchers.<Optional<byte[]>>any()))
@Override .thenAnswer(new Answer<CompletableFuture<WebSocketResponseMessage>>() {
public CompletableFuture<WebSocketResponseMessage> answer(InvocationOnMock invocationOnMock) throws Throwable { @Override
CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>(); public CompletableFuture<WebSocketResponseMessage> answer(InvocationOnMock invocationOnMock)
futures.add(future); throws Throwable {
return future; CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>();
} futures.add(future);
}); return future;
}
});
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, account, device, client, retrySchedulingExecutor); WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client,
retrySchedulingExecutor);
connection.start(); connection.start();
verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any()); verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class),
ArgumentMatchers.<Optional<byte[]>>any());
assertEquals(2, futures.size()); assertEquals(2, futures.size());
@ -737,7 +761,6 @@ public class WebSocketConnectionTest {
when(device.getId()).thenReturn(2L); when(device.getId()).thenReturn(2L);
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn("+14152222222"); when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid); when(account.getUuid()).thenReturn(accountUuid);
@ -762,20 +785,24 @@ public class WebSocketConnectionTest {
final WebSocketClient client = mock(WebSocketClient.class); final WebSocketClient client = mock(WebSocketClient.class);
when(client.getUserAgent()).thenReturn(userAgent); when(client.getUserAgent()).thenReturn(userAgent);
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any())) when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class),
.thenAnswer(new Answer<CompletableFuture<WebSocketResponseMessage>>() { ArgumentMatchers.<Optional<byte[]>>any()))
@Override .thenAnswer(new Answer<CompletableFuture<WebSocketResponseMessage>>() {
public CompletableFuture<WebSocketResponseMessage> answer(InvocationOnMock invocationOnMock) throws Throwable { @Override
CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>(); public CompletableFuture<WebSocketResponseMessage> answer(InvocationOnMock invocationOnMock)
futures.add(future); throws Throwable {
return future; CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>();
} futures.add(future);
}); return future;
}
});
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, account, device, client, retrySchedulingExecutor); WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client,
retrySchedulingExecutor);
connection.start(); connection.start();
verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any()); verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class),
ArgumentMatchers.<Optional<byte[]>>any());
assertEquals(3, futures.size()); assertEquals(3, futures.size());
@ -799,7 +826,6 @@ public class WebSocketConnectionTest {
when(device.getId()).thenReturn(2L); when(device.getId()).thenReturn(2L);
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn("+14152222222"); when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid); when(account.getUuid()).thenReturn(accountUuid);
@ -808,17 +834,20 @@ public class WebSocketConnectionTest {
when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false)) when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false))
.thenThrow(new RedisException("OH NO")); .thenThrow(new RedisException("OH NO"));
when(retrySchedulingExecutor.schedule(any(Runnable.class), anyLong(), any())).thenAnswer((Answer<ScheduledFuture<?>>) invocation -> { when(retrySchedulingExecutor.schedule(any(Runnable.class), anyLong(), any())).thenAnswer(
invocation.getArgument(0, Runnable.class).run(); (Answer<ScheduledFuture<?>>) invocation -> {
return mock(ScheduledFuture.class); invocation.getArgument(0, Runnable.class).run();
}); return mock(ScheduledFuture.class);
});
final WebSocketClient client = mock(WebSocketClient.class); final WebSocketClient client = mock(WebSocketClient.class);
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, account, device, client, retrySchedulingExecutor); WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client,
retrySchedulingExecutor);
connection.start(); connection.start();
verify(retrySchedulingExecutor, times(WebSocketConnection.MAX_CONSECUTIVE_RETRIES)).schedule(any(Runnable.class), anyLong(), any()); verify(retrySchedulingExecutor, times(WebSocketConnection.MAX_CONSECUTIVE_RETRIES)).schedule(any(Runnable.class),
anyLong(), any());
verify(client).close(eq(1011), anyString()); verify(client).close(eq(1011), anyString());
} }