Use refreshing `AuthenticatedAccount` for `@Auth`

This commit is contained in:
Chris Eager 2021-08-11 14:52:25 -05:00 committed by GitHub
parent b3e6a50dee
commit 31022aeb79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
53 changed files with 1251 additions and 969 deletions

View File

@ -64,8 +64,9 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.dispatch.DispatchManager;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.CertificateGenerator;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator;
@ -149,7 +150,6 @@ import org.whispersystems.textsecuregcm.sms.TwilioSmsSender;
import org.whispersystems.textsecuregcm.sms.TwilioVerifyExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.AbusiveHostRules;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountCleaner;
import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawler;
import org.whispersystems.textsecuregcm.storage.AccountDatabaseCrawlerCache;
@ -544,31 +544,40 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
.credentialsProvider(cdnCredentialsProvider)
.region(Region.of(config.getCdnConfiguration().getRegion()))
.build();
PostPolicyGenerator profileCdnPolicyGenerator = new PostPolicyGenerator(config.getCdnConfiguration().getRegion(), config.getCdnConfiguration().getBucket(), config.getCdnConfiguration().getAccessKey());
PolicySigner profileCdnPolicySigner = new PolicySigner(config.getCdnConfiguration().getAccessSecret(), config.getCdnConfiguration().getRegion());
PostPolicyGenerator profileCdnPolicyGenerator = new PostPolicyGenerator(config.getCdnConfiguration().getRegion(),
config.getCdnConfiguration().getBucket(), config.getCdnConfiguration().getAccessKey());
PolicySigner profileCdnPolicySigner = new PolicySigner(config.getCdnConfiguration().getAccessSecret(),
config.getCdnConfiguration().getRegion());
ServerSecretParams zkSecretParams = new ServerSecretParams(config.getZkConfig().getServerSecret());
ServerZkProfileOperations zkProfileOperations = new ServerZkProfileOperations(zkSecretParams);
ServerZkAuthOperations zkAuthOperations = new ServerZkAuthOperations(zkSecretParams);
boolean isZkEnabled = config.getZkConfig().isEnabled();
ServerSecretParams zkSecretParams = new ServerSecretParams(config.getZkConfig().getServerSecret());
ServerZkProfileOperations zkProfileOperations = new ServerZkProfileOperations(zkSecretParams);
ServerZkAuthOperations zkAuthOperations = new ServerZkAuthOperations(zkSecretParams);
boolean isZkEnabled = config.getZkConfig().isEnabled();
AuthFilter<BasicCredentials, Account> accountAuthFilter = new BasicCredentialAuthFilter.Builder<Account>().setAuthenticator(accountAuthenticator).buildAuthFilter ();
AuthFilter<BasicCredentials, DisabledPermittedAccount> disabledPermittedAccountAuthFilter = new BasicCredentialAuthFilter.Builder<DisabledPermittedAccount>().setAuthenticator(disabledPermittedAccountAuthenticator).buildAuthFilter();
AuthFilter<BasicCredentials, AuthenticatedAccount> accountAuthFilter = new BasicCredentialAuthFilter.Builder<AuthenticatedAccount>().setAuthenticator(
accountAuthenticator).buildAuthFilter();
AuthFilter<BasicCredentials, DisabledPermittedAuthenticatedAccount> disabledPermittedAccountAuthFilter = new BasicCredentialAuthFilter.Builder<DisabledPermittedAuthenticatedAccount>().setAuthenticator(
disabledPermittedAccountAuthenticator).buildAuthFilter();
environment.servlets().addFilter("RemoteDeprecationFilter", new RemoteDeprecationFilter(dynamicConfigurationManager))
environment.servlets()
.addFilter("RemoteDeprecationFilter", new RemoteDeprecationFilter(dynamicConfigurationManager))
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
environment.jersey().register(new ContentLengthFilter(TrafficSource.HTTP));
environment.jersey().register(MultiRecipientMessageProvider.class);
environment.jersey().register(new MetricsApplicationEventListener(TrafficSource.HTTP));
environment.jersey().register(new PolymorphicAuthDynamicFeature<>(ImmutableMap.of(Account.class, accountAuthFilter,
DisabledPermittedAccount.class, disabledPermittedAccountAuthFilter)));
environment.jersey().register(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)));
environment.jersey()
.register(new PolymorphicAuthDynamicFeature<>(ImmutableMap.of(AuthenticatedAccount.class, accountAuthFilter,
DisabledPermittedAuthenticatedAccount.class, disabledPermittedAccountAuthFilter)));
environment.jersey().register(new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)));
environment.jersey().register(new TimestampResponseFilter());
environment.jersey().register(new VoiceVerificationController(config.getVoiceVerificationConfiguration().getUrl(), config.getVoiceVerificationConfiguration().getLocales()));
environment.jersey().register(new VoiceVerificationController(config.getVoiceVerificationConfiguration().getUrl(),
config.getVoiceVerificationConfiguration().getLocales()));
///
WebSocketEnvironment<Account> webSocketEnvironment = new WebSocketEnvironment<>(environment, config.getWebSocketConfiguration(), 90000);
WebSocketEnvironment<AuthenticatedAccount> webSocketEnvironment = new WebSocketEnvironment<>(environment,
config.getWebSocketConfiguration(), 90000);
webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator));
webSocketEnvironment.setConnectListener(
new AuthenticatedConnectListener(receiptSender, messagesManager, messageSender, apnFallbackManager,
@ -602,15 +611,18 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
new RemoteConfigController(remoteConfigsManager, config.getRemoteConfigConfiguration().getAuthorizedTokens(), config.getRemoteConfigConfiguration().getGlobalConfig()),
new SecureBackupController(backupCredentialsGenerator),
new SecureStorageController(storageCredentialsGenerator),
new StickerController(rateLimiters, config.getCdnConfiguration().getAccessKey(), config.getCdnConfiguration().getAccessSecret(), config.getCdnConfiguration().getRegion(), config.getCdnConfiguration().getBucket())
new StickerController(rateLimiters, config.getCdnConfiguration().getAccessKey(),
config.getCdnConfiguration().getAccessSecret(), config.getCdnConfiguration().getRegion(),
config.getCdnConfiguration().getBucket())
);
for (Object controller : commonControllers) {
environment.jersey().register(controller);
webSocketEnvironment.jersey().register(controller);
environment.jersey().register(controller);
webSocketEnvironment.jersey().register(controller);
}
WebSocketEnvironment<Account> provisioningEnvironment = new WebSocketEnvironment<>(environment, webSocketEnvironment.getRequestLog(), 60000);
WebSocketEnvironment<AuthenticatedAccount> provisioningEnvironment = new WebSocketEnvironment<>(environment,
webSocketEnvironment.getRequestLog(), 60000);
provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(pubSubManager));
provisioningEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET));
provisioningEnvironment.jersey().register(new KeepAliveController(clientPresenceManager));
@ -618,16 +630,19 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
registerCorsFilter(environment);
registerExceptionMappers(environment, webSocketEnvironment, provisioningEnvironment);
RateLimitChallengeExceptionMapper rateLimitChallengeExceptionMapper = new RateLimitChallengeExceptionMapper(rateLimitChallengeManager);
RateLimitChallengeExceptionMapper rateLimitChallengeExceptionMapper = new RateLimitChallengeExceptionMapper(
rateLimitChallengeManager);
environment.jersey().register(rateLimitChallengeExceptionMapper);
webSocketEnvironment.jersey().register(rateLimitChallengeExceptionMapper);
provisioningEnvironment.jersey().register(rateLimitChallengeExceptionMapper);
WebSocketResourceProviderFactory<Account> webSocketServlet = new WebSocketResourceProviderFactory<>(webSocketEnvironment, Account.class);
WebSocketResourceProviderFactory<Account> provisioningServlet = new WebSocketResourceProviderFactory<>(provisioningEnvironment, Account.class);
WebSocketResourceProviderFactory<AuthenticatedAccount> webSocketServlet = new WebSocketResourceProviderFactory<>(
webSocketEnvironment, AuthenticatedAccount.class);
WebSocketResourceProviderFactory<AuthenticatedAccount> provisioningServlet = new WebSocketResourceProviderFactory<>(
provisioningEnvironment, AuthenticatedAccount.class);
ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", webSocketServlet );
ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", webSocketServlet);
ServletRegistration.Dynamic provisioning = environment.servlets().addServlet("Provisioning", provisioningServlet);
websocket.addMapping("/v1/websocket/");
@ -649,14 +664,18 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.metrics().register(name(NetworkReceivedGauge.class, "bytes_received"), new NetworkReceivedGauge());
environment.metrics().register(name(FileDescriptorGauge.class, "fd_count"), new FileDescriptorGauge());
environment.metrics().register(name(MaxFileDescriptorGauge.class, "max_fd_count"), new MaxFileDescriptorGauge());
environment.metrics().register(name(OperatingSystemMemoryGauge.class, "buffers"), new OperatingSystemMemoryGauge("Buffers"));
environment.metrics().register(name(OperatingSystemMemoryGauge.class, "cached"), new OperatingSystemMemoryGauge("Cached"));
environment.metrics()
.register(name(OperatingSystemMemoryGauge.class, "buffers"), new OperatingSystemMemoryGauge("Buffers"));
environment.metrics()
.register(name(OperatingSystemMemoryGauge.class, "cached"), new OperatingSystemMemoryGauge("Cached"));
BufferPoolGauges.registerMetrics();
GarbageCollectionGauges.registerMetrics();
}
private void registerExceptionMappers(Environment environment, WebSocketEnvironment<Account> webSocketEnvironment, WebSocketEnvironment<Account> provisioningEnvironment) {
private void registerExceptionMappers(Environment environment,
WebSocketEnvironment<AuthenticatedAccount> webSocketEnvironment,
WebSocketEnvironment<AuthenticatedAccount> provisioningEnvironment) {
environment.jersey().register(new LoggingUnhandledExceptionMapper());
environment.jersey().register(new IOExceptionMapper());
environment.jersey().register(new RateLimitExceededExceptionMapper());

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
@ -7,17 +7,17 @@ package org.whispersystems.textsecuregcm.auth;
import io.dropwizard.auth.Authenticator;
import io.dropwizard.auth.basic.BasicCredentials;
import java.util.Optional;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
public class AccountAuthenticator extends BaseAccountAuthenticator implements Authenticator<BasicCredentials, Account> {
public class AccountAuthenticator extends BaseAccountAuthenticator implements
Authenticator<BasicCredentials, AuthenticatedAccount> {
public AccountAuthenticator(AccountsManager accountsManager) {
super(accountsManager);
}
@Override
public Optional<Account> authenticate(BasicCredentials basicCredentials) {
public Optional<AuthenticatedAccount> authenticate(BasicCredentials basicCredentials) {
return super.authenticate(basicCredentials, true);
}

View File

@ -0,0 +1,42 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import java.security.Principal;
import java.util.function.Supplier;
import javax.security.auth.Subject;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Pair;
public class AuthenticatedAccount implements Principal {
private final Supplier<Pair<Account, Device>> accountAndDevice;
public AuthenticatedAccount(final Supplier<Pair<Account, Device>> accountAndDevice) {
this.accountAndDevice = accountAndDevice;
}
public Account getAccount() {
return accountAndDevice.get().first();
}
public Device getAuthenticatedDevice() {
return accountAndDevice.get().second();
}
// Principal implementation
@Override
public String getName() {
return null;
}
@Override
public boolean implies(final Subject subject) {
return false;
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -19,6 +19,7 @@ import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.RefreshingAccountAndDeviceSupplier;
import org.whispersystems.textsecuregcm.util.Util;
public class BaseAccountAuthenticator {
@ -45,14 +46,15 @@ public class BaseAccountAuthenticator {
this.clock = clock;
}
public Optional<Account> authenticate(BasicCredentials basicCredentials, boolean enabledRequired) {
public Optional<AuthenticatedAccount> authenticate(BasicCredentials basicCredentials, boolean enabledRequired) {
boolean succeeded = false;
String failureReason = null;
String credentialType = null;
try {
AuthorizationHeader authorizationHeader = AuthorizationHeader.fromUserAndPassword(basicCredentials.getUsername(), basicCredentials.getPassword());
Optional<Account> account = accountsManager.get(authorizationHeader.getIdentifier());
AuthorizationHeader authorizationHeader = AuthorizationHeader.fromUserAndPassword(basicCredentials.getUsername(),
basicCredentials.getPassword());
Optional<Account> account = accountsManager.get(authorizationHeader.getIdentifier());
credentialType = authorizationHeader.getIdentifier().hasNumber() ? "e164" : "uuid";
@ -83,9 +85,8 @@ public class BaseAccountAuthenticator {
if (device.get().getAuthenticationCredentials().verify(basicCredentials.getPassword())) {
succeeded = true;
final Account authenticatedAccount = updateLastSeen(account.get(), device.get());
// the device in scope might be stale after the update, so get the latest from the authenticated account
authenticatedAccount.setAuthenticatedDevice(authenticatedAccount.getDevice(device.get().getId()).orElseThrow());
return Optional.of(authenticatedAccount);
return Optional.of(new AuthenticatedAccount(
new RefreshingAccountAndDeviceSupplier(authenticatedAccount, device.get().getId(), accountsManager)));
}
return Optional.empty();

View File

@ -1,36 +0,0 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import org.whispersystems.textsecuregcm.storage.Account;
import javax.security.auth.Subject;
import java.security.Principal;
public class DisabledPermittedAccount implements Principal {
private final Account account;
public DisabledPermittedAccount(Account account) {
this.account = account;
}
public Account getAccount() {
return account;
}
// Principal implementation
@Override
public String getName() {
return null;
}
@Override
public boolean implies(Subject subject) {
return false;
}
}

View File

@ -1,27 +1,25 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import java.util.Optional;
import io.dropwizard.auth.Authenticator;
import io.dropwizard.auth.basic.BasicCredentials;
import java.util.Optional;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
public class DisabledPermittedAccountAuthenticator extends BaseAccountAuthenticator implements Authenticator<BasicCredentials, DisabledPermittedAccount> {
public class DisabledPermittedAccountAuthenticator extends BaseAccountAuthenticator implements
Authenticator<BasicCredentials, DisabledPermittedAuthenticatedAccount> {
public DisabledPermittedAccountAuthenticator(AccountsManager accountsManager) {
super(accountsManager);
}
@Override
public Optional<DisabledPermittedAccount> authenticate(BasicCredentials credentials) {
Optional<Account> account = super.authenticate(credentials, false);
return account.map(DisabledPermittedAccount::new);
public Optional<DisabledPermittedAuthenticatedAccount> authenticate(BasicCredentials credentials) {
Optional<AuthenticatedAccount> account = super.authenticate(credentials, false);
return account.map(DisabledPermittedAuthenticatedAccount::new);
}
}

View File

@ -0,0 +1,40 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import java.security.Principal;
import javax.security.auth.Subject;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
public class DisabledPermittedAuthenticatedAccount implements Principal {
private final AuthenticatedAccount authenticatedAccount;
public DisabledPermittedAuthenticatedAccount(final AuthenticatedAccount authenticatedAccount) {
this.authenticatedAccount = authenticatedAccount;
}
public Account getAccount() {
return authenticatedAccount.getAccount();
}
public Device getAuthenticatedDevice() {
return authenticatedAccount.getAuthenticatedDevice();
}
// Principal implementation
@Override
public String getName() {
return null;
}
@Override
public boolean implies(Subject subject) {
return false;
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
@ -39,9 +39,10 @@ import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.auth.AuthorizationHeader;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.auth.InvalidAuthorizationHeaderException;
@ -404,8 +405,8 @@ public class AccountController {
@GET
@Path("/turn/")
@Produces(MediaType.APPLICATION_JSON)
public TurnToken getTurnToken(@Auth Account account) throws RateLimitExceededException {
rateLimiters.getTurnLimiter().validate(account.getUuid());
public TurnToken getTurnToken(@Auth AuthenticatedAccount auth) throws RateLimitExceededException {
rateLimiters.getTurnLimiter().validate(auth.getAccount().getUuid());
return turnTokenGenerator.generate();
}
@ -413,13 +414,13 @@ public class AccountController {
@PUT
@Path("/gcm/")
@Consumes(MediaType.APPLICATION_JSON)
public void setGcmRegistrationId(@Auth DisabledPermittedAccount disabledPermittedAccount, @Valid GcmRegistrationId registrationId) {
Account account = disabledPermittedAccount.getAccount();
Device device = account.getAuthenticatedDevice().get();
public void setGcmRegistrationId(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth,
@Valid GcmRegistrationId registrationId) {
Account account = disabledPermittedAuth.getAccount();
Device device = disabledPermittedAuth.getAuthenticatedDevice();
if (device.getGcmId() != null &&
device.getGcmId().equals(registrationId.getGcmRegistrationId()))
{
device.getGcmId().equals(registrationId.getGcmRegistrationId())) {
return;
}
@ -434,9 +435,9 @@ public class AccountController {
@Timed
@DELETE
@Path("/gcm/")
public void deleteGcmRegistrationId(@Auth DisabledPermittedAccount disabledPermittedAccount) {
Account account = disabledPermittedAccount.getAccount();
Device device = account.getAuthenticatedDevice().get();
public void deleteGcmRegistrationId(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth) {
Account account = disabledPermittedAuth.getAccount();
Device device = disabledPermittedAuth.getAuthenticatedDevice();
accounts.updateDevice(account, device.getId(), d -> {
d.setGcmId(null);
@ -449,9 +450,10 @@ public class AccountController {
@PUT
@Path("/apn/")
@Consumes(MediaType.APPLICATION_JSON)
public void setApnRegistrationId(@Auth DisabledPermittedAccount disabledPermittedAccount, @Valid ApnRegistrationId registrationId) {
Account account = disabledPermittedAccount.getAccount();
Device device = account.getAuthenticatedDevice().get();
public void setApnRegistrationId(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth,
@Valid ApnRegistrationId registrationId) {
Account account = disabledPermittedAuth.getAccount();
Device device = disabledPermittedAuth.getAuthenticatedDevice();
accounts.updateDevice(account, device.getId(), d -> {
d.setApnId(registrationId.getApnRegistrationId());
@ -464,9 +466,9 @@ public class AccountController {
@Timed
@DELETE
@Path("/apn/")
public void deleteApnRegistrationId(@Auth DisabledPermittedAccount disabledPermittedAccount) {
Account account = disabledPermittedAccount.getAccount();
Device device = account.getAuthenticatedDevice().get();
public void deleteApnRegistrationId(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth) {
Account account = disabledPermittedAuth.getAccount();
Device device = disabledPermittedAuth.getAuthenticatedDevice();
accounts.updateDevice(account, device.getId(), d -> {
d.setApnId(null);
@ -483,57 +485,54 @@ public class AccountController {
@PUT
@Produces(MediaType.APPLICATION_JSON)
@Path("/registration_lock")
public void setRegistrationLock(@Auth Account account, @Valid RegistrationLock accountLock) {
public void setRegistrationLock(@Auth AuthenticatedAccount auth, @Valid RegistrationLock accountLock) {
AuthenticationCredentials credentials = new AuthenticationCredentials(accountLock.getRegistrationLock());
accounts.update(account, a -> {
a.setRegistrationLock(credentials.getHashedAuthenticationToken(), credentials.getSalt());
});
accounts.update(auth.getAccount(),
a -> a.setRegistrationLock(credentials.getHashedAuthenticationToken(), credentials.getSalt()));
}
@Timed
@DELETE
@Path("/registration_lock")
public void removeRegistrationLock(@Auth Account account) {
accounts.update(account, a -> a.setRegistrationLock(null, null));
public void removeRegistrationLock(@Auth AuthenticatedAccount auth) {
accounts.update(auth.getAccount(), a -> a.setRegistrationLock(null, null));
}
@Timed
@PUT
@Path("/name/")
public void setName(@Auth DisabledPermittedAccount disabledPermittedAccount, @Valid DeviceName deviceName) {
Account account = disabledPermittedAccount.getAccount();
Device device = account.getAuthenticatedDevice().get();
public void setName(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth, @Valid DeviceName deviceName) {
Account account = disabledPermittedAuth.getAccount();
Device device = disabledPermittedAuth.getAuthenticatedDevice();
accounts.updateDevice(account, device.getId(), d -> d.setName(deviceName.getDeviceName()));
}
@Timed
@DELETE
@Path("/signaling_key")
public void removeSignalingKey(@Auth DisabledPermittedAccount disabledPermittedAccount) {
public void removeSignalingKey(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth) {
}
@Timed
@PUT
@Path("/attributes/")
@Consumes(MediaType.APPLICATION_JSON)
public void setAccountAttributes(@Auth DisabledPermittedAccount disabledPermittedAccount,
@HeaderParam("X-Signal-Agent") String userAgent,
@Valid AccountAttributes attributes)
{
Account account = disabledPermittedAccount.getAccount();
long deviceId = account.getAuthenticatedDevice().get().getId();
accounts.update(account, a-> {
public void setAccountAttributes(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth,
@HeaderParam("X-Signal-Agent") String userAgent,
@Valid AccountAttributes attributes) {
Account account = disabledPermittedAuth.getAccount();
long deviceId = disabledPermittedAuth.getAuthenticatedDevice().getId();
accounts.update(account, a -> {
a.getDevice(deviceId).ifPresent(d -> {
d.setFetchesMessages(attributes.getFetchesMessages());
d.setName(attributes.getName());
d.setLastSeen(Util.todayInMillis());
d.setCapabilities(attributes.getCapabilities());
d.setRegistrationId(attributes.getRegistrationId());
d.setUserAgent(userAgent);
});
d.setFetchesMessages(attributes.getFetchesMessages());
d.setName(attributes.getName());
d.setLastSeen(Util.todayInMillis());
d.setCapabilities(attributes.getCapabilities());
d.setRegistrationId(attributes.getRegistrationId());
d.setUserAgent(userAgent);
});
a.setRegistrationLockFromAttributes(attributes);
@ -546,29 +545,30 @@ public class AccountController {
@GET
@Path("/me")
@Produces(MediaType.APPLICATION_JSON)
public AccountCreationResult getMe(@Auth Account account) {
return whoAmI(account);
public AccountCreationResult getMe(@Auth AuthenticatedAccount auth) {
return whoAmI(auth);
}
@GET
@Path("/whoami")
@Produces(MediaType.APPLICATION_JSON)
public AccountCreationResult whoAmI(@Auth Account account) {
return new AccountCreationResult(account.getUuid(), account.isStorageSupported());
public AccountCreationResult whoAmI(@Auth AuthenticatedAccount auth) {
return new AccountCreationResult(auth.getAccount().getUuid(), auth.getAccount().isStorageSupported());
}
@DELETE
@Path("/username")
@Produces(MediaType.APPLICATION_JSON)
public void deleteUsername(@Auth Account account) {
usernames.delete(account.getUuid());
public void deleteUsername(@Auth AuthenticatedAccount auth) {
usernames.delete(auth.getAccount().getUuid());
}
@PUT
@Path("/username/{username}")
@Produces(MediaType.APPLICATION_JSON)
public Response setUsername(@Auth Account account, @PathParam("username") String username) throws RateLimitExceededException {
rateLimiters.getUsernameSetLimiter().validate(account.getUuid());
public Response setUsername(@Auth AuthenticatedAccount auth, @PathParam("username") String username)
throws RateLimitExceededException {
rateLimiters.getUsernameSetLimiter().validate(auth.getAccount().getUuid());
if (username == null || username.isEmpty()) {
return Response.status(Response.Status.BAD_REQUEST).build();
@ -580,7 +580,7 @@ public class AccountController {
return Response.status(Response.Status.BAD_REQUEST).build();
}
if (!usernames.put(account.getUuid(), username)) {
if (!usernames.put(auth.getAccount().getUuid(), username)) {
return Response.status(Response.Status.CONFLICT).build();
}
@ -678,8 +678,8 @@ public class AccountController {
@Timed
@DELETE
@Path("/me")
public void deleteAccount(@Auth Account account) throws InterruptedException {
accounts.delete(account, AccountsManager.DeletionReason.USER_REQUEST);
public void deleteAccount(@Auth AuthenticatedAccount auth) throws InterruptedException {
accounts.delete(auth.getAccount(), AccountsManager.DeletionReason.USER_REQUEST);
}
private boolean shouldAutoBlock(String sourceHost) {

View File

@ -1,29 +1,27 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import com.amazonaws.HttpMethod;
import com.codahale.metrics.annotation.Timed;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.AttachmentDescriptorV1;
import org.whispersystems.textsecuregcm.entities.AttachmentUri;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.s3.UrlSigner;
import org.whispersystems.textsecuregcm.storage.Account;
import io.dropwizard.auth.Auth;
import java.io.IOException;
import java.net.URL;
import java.util.stream.Stream;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import java.io.IOException;
import java.net.URL;
import java.util.stream.Stream;
import io.dropwizard.auth.Auth;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.entities.AttachmentDescriptorV1;
import org.whispersystems.textsecuregcm.entities.AttachmentUri;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.s3.UrlSigner;
@Path("/v1/attachments")
@ -35,25 +33,25 @@ public class AttachmentControllerV1 extends AttachmentControllerBase {
private static final String[] UNACCELERATED_REGIONS = {"+20", "+971", "+968", "+974"};
private final RateLimiters rateLimiters;
private final UrlSigner urlSigner;
private final UrlSigner urlSigner;
public AttachmentControllerV1(RateLimiters rateLimiters, String accessKey, String accessSecret, String bucket) {
this.rateLimiters = rateLimiters;
this.urlSigner = new UrlSigner(accessKey, accessSecret, bucket);
this.urlSigner = new UrlSigner(accessKey, accessSecret, bucket);
}
@Timed
@GET
@Produces(MediaType.APPLICATION_JSON)
public AttachmentDescriptorV1 allocateAttachment(@Auth Account account)
throws RateLimitExceededException
{
if (account.isRateLimited()) {
rateLimiters.getAttachmentLimiter().validate(account.getUuid());
public AttachmentDescriptorV1 allocateAttachment(@Auth AuthenticatedAccount auth)
throws RateLimitExceededException {
if (auth.getAccount().isRateLimited()) {
rateLimiters.getAttachmentLimiter().validate(auth.getAccount().getUuid());
}
long attachmentId = generateAttachmentId();
URL url = urlSigner.getPreSignedUrl(attachmentId, HttpMethod.PUT, Stream.of(UNACCELERATED_REGIONS).anyMatch(region -> account.getNumber().startsWith(region)));
URL url = urlSigner.getPreSignedUrl(attachmentId, HttpMethod.PUT,
Stream.of(UNACCELERATED_REGIONS).anyMatch(region -> auth.getAccount().getNumber().startsWith(region)));
return new AttachmentDescriptorV1(attachmentId, url.toExternalForm());
@ -63,11 +61,11 @@ public class AttachmentControllerV1 extends AttachmentControllerBase {
@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/{attachmentId}")
public AttachmentUri redirectToAttachment(@Auth Account account,
@PathParam("attachmentId") long attachmentId)
throws IOException
{
return new AttachmentUri(urlSigner.getPreSignedUrl(attachmentId, HttpMethod.GET, Stream.of(UNACCELERATED_REGIONS).anyMatch(region -> account.getNumber().startsWith(region))));
public AttachmentUri redirectToAttachment(@Auth AuthenticatedAccount auth,
@PathParam("attachmentId") long attachmentId)
throws IOException {
return new AttachmentUri(urlSigner.getPreSignedUrl(attachmentId, HttpMethod.GET,
Stream.of(UNACCELERATED_REGIONS).anyMatch(region -> auth.getAccount().getNumber().startsWith(region))));
}
}

View File

@ -1,28 +1,26 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed;
import io.dropwizard.auth.Auth;
import java.time.ZoneOffset;
import java.time.ZonedDateTime;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.entities.AttachmentDescriptorV2;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.s3.PolicySigner;
import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.Pair;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import java.time.ZoneOffset;
import java.time.ZonedDateTime;
import io.dropwizard.auth.Auth;
@Path("/v2/attachments")
public class AttachmentControllerV2 extends AttachmentControllerBase {
@ -40,19 +38,20 @@ public class AttachmentControllerV2 extends AttachmentControllerBase {
@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/form/upload")
public AttachmentDescriptorV2 getAttachmentUploadForm(@Auth Account account) throws RateLimitExceededException {
rateLimiter.validate(account.getUuid());
public AttachmentDescriptorV2 getAttachmentUploadForm(@Auth AuthenticatedAccount auth)
throws RateLimitExceededException {
rateLimiter.validate(auth.getAccount().getUuid());
ZonedDateTime now = ZonedDateTime.now(ZoneOffset.UTC);
long attachmentId = generateAttachmentId();
String objectName = String.valueOf(attachmentId);
Pair<String, String> policy = policyGenerator.createFor(now, String.valueOf(objectName), 100 * 1024 * 1024);
String signature = policySigner.getSignature(now, policy.second());
ZonedDateTime now = ZonedDateTime.now(ZoneOffset.UTC);
long attachmentId = generateAttachmentId();
String objectName = String.valueOf(attachmentId);
Pair<String, String> policy = policyGenerator.createFor(now, String.valueOf(objectName), 100 * 1024 * 1024);
String signature = policySigner.getSignature(now, policy.second());
return new AttachmentDescriptorV2(attachmentId, objectName, policy.first(),
"private", "AWS4-HMAC-SHA256",
now.format(PostPolicyGenerator.AWS_DATE_TIME),
policy.second(), signature);
"private", "AWS4-HMAC-SHA256",
now.format(PostPolicyGenerator.AWS_DATE_TIME),
policy.second(), signature);
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -7,19 +7,6 @@ package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed;
import io.dropwizard.auth.Auth;
import org.whispersystems.textsecuregcm.entities.AttachmentDescriptorV3;
import org.whispersystems.textsecuregcm.gcp.CanonicalRequest;
import org.whispersystems.textsecuregcm.gcp.CanonicalRequestGenerator;
import org.whispersystems.textsecuregcm.gcp.CanonicalRequestSigner;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import javax.annotation.Nonnull;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import java.io.IOException;
import java.security.InvalidKeyException;
import java.security.SecureRandom;
@ -29,6 +16,18 @@ import java.time.ZonedDateTime;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import javax.annotation.Nonnull;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.entities.AttachmentDescriptorV3;
import org.whispersystems.textsecuregcm.gcp.CanonicalRequest;
import org.whispersystems.textsecuregcm.gcp.CanonicalRequestGenerator;
import org.whispersystems.textsecuregcm.gcp.CanonicalRequestSigner;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
@Path("/v3/attachments")
public class AttachmentControllerV3 extends AttachmentControllerBase {
@ -45,26 +44,29 @@ public class AttachmentControllerV3 extends AttachmentControllerBase {
@Nonnull
private final SecureRandom secureRandom;
public AttachmentControllerV3(@Nonnull RateLimiters rateLimiters, @Nonnull String domain, @Nonnull String email, int maxSizeInBytes, @Nonnull String pathPrefix, @Nonnull String rsaSigningKey)
public AttachmentControllerV3(@Nonnull RateLimiters rateLimiters, @Nonnull String domain, @Nonnull String email,
int maxSizeInBytes, @Nonnull String pathPrefix, @Nonnull String rsaSigningKey)
throws IOException, InvalidKeyException, InvalidKeySpecException {
this.rateLimiter = rateLimiters.getAttachmentLimiter();
this.rateLimiter = rateLimiters.getAttachmentLimiter();
this.canonicalRequestGenerator = new CanonicalRequestGenerator(domain, email, maxSizeInBytes, pathPrefix);
this.canonicalRequestSigner = new CanonicalRequestSigner(rsaSigningKey);
this.secureRandom = new SecureRandom();
this.canonicalRequestSigner = new CanonicalRequestSigner(rsaSigningKey);
this.secureRandom = new SecureRandom();
}
@Timed
@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/form/upload")
public AttachmentDescriptorV3 getAttachmentUploadForm(@Auth Account account) throws RateLimitExceededException {
rateLimiter.validate(account.getUuid());
public AttachmentDescriptorV3 getAttachmentUploadForm(@Auth AuthenticatedAccount auth)
throws RateLimitExceededException {
rateLimiter.validate(auth.getAccount().getUuid());
final ZonedDateTime now = ZonedDateTime.now(ZoneOffset.UTC);
final String key = generateAttachmentKey();
final ZonedDateTime now = ZonedDateTime.now(ZoneOffset.UTC);
final String key = generateAttachmentKey();
final CanonicalRequest canonicalRequest = canonicalRequestGenerator.createFor(key, now);
return new AttachmentDescriptorV3(2, key, getHeaderMap(canonicalRequest), getSignedUploadLocation(canonicalRequest));
return new AttachmentDescriptorV3(2, key, getHeaderMap(canonicalRequest),
getSignedUploadLocation(canonicalRequest));
}
private String getSignedUploadLocation(@Nonnull CanonicalRequest canonicalRequest) {

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -10,7 +10,6 @@ import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.annotation.Timed;
import io.dropwizard.auth.Auth;
import io.micrometer.core.instrument.Metrics;
import java.io.IOException;
import java.security.InvalidKeyException;
import java.util.LinkedList;
import java.util.List;
@ -24,10 +23,10 @@ import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.signal.zkgroup.auth.ServerZkAuthOperations;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.CertificateGenerator;
import org.whispersystems.textsecuregcm.entities.DeliveryCertificate;
import org.whispersystems.textsecuregcm.entities.GroupCredentials;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.Util;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@ -51,43 +50,49 @@ public class CertificateController {
@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/delivery")
public DeliveryCertificate getDeliveryCertificate(@Auth Account account,
@QueryParam("includeE164") Optional<Boolean> maybeIncludeE164)
throws InvalidKeyException
{
if (account.getAuthenticatedDevice().isEmpty()) {
throw new AssertionError();
}
if (Util.isEmpty(account.getIdentityKey())) {
public DeliveryCertificate getDeliveryCertificate(@Auth AuthenticatedAccount auth,
@QueryParam("includeE164") Optional<Boolean> maybeIncludeE164)
throws InvalidKeyException {
if (Util.isEmpty(auth.getAccount().getIdentityKey())) {
throw new WebApplicationException(Response.Status.BAD_REQUEST);
}
final boolean includeE164 = maybeIncludeE164.orElse(true);
Metrics.counter(GENERATE_DELIVERY_CERTIFICATE_COUNTER_NAME, INCLUDE_E164_TAG_NAME, String.valueOf(includeE164)).increment();
Metrics.counter(GENERATE_DELIVERY_CERTIFICATE_COUNTER_NAME, INCLUDE_E164_TAG_NAME, String.valueOf(includeE164))
.increment();
return new DeliveryCertificate(certificateGenerator.createFor(account, account.getAuthenticatedDevice().get(), includeE164));
return new DeliveryCertificate(
certificateGenerator.createFor(auth.getAccount(), auth.getAuthenticatedDevice(), includeE164));
}
@Timed
@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/group/{startRedemptionTime}/{endRedemptionTime}")
public GroupCredentials getAuthenticationCredentials(@Auth Account account,
@PathParam("startRedemptionTime") int startRedemptionTime,
@PathParam("endRedemptionTime") int endRedemptionTime)
{
if (!isZkEnabled) throw new WebApplicationException(Response.Status.NOT_FOUND);
if (startRedemptionTime > endRedemptionTime) throw new WebApplicationException(Response.Status.BAD_REQUEST);
if (endRedemptionTime > Util.currentDaysSinceEpoch() + 7) throw new WebApplicationException(Response.Status.BAD_REQUEST);
if (startRedemptionTime < Util.currentDaysSinceEpoch()) throw new WebApplicationException(Response.Status.BAD_REQUEST);
public GroupCredentials getAuthenticationCredentials(@Auth AuthenticatedAccount auth,
@PathParam("startRedemptionTime") int startRedemptionTime,
@PathParam("endRedemptionTime") int endRedemptionTime) {
if (!isZkEnabled) {
throw new WebApplicationException(Response.Status.NOT_FOUND);
}
if (startRedemptionTime > endRedemptionTime) {
throw new WebApplicationException(Response.Status.BAD_REQUEST);
}
if (endRedemptionTime > Util.currentDaysSinceEpoch() + 7) {
throw new WebApplicationException(Response.Status.BAD_REQUEST);
}
if (startRedemptionTime < Util.currentDaysSinceEpoch()) {
throw new WebApplicationException(Response.Status.BAD_REQUEST);
}
List<GroupCredentials.GroupCredential> credentials = new LinkedList<>();
for (int i=startRedemptionTime;i<=endRedemptionTime;i++) {
credentials.add(new GroupCredentials.GroupCredential(serverZkAuthOperations.issueAuthCredential(account.getUuid(), i)
.serialize(),
i));
for (int i = startRedemptionTime; i <= endRedemptionTime; i++) {
credentials.add(new GroupCredentials.GroupCredential(
serverZkAuthOperations.issueAuthCredential(auth.getAccount().getUuid(), i)
.serialize(),
i));
}
return new GroupCredentials(credentials);

View File

@ -17,12 +17,12 @@ import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.entities.AnswerChallengeRequest;
import org.whispersystems.textsecuregcm.entities.AnswerPushChallengeRequest;
import org.whispersystems.textsecuregcm.entities.AnswerRecaptchaChallengeRequest;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.ForwardedIpUtil;
@Path("/v1/challenge")
@ -38,7 +38,7 @@ public class ChallengeController {
@PUT
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
public Response handleChallengeResponse(@Auth final Account account,
public Response handleChallengeResponse(@Auth final AuthenticatedAccount auth,
@Valid final AnswerChallengeRequest answerRequest,
@HeaderParam("X-Forwarded-For") String forwardedFor) throws RetryLaterException {
@ -46,14 +46,15 @@ public class ChallengeController {
if (answerRequest instanceof AnswerPushChallengeRequest) {
final AnswerPushChallengeRequest pushChallengeRequest = (AnswerPushChallengeRequest) answerRequest;
rateLimitChallengeManager.answerPushChallenge(account, pushChallengeRequest.getChallenge());
rateLimitChallengeManager.answerPushChallenge(auth.getAccount(), pushChallengeRequest.getChallenge());
} else if (answerRequest instanceof AnswerRecaptchaChallengeRequest) {
try {
final AnswerRecaptchaChallengeRequest recaptchaChallengeRequest = (AnswerRecaptchaChallengeRequest) answerRequest;
final String mostRecentProxy = ForwardedIpUtil.getMostRecentProxy(forwardedFor).orElseThrow();
rateLimitChallengeManager.answerRecaptchaChallenge(account, recaptchaChallengeRequest.getCaptcha(), mostRecentProxy);
rateLimitChallengeManager.answerRecaptchaChallenge(auth.getAccount(), recaptchaChallengeRequest.getCaptcha(),
mostRecentProxy);
} catch (final NoSuchElementException e) {
return Response.status(400).build();
@ -69,9 +70,9 @@ public class ChallengeController {
@Timed
@POST
@Path("/push")
public Response requestPushChallenge(@Auth final Account account) {
public Response requestPushChallenge(@Auth final AuthenticatedAccount auth) {
try {
rateLimitChallengeManager.sendPushChallenge(account);
rateLimitChallengeManager.sendPushChallenge(auth.getAccount());
return Response.status(200).build();
} catch (final NotPushRegisteredException e) {
return Response.status(404).build();

View File

@ -26,6 +26,7 @@ import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.auth.AuthorizationHeader;
import org.whispersystems.textsecuregcm.auth.InvalidAuthorizationHeaderException;
@ -79,12 +80,12 @@ public class DeviceController {
@Timed
@GET
@Produces(MediaType.APPLICATION_JSON)
public DeviceInfoList getDevices(@Auth Account account) {
public DeviceInfoList getDevices(@Auth AuthenticatedAccount auth) {
List<DeviceInfo> devices = new LinkedList<>();
for (Device device : account.getDevices()) {
for (Device device : auth.getAccount().getDevices()) {
devices.add(new DeviceInfo(device.getId(), device.getName(),
device.getLastSeen(), device.getCreated()));
device.getLastSeen(), device.getCreated()));
}
return new DeviceInfoList(devices);
@ -93,8 +94,9 @@ public class DeviceController {
@Timed
@DELETE
@Path("/{device_id}")
public void removeDevice(@Auth Account account, @PathParam("device_id") long deviceId) {
if (account.getAuthenticatedDevice().get().getId() != Device.MASTER_ID) {
public void removeDevice(@Auth AuthenticatedAccount auth, @PathParam("device_id") long deviceId) {
Account account = auth.getAccount();
if (auth.getAuthenticatedDevice().getId() != Device.MASTER_ID) {
throw new WebApplicationException(Response.Status.UNAUTHORIZED);
}
@ -109,9 +111,11 @@ public class DeviceController {
@GET
@Path("/provisioning/code")
@Produces(MediaType.APPLICATION_JSON)
public VerificationCode createDeviceToken(@Auth Account account)
throws RateLimitExceededException, DeviceLimitExceededException
{
public VerificationCode createDeviceToken(@Auth AuthenticatedAccount auth)
throws RateLimitExceededException, DeviceLimitExceededException {
final Account account = auth.getAccount();
rateLimiters.getAllocateDeviceLimiter().validate(account.getUuid());
int maxDeviceLimit = MAX_DEVICES;
@ -124,7 +128,7 @@ public class DeviceController {
throw new DeviceLimitExceededException(account.getDevices().size(), MAX_DEVICES);
}
if (account.getAuthenticatedDevice().get().getId() != Device.MASTER_ID) {
if (auth.getAuthenticatedDevice().getId() != Device.MASTER_ID) {
throw new WebApplicationException(Response.Status.UNAUTHORIZED);
}
@ -213,18 +217,18 @@ public class DeviceController {
@Timed
@PUT
@Path("/unauthenticated_delivery")
public void setUnauthenticatedDelivery(@Auth Account account) {
assert(account.getAuthenticatedDevice().isPresent());
public void setUnauthenticatedDelivery(@Auth AuthenticatedAccount auth) {
assert (auth.getAuthenticatedDevice() != null);
// Deprecated
}
@Timed
@PUT
@Path("/capabilities")
public void setCapabiltities(@Auth Account account, @Valid DeviceCapabilities capabilities) {
assert(account.getAuthenticatedDevice().isPresent());
final long deviceId = account.getAuthenticatedDevice().get().getId();
accounts.updateDevice(account, deviceId, d -> d.setCapabilities(capabilities));
public void setCapabiltities(@Auth AuthenticatedAccount auth, @Valid DeviceCapabilities capabilities) {
assert (auth.getAuthenticatedDevice() != null);
final long deviceId = auth.getAuthenticatedDevice().getId();
accounts.updateDevice(auth.getAccount(), deviceId, d -> d.setCapabilities(capabilities));
}
@VisibleForTesting protected VerificationCode generateVerificationCode() {

View File

@ -1,14 +1,11 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed;
import io.dropwizard.auth.Auth;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.storage.Account;
import javax.ws.rs.Consumes;
import javax.ws.rs.GET;
import javax.ws.rs.PUT;
@ -16,6 +13,8 @@ import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
@Path("/v1/directory")
public class DirectoryController {
@ -30,15 +29,15 @@ public class DirectoryController {
@GET
@Path("/auth")
@Produces(MediaType.APPLICATION_JSON)
public Response getAuthToken(@Auth Account account) {
return Response.ok().entity(directoryServiceTokenGenerator.generateFor(account.getNumber())).build();
public Response getAuthToken(@Auth AuthenticatedAccount auth) {
return Response.ok().entity(directoryServiceTokenGenerator.generateFor(auth.getAccount().getNumber())).build();
}
@PUT
@Path("/feedback-v3/{status}")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public Response setFeedback(@Auth Account account) {
public Response setFeedback(@Auth AuthenticatedAccount auth) {
return Response.ok().build();
}
@ -47,7 +46,7 @@ public class DirectoryController {
@GET
@Path("/{token}")
@Produces(MediaType.APPLICATION_JSON)
public Response getTokenPresence(@Auth Account account) {
public Response getTokenPresence(@Auth AuthenticatedAccount auth) {
return Response.status(429).build();
}
@ -56,7 +55,7 @@ public class DirectoryController {
@Path("/tokens")
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
public Response getContactIntersection(@Auth Account account) {
public Response getContactIntersection(@Auth AuthenticatedAccount auth) {
return Response.status(429).build();
}
}

View File

@ -34,12 +34,12 @@ import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.configuration.DonationConfiguration;
import org.whispersystems.textsecuregcm.entities.ApplePayAuthorizationRequest;
import org.whispersystems.textsecuregcm.entities.ApplePayAuthorizationResponse;
import org.whispersystems.textsecuregcm.http.FaultTolerantHttpClient;
import org.whispersystems.textsecuregcm.http.FormDataBodyPublisher;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@Path("/v1/donation")
@ -75,7 +75,7 @@ public class DonationController {
@Path("/authorize-apple-pay")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public CompletableFuture<Response> getApplePayAuthorization(@Auth Account account, @Valid ApplePayAuthorizationRequest request) {
public CompletableFuture<Response> getApplePayAuthorization(@Auth AuthenticatedAccount auth, @Valid ApplePayAuthorizationRequest request) {
if (!supportedCurrencies.contains(request.getCurrency())) {
return CompletableFuture.completedFuture(Response.status(422).build());
}

View File

@ -1,28 +1,27 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.annotation.Timed;
import io.dropwizard.auth.Auth;
import io.micrometer.core.instrument.Metrics;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.core.Response;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import org.whispersystems.websocket.session.WebSocketSession;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.core.Response;
import static com.codahale.metrics.MetricRegistry.name;
@Path("/v1/keepalive")
public class KeepAliveController {
@ -40,15 +39,14 @@ public class KeepAliveController {
@Timed
@GET
public Response getKeepAlive(@Auth Account account,
@WebSocketSession WebSocketSessionContext context)
{
if (account != null) {
if (!clientPresenceManager.isLocallyPresent(account.getUuid(), account.getAuthenticatedDevice().get().getId())) {
public Response getKeepAlive(@Auth AuthenticatedAccount auth,
@WebSocketSession WebSocketSessionContext context) {
if (auth != null) {
if (!clientPresenceManager.isLocallyPresent(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId())) {
logger.warn("***** No local subscription found for {}::{}; age = {}ms, User-Agent = {}",
account.getUuid(), account.getAuthenticatedDevice().get().getId(),
System.currentTimeMillis() - context.getClient().getCreatedTimestamp(),
context.getClient().getUserAgent());
auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId(),
System.currentTimeMillis() - context.getClient().getCreatedTimestamp(),
context.getClient().getUserAgent());
context.getClient().close(1000, "OK");

View File

@ -28,7 +28,8 @@ import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.Anonymous;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.PreKeyCount;
@ -76,8 +77,8 @@ public class KeysController {
@GET
@Produces(MediaType.APPLICATION_JSON)
public PreKeyCount getStatus(@Auth Account account) {
int count = keysDynamoDb.getCount(account, account.getAuthenticatedDevice().get().getId());
public PreKeyCount getStatus(@Auth AuthenticatedAccount auth) {
int count = keysDynamoDb.getCount(auth.getAccount(), auth.getAuthenticatedDevice().getId());
if (count > 0) {
count = count - 1;
@ -89,10 +90,10 @@ public class KeysController {
@Timed
@PUT
@Consumes(MediaType.APPLICATION_JSON)
public void setKeys(@Auth DisabledPermittedAccount disabledPermittedAccount, @Valid PreKeyState preKeys) {
Account account = disabledPermittedAccount.getAccount();
Device device = account.getAuthenticatedDevice().get();
boolean updateAccount = false;
public void setKeys(@Auth DisabledPermittedAuthenticatedAccount disabledPermittedAuth, @Valid PreKeyState preKeys) {
Account account = disabledPermittedAuth.getAccount();
Device device = disabledPermittedAuth.getAuthenticatedDevice();
boolean updateAccount = false;
if (!preKeys.getSignedPreKey().equals(device.getSignedPreKey())) {
updateAccount = true;
@ -116,7 +117,7 @@ public class KeysController {
@GET
@Path("/{identifier}/{device_id}")
@Produces(MediaType.APPLICATION_JSON)
public Response getDeviceKeys(@Auth Optional<Account> account,
public Response getDeviceKeys(@Auth Optional<AuthenticatedAccount> auth,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@PathParam("identifier") AmbiguousIdentifier targetName,
@PathParam("device_id") String deviceId,
@ -125,14 +126,16 @@ public class KeysController {
targetName.incrementRequestCounter("getDeviceKeys", userAgent);
if (!account.isPresent() && !accessKey.isPresent()) {
if (auth.isEmpty() && accessKey.isEmpty()) {
throw new WebApplicationException(Response.Status.UNAUTHORIZED);
}
final Optional<Account> account = auth.map(AuthenticatedAccount::getAccount);
Optional<Account> target = accounts.get(targetName);
OptionalAccess.verify(account, accessKey, target, deviceId);
assert(target.isPresent());
assert (target.isPresent());
{
final String sourceCountryCode = account.map(a -> Util.getCountryCode(a.getNumber())).orElse("0");
@ -146,7 +149,9 @@ public class KeysController {
}
if (account.isPresent()) {
rateLimiters.getPreKeysLimiter().validate(account.get().getUuid() + "." + account.get().getAuthenticatedDevice().get().getId() + "__" + target.get().getUuid() + "." + deviceId);
rateLimiters.getPreKeysLimiter().validate(
account.get().getUuid() + "." + auth.get().getAuthenticatedDevice().getId() + "__" + target.get().getUuid()
+ "." + deviceId);
try {
preKeyRateLimiter.validate(account.get());
@ -188,22 +193,25 @@ public class KeysController {
@PUT
@Path("/signed")
@Consumes(MediaType.APPLICATION_JSON)
public void setSignedKey(@Auth Account account, @Valid SignedPreKey signedPreKey) {
Device device = account.getAuthenticatedDevice().get();
public void setSignedKey(@Auth AuthenticatedAccount auth, @Valid SignedPreKey signedPreKey) {
Device device = auth.getAuthenticatedDevice();
accounts.updateDevice(account, device.getId(), d -> d.setSignedPreKey(signedPreKey));
accounts.updateDevice(auth.getAccount(), device.getId(), d -> d.setSignedPreKey(signedPreKey));
}
@Timed
@GET
@Path("/signed")
@Produces(MediaType.APPLICATION_JSON)
public Optional<SignedPreKey> getSignedKey(@Auth Account account) {
Device device = account.getAuthenticatedDevice().get();
public Optional<SignedPreKey> getSignedKey(@Auth AuthenticatedAccount auth) {
Device device = auth.getAuthenticatedDevice();
SignedPreKey signedPreKey = device.getSignedPreKey();
if (signedPreKey != null) return Optional.of(signedPreKey);
else return Optional.empty();
if (signedPreKey != null) {
return Optional.of(signedPreKey);
} else {
return Optional.empty();
}
}
private Map<Long, PreKey> getLocalKeys(Account destination, String deviceIdSelector) {

View File

@ -62,6 +62,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.Anonymous;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.CombinedUnidentifiedSenderAccessKeys;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration;
@ -189,12 +190,12 @@ public class MessageController {
@PUT
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public Response sendMessage(@Auth Optional<Account> source,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@HeaderParam("User-Agent") String userAgent,
@HeaderParam("X-Forwarded-For") String forwardedFor,
@PathParam("destination") AmbiguousIdentifier destinationName,
@Valid IncomingMessageList messages)
public Response sendMessage(@Auth Optional<AuthenticatedAccount> source,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@HeaderParam("User-Agent") String userAgent,
@HeaderParam("X-Forwarded-For") String forwardedFor,
@PathParam("destination") AmbiguousIdentifier destinationName,
@Valid IncomingMessageList messages)
throws RateLimitExceededException, RateLimitChallengeException {
destinationName.incrementRequestCounter("sendMessage", userAgent);
@ -203,20 +204,22 @@ public class MessageController {
throw new WebApplicationException(Response.Status.UNAUTHORIZED);
}
if (source.isPresent() && !source.get().isFor(destinationName)) {
assert source.get().getMasterDevice().isPresent();
if (source.isPresent() && !source.get().getAccount().isFor(destinationName)) {
assert source.get().getAccount().getMasterDevice().isPresent();
final Device masterDevice = source.get().getMasterDevice().get();
final String senderCountryCode = Util.getCountryCode(source.get().getNumber());
final Device masterDevice = source.get().getAccount().getMasterDevice().get();
final String senderCountryCode = Util.getCountryCode(source.get().getAccount().getNumber());
if (StringUtils.isAllBlank(masterDevice.getApnId(), masterDevice.getVoipApnId(), masterDevice.getGcmId()) || masterDevice.getUninstalledFeedbackTimestamp() > 0) {
Metrics.counter(UNSEALED_SENDER_WITHOUT_PUSH_TOKEN_COUNTER_NAME, SENDER_COUNTRY_TAG_NAME, senderCountryCode).increment();
if (StringUtils.isAllBlank(masterDevice.getApnId(), masterDevice.getVoipApnId(), masterDevice.getGcmId())
|| masterDevice.getUninstalledFeedbackTimestamp() > 0) {
Metrics.counter(UNSEALED_SENDER_WITHOUT_PUSH_TOKEN_COUNTER_NAME, SENDER_COUNTRY_TAG_NAME, senderCountryCode)
.increment();
}
}
final String senderType;
if (source.isPresent() && !source.get().isFor(destinationName)) {
if (source.isPresent() && !source.get().getAccount().isFor(destinationName)) {
identifiedMeter.mark();
senderType = "identified";
} else if (source.isEmpty()) {
@ -246,23 +249,26 @@ public class MessageController {
}
try {
boolean isSyncMessage = source.isPresent() && source.get().isFor(destinationName);
boolean isSyncMessage = source.isPresent() && source.get().getAccount().isFor(destinationName);
Optional<Account> destination;
if (!isSyncMessage) destination = accountsManager.get(destinationName);
else destination = source;
if (!isSyncMessage) {
destination = accountsManager.get(destinationName);
} else {
destination = source.map(AuthenticatedAccount::getAccount);
}
OptionalAccess.verify(source, accessKey, destination);
assert(destination.isPresent());
OptionalAccess.verify(source.map(AuthenticatedAccount::getAccount), accessKey, destination);
assert (destination.isPresent());
if (source.isPresent() && !source.get().isFor(destinationName)) {
rateLimiters.getMessagesLimiter().validate(source.get().getUuid(), destination.get().getUuid());
if (source.isPresent() && !source.get().getAccount().isFor(destinationName)) {
rateLimiters.getMessagesLimiter().validate(source.get().getAccount().getUuid(), destination.get().getUuid());
final String senderCountryCode = Util.getCountryCode(source.get().getNumber());
final String senderCountryCode = Util.getCountryCode(source.get().getAccount().getNumber());
try {
unsealedSenderRateLimiter.validate(source.get(), destination.get());
unsealedSenderRateLimiter.validate(source.get().getAccount(), destination.get());
} catch (final RateLimitExceededException e) {
final boolean legacyClient = rateLimitChallengeManager.isClientBelowMinimumVersion(userAgent);
@ -276,11 +282,11 @@ public class MessageController {
throw e;
}
throw new RateLimitChallengeException(source.get(), e.getRetryDuration());
throw new RateLimitChallengeException(source.get().getAccount(), e.getRetryDuration());
}
final String destinationCountryCode = Util.getCountryCode(destination.get().getNumber());
final Device masterDevice = source.get().getMasterDevice().get();
final Device masterDevice = source.get().getAccount().getMasterDevice().get();
if (!senderCountryCode.equals(destinationCountryCode)) {
recordInternationalUnsealedSenderMetrics(forwardedFor, senderCountryCode, destination.get().getNumber());
@ -293,31 +299,34 @@ public class MessageController {
.orElse(false);
if (isRateLimitedHost) {
return declineDelivery(messages, source.get(), destination.get());
return declineDelivery(messages, source.get().getAccount(), destination.get());
}
}
}
}
}
validateCompleteDeviceList(destination.get(), messages.getMessages(), isSyncMessage);
validateCompleteDeviceList(destination.get(), messages.getMessages(), isSyncMessage,
source.map(AuthenticatedAccount::getAuthenticatedDevice).map(Device::getId));
validateRegistrationIds(destination.get(), messages.getMessages());
final List<Tag> tags = List.of(UserAgentTagUtil.getPlatformTag(userAgent),
Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(messages.isOnline())),
Tag.of(SENDER_TYPE_TAG_NAME, senderType),
Tag.of(DESTINATION_TYPE_TAG_NAME, destinationName.hasNumber() ? "e164" : "uuid"));
Tag.of(EPHEMERAL_TAG_NAME, String.valueOf(messages.isOnline())),
Tag.of(SENDER_TYPE_TAG_NAME, senderType),
Tag.of(DESTINATION_TYPE_TAG_NAME, destinationName.hasNumber() ? "e164" : "uuid"));
for (IncomingMessage incomingMessage : messages.getMessages()) {
Optional<Device> destinationDevice = destination.get().getDevice(incomingMessage.getDestinationDeviceId());
if (destinationDevice.isPresent()) {
Metrics.counter(SENT_MESSAGE_COUNTER_NAME, tags).increment();
sendMessage(source, destination.get(), destinationDevice.get(), messages.getTimestamp(), messages.isOnline(), incomingMessage);
sendMessage(source, destination.get(), destinationDevice.get(), messages.getTimestamp(), messages.isOnline(),
incomingMessage);
}
}
return Response.ok(new SendMessageResponse(!isSyncMessage && source.isPresent() && source.get().getEnabledDeviceCount() > 1)).build();
return Response.ok(new SendMessageResponse(
!isSyncMessage && source.isPresent() && source.get().getAccount().getEnabledDeviceCount() > 1)).build();
} catch (NoSuchUserException e) {
throw new WebApplicationException(Response.status(404).build());
} catch (MismatchedDevicesException e) {
@ -380,7 +389,7 @@ public class MessageController {
final Set<Pair<Long, Integer>> deviceIdAndRegistrationIdSet = accountToDeviceIdAndRegistrationIdMap.get(account);
final Set<Long> deviceIds = deviceIdAndRegistrationIdSet.stream().map(Pair::first).collect(Collectors.toSet());
try {
validateCompleteDeviceList(account, deviceIds, false);
validateCompleteDeviceList(account, deviceIds, false, Optional.empty());
validateRegistrationIds(account, deviceIdAndRegistrationIdSet.stream());
} catch (MismatchedDevicesException e) {
accountMismatchedDevices.add(new AccountMismatchedDevices(account.getUuid(),
@ -476,7 +485,9 @@ public class MessageController {
if (random.nextDouble() <= messageRateConfiguration.getReceiptProbability()) {
receiptExecutorService.schedule(() -> {
try {
receiptSender.sendReceipt(destination, source.getNumber(), timestamp);
receiptSender.sendReceipt(
new AuthenticatedAccount(() -> new Pair<>(destination, destination.getMasterDevice().get())),
source.getNumber(), timestamp);
} catch (final NoSuchUserException ignored) {
}
}, receiptDelay.toMillis(), TimeUnit.MILLISECONDS);
@ -503,16 +514,17 @@ public class MessageController {
@Timed
@GET
@Produces(MediaType.APPLICATION_JSON)
public OutgoingMessageEntityList getPendingMessages(@Auth Account account, @HeaderParam("User-Agent") String userAgent) {
assert account.getAuthenticatedDevice().isPresent();
public OutgoingMessageEntityList getPendingMessages(@Auth AuthenticatedAccount auth,
@HeaderParam("User-Agent") String userAgent) {
assert auth.getAuthenticatedDevice() != null;
if (!Util.isEmpty(account.getAuthenticatedDevice().get().getApnId())) {
RedisOperation.unchecked(() -> apnFallbackManager.cancel(account, account.getAuthenticatedDevice().get()));
if (!Util.isEmpty(auth.getAuthenticatedDevice().getApnId())) {
RedisOperation.unchecked(() -> apnFallbackManager.cancel(auth.getAccount(), auth.getAuthenticatedDevice()));
}
final OutgoingMessageEntityList outgoingMessages = messagesManager.getMessagesForDevice(
account.getUuid(),
account.getAuthenticatedDevice().get().getId(),
auth.getAccount().getUuid(),
auth.getAuthenticatedDevice().getId(),
userAgent,
false);
@ -549,21 +561,20 @@ public class MessageController {
@Timed
@DELETE
@Path("/{source}/{timestamp}")
public void removePendingMessage(@Auth Account account,
@PathParam("source") String source,
@PathParam("timestamp") long timestamp)
{
public void removePendingMessage(@Auth AuthenticatedAccount auth,
@PathParam("source") String source,
@PathParam("timestamp") long timestamp) {
try {
WebSocketConnection.recordMessageDeliveryDuration(timestamp, account.getAuthenticatedDevice().get());
WebSocketConnection.recordMessageDeliveryDuration(timestamp, auth.getAuthenticatedDevice());
Optional<OutgoingMessageEntity> message = messagesManager.delete(
account.getUuid(),
account.getAuthenticatedDevice().get().getId(),
source, timestamp);
auth.getAccount().getUuid(),
auth.getAuthenticatedDevice().getId(),
source, timestamp);
if (message.isPresent() && message.get().getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE) {
receiptSender.sendReceipt(account,
message.get().getSource(),
message.get().getTimestamp());
receiptSender.sendReceipt(auth,
message.get().getSource(),
message.get().getTimestamp());
}
} catch (NoSuchUserException e) {
logger.warn("Sending delivery receipt", e);
@ -573,17 +584,18 @@ public class MessageController {
@Timed
@DELETE
@Path("/uuid/{uuid}")
public void removePendingMessage(@Auth Account account, @PathParam("uuid") UUID uuid) {
public void removePendingMessage(@Auth AuthenticatedAccount auth, @PathParam("uuid") UUID uuid) {
try {
Optional<OutgoingMessageEntity> message = messagesManager.delete(
account.getUuid(),
account.getAuthenticatedDevice().get().getId(),
uuid);
auth.getAccount().getUuid(),
auth.getAuthenticatedDevice().getId(),
uuid);
if (message.isPresent()) {
WebSocketConnection.recordMessageDeliveryDuration(message.get().getTimestamp(), account.getAuthenticatedDevice().get());
if (!Util.isEmpty(message.get().getSource()) && message.get().getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE) {
receiptSender.sendReceipt(account, message.get().getSource(), message.get().getTimestamp());
WebSocketConnection.recordMessageDeliveryDuration(message.get().getTimestamp(), auth.getAuthenticatedDevice());
if (!Util.isEmpty(message.get().getSource())
&& message.get().getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT_VALUE) {
receiptSender.sendReceipt(auth, message.get().getSource(), message.get().getTimestamp());
}
}
@ -595,7 +607,8 @@ public class MessageController {
@Timed
@POST
@Path("/report/{sourceNumber}/{messageGuid}")
public Response reportMessage(@Auth Account account, @PathParam("sourceNumber") String sourceNumber, @PathParam("messageGuid") UUID messageGuid) {
public Response reportMessage(@Auth AuthenticatedAccount auth, @PathParam("sourceNumber") String sourceNumber,
@PathParam("messageGuid") UUID messageGuid) {
reportMessageManager.report(sourceNumber, messageGuid);
@ -603,27 +616,26 @@ public class MessageController {
.build();
}
private void sendMessage(Optional<Account> source,
Account destinationAccount,
Device destinationDevice,
long timestamp,
boolean online,
IncomingMessage incomingMessage)
throws NoSuchUserException
{
private void sendMessage(Optional<AuthenticatedAccount> source,
Account destinationAccount,
Device destinationDevice,
long timestamp,
boolean online,
IncomingMessage incomingMessage)
throws NoSuchUserException {
try (final Timer.Context ignored = sendMessageInternalTimer.time()) {
Optional<byte[]> messageBody = getMessageBody(incomingMessage);
Optional<byte[]> messageBody = getMessageBody(incomingMessage);
Optional<byte[]> messageContent = getMessageContent(incomingMessage);
Envelope.Builder messageBuilder = Envelope.newBuilder();
messageBuilder.setType(Envelope.Type.forNumber(incomingMessage.getType()))
.setTimestamp(timestamp == 0 ? System.currentTimeMillis() : timestamp)
.setServerTimestamp(System.currentTimeMillis());
.setTimestamp(timestamp == 0 ? System.currentTimeMillis() : timestamp)
.setServerTimestamp(System.currentTimeMillis());
if (source.isPresent()) {
messageBuilder.setSource(source.get().getNumber())
.setSourceUuid(source.get().getUuid().toString())
.setSourceDevice((int)source.get().getAuthenticatedDevice().get().getId());
messageBuilder.setSource(source.get().getAccount().getNumber())
.setSourceUuid(source.get().getAccount().getUuid().toString())
.setSourceDevice((int) source.get().getAuthenticatedDevice().getId());
}
if (messageBody.isPresent()) {
@ -697,24 +709,26 @@ public class MessageController {
}
@VisibleForTesting
public static void validateCompleteDeviceList(Account account, List<IncomingMessage> messages, boolean isSyncMessage)
public static void validateCompleteDeviceList(Account account, List<IncomingMessage> messages, boolean isSyncMessage,
Optional<Long> authenticatedDeviceId)
throws MismatchedDevicesException {
Set<Long> messageDeviceIds = messages.stream().map(IncomingMessage::getDestinationDeviceId).collect(Collectors.toSet());
validateCompleteDeviceList(account, messageDeviceIds, isSyncMessage);
Set<Long> messageDeviceIds = messages.stream().map(IncomingMessage::getDestinationDeviceId)
.collect(Collectors.toSet());
validateCompleteDeviceList(account, messageDeviceIds, isSyncMessage, authenticatedDeviceId);
}
@VisibleForTesting
public static void validateCompleteDeviceList(Account account, Set<Long> messageDeviceIds, boolean isSyncMessage)
public static void validateCompleteDeviceList(Account account, Set<Long> messageDeviceIds, boolean isSyncMessage,
Optional<Long> authenticatedDeviceId)
throws MismatchedDevicesException {
Set<Long> accountDeviceIds = new HashSet<>();
List<Long> missingDeviceIds = new LinkedList<>();
List<Long> extraDeviceIds = new LinkedList<>();
List<Long> extraDeviceIds = new LinkedList<>();
for (Device device : account.getDevices()) {
for (Device device : account.getDevices()) {
if (device.isEnabled() &&
!(isSyncMessage && device.getId() == account.getAuthenticatedDevice().get().getId()))
{
!(isSyncMessage && device.getId() == authenticatedDeviceId.get())) {
accountDeviceIds.add(device.getId());
if (!messageDeviceIds.contains(device.getId())) {

View File

@ -1,23 +1,21 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.entities.CurrencyConversionEntityList;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.currency.CurrencyConversionManager;
import io.dropwizard.auth.Auth;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import io.dropwizard.auth.Auth;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.currency.CurrencyConversionManager;
import org.whispersystems.textsecuregcm.entities.CurrencyConversionEntityList;
@Path("/v1/payments")
public class PaymentsController {
@ -34,15 +32,15 @@ public class PaymentsController {
@GET
@Path("/auth")
@Produces(MediaType.APPLICATION_JSON)
public ExternalServiceCredentials getAuth(@Auth Account account) {
return paymentsServiceCredentialGenerator.generateFor(account.getUuid().toString());
public ExternalServiceCredentials getAuth(@Auth AuthenticatedAccount auth) {
return paymentsServiceCredentialGenerator.generateFor(auth.getAccount().getUuid().toString());
}
@Timed
@GET
@Path("/conversions")
@Produces(MediaType.APPLICATION_JSON)
public CurrencyConversionEntityList getConversions(@Auth Account account) {
public CurrencyConversionEntityList getConversions(@Auth AuthenticatedAccount auth) {
return currencyManager.getCurrencyConversions().orElseThrow();
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -41,6 +41,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.Anonymous;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessChecksum;
import org.whispersystems.textsecuregcm.entities.CreateProfileRequest;
@ -107,60 +108,64 @@ public class ProfileController {
this.isZkEnabled = isZkEnabled;
}
@Timed
@PUT
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
public Response setProfile(@Auth Account account, @Valid CreateProfileRequest request) {
if (!isZkEnabled) throw new WebApplicationException(Response.Status.NOT_FOUND);
@Timed
@PUT
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
public Response setProfile(@Auth AuthenticatedAccount auth, @Valid CreateProfileRequest request) {
if (!isZkEnabled) {
throw new WebApplicationException(Response.Status.NOT_FOUND);
}
final Set<String> allowedPaymentsCountryCodes =
dynamicConfigurationManager.getConfiguration().getPaymentsConfiguration().getAllowedCountryCodes();
final Set<String> allowedPaymentsCountryCodes =
dynamicConfigurationManager.getConfiguration().getPaymentsConfiguration().getAllowedCountryCodes();
if (StringUtils.isNotBlank(request.getPaymentAddress()) &&
!allowedPaymentsCountryCodes.contains(Util.getCountryCode(account.getNumber()))) {
if (StringUtils.isNotBlank(request.getPaymentAddress()) &&
!allowedPaymentsCountryCodes.contains(Util.getCountryCode(auth.getAccount().getNumber()))) {
return Response.status(Status.FORBIDDEN).build();
}
return Response.status(Status.FORBIDDEN).build();
}
Optional<VersionedProfile> currentProfile = profilesManager.get(account.getUuid(), request.getVersion());
String avatar = request.isAvatar() ? generateAvatarObjectName() : null;
Optional<ProfileAvatarUploadAttributes> response = Optional.empty();
Optional<VersionedProfile> currentProfile = profilesManager.get(auth.getAccount().getUuid(), request.getVersion());
String avatar = request.isAvatar() ? generateAvatarObjectName() : null;
Optional<ProfileAvatarUploadAttributes> response = Optional.empty();
profilesManager.set(account.getUuid(),
new VersionedProfile(
request.getVersion(),
request.getName(),
avatar,
request.getAboutEmoji(),
request.getAbout(),
request.getPaymentAddress(),
request.getCommitment().serialize()));
profilesManager.set(auth.getAccount().getUuid(),
new VersionedProfile(
request.getVersion(),
request.getName(),
avatar,
request.getAboutEmoji(),
request.getAbout(),
request.getPaymentAddress(),
request.getCommitment().serialize()));
if (request.isAvatar()) {
Optional<String> currentAvatar = Optional.empty();
Optional<String> currentAvatar = Optional.empty();
if (currentProfile.isPresent() && currentProfile.get().getAvatar() != null && currentProfile.get().getAvatar().startsWith("profiles/")) {
currentAvatar = Optional.of(currentProfile.get().getAvatar());
}
if (currentProfile.isPresent() && currentProfile.get().getAvatar() != null && currentProfile.get().getAvatar()
.startsWith("profiles/")) {
currentAvatar = Optional.of(currentProfile.get().getAvatar());
}
if (currentAvatar.isEmpty() && account.getAvatar() != null && account.getAvatar().startsWith("profiles/")) {
currentAvatar = Optional.of(account.getAvatar());
}
if (currentAvatar.isEmpty() && auth.getAccount().getAvatar() != null && auth.getAccount().getAvatar()
.startsWith("profiles/")) {
currentAvatar = Optional.of(auth.getAccount().getAvatar());
}
currentAvatar.ifPresent(s -> s3client.deleteObject(DeleteObjectRequest.builder()
.bucket(bucket)
.key(s)
.build()));
currentAvatar.ifPresent(s -> s3client.deleteObject(DeleteObjectRequest.builder()
.bucket(bucket)
.key(s)
.build()));
response = Optional.of(generateAvatarUploadForm(avatar));
response = Optional.of(generateAvatarUploadForm(avatar));
}
accountsManager.update(account, a -> {
a.setProfileName(request.getName());
a.setAvatar(avatar);
a.setCurrentProfileVersion(request.getVersion());
});
accountsManager.update(auth.getAccount(), a -> {
a.setProfileName(request.getName());
a.setAvatar(avatar);
a.setCurrentProfileVersion(request.getVersion());
});
if (response.isPresent()) return Response.ok(response).build();
else return Response.ok().build();
@ -170,29 +175,32 @@ public class ProfileController {
@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/{uuid}/{version}")
public Optional<Profile> getProfile(@Auth Optional<Account> requestAccount,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@PathParam("uuid") UUID uuid,
@PathParam("version") String version)
throws RateLimitExceededException
{
if (!isZkEnabled) throw new WebApplicationException(Response.Status.NOT_FOUND);
return getVersionedProfile(requestAccount, accessKey, uuid, version, Optional.empty());
public Optional<Profile> getProfile(@Auth Optional<AuthenticatedAccount> auth,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@PathParam("uuid") UUID uuid,
@PathParam("version") String version)
throws RateLimitExceededException {
if (!isZkEnabled) {
throw new WebApplicationException(Response.Status.NOT_FOUND);
}
return getVersionedProfile(auth.map(AuthenticatedAccount::getAccount), accessKey, uuid, version, Optional.empty());
}
@Timed
@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/{uuid}/{version}/{credentialRequest}")
public Optional<Profile> getProfile(@Auth Optional<Account> requestAccount,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@PathParam("uuid") UUID uuid,
@PathParam("version") String version,
@PathParam("credentialRequest") String credentialRequest)
throws RateLimitExceededException
{
if (!isZkEnabled) throw new WebApplicationException(Response.Status.NOT_FOUND);
return getVersionedProfile(requestAccount, accessKey, uuid, version, Optional.of(credentialRequest));
public Optional<Profile> getProfile(@Auth Optional<AuthenticatedAccount> auth,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@PathParam("uuid") UUID uuid,
@PathParam("version") String version,
@PathParam("credentialRequest") String credentialRequest)
throws RateLimitExceededException {
if (!isZkEnabled) {
throw new WebApplicationException(Response.Status.NOT_FOUND);
}
return getVersionedProfile(auth.map(AuthenticatedAccount::getAccount), accessKey, uuid, version,
Optional.of(credentialRequest));
}
private Optional<Profile> getVersionedProfile(Optional<Account> requestAccount,
@ -255,22 +263,23 @@ public class ProfileController {
}
@Timed
@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/username/{username}")
public Profile getProfileByUsername(@Auth Account account, @PathParam("username") String username) throws RateLimitExceededException {
rateLimiters.getUsernameLookupLimiter().validate(account.getUuid());
@Timed
@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/username/{username}")
public Profile getProfileByUsername(@Auth AuthenticatedAccount auth, @PathParam("username") String username)
throws RateLimitExceededException {
rateLimiters.getUsernameLookupLimiter().validate(auth.getAccount().getUuid());
username = username.toLowerCase();
username = username.toLowerCase();
Optional<UUID> uuid = usernamesManager.get(username);
Optional<UUID> uuid = usernamesManager.get(username);
if (uuid.isEmpty()) {
throw new WebApplicationException(Response.status(Response.Status.NOT_FOUND).build());
}
if (uuid.isEmpty()) {
throw new WebApplicationException(Response.status(Response.Status.NOT_FOUND).build());
}
Optional<Account> accountProfile = accountsManager.get(uuid.get());
Optional<Account> accountProfile = accountsManager.get(uuid.get());
if (accountProfile.isEmpty()) {
throw new WebApplicationException(Response.status(Response.Status.NOT_FOUND).build());
@ -312,40 +321,40 @@ public class ProfileController {
// Old profile endpoints. Replaced by versioned profile endpoints (above)
@Deprecated
@Timed
@PUT
@Produces(MediaType.APPLICATION_JSON)
@Path("/name/{name}")
public void setProfile(@Auth Account account, @PathParam("name") @ExactlySize(value = {72, 108}, payload = {Unwrapping.Unwrap.class}) Optional<String> name) {
accountsManager.update(account, a -> a.setProfileName(name.orElse(null)));
}
@Deprecated
@Timed
@PUT
@Produces(MediaType.APPLICATION_JSON)
@Path("/name/{name}")
public void setProfile(@Auth AuthenticatedAccount auth,
@PathParam("name") @ExactlySize(value = {72, 108}, payload = {Unwrapping.Unwrap.class}) Optional<String> name) {
accountsManager.update(auth.getAccount(), a -> a.setProfileName(name.orElse(null)));
}
@Deprecated
@Timed
@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/{identifier}")
public Profile getProfile(@Auth Optional<Account> requestAccount,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@HeaderParam("User-Agent") String userAgent,
@PathParam("identifier") AmbiguousIdentifier identifier,
@QueryParam("ca") boolean useCaCertificate)
throws RateLimitExceededException
{
public Profile getProfile(@Auth Optional<AuthenticatedAccount> auth,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@HeaderParam("User-Agent") String userAgent,
@PathParam("identifier") AmbiguousIdentifier identifier,
@QueryParam("ca") boolean useCaCertificate)
throws RateLimitExceededException {
identifier.incrementRequestCounter("getProfile", userAgent);
if (requestAccount.isEmpty() && accessKey.isEmpty()) {
if (auth.isEmpty() && accessKey.isEmpty()) {
throw new WebApplicationException(Response.Status.UNAUTHORIZED);
}
if (requestAccount.isPresent()) {
rateLimiters.getProfileLimiter().validate(requestAccount.get().getUuid());
if (auth.isPresent()) {
rateLimiters.getProfileLimiter().validate(auth.get().getAccount().getUuid());
}
Optional<Account> accountProfile = accountsManager.get(identifier);
OptionalAccess.verify(requestAccount, accessKey, accountProfile);
OptionalAccess.verify(auth.map(AuthenticatedAccount::getAccount), accessKey, accountProfile);
Optional<String> username = Optional.empty();
@ -369,24 +378,24 @@ public class ProfileController {
}
@Deprecated
@Timed
@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/form/avatar")
public ProfileAvatarUploadAttributes getAvatarUploadForm(@Auth Account account) {
String previousAvatar = account.getAvatar();
String objectName = generateAvatarObjectName();
ProfileAvatarUploadAttributes profileAvatarUploadAttributes = generateAvatarUploadForm(objectName);
@Deprecated
@Timed
@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/form/avatar")
public ProfileAvatarUploadAttributes getAvatarUploadForm(@Auth AuthenticatedAccount auth) {
String previousAvatar = auth.getAccount().getAvatar();
String objectName = generateAvatarObjectName();
ProfileAvatarUploadAttributes profileAvatarUploadAttributes = generateAvatarUploadForm(objectName);
if (previousAvatar != null && previousAvatar.startsWith("profiles/")) {
s3client.deleteObject(DeleteObjectRequest.builder()
.bucket(bucket)
.key(previousAvatar)
.build());
}
if (previousAvatar != null && previousAvatar.startsWith("profiles/")) {
s3client.deleteObject(DeleteObjectRequest.builder()
.bucket(bucket)
.key(previousAvatar)
.build());
}
accountsManager.update(account, a -> a.setAvatar(objectName));
accountsManager.update(auth.getAccount(), a -> a.setAvatar(objectName));
return profileAvatarUploadAttributes;
}

View File

@ -17,10 +17,10 @@ import javax.ws.rs.Produces;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.entities.ProvisioningMessage;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.push.ProvisioningManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.websocket.ProvisioningAddress;
@Path("/v1/provisioning")
@ -39,16 +39,15 @@ public class ProvisioningController {
@PUT
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public void sendProvisioningMessage(@Auth Account source,
@PathParam("destination") String destinationName,
@Valid ProvisioningMessage message)
public void sendProvisioningMessage(@Auth AuthenticatedAccount auth,
@PathParam("destination") String destinationName,
@Valid ProvisioningMessage message)
throws RateLimitExceededException {
rateLimiters.getMessagesLimiter().validate(source.getUuid());
rateLimiters.getMessagesLimiter().validate(auth.getAccount().getUuid());
if (!provisioningManager.sendProvisioningMessage(new ProvisioningAddress(destinationName, 0),
Base64.getDecoder().decode(message.getBody())))
{
Base64.getDecoder().decode(message.getBody()))) {
throw new WebApplicationException(Response.Status.NOT_FOUND);
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -8,13 +8,16 @@ package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed;
import com.google.common.annotations.VisibleForTesting;
import io.dropwizard.auth.Auth;
import org.whispersystems.textsecuregcm.entities.UserRemoteConfig;
import org.whispersystems.textsecuregcm.entities.UserRemoteConfigList;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.RemoteConfig;
import org.whispersystems.textsecuregcm.storage.RemoteConfigsManager;
import org.whispersystems.textsecuregcm.util.Conversions;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.validation.Valid;
import javax.ws.rs.Consumes;
import javax.ws.rs.DELETE;
@ -27,16 +30,12 @@ import javax.ws.rs.Produces;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.entities.UserRemoteConfig;
import org.whispersystems.textsecuregcm.entities.UserRemoteConfigList;
import org.whispersystems.textsecuregcm.storage.RemoteConfig;
import org.whispersystems.textsecuregcm.storage.RemoteConfigsManager;
import org.whispersystems.textsecuregcm.util.Conversions;
@Path("/v1/config")
public class RemoteConfigController {
@ -57,15 +56,19 @@ public class RemoteConfigController {
@GET
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public UserRemoteConfigList getAll(@Auth Account account) {
public UserRemoteConfigList getAll(@Auth AuthenticatedAccount auth) {
try {
MessageDigest digest = MessageDigest.getInstance("SHA1");
final Stream<UserRemoteConfig> globalConfigStream = globalConfig.entrySet().stream().map(entry -> new UserRemoteConfig(GLOBAL_CONFIG_PREFIX + entry.getKey(), true, entry.getValue()));
final Stream<UserRemoteConfig> globalConfigStream = globalConfig.entrySet().stream()
.map(entry -> new UserRemoteConfig(GLOBAL_CONFIG_PREFIX + entry.getKey(), true, entry.getValue()));
return new UserRemoteConfigList(Stream.concat(remoteConfigsManager.getAll().stream().map(config -> {
final byte[] hashKey = config.getHashKey() != null ? config.getHashKey().getBytes(StandardCharsets.UTF_8) : config.getName().getBytes(StandardCharsets.UTF_8);
boolean inBucket = isInBucket(digest, account.getUuid(), hashKey, config.getPercentage(), config.getUuids());
return new UserRemoteConfig(config.getName(), inBucket, inBucket ? config.getValue() : config.getDefaultValue());
final byte[] hashKey = config.getHashKey() != null ? config.getHashKey().getBytes(StandardCharsets.UTF_8)
: config.getName().getBytes(StandardCharsets.UTF_8);
boolean inBucket = isInBucket(digest, auth.getAccount().getUuid(), hashKey, config.getPercentage(),
config.getUuids());
return new UserRemoteConfig(config.getName(), inBucket,
inBucket ? config.getValue() : config.getDefaultValue());
}), globalConfigStream).collect(Collectors.toList()));
} catch (NoSuchAlgorithmException e) {
throw new AssertionError(e);

View File

@ -1,21 +1,19 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.storage.Account;
import io.dropwizard.auth.Auth;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import io.dropwizard.auth.Auth;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
@Path("/v1/backup")
public class SecureBackupController {
@ -30,7 +28,7 @@ public class SecureBackupController {
@GET
@Path("/auth")
@Produces(MediaType.APPLICATION_JSON)
public ExternalServiceCredentials getAuth(@Auth Account account) {
return backupServiceCredentialGenerator.generateFor(account.getUuid().toString());
public ExternalServiceCredentials getAuth(@Auth AuthenticatedAccount auth) {
return backupServiceCredentialGenerator.generateFor(auth.getAccount().getUuid().toString());
}
}

View File

@ -1,21 +1,19 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.storage.Account;
import io.dropwizard.auth.Auth;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import io.dropwizard.auth.Auth;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
@Path("/v1/storage")
public class SecureStorageController {
@ -30,7 +28,7 @@ public class SecureStorageController {
@GET
@Path("/auth")
@Produces(MediaType.APPLICATION_JSON)
public ExternalServiceCredentials getAuth(@Auth Account account) {
return storageServiceCredentialGenerator.generateFor(account.getUuid().toString());
public ExternalServiceCredentials getAuth(@Auth AuthenticatedAccount auth) {
return storageServiceCredentialGenerator.generateFor(auth.getAccount().getUuid().toString());
}
}

View File

@ -1,21 +1,16 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import io.dropwizard.auth.Auth;
import org.whispersystems.textsecuregcm.entities.StickerPackFormUploadAttributes;
import org.whispersystems.textsecuregcm.entities.StickerPackFormUploadAttributes.StickerPackFormUploadItem;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.s3.PolicySigner;
import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Hex;
import org.whispersystems.textsecuregcm.util.Pair;
import java.security.SecureRandom;
import java.time.ZoneOffset;
import java.time.ZonedDateTime;
import java.util.LinkedList;
import java.util.List;
import javax.validation.constraints.Max;
import javax.validation.constraints.Min;
import javax.ws.rs.GET;
@ -23,11 +18,15 @@ import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import java.security.SecureRandom;
import java.time.ZoneOffset;
import java.time.ZonedDateTime;
import java.util.LinkedList;
import java.util.List;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.entities.StickerPackFormUploadAttributes;
import org.whispersystems.textsecuregcm.entities.StickerPackFormUploadAttributes.StickerPackFormUploadItem;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.s3.PolicySigner;
import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Hex;
import org.whispersystems.textsecuregcm.util.Pair;
@Path("/v1/sticker")
public class StickerController {
@ -45,30 +44,31 @@ public class StickerController {
@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/pack/form/{count}")
public StickerPackFormUploadAttributes getStickersForm(@Auth Account account,
@PathParam("count") @Min(1) @Max(201) int stickerCount)
throws RateLimitExceededException
{
rateLimiters.getStickerPackLimiter().validate(account.getUuid());
ZonedDateTime now = ZonedDateTime.now(ZoneOffset.UTC);
String packId = generatePackId();
String packLocation = "stickers/" + packId;
String manifestKey = packLocation + "/manifest.proto";
Pair<String, String> manifestPolicy = policyGenerator.createFor(now, manifestKey, Constants.MAXIMUM_STICKER_MANIFEST_SIZE_BYTES);
String manifestSignature = policySigner.getSignature(now, manifestPolicy.second());
StickerPackFormUploadItem manifest = new StickerPackFormUploadItem(-1, manifestKey, manifestPolicy.first(), "private", "AWS4-HMAC-SHA256",
now.format(PostPolicyGenerator.AWS_DATE_TIME), manifestPolicy.second(), manifestSignature);
public StickerPackFormUploadAttributes getStickersForm(@Auth AuthenticatedAccount auth,
@PathParam("count") @Min(1) @Max(201) int stickerCount)
throws RateLimitExceededException {
rateLimiters.getStickerPackLimiter().validate(auth.getAccount().getUuid());
ZonedDateTime now = ZonedDateTime.now(ZoneOffset.UTC);
String packId = generatePackId();
String packLocation = "stickers/" + packId;
String manifestKey = packLocation + "/manifest.proto";
Pair<String, String> manifestPolicy = policyGenerator.createFor(now, manifestKey,
Constants.MAXIMUM_STICKER_MANIFEST_SIZE_BYTES);
String manifestSignature = policySigner.getSignature(now, manifestPolicy.second());
StickerPackFormUploadItem manifest = new StickerPackFormUploadItem(-1, manifestKey, manifestPolicy.first(),
"private", "AWS4-HMAC-SHA256",
now.format(PostPolicyGenerator.AWS_DATE_TIME), manifestPolicy.second(), manifestSignature);
List<StickerPackFormUploadItem> stickers = new LinkedList<>();
for (int i=0;i<stickerCount;i++) {
String stickerKey = packLocation + "/full/" + i;
Pair<String, String> stickerPolicy = policyGenerator.createFor(now, stickerKey, Constants.MAXIMUM_STICKER_SIZE_BYTES);
String stickerSignature = policySigner.getSignature(now, stickerPolicy.second());
for (int i = 0; i < stickerCount; i++) {
String stickerKey = packLocation + "/full/" + i;
Pair<String, String> stickerPolicy = policyGenerator.createFor(now, stickerKey,
Constants.MAXIMUM_STICKER_SIZE_BYTES);
String stickerSignature = policySigner.getSignature(now, stickerPolicy.second());
stickers.add(new StickerPackFormUploadItem(i, stickerKey, stickerPolicy.first(), "private", "AWS4-HMAC-SHA256",
now.format(PostPolicyGenerator.AWS_DATE_TIME), stickerPolicy.second(), stickerSignature));
now.format(PostPolicyGenerator.AWS_DATE_TIME), stickerPolicy.second(), stickerSignature));
}
return new StickerPackFormUploadAttributes(packId, manifest, stickers);

View File

@ -1,20 +1,20 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.push;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.controllers.NoSuchUserException;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import java.util.Optional;
public class ReceiptSender {
private final MessageSender messageSender;
@ -23,30 +23,29 @@ public class ReceiptSender {
private static final Logger logger = LoggerFactory.getLogger(ReceiptSender.class);
public ReceiptSender(AccountsManager accountManager,
MessageSender messageSender)
{
MessageSender messageSender) {
this.accountManager = accountManager;
this.messageSender = messageSender;
this.messageSender = messageSender;
}
public void sendReceipt(Account source, String destination, long messageId)
throws NoSuchUserException
{
if (source.getNumber().equals(destination)) {
public void sendReceipt(AuthenticatedAccount source, String destination, long messageId)
throws NoSuchUserException {
final Account sourceAccount = source.getAccount();
if (sourceAccount.getNumber().equals(destination)) {
return;
}
Account destinationAccount = getDestinationAccount(destination);
Envelope.Builder message = Envelope.newBuilder()
.setServerTimestamp(System.currentTimeMillis())
.setSource(source.getNumber())
.setSourceUuid(source.getUuid().toString())
.setSourceDevice((int) source.getAuthenticatedDevice().get().getId())
.setTimestamp(messageId)
.setType(Envelope.Type.SERVER_DELIVERY_RECEIPT);
Account destinationAccount = getDestinationAccount(destination);
Envelope.Builder message = Envelope.newBuilder()
.setServerTimestamp(System.currentTimeMillis())
.setSource(sourceAccount.getNumber())
.setSourceUuid(sourceAccount.getUuid().toString())
.setSourceDevice((int) source.getAuthenticatedDevice().getId())
.setTimestamp(messageId)
.setType(Envelope.Type.SERVER_DELIVERY_RECEIPT);
if (source.getRelay().isPresent()) {
message.setRelay(source.getRelay().get());
if (sourceAccount.getRelay().isPresent()) {
message.setRelay(sourceAccount.getRelay().get());
}
for (final Device destinationDevice : destinationAccount.getDevices()) {
@ -63,7 +62,7 @@ public class ReceiptSender {
{
Optional<Account> account = accountManager.get(destination);
if (!account.isPresent()) {
if (account.isEmpty()) {
throw new NoSuchUserException(destination);
}

View File

@ -8,12 +8,10 @@ package org.whispersystems.textsecuregcm.storage;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
import java.security.Principal;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import javax.security.auth.Subject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
@ -22,7 +20,7 @@ import org.whispersystems.textsecuregcm.auth.StoredRegistrationLock;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.util.Util;
public class Account implements Principal {
public class Account {
@JsonIgnore
private static final Logger logger = LoggerFactory.getLogger(Account.class);
@ -63,9 +61,6 @@ public class Account implements Principal {
@JsonProperty("inCds")
private boolean discoverableByPhoneNumber = true;
@JsonIgnore
private Device authenticatedDevice;
@JsonProperty
private int version;
@ -82,18 +77,6 @@ public class Account implements Principal {
this.unidentifiedAccessKey = unidentifiedAccessKey;
}
public Optional<Device> getAuthenticatedDevice() {
requireNotStale();
return Optional.ofNullable(authenticatedDevice);
}
public void setAuthenticatedDevice(Device device) {
requireNotStale();
this.authenticatedDevice = device;
}
public UUID getUuid() {
// this is the one method that may be called on a stale account
return uuid;
@ -390,6 +373,10 @@ public class Account implements Principal {
this.version = version;
}
boolean isStale() {
return stale;
}
public void markStale() {
stale = true;
}
@ -403,17 +390,4 @@ public class Account implements Principal {
}
}
// Principal implementation
@Override
@JsonIgnore
public String getName() {
return null;
}
@Override
@JsonIgnore
public boolean implies(Subject subject) {
return false;
}
}

View File

@ -0,0 +1,14 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
public class RefreshingAccountAndDeviceNotFoundException extends RuntimeException {
public RefreshingAccountAndDeviceNotFoundException(final String message) {
super(message);
}
}

View File

@ -0,0 +1,35 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import java.util.function.Supplier;
import org.whispersystems.textsecuregcm.util.Pair;
public class RefreshingAccountAndDeviceSupplier implements Supplier<Pair<Account, Device>> {
private Account account;
private Device device;
private final AccountsManager accountsManager;
public RefreshingAccountAndDeviceSupplier(Account account, long deviceId, AccountsManager accountsManager) {
this.account = account;
this.device = account.getDevice(deviceId)
.orElseThrow(() -> new RefreshingAccountAndDeviceNotFoundException("Could not find device"));
this.accountsManager = accountsManager;
}
@Override
public Pair<Account, Device> get() {
if (account.isStale()) {
account = accountsManager.get(account.getUuid())
.orElseThrow(() -> new RuntimeException("Could not find account"));
device = account.getDevice(device.getId())
.orElseThrow(() -> new RefreshingAccountAndDeviceNotFoundException("Could not find device"));
}
return new Pair<>(account, device);
}
}

View File

@ -1,32 +1,31 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.websocket;
import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.Counter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import java.util.concurrent.ScheduledExecutorService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.redis.RedisOperation;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import org.whispersystems.websocket.setup.WebSocketConnectListener;
import java.util.concurrent.ScheduledExecutorService;
import static com.codahale.metrics.MetricRegistry.name;
public class AuthenticatedConnectListener implements WebSocketConnectListener {
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
@ -60,16 +59,16 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
@Override
public void onWebSocketConnect(WebSocketSessionContext context) {
if (context.getAuthenticated() != null) {
final Account account = context.getAuthenticated(Account.class);
final Device device = account.getAuthenticatedDevice().get();
final Timer.Context timer = durationTimer.time();
final WebSocketConnection connection = new WebSocketConnection(receiptSender,
messagesManager, account, device,
context.getClient(),
retrySchedulingExecutor);
final AuthenticatedAccount auth = context.getAuthenticated(AuthenticatedAccount.class);
final Device device = auth.getAuthenticatedDevice();
final Timer.Context timer = durationTimer.time();
final WebSocketConnection connection = new WebSocketConnection(receiptSender,
messagesManager, auth, device,
context.getClient(),
retrySchedulingExecutor);
openWebsocketCounter.inc();
RedisOperation.unchecked(() -> apnFallbackManager.cancel(account, device));
RedisOperation.unchecked(() -> apnFallbackManager.cancel(auth.getAccount(), device));
context.addListener(new WebSocketSessionContext.WebSocketEventListener() {
@Override
@ -79,20 +78,21 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
connection.stop();
RedisOperation.unchecked(() -> clientPresenceManager.clearPresence(account.getUuid(), device.getId()));
RedisOperation.unchecked(
() -> clientPresenceManager.clearPresence(auth.getAccount().getUuid(), device.getId()));
RedisOperation.unchecked(() -> {
messagesManager.removeMessageAvailabilityListener(connection);
if (messagesManager.hasCachedMessages(account.getUuid(), device.getId())) {
messageSender.sendNewMessageNotification(account, device);
if (messagesManager.hasCachedMessages(auth.getAccount().getUuid(), device.getId())) {
messageSender.sendNewMessageNotification(auth.getAccount(), device);
}
});
}
});
try {
clientPresenceManager.setPresent(account.getUuid(), device.getId(), connection);
messagesManager.addMessageAvailabilityListener(account.getUuid(), device.getId(), connection);
clientPresenceManager.setPresent(auth.getAccount().getUuid(), device.getId(), connection);
messagesManager.addMessageAvailabilityListener(auth.getAccount().getUuid(), device.getId(), connection);
connection.start();
} catch (final Exception e) {
log.warn("Failed to initialize websocket", e);

View File

@ -1,23 +1,21 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.websocket;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
import io.dropwizard.auth.basic.BasicCredentials;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import io.dropwizard.auth.basic.BasicCredentials;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.websocket.auth.WebSocketAuthenticator;
public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<Account> {
public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<AuthenticatedAccount> {
private final AccountAuthenticator accountAuthenticator;
@ -26,19 +24,18 @@ public class WebSocketAccountAuthenticator implements WebSocketAuthenticator<Acc
}
@Override
public AuthenticationResult<Account> authenticate(UpgradeRequest request) {
public AuthenticationResult<AuthenticatedAccount> authenticate(UpgradeRequest request) {
Map<String, List<String>> parameters = request.getParameterMap();
List<String> usernames = parameters.get("login");
List<String> passwords = parameters.get("password");
List<String> usernames = parameters.get("login");
List<String> passwords = parameters.get("password");
if (usernames == null || usernames.size() == 0 ||
passwords == null || passwords.size() == 0)
{
passwords == null || passwords.size() == 0) {
return new AuthenticationResult<>(Optional.empty(), false);
}
BasicCredentials credentials = new BasicCredentials(usernames.get(0).replace(" ", "+"),
passwords.get(0).replace(" ", "+"));
passwords.get(0).replace(" ", "+"));
return new AuthenticationResult<>(accountAuthenticator.authenticate(credentials), true);
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -36,6 +36,7 @@ import javax.ws.rs.WebApplicationException;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.controllers.NoSuchUserException;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
@ -43,7 +44,6 @@ import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.DisplacedPresenceListener;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessageAvailabilityListener;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
@ -90,21 +90,22 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class);
private final ReceiptSender receiptSender;
private final MessagesManager messagesManager;
private final ReceiptSender receiptSender;
private final MessagesManager messagesManager;
private final Account account;
private final Device device;
private final WebSocketClient client;
private final AuthenticatedAccount auth;
private final Device device;
private final WebSocketClient client;
private final ScheduledExecutorService retrySchedulingExecutor;
private final boolean isDesktopClient;
private final boolean isDesktopClient;
private final Semaphore processStoredMessagesSemaphore = new Semaphore(1);
private final AtomicReference<StoredMessageState> storedMessageState = new AtomicReference<>(StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE);
private final AtomicBoolean sentInitialQueueEmptyMessage = new AtomicBoolean(false);
private final LongAdder sentMessageCounter = new LongAdder();
private final AtomicLong queueDrainStartTime = new AtomicLong();
private final Semaphore processStoredMessagesSemaphore = new Semaphore(1);
private final AtomicReference<StoredMessageState> storedMessageState = new AtomicReference<>(
StoredMessageState.PERSISTED_NEW_MESSAGES_AVAILABLE);
private final AtomicBoolean sentInitialQueueEmptyMessage = new AtomicBoolean(false);
private final LongAdder sentMessageCounter = new LongAdder();
private final AtomicLong queueDrainStartTime = new AtomicLong();
private final AtomicInteger consecutiveRetries = new AtomicInteger();
private final AtomicReference<ScheduledFuture<?>> retryFuture = new AtomicReference<>();
@ -118,16 +119,15 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
public WebSocketConnection(ReceiptSender receiptSender,
MessagesManager messagesManager,
Account account,
AuthenticatedAccount auth,
Device device,
WebSocketClient client,
ScheduledExecutorService retrySchedulingExecutor)
{
this.receiptSender = receiptSender;
ScheduledExecutorService retrySchedulingExecutor) {
this.receiptSender = receiptSender;
this.messagesManager = messagesManager;
this.account = account;
this.device = device;
this.client = client;
this.auth = auth;
this.device = device;
this.client = client;
this.retrySchedulingExecutor = retrySchedulingExecutor;
Optional<ClientPlatform> maybePlatform;
@ -168,7 +168,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
if (throwable == null) {
if (isSuccessResponse(response)) {
if (storedMessageInfo.isPresent()) {
messagesManager.delete(account.getUuid(), device.getId(), storedMessageInfo.get().getGuid());
messagesManager.delete(auth.getAccount().getUuid(), device.getId(), storedMessageInfo.get().getGuid());
}
if (message.getType() != Envelope.Type.SERVER_DELIVERY_RECEIPT) {
@ -204,7 +204,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
if (!message.hasSource()) return;
try {
receiptSender.sendReceipt(account, message.getSource(), message.getTimestamp());
receiptSender.sendReceipt(auth, message.getSource(), message.getTimestamp());
} catch (NoSuchUserException e) {
logger.info("No longer registered " + e.getMessage());
} catch (WebApplicationException e) {
@ -267,7 +267,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
private void sendNextMessagePage(final boolean cachedMessagesOnly, final CompletableFuture<Void> queueClearedFuture) {
try {
final OutgoingMessageEntityList messages = messagesManager
.getMessagesForDevice(account.getUuid(), device.getId(), client.getUserAgent(), cachedMessagesOnly);
.getMessagesForDevice(auth.getAccount().getUuid(), device.getId(), client.getUserAgent(), cachedMessagesOnly);
final CompletableFuture<?>[] sendFutures = new CompletableFuture[messages.getMessages().size()];
@ -303,7 +303,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
final Envelope envelope = builder.build();
if (envelope.getSerializedSize() > MAX_DESKTOP_MESSAGE_SIZE && isDesktopClient) {
messagesManager.delete(account.getUuid(), device.getId(), message.getGuid());
messagesManager.delete(auth.getAccount().getUuid(), device.getId(), message.getGuid());
discardedMessagesMeter.mark();
sendFutures[i] = CompletableFuture.completedFuture(null);
@ -340,7 +340,7 @@ public class WebSocketConnection implements MessageAvailabilityListener, Displac
public void handleNewEphemeralMessageAvailable() {
ephemeralMessageAvailableMeter.mark();
messagesManager.takeEphemeralMessage(account.getUuid(), device.getId())
messagesManager.takeEphemeralMessage(auth.getAccount().getUuid(), device.getId())
.ifPresent(message -> sendMessage(message, Optional.empty()));
}

View File

@ -24,11 +24,11 @@ import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.mappers.RetryLaterExceptionMapper;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -41,7 +41,8 @@ class ChallengeControllerTest {
private static final ResourceExtension EXTENSION = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(Set.of(Account.class, DisabledPermittedAccount.class)))
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
Set.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new RetryLaterExceptionMapper())

View File

@ -0,0 +1,71 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.util.Optional;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.util.Pair;
class RefreshingAccountAndDeviceSupplierTest {
@Test
void test() {
final AccountsManager accountsManager = mock(AccountsManager.class);
final UUID uuid = UUID.randomUUID();
final long deviceId = 2L;
final Account initialAccount = mock(Account.class);
final Device initialDevice = mock(Device.class);
when(initialAccount.getUuid()).thenReturn(uuid);
when(initialDevice.getId()).thenReturn(deviceId);
when(initialAccount.getDevice(deviceId)).thenReturn(Optional.of(initialDevice));
when(accountsManager.get(any(UUID.class))).thenAnswer(answer -> {
final Account account = mock(Account.class);
final Device device = mock(Device.class);
when(account.getUuid()).thenReturn(answer.getArgument(0, UUID.class));
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(device.getId()).thenReturn(deviceId);
return Optional.of(account);
});
final RefreshingAccountAndDeviceSupplier refreshingAccountAndDeviceSupplier = new RefreshingAccountAndDeviceSupplier(
initialAccount, deviceId, accountsManager);
Pair<Account, Device> accountAndDevice = refreshingAccountAndDeviceSupplier.get();
assertSame(initialAccount, accountAndDevice.first());
assertSame(initialDevice, accountAndDevice.second());
accountAndDevice = refreshingAccountAndDeviceSupplier.get();
assertSame(initialAccount, accountAndDevice.first());
assertSame(initialDevice, accountAndDevice.second());
when(initialAccount.isStale()).thenReturn(true);
accountAndDevice = refreshingAccountAndDeviceSupplier.get();
assertNotSame(initialAccount, accountAndDevice.first());
assertNotSame(initialDevice, accountAndDevice.second());
assertEquals(uuid, accountAndDevice.first().getUuid());
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -52,8 +52,9 @@ import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.auth.StoredRegistrationLock;
import org.whispersystems.textsecuregcm.auth.StoredVerificationCode;
@ -145,27 +146,29 @@ class AccountControllerTest {
private static ExternalServiceCredentialGenerator storageCredentialGenerator = new ExternalServiceCredentialGenerator(new byte[32], new byte[32], false);
private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.addProvider(new RateLimitExceededExceptionMapper())
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new AccountController(pendingAccountsManager,
accountsManager,
usernamesManager,
abusiveHostRules,
rateLimiters,
smsSender,
dynamicConfigurationManager,
turnTokenGenerator,
new HashMap<>(),
recaptchaClient,
gcmSender,
apnSender,
storageCredentialGenerator,
verifyExperimentEnrollmentManager))
.build();
.addProvider(AuthHelper.getAuthFilter())
.addProvider(
new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(AuthenticatedAccount.class,
DisabledPermittedAuthenticatedAccount.class)))
.addProvider(new RateLimitExceededExceptionMapper())
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new AccountController(pendingAccountsManager,
accountsManager,
usernamesManager,
abusiveHostRules,
rateLimiters,
smsSender,
dynamicConfigurationManager,
turnTokenGenerator,
new HashMap<>(),
recaptchaClient,
gcmSender,
apnSender,
storageCredentialGenerator,
verifyExperimentEnrollmentManager))
.build();
@BeforeEach

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -33,7 +33,8 @@ import org.assertj.core.api.InstanceOfAssertFactories;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV1;
import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV2;
import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV3;
@ -43,7 +44,6 @@ import org.whispersystems.textsecuregcm.entities.AttachmentDescriptorV3;
import org.whispersystems.textsecuregcm.entities.AttachmentUri;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -78,8 +78,9 @@ class AttachmentControllerTest {
static {
try {
resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new AttachmentControllerV1(rateLimiters, "accessKey", "accessSecret", "attachment-bucket"))

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -28,8 +28,9 @@ import org.signal.zkgroup.auth.AuthCredential;
import org.signal.zkgroup.auth.AuthCredentialResponse;
import org.signal.zkgroup.auth.ClientZkAuthOperations;
import org.signal.zkgroup.auth.ServerZkAuthOperations;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.CertificateGenerator;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.controllers.CertificateController;
import org.whispersystems.textsecuregcm.crypto.Curve;
@ -37,7 +38,6 @@ import org.whispersystems.textsecuregcm.entities.DeliveryCertificate;
import org.whispersystems.textsecuregcm.entities.GroupCredentials;
import org.whispersystems.textsecuregcm.entities.MessageProtos.SenderCertificate;
import org.whispersystems.textsecuregcm.entities.MessageProtos.ServerCertificate;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.Util;
@ -66,12 +66,13 @@ class CertificateControllerTest {
private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new CertificateController(certificateGenerator, serverZkAuthOperations, true))
.build();
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new CertificateController(certificateGenerator, serverZkAuthOperations, true))
.build();
@Test
void testValidCertificate() throws Exception {

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.tests.controllers;
@ -36,7 +36,8 @@ import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.StoredVerificationCode;
import org.whispersystems.textsecuregcm.controllers.DeviceController;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
@ -88,17 +89,18 @@ class DeviceControllerTest {
private static Map<String, Integer> deviceConfiguration = new HashMap<>();
private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addProvider(new DeviceLimitExceededExceptionMapper())
.addResource(new DumbVerificationDeviceController(pendingDevicesManager,
accountsManager,
messagesManager,
keys,
rateLimiters,
deviceConfiguration))
.build();
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addProvider(new DeviceLimitExceededExceptionMapper())
.addResource(new DumbVerificationDeviceController(pendingDevicesManager,
accountsManager,
messagesManager,
keys,
rateLimiters,
deviceConfiguration))
.build();
@BeforeEach
@ -114,15 +116,14 @@ class DeviceControllerTest {
when(account.getNextDeviceId()).thenReturn(42L);
when(account.getNumber()).thenReturn(AuthHelper.VALID_NUMBER);
when(account.getUuid()).thenReturn(AuthHelper.VALID_UUID);
// when(maxedAccount.getActiveDeviceCount()).thenReturn(6);
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(masterDevice));
when(account.isEnabled()).thenReturn(false);
when(account.isGroupsV2Supported()).thenReturn(true);
when(account.isGv1MigrationSupported()).thenReturn(true);
when(account.isSenderKeySupported()).thenReturn(true);
when(account.isAnnouncementGroupSupported()).thenReturn(true);
when(pendingDevicesManager.getCodeForNumber(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(new StoredVerificationCode("5678901", System.currentTimeMillis(), null, null)));
when(pendingDevicesManager.getCodeForNumber(AuthHelper.VALID_NUMBER)).thenReturn(
Optional.of(new StoredVerificationCode("5678901", System.currentTimeMillis(), null, null)));
when(pendingDevicesManager.getCodeForNumber(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.empty());
when(accountsManager.get(AuthHelper.VALID_NUMBER)).thenReturn(Optional.of(account));
when(accountsManager.get(AuthHelper.VALID_NUMBER_TWO)).thenReturn(Optional.of(maxedAccount));

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -20,14 +20,14 @@ import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.Status.Family;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.controllers.DirectoryController;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
@ExtendWith(DropwizardExtensionsSupport.class)
@ -37,11 +37,12 @@ class DirectoryControllerTest {
private static final ExternalServiceCredentials validCredentials = new ExternalServiceCredentials("username", "password");
private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new DirectoryController(directoryCredentialsGenerator))
.build();
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new DirectoryController(directoryCredentialsGenerator))
.build();
@BeforeEach
void setup() {

View File

@ -26,14 +26,14 @@ import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
import org.whispersystems.textsecuregcm.configuration.DonationConfiguration;
import org.whispersystems.textsecuregcm.configuration.RetryConfiguration;
import org.whispersystems.textsecuregcm.controllers.DonationController;
import org.whispersystems.textsecuregcm.entities.ApplePayAuthorizationRequest;
import org.whispersystems.textsecuregcm.entities.ApplePayAuthorizationResponse;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -57,7 +57,8 @@ public class DonationControllerTest {
configuration.setSupportedCurrencies(Set.of("usd", "gbp"));
resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new DonationController(executor, configuration))

View File

@ -44,7 +44,8 @@ import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatcher;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.controllers.KeysController;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
@ -102,8 +103,8 @@ class KeysControllerTest {
private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(
AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new RateLimitChallengeExceptionMapper(rateLimitChallengeManager))
.addResource(new ServerRejectedExceptionMapper())

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -68,7 +68,8 @@ import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatcher;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration;
@ -142,15 +143,16 @@ class MessageControllerTest {
private final ObjectMapper mapper = new ObjectMapper();
private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.addProvider(RateLimitExceededExceptionMapper.class)
.addProvider(new RateLimitChallengeExceptionMapper(rateLimitChallengeManager))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new MessageController(rateLimiters, messageSender, receiptSender, accountsManager,
messagesManager, unsealedSenderRateLimiter, apnFallbackManager, dynamicConfigurationManager,
rateLimitChallengeManager, reportMessageManager, metricsCluster, receiptExecutor))
.build();
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.addProvider(RateLimitExceededExceptionMapper.class)
.addProvider(new RateLimitChallengeExceptionMapper(rateLimitChallengeManager))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new MessageController(rateLimiters, messageSender, receiptSender, accountsManager,
messagesManager, unsealedSenderRateLimiter, apnFallbackManager, dynamicConfigurationManager,
rateLimitChallengeManager, reportMessageManager, metricsCluster, receiptExecutor))
.build();
@BeforeEach
void setup() throws Exception {
@ -576,7 +578,7 @@ class MessageControllerTest {
.delete();
assertThat("Good Response Code", response.getStatus(), is(equalTo(204)));
verify(receiptSender).sendReceipt(any(Account.class), eq("+14152222222"), eq(timestamp));
verify(receiptSender).sendReceipt(any(AuthenticatedAccount.class), eq("+14152222222"), eq(timestamp));
response = resources.getJerseyTest()
.target(String.format("/v1/messages/%s/%d", "+14152222222", 31338))
@ -731,22 +733,54 @@ class MessageControllerTest {
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L, 3L),
null,
null,
false,
null),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L, 2L, 3L),
null,
Set.of(2L)),
Set.of(2L),
false,
null),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L),
Set.of(3L),
null,
false,
null),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L, 2L),
Set.of(3L),
Set.of(2L))
Set.of(2L),
false,
null),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(1L),
Set.of(3L),
Set.of(1L),
true,
1L
),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(2L),
Set.of(3L),
Set.of(2L),
true,
1L
),
arguments(
mockAccountWithDeviceAndEnabled(1L, true, 2L, false, 3L, true),
Set.of(3L),
null,
null,
true,
1L
)
);
}
@ -756,10 +790,13 @@ class MessageControllerTest {
Account account,
Set<Long> deviceIds,
Collection<Long> expectedMissingDeviceIds,
Collection<Long> expectedExtraDeviceIds) throws Exception {
Collection<Long> expectedExtraDeviceIds,
boolean isSyncMessage,
Long authenticatedDeviceId) throws Exception {
if (expectedMissingDeviceIds != null || expectedExtraDeviceIds != null) {
final MismatchedDevicesException mismatchedDevicesException = assertThrows(MismatchedDevicesException.class,
() -> MessageController.validateCompleteDeviceList(account, deviceIds, false));
() -> MessageController.validateCompleteDeviceList(account, deviceIds, isSyncMessage,
Optional.ofNullable(authenticatedDeviceId)));
if (expectedMissingDeviceIds != null) {
Assertions.assertThat(mismatchedDevicesException.getMissingDevices())
.hasSameElementsAs(expectedMissingDeviceIds);
@ -768,7 +805,8 @@ class MessageControllerTest {
Assertions.assertThat(mismatchedDevicesException.getExtraDevices()).hasSameElementsAs(expectedExtraDeviceIds);
}
} else {
MessageController.validateCompleteDeviceList(account, deviceIds, false);
MessageController.validateCompleteDeviceList(account, deviceIds, isSyncMessage,
Optional.ofNullable(authenticatedDeviceId));
}
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -22,14 +22,14 @@ import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.controllers.PaymentsController;
import org.whispersystems.textsecuregcm.currency.CurrencyConversionManager;
import org.whispersystems.textsecuregcm.entities.CurrencyConversionEntity;
import org.whispersystems.textsecuregcm.entities.CurrencyConversionEntityList;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
@ExtendWith(DropwizardExtensionsSupport.class)
@ -41,11 +41,12 @@ class PaymentsControllerTest {
private final ExternalServiceCredentials validCredentials = new ExternalServiceCredentials("username", "password");
private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new PaymentsController(currencyManager, paymentsCredentialGenerator))
.build();
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new PaymentsController(currencyManager, paymentsCredentialGenerator))
.build();
@BeforeEach

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -41,7 +41,8 @@ import org.signal.zkgroup.profiles.ProfileKey;
import org.signal.zkgroup.profiles.ProfileKeyCommitment;
import org.signal.zkgroup.profiles.ServerZkProfileOperations;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicPaymentsConfiguration;
import org.whispersystems.textsecuregcm.controllers.ProfileController;
@ -87,22 +88,23 @@ class ProfileControllerTest {
private Account profileAccount;
private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new ProfileController(rateLimiters,
accountsManager,
profilesManager,
usernamesManager,
dynamicConfigurationManager,
s3client,
postPolicyGenerator,
policySigner,
"profilesBucket",
zkProfileOperations,
true))
.build();
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new ProfileController(rateLimiters,
accountsManager,
profilesManager,
usernamesManager,
dynamicConfigurationManager,
s3client,
postPolicyGenerator,
policySigner,
"profilesBucket",
zkProfileOperations,
true))
.build();
@BeforeEach
void setup() throws Exception {

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -30,17 +30,17 @@ import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.controllers.RemoteConfigController;
import org.whispersystems.textsecuregcm.entities.UserRemoteConfig;
import org.whispersystems.textsecuregcm.entities.UserRemoteConfigList;
import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.RemoteConfig;
import org.whispersystems.textsecuregcm.storage.RemoteConfigsManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
@ -52,12 +52,13 @@ class RemoteConfigControllerTest {
private static final List<String> remoteConfigsAuth = List.of("foo", "bar");
private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addProvider(new DeviceLimitExceededExceptionMapper())
.addResource(new RemoteConfigController(remoteConfigsManager, remoteConfigsAuth, Map.of("maxGroupSize", "42")))
.build();
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addProvider(new DeviceLimitExceededExceptionMapper())
.addResource(new RemoteConfigController(remoteConfigsManager, remoteConfigsAuth, Map.of("maxGroupSize", "42")))
.build();
@BeforeEach

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -15,11 +15,11 @@ import javax.ws.rs.core.Response;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentialGenerator;
import org.whispersystems.textsecuregcm.auth.ExternalServiceCredentials;
import org.whispersystems.textsecuregcm.controllers.SecureStorageController;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -29,12 +29,13 @@ class SecureStorageControllerTest {
private static final ExternalServiceCredentialGenerator storageCredentialGenerator = new ExternalServiceCredentialGenerator(new byte[32], new byte[32], false);
private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new SecureStorageController(storageCredentialGenerator))
.build();
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new SecureStorageController(storageCredentialGenerator))
.build();
@Test

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -21,13 +21,13 @@ import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.controllers.StickerController;
import org.whispersystems.textsecuregcm.entities.StickerPackFormUploadAttributes;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -38,12 +38,13 @@ class StickerControllerTest {
private static final RateLimiters rateLimiters = mock(RateLimiters.class);
private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new StickerController(rateLimiters, "foo", "bar", "us-east-1", "mybucket"))
.build();
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new StickerController(rateLimiters, "foo", "bar", "us-east-1", "mybucket"))
.build();
@BeforeEach
void setup() {

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -18,10 +18,10 @@ import javax.ws.rs.core.Response;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.controllers.VoiceVerificationController;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@ -29,14 +29,15 @@ import org.whispersystems.textsecuregcm.util.SystemMapper;
class VoiceVerificationControllerTest {
private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.addProvider(new RateLimitExceededExceptionMapper())
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new VoiceVerificationController("https://foo.com/bar",
new HashSet<>(Arrays.asList("pt-BR", "ru"))))
.build();
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.addProvider(new RateLimitExceededExceptionMapper())
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new VoiceVerificationController("https://foo.com/bar",
new HashSet<>(Arrays.asList("pt-BR", "ru"))))
.build();
@Test
void testTwimlLocale() {

View File

@ -113,10 +113,6 @@ public class AccountsHelper {
when(updatedAccount.getMasterDevice()).thenAnswer(stubbing);
break;
}
case "getAuthenticatedDevice": {
when(updatedAccount.getAuthenticatedDevice()).thenAnswer(stubbing);
break;
}
case "isEnabled": {
when(updatedAccount.isEnabled()).thenAnswer(stubbing);
break;

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -24,8 +24,9 @@ import java.util.UUID;
import org.mockito.ArgumentMatcher;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.AuthenticationCredentials;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccountAuthenticator;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
@ -121,11 +122,6 @@ public class AuthHelper {
when(UNDISCOVERABLE_ACCOUNT.getNumber()).thenReturn(UNDISCOVERABLE_NUMBER);
when(UNDISCOVERABLE_ACCOUNT.getUuid()).thenReturn(UNDISCOVERABLE_UUID);
when(VALID_ACCOUNT.getAuthenticatedDevice()).thenReturn(Optional.of(VALID_DEVICE));
when(VALID_ACCOUNT_TWO.getAuthenticatedDevice()).thenReturn(Optional.of(VALID_DEVICE_TWO));
when(DISABLED_ACCOUNT.getAuthenticatedDevice()).thenReturn(Optional.of(DISABLED_DEVICE));
when(UNDISCOVERABLE_ACCOUNT.getAuthenticatedDevice()).thenReturn(Optional.of(UNDISCOVERABLE_DEVICE));
when(VALID_ACCOUNT.getRelay()).thenReturn(Optional.empty());
when(VALID_ACCOUNT_TWO.getRelay()).thenReturn(Optional.empty());
when(UNDISCOVERABLE_ACCOUNT.getRelay()).thenReturn(Optional.empty());
@ -151,18 +147,30 @@ public class AuthHelper {
when(ACCOUNTS_MANAGER.get(VALID_NUMBER_TWO)).thenReturn(Optional.of(VALID_ACCOUNT_TWO));
when(ACCOUNTS_MANAGER.get(VALID_UUID_TWO)).thenReturn(Optional.of(VALID_ACCOUNT_TWO));
when(ACCOUNTS_MANAGER.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(VALID_NUMBER_TWO)))).thenReturn(Optional.of(VALID_ACCOUNT_TWO));
when(ACCOUNTS_MANAGER.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasUuid() && identifier.getUuid().equals(VALID_UUID_TWO)))).thenReturn(Optional.of(VALID_ACCOUNT_TWO));
when(ACCOUNTS_MANAGER.get(argThat(
(ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber()
&& identifier.getNumber().equals(VALID_NUMBER_TWO)))).thenReturn(Optional.of(VALID_ACCOUNT_TWO));
when(ACCOUNTS_MANAGER.get(argThat(
(ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasUuid()
&& identifier.getUuid().equals(VALID_UUID_TWO)))).thenReturn(Optional.of(VALID_ACCOUNT_TWO));
when(ACCOUNTS_MANAGER.get(DISABLED_NUMBER)).thenReturn(Optional.of(DISABLED_ACCOUNT));
when(ACCOUNTS_MANAGER.get(DISABLED_UUID)).thenReturn(Optional.of(DISABLED_ACCOUNT));
when(ACCOUNTS_MANAGER.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(DISABLED_NUMBER)))).thenReturn(Optional.of(DISABLED_ACCOUNT));
when(ACCOUNTS_MANAGER.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasUuid() && identifier.getUuid().equals(DISABLED_UUID)))).thenReturn(Optional.of(DISABLED_ACCOUNT));
when(ACCOUNTS_MANAGER.get(argThat(
(ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber()
&& identifier.getNumber().equals(DISABLED_NUMBER)))).thenReturn(Optional.of(DISABLED_ACCOUNT));
when(ACCOUNTS_MANAGER.get(argThat(
(ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasUuid()
&& identifier.getUuid().equals(DISABLED_UUID)))).thenReturn(Optional.of(DISABLED_ACCOUNT));
when(ACCOUNTS_MANAGER.get(UNDISCOVERABLE_NUMBER)).thenReturn(Optional.of(UNDISCOVERABLE_ACCOUNT));
when(ACCOUNTS_MANAGER.get(UNDISCOVERABLE_UUID)).thenReturn(Optional.of(UNDISCOVERABLE_ACCOUNT));
when(ACCOUNTS_MANAGER.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(UNDISCOVERABLE_NUMBER)))).thenReturn(Optional.of(UNDISCOVERABLE_ACCOUNT));
when(ACCOUNTS_MANAGER.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasUuid() && identifier.getUuid().equals(UNDISCOVERABLE_UUID)))).thenReturn(Optional.of(UNDISCOVERABLE_ACCOUNT));
when(ACCOUNTS_MANAGER.get(argThat(
(ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber()
&& identifier.getNumber().equals(UNDISCOVERABLE_NUMBER)))).thenReturn(Optional.of(UNDISCOVERABLE_ACCOUNT));
when(ACCOUNTS_MANAGER.get(argThat(
(ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasUuid()
&& identifier.getUuid().equals(UNDISCOVERABLE_UUID)))).thenReturn(Optional.of(UNDISCOVERABLE_ACCOUNT));
AccountsHelper.setupMockUpdateForAuthHelper(ACCOUNTS_MANAGER);
@ -170,11 +178,13 @@ public class AuthHelper {
testAccount.setup(ACCOUNTS_MANAGER);
}
AuthFilter<BasicCredentials, Account> accountAuthFilter = new BasicCredentialAuthFilter.Builder<Account>().setAuthenticator(new AccountAuthenticator(ACCOUNTS_MANAGER)).buildAuthFilter ();
AuthFilter<BasicCredentials, DisabledPermittedAccount> disabledPermittedAccountAuthFilter = new BasicCredentialAuthFilter.Builder<DisabledPermittedAccount>().setAuthenticator(new DisabledPermittedAccountAuthenticator(ACCOUNTS_MANAGER)).buildAuthFilter();
AuthFilter<BasicCredentials, AuthenticatedAccount> accountAuthFilter = new BasicCredentialAuthFilter.Builder<AuthenticatedAccount>().setAuthenticator(
new AccountAuthenticator(ACCOUNTS_MANAGER)).buildAuthFilter();
AuthFilter<BasicCredentials, DisabledPermittedAuthenticatedAccount> disabledPermittedAccountAuthFilter = new BasicCredentialAuthFilter.Builder<DisabledPermittedAuthenticatedAccount>().setAuthenticator(
new DisabledPermittedAccountAuthenticator(ACCOUNTS_MANAGER)).buildAuthFilter();
return new PolymorphicAuthDynamicFeature<>(ImmutableMap.of(Account.class, accountAuthFilter,
DisabledPermittedAccount.class, disabledPermittedAccountAuthFilter));
return new PolymorphicAuthDynamicFeature<>(ImmutableMap.of(AuthenticatedAccount.class, accountAuthFilter,
DisabledPermittedAuthenticatedAccount.class, disabledPermittedAccountAuthFilter));
}
public static String getAuthHeader(String number, String password) {
@ -223,13 +233,16 @@ public class AuthHelper {
when(account.getMasterDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn(number);
when(account.getUuid()).thenReturn(uuid);
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getRelay()).thenReturn(Optional.empty());
when(account.isEnabled()).thenReturn(true);
when(accountsManager.get(number)).thenReturn(Optional.of(account));
when(accountsManager.get(uuid)).thenReturn(Optional.of(account));
when(accountsManager.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(number)))).thenReturn(Optional.of(account));
when(accountsManager.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasUuid() && identifier.getUuid().equals(uuid)))).thenReturn(Optional.of(account));
when(accountsManager.get(argThat(
(ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber()
&& identifier.getNumber().equals(number)))).thenReturn(Optional.of(account));
when(accountsManager.get(argThat(
(ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasUuid()
&& identifier.getUuid().equals(uuid)))).thenReturn(Optional.of(account));
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -40,6 +40,7 @@ import org.junit.Rule;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager;
@ -52,6 +53,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.storage.ReportMessageManager;
import org.whispersystems.textsecuregcm.tests.util.MessagesDynamoDbRule;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
@ -90,13 +92,13 @@ public class WebSocketConnectionIntegrationTest extends AbstractRedisClusterTest
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(1L);
webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, mock(PushLatencyManager.class), reportMessageManager),
account,
device,
webSocketClient,
retrySchedulingExecutor);
webSocketConnection = new WebSocketConnection(
mock(ReceiptSender.class),
new MessagesManager(messagesDynamoDb, messagesCache, mock(PushLatencyManager.class), reportMessageManager),
new AuthenticatedAccount(() -> new Pair<>(account, device)),
device,
webSocketClient,
retrySchedulingExecutor);
}
@After

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -24,6 +24,7 @@ import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import io.dropwizard.auth.basic.BasicCredentials;
import io.lettuce.core.RedisException;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
@ -39,7 +40,6 @@ import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import io.lettuce.core.RedisException;
import org.apache.commons.lang3.RandomStringUtils;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.junit.Before;
@ -48,6 +48,7 @@ import org.mockito.ArgumentMatchers;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
@ -58,6 +59,7 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.websocket.WebSocketClient;
import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult;
import org.whispersystems.websocket.messages.WebSocketResponseMessage;
@ -75,6 +77,7 @@ public class WebSocketConnectionTest {
private AccountsManager accountsManager;
private Account account;
private Device device;
private AuthenticatedAccount auth;
private UpgradeRequest upgradeRequest;
private ReceiptSender receiptSender;
private ApnFallbackManager apnFallbackManager;
@ -86,6 +89,7 @@ public class WebSocketConnectionTest {
accountsManager = mock(AccountsManager.class);
account = mock(Account.class);
device = mock(Device.class);
auth = new AuthenticatedAccount(() -> new Pair<>(account, device));
upgradeRequest = mock(UpgradeRequest.class);
receiptSender = mock(ReceiptSender.class);
apnFallbackManager = mock(ApnFallbackManager.class);
@ -94,35 +98,42 @@ public class WebSocketConnectionTest {
@Test
public void testCredentials() throws Exception {
MessagesManager storedMessages = mock(MessagesManager.class);
MessagesManager storedMessages = mock(MessagesManager.class);
WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator(accountAuthenticator);
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, storedMessages, mock(MessageSender.class), apnFallbackManager, mock(ClientPresenceManager.class),
AuthenticatedConnectListener connectListener = new AuthenticatedConnectListener(receiptSender, storedMessages,
mock(MessageSender.class), apnFallbackManager, mock(ClientPresenceManager.class),
retrySchedulingExecutor);
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class);
when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD))))
.thenReturn(Optional.of(account));
.thenReturn(Optional.of(new AuthenticatedAccount(() -> new Pair<>(account, device))));
when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD))))
.thenReturn(Optional.<Account>empty());
.thenReturn(Optional.empty());
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<String, List<String>>() {{
put("login", new LinkedList<String>() {{add(VALID_USER);}});
put("password", new LinkedList<String>() {{add(VALID_PASSWORD);}});
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<>() {{
put("login", new LinkedList<>() {{
add(VALID_USER);
}});
put("password", new LinkedList<>() {{
add(VALID_PASSWORD);
}});
}});
AuthenticationResult<Account> account = webSocketAuthenticator.authenticate(upgradeRequest);
when(sessionContext.getAuthenticated(Account.class)).thenReturn(account.getUser().orElse(null));
AuthenticationResult<AuthenticatedAccount> account = webSocketAuthenticator.authenticate(upgradeRequest);
when(sessionContext.getAuthenticated(AuthenticatedAccount.class)).thenReturn(account.getUser().orElse(null));
connectListener.onWebSocketConnect(sessionContext);
verify(sessionContext).addListener(any(WebSocketSessionContext.WebSocketEventListener.class));
when(upgradeRequest.getParameterMap()).thenReturn(new HashMap<String, List<String>>() {{
put("login", new LinkedList<String>() {{add(INVALID_USER);}});
put("password", new LinkedList<String>() {{add(INVALID_PASSWORD);}});
put("login", new LinkedList<String>() {{
add(INVALID_USER);
}});
put("password", new LinkedList<String>() {{
add(INVALID_PASSWORD);
}});
}});
account = webSocketAuthenticator.authenticate(upgradeRequest);
@ -148,7 +159,6 @@ public class WebSocketConnectionTest {
when(device.getId()).thenReturn(2L);
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
@ -184,12 +194,13 @@ public class WebSocketConnectionTest {
});
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages,
account, device, client, retrySchedulingExecutor);
auth, device, client, retrySchedulingExecutor);
connection.start();
verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any());
verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class),
ArgumentMatchers.<Optional<byte[]>>any());
assertTrue(futures.size() == 3);
assertEquals(3, futures.size());
WebSocketResponseMessage response = mock(WebSocketResponseMessage.class);
when(response.getStatus()).thenReturn(200);
@ -199,7 +210,7 @@ public class WebSocketConnectionTest {
futures.get(2).completeExceptionally(new IOException());
verify(storedMessages, times(1)).delete(eq(accountUuid), eq(2L), eq(outgoingMessages.get(1).getGuid()));
verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender1"), eq(2222L));
verify(receiptSender, times(1)).sendReceipt(eq(auth), eq("sender1"), eq(2222L));
connection.stop();
verify(client).close(anyInt(), anyString());
@ -207,9 +218,10 @@ public class WebSocketConnectionTest {
@Test(timeout = 5_000L)
public void testOnlineSend() throws Exception {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor);
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
retrySchedulingExecutor);
final UUID accountUuid = UUID.randomUUID();
@ -219,7 +231,7 @@ public class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false))
.thenReturn(new OutgoingMessageEntityList(List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first")), false))
.thenReturn(new OutgoingMessageEntityList(List.of(createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second")), false));
@ -300,7 +312,6 @@ public class WebSocketConnectionTest {
when(device.getId()).thenReturn(2L);
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(UUID.randomUUID());
@ -336,11 +347,12 @@ public class WebSocketConnectionTest {
});
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages,
account, device, client, retrySchedulingExecutor);
auth, device, client, retrySchedulingExecutor);
connection.start();
verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any());
verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), any(List.class),
ArgumentMatchers.<Optional<byte[]>>any());
assertEquals(futures.size(), 2);
@ -349,7 +361,7 @@ public class WebSocketConnectionTest {
futures.get(1).complete(response);
futures.get(0).completeExceptionally(new IOException());
verify(receiptSender, times(1)).sendReceipt(eq(account), eq("sender2"), eq(secondMessage.getTimestamp()));
verify(receiptSender, times(1)).sendReceipt(eq(auth), eq("sender2"), eq(secondMessage.getTimestamp()));
connection.stop();
verify(client).close(anyInt(), anyString());
@ -357,19 +369,21 @@ public class WebSocketConnectionTest {
@Test(timeout = 5000L)
public void testProcessStoredMessageConcurrency() throws InterruptedException {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor);
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
retrySchedulingExecutor);
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
when(device.getId()).thenReturn(1L);
when(client.getUserAgent()).thenReturn("Test-UA");
final AtomicBoolean threadWaiting = new AtomicBoolean(false);
final AtomicBoolean threadWaiting = new AtomicBoolean(false);
final AtomicBoolean returnMessageList = new AtomicBoolean(false);
when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false)).thenAnswer((Answer<OutgoingMessageEntityList>)invocation -> {
when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false)).thenAnswer(
(Answer<OutgoingMessageEntityList>) invocation -> {
synchronized (threadWaiting) {
threadWaiting.set(true);
threadWaiting.notifyAll();
@ -418,9 +432,10 @@ public class WebSocketConnectionTest {
@Test(timeout = 5000L)
public void testProcessStoredMessagesMultiplePages() throws InterruptedException {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor);
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
retrySchedulingExecutor);
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
@ -428,8 +443,8 @@ public class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA");
final List<OutgoingMessageEntity> firstPageMessages =
List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first"),
createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second"));
List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first"),
createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second"));
final List<OutgoingMessageEntity> secondPageMessages =
List.of(createMessage(3L, false, "sender1", UUID.randomUUID(), 3333, false, "third"));
@ -463,7 +478,8 @@ public class WebSocketConnectionTest {
public void testProcessStoredMessagesContainsSenderUuid() throws InterruptedException {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
retrySchedulingExecutor);
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(UUID.randomUUID());
@ -471,7 +487,8 @@ public class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA");
final UUID senderUuid = UUID.randomUUID();
final List<OutgoingMessageEntity> messages = List.of(createMessage(1L, false, "senderE164", senderUuid, 1111L, false, "message the first"));
final List<OutgoingMessageEntity> messages = List.of(
createMessage(1L, false, "senderE164", senderUuid, 1111L, false, "message the first"));
final OutgoingMessageEntityList firstPage = new OutgoingMessageEntityList(messages, false);
when(messagesManager.getMessagesForDevice(account.getUuid(), 1L, client.getUserAgent(), false)).thenReturn(firstPage);
@ -511,9 +528,10 @@ public class WebSocketConnectionTest {
@Test
public void testProcessStoredMessagesSingleEmptyCall() {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor);
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
retrySchedulingExecutor);
final UUID accountUuid = UUID.randomUUID();
@ -523,7 +541,7 @@ public class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
@ -540,10 +558,11 @@ public class WebSocketConnectionTest {
@Test(timeout = 5000L)
public void testRequeryOnStateMismatch() throws InterruptedException {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor);
final UUID accountUuid = UUID.randomUUID();
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
retrySchedulingExecutor);
final UUID accountUuid = UUID.randomUUID();
when(account.getNumber()).thenReturn("+18005551234");
when(account.getUuid()).thenReturn(accountUuid);
@ -551,8 +570,8 @@ public class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA");
final List<OutgoingMessageEntity> firstPageMessages =
List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first"),
createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second"));
List.of(createMessage(1L, false, "sender1", UUID.randomUUID(), 1111, false, "first"),
createMessage(2L, false, "sender1", UUID.randomUUID(), 2222, false, "second"));
final List<OutgoingMessageEntity> secondPageMessages =
List.of(createMessage(3L, false, "sender1", UUID.randomUUID(), 3333, false, "third"));
@ -587,9 +606,10 @@ public class WebSocketConnectionTest {
@Test
public void testProcessCachedMessagesOnly() {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor);
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
retrySchedulingExecutor);
final UUID accountUuid = UUID.randomUUID();
@ -599,7 +619,7 @@ public class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
@ -619,9 +639,10 @@ public class WebSocketConnectionTest {
@Test
public void testProcessDatabaseMessagesAfterPersist() {
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, account, device, client, retrySchedulingExecutor);
final MessagesManager messagesManager = mock(MessagesManager.class);
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, client,
retrySchedulingExecutor);
final UUID accountUuid = UUID.randomUUID();
@ -631,7 +652,7 @@ public class WebSocketConnectionTest {
when(client.getUserAgent()).thenReturn("Test-UA");
when(messagesManager.getMessagesForDevice(eq(accountUuid), eq(1L), eq("Test-UA"), anyBoolean()))
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
.thenReturn(new OutgoingMessageEntityList(Collections.emptyList(), false));
final WebSocketResponseMessage successResponse = mock(WebSocketResponseMessage.class);
when(successResponse.getStatus()).thenReturn(200);
@ -664,7 +685,6 @@ public class WebSocketConnectionTest {
when(device.getId()).thenReturn(2L);
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
@ -689,20 +709,24 @@ public class WebSocketConnectionTest {
final WebSocketClient client = mock(WebSocketClient.class);
when(client.getUserAgent()).thenReturn(userAgent);
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any()))
.thenAnswer(new Answer<CompletableFuture<WebSocketResponseMessage>>() {
@Override
public CompletableFuture<WebSocketResponseMessage> answer(InvocationOnMock invocationOnMock) throws Throwable {
CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>();
futures.add(future);
return future;
}
});
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class),
ArgumentMatchers.<Optional<byte[]>>any()))
.thenAnswer(new Answer<CompletableFuture<WebSocketResponseMessage>>() {
@Override
public CompletableFuture<WebSocketResponseMessage> answer(InvocationOnMock invocationOnMock)
throws Throwable {
CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>();
futures.add(future);
return future;
}
});
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, account, device, client, retrySchedulingExecutor);
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client,
retrySchedulingExecutor);
connection.start();
verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any());
verify(client, times(2)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class),
ArgumentMatchers.<Optional<byte[]>>any());
assertEquals(2, futures.size());
@ -737,7 +761,6 @@ public class WebSocketConnectionTest {
when(device.getId()).thenReturn(2L);
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
@ -762,20 +785,24 @@ public class WebSocketConnectionTest {
final WebSocketClient client = mock(WebSocketClient.class);
when(client.getUserAgent()).thenReturn(userAgent);
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any()))
.thenAnswer(new Answer<CompletableFuture<WebSocketResponseMessage>>() {
@Override
public CompletableFuture<WebSocketResponseMessage> answer(InvocationOnMock invocationOnMock) throws Throwable {
CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>();
futures.add(future);
return future;
}
});
when(client.sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class),
ArgumentMatchers.<Optional<byte[]>>any()))
.thenAnswer(new Answer<CompletableFuture<WebSocketResponseMessage>>() {
@Override
public CompletableFuture<WebSocketResponseMessage> answer(InvocationOnMock invocationOnMock)
throws Throwable {
CompletableFuture<WebSocketResponseMessage> future = new CompletableFuture<>();
futures.add(future);
return future;
}
});
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, account, device, client, retrySchedulingExecutor);
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client,
retrySchedulingExecutor);
connection.start();
verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class), ArgumentMatchers.<Optional<byte[]>>any());
verify(client, times(3)).sendRequest(eq("PUT"), eq("/api/v1/message"), ArgumentMatchers.nullable(List.class),
ArgumentMatchers.<Optional<byte[]>>any());
assertEquals(3, futures.size());
@ -799,7 +826,6 @@ public class WebSocketConnectionTest {
when(device.getId()).thenReturn(2L);
when(account.getAuthenticatedDevice()).thenReturn(Optional.of(device));
when(account.getNumber()).thenReturn("+14152222222");
when(account.getUuid()).thenReturn(accountUuid);
@ -808,17 +834,20 @@ public class WebSocketConnectionTest {
when(storedMessages.getMessagesForDevice(account.getUuid(), device.getId(), userAgent, false))
.thenThrow(new RedisException("OH NO"));
when(retrySchedulingExecutor.schedule(any(Runnable.class), anyLong(), any())).thenAnswer((Answer<ScheduledFuture<?>>) invocation -> {
invocation.getArgument(0, Runnable.class).run();
return mock(ScheduledFuture.class);
});
when(retrySchedulingExecutor.schedule(any(Runnable.class), anyLong(), any())).thenAnswer(
(Answer<ScheduledFuture<?>>) invocation -> {
invocation.getArgument(0, Runnable.class).run();
return mock(ScheduledFuture.class);
});
final WebSocketClient client = mock(WebSocketClient.class);
final WebSocketClient client = mock(WebSocketClient.class);
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, account, device, client, retrySchedulingExecutor);
WebSocketConnection connection = new WebSocketConnection(receiptSender, storedMessages, auth, device, client,
retrySchedulingExecutor);
connection.start();
verify(retrySchedulingExecutor, times(WebSocketConnection.MAX_CONSECUTIVE_RETRIES)).schedule(any(Runnable.class), anyLong(), any());
verify(retrySchedulingExecutor, times(WebSocketConnection.MAX_CONSECUTIVE_RETRIES)).schedule(any(Runnable.class),
anyLong(), any());
verify(client).close(eq(1011), anyString());
}