Add client challenges for prekey and message rate limiters

This commit is contained in:
Jon Chambers 2021-05-11 17:21:32 -04:00 committed by GitHub
parent 5752853bba
commit 46110d4d65
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
46 changed files with 2289 additions and 255 deletions

View File

@ -152,6 +152,11 @@ public class WhisperServerConfiguration extends Configuration {
@JsonProperty
private DynamoDbConfiguration migrationRetryAccountsDynamoDb;
@Valid
@NotNull
@JsonProperty
private DynamoDbConfiguration pushChallengeDynamoDb;
@Valid
@NotNull
@JsonProperty
@ -433,6 +438,10 @@ public class WhisperServerConfiguration extends Configuration {
return appConfig;
}
public DynamoDbConfiguration getPushChallengeDynamoDbConfiguration() {
return pushChallengeDynamoDb;
}
public TorExitNodeConfiguration getTorExitNodeConfiguration() {
return tor;
}

View File

@ -78,6 +78,7 @@ import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV1;
import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV2;
import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV3;
import org.whispersystems.textsecuregcm.controllers.CertificateController;
import org.whispersystems.textsecuregcm.controllers.ChallengeController;
import org.whispersystems.textsecuregcm.controllers.DeviceController;
import org.whispersystems.textsecuregcm.controllers.DirectoryController;
import org.whispersystems.textsecuregcm.controllers.DonationController;
@ -98,11 +99,17 @@ import org.whispersystems.textsecuregcm.currency.FtxClient;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter;
import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter;
import org.whispersystems.textsecuregcm.limits.PreKeyRateLimiter;
import org.whispersystems.textsecuregcm.limits.PushChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimitResetMetricsManager;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.limits.UnsealedSenderRateLimiter;
import org.whispersystems.textsecuregcm.liquibase.NameableMigrationsBundle;
import org.whispersystems.textsecuregcm.mappers.DeviceLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.IOExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.InvalidWebsocketAddressExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RateLimitChallengeExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RetryLaterExceptionMapper;
import org.whispersystems.textsecuregcm.metrics.BufferPoolGauges;
@ -169,6 +176,7 @@ import org.whispersystems.textsecuregcm.storage.PendingDevicesManager;
import org.whispersystems.textsecuregcm.storage.Profiles;
import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.PushChallengeDynamoDb;
import org.whispersystems.textsecuregcm.storage.PushFeedbackProcessor;
import org.whispersystems.textsecuregcm.storage.RegistrationLockVersionCounter;
import org.whispersystems.textsecuregcm.storage.RemoteConfigs;
@ -314,6 +322,13 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
.withRequestTimeout((int) config.getMigrationRetryAccountsDynamoDbConfiguration().getClientRequestTimeout().toMillis()))
.withCredentials(InstanceProfileCredentialsProvider.getInstance());
AmazonDynamoDBClientBuilder pushChallengeDynamoDbClientBuilder = AmazonDynamoDBClientBuilder
.standard()
.withRegion(config.getPushChallengeDynamoDbConfiguration().getRegion())
.withClientConfiguration(new ClientConfiguration().withClientExecutionTimeout(((int) config.getPushChallengeDynamoDbConfiguration().getClientExecutionTimeout().toMillis()))
.withRequestTimeout((int) config.getPushChallengeDynamoDbConfiguration().getClientRequestTimeout().toMillis()))
.withCredentials(InstanceProfileCredentialsProvider.getInstance());
DynamoDB messageDynamoDb = new DynamoDB(messageDynamoDbClientBuilder.build());
DynamoDB preKeyDynamoDb = new DynamoDB(keysDynamoDbClientBuilder.build());
@ -337,6 +352,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(messageDynamoDb, config.getMessageDynamoDbConfiguration().getTableName(), config.getMessageDynamoDbConfiguration().getTimeToLive());
AbusiveHostRules abusiveHostRules = new AbusiveHostRules(abuseDatabase);
RemoteConfigs remoteConfigs = new RemoteConfigs(accountDatabase);
PushChallengeDynamoDb pushChallengeDynamoDb = new PushChallengeDynamoDb(new DynamoDB(pushChallengeDynamoDbClientBuilder.build()), config.getPushChallengeDynamoDbConfiguration().getTableName());
RedisClientFactory pubSubClientFactory = new RedisClientFactory("pubsub_cache", config.getPubsubCacheConfiguration().getUrl(), config.getPubsubCacheConfiguration().getReplicaUrls(), config.getPubsubCacheConfiguration().getCircuitBreakerConfiguration());
ReplicatedJedisPool pubsubClient = pubSubClientFactory.getRedisClientPool();
@ -415,6 +431,11 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
AccountAuthenticator accountAuthenticator = new AccountAuthenticator(accountsManager);
DisabledPermittedAccountAuthenticator disabledPermittedAccountAuthenticator = new DisabledPermittedAccountAuthenticator(accountsManager);
RateLimitResetMetricsManager rateLimitResetMetricsManager = new RateLimitResetMetricsManager(metricsCluster, Metrics.globalRegistry);
UnsealedSenderRateLimiter unsealedSenderRateLimiter = new UnsealedSenderRateLimiter(rateLimiters, rateLimitersCluster, dynamicConfigurationManager, rateLimitResetMetricsManager);
PreKeyRateLimiter preKeyRateLimiter = new PreKeyRateLimiter(rateLimiters, dynamicConfigurationManager, rateLimitResetMetricsManager);
ApnFallbackManager apnFallbackManager = new ApnFallbackManager(pushSchedulerCluster, apnSender, accountsManager);
TwilioSmsSender twilioSmsSender = new TwilioSmsSender(config.getTwilioConfiguration(), dynamicConfigurationManager);
SmsSender smsSender = new SmsSender(twilioSmsSender);
@ -422,6 +443,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
ReceiptSender receiptSender = new ReceiptSender(accountsManager, messageSender);
TurnTokenGenerator turnTokenGenerator = new TurnTokenGenerator(config.getTurnConfiguration());
RecaptchaClient recaptchaClient = new RecaptchaClient(config.getRecaptchaConfiguration().getSecret());
PushChallengeManager pushChallengeManager = new PushChallengeManager(apnSender, gcmSender, pushChallengeDynamoDb);
RateLimitChallengeManager rateLimitChallengeManager = new RateLimitChallengeManager(pushChallengeManager, recaptchaClient, preKeyRateLimiter, unsealedSenderRateLimiter, rateLimiters, dynamicConfigurationManager);
MessagePersister messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, dynamicConfigurationManager, Duration.ofMinutes(config.getMessageCacheConfiguration().getPersistDelayMinutes()));
@ -472,11 +495,12 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
AttachmentControllerV2 attachmentControllerV2 = new AttachmentControllerV2(rateLimiters, config.getAwsAttachmentsConfiguration().getAccessKey(), config.getAwsAttachmentsConfiguration().getAccessSecret(), config.getAwsAttachmentsConfiguration().getRegion(), config.getAwsAttachmentsConfiguration().getBucket());
AttachmentControllerV3 attachmentControllerV3 = new AttachmentControllerV3(rateLimiters, config.getGcpAttachmentsConfiguration().getDomain(), config.getGcpAttachmentsConfiguration().getEmail(), config.getGcpAttachmentsConfiguration().getMaxSizeInBytes(), config.getGcpAttachmentsConfiguration().getPathPrefix(), config.getGcpAttachmentsConfiguration().getRsaSigningKey());
DonationController donationController = new DonationController(donationExecutor, config.getDonationConfiguration());
KeysController keysController = new KeysController(rateLimiters, keysDynamoDb, accountsManager, directoryQueue);
MessageController messageController = new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, messagesManager, apnFallbackManager, dynamicConfigurationManager, metricsCluster, declinedMessageReceiptExecutor);
KeysController keysController = new KeysController(rateLimiters, keysDynamoDb, accountsManager, directoryQueue, preKeyRateLimiter, dynamicConfigurationManager, rateLimitChallengeManager);
MessageController messageController = new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, messagesManager, unsealedSenderRateLimiter, apnFallbackManager, dynamicConfigurationManager, rateLimitChallengeManager, metricsCluster, declinedMessageReceiptExecutor);
ProfileController profileController = new ProfileController(rateLimiters, accountsManager, profilesManager, usernamesManager, dynamicConfigurationManager, cdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner, config.getCdnConfiguration().getBucket(), zkProfileOperations, isZkEnabled);
StickerController stickerController = new StickerController(rateLimiters, config.getCdnConfiguration().getAccessKey(), config.getCdnConfiguration().getAccessSecret(), config.getCdnConfiguration().getRegion(), config.getCdnConfiguration().getBucket());
RemoteConfigController remoteConfigController = new RemoteConfigController(remoteConfigsManager, config.getRemoteConfigConfiguration().getAuthorizedTokens(), config.getRemoteConfigConfiguration().getGlobalConfig());
ChallengeController challengeController = new ChallengeController(rateLimitChallengeManager);
AuthFilter<BasicCredentials, Account> accountAuthFilter = new BasicCredentialAuthFilter.Builder<Account>().setAuthenticator(accountAuthenticator).buildAuthFilter ();
AuthFilter<BasicCredentials, DisabledPermittedAccount> disabledPermittedAccountAuthFilter = new BasicCredentialAuthFilter.Builder<DisabledPermittedAccount>().setAuthenticator(disabledPermittedAccountAuthenticator).buildAuthFilter();
@ -508,6 +532,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.jersey().register(profileController);
environment.jersey().register(stickerController);
environment.jersey().register(remoteConfigController);
environment.jersey().register(challengeController);
///
WebSocketEnvironment<Account> webSocketEnvironment = new WebSocketEnvironment<>(environment, config.getWebSocketConfiguration(), 90000);
@ -531,6 +556,12 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
registerCorsFilter(environment);
registerExceptionMappers(environment, webSocketEnvironment, provisioningEnvironment);
RateLimitChallengeExceptionMapper rateLimitChallengeExceptionMapper = new RateLimitChallengeExceptionMapper(rateLimitChallengeManager);
environment.jersey().register(rateLimitChallengeExceptionMapper);
webSocketEnvironment.jersey().register(rateLimitChallengeExceptionMapper);
provisioningEnvironment.jersey().register(rateLimitChallengeExceptionMapper);
WebSocketResourceProviderFactory<Account> webSocketServlet = new WebSocketResourceProviderFactory<>(webSocketEnvironment, Account.class);
WebSocketResourceProviderFactory<Account> provisioningServlet = new WebSocketResourceProviderFactory<>(provisioningEnvironment, Account.class);

View File

@ -165,16 +165,12 @@ public class RateLimitsConfiguration {
@JsonProperty
private Duration ttl;
@JsonProperty
private Duration ttlJitter;
public CardinalityRateLimitConfiguration() {
}
public CardinalityRateLimitConfiguration(int maxCardinality, Duration ttl, Duration ttlJitter) {
public CardinalityRateLimitConfiguration(int maxCardinality, Duration ttl) {
this.maxCardinality = maxCardinality;
this.ttl = ttl;
this.ttlJitter = ttlJitter;
}
public int getMaxCardinality() {
@ -184,9 +180,5 @@ public class RateLimitsConfiguration {
public Duration getTtl() {
return ttl;
}
public Duration getTtlJitter() {
return ttlJitter;
}
}
}

View File

@ -47,6 +47,10 @@ public class DynamicConfiguration {
@JsonProperty
private DynamicAccountsDynamoDbMigrationConfiguration accountsDynamoDbMigration = new DynamicAccountsDynamoDbMigrationConfiguration();
@JsonProperty
@Valid
private DynamicRateLimitChallengeConfiguration rateLimitChallenge = new DynamicRateLimitChallengeConfiguration();
public Optional<DynamicExperimentEnrollmentConfiguration> getExperimentEnrollmentConfiguration(
final String experimentName) {
return Optional.ofNullable(experiments.get(experimentName));
@ -93,4 +97,8 @@ public class DynamicConfiguration {
public DynamicAccountsDynamoDbMigrationConfiguration getAccountsDynamoDbMigrationConfiguration() {
return accountsDynamoDbMigration;
}
public DynamicRateLimitChallengeConfiguration getRateLimitChallengeConfiguration() {
return rateLimitChallenge;
}
}

View File

@ -36,6 +36,7 @@ public class DynamicMessageRateConfiguration {
@JsonProperty
private double receiptProbability = 0.82;
public boolean isEnforceUnsealedSenderRateLimit() {
return enforceUnsealedSenderRateLimit;
}

View File

@ -0,0 +1,40 @@
package org.whispersystems.textsecuregcm.configuration.dynamic;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.annotations.VisibleForTesting;
import com.vdurmont.semver4j.Semver;
import java.util.Collections;
import java.util.Map;
import java.util.Optional;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import javax.validation.constraints.NotNull;
public class DynamicRateLimitChallengeConfiguration {
@JsonProperty
private boolean preKeyLimitEnforced = false;
@JsonProperty
boolean unsealedSenderLimitEnforced = false;
@JsonProperty
@NotNull
private Map<ClientPlatform, Semver> clientSupportedVersions = Collections.emptyMap();
@VisibleForTesting
Map<ClientPlatform, Semver> getClientSupportedVersions() {
return clientSupportedVersions;
}
public Optional<Semver> getMinimumSupportedVersion(final ClientPlatform platform) {
return Optional.ofNullable(clientSupportedVersions.get(platform));
}
public boolean isPreKeyLimitEnforced() {
return preKeyLimitEnforced;
}
public boolean isUnsealedSenderLimitEnforced() {
return unsealedSenderLimitEnforced;
}
}

View File

@ -8,11 +8,35 @@ import java.time.Duration;
public class DynamicRateLimitsConfiguration {
@JsonProperty
private CardinalityRateLimitConfiguration unsealedSenderNumber = new CardinalityRateLimitConfiguration(100, Duration.ofDays(1), Duration.ofDays(1));
private CardinalityRateLimitConfiguration unsealedSenderNumber = new CardinalityRateLimitConfiguration(100, Duration.ofDays(1));
@JsonProperty
private int unsealedSenderDefaultCardinalityLimit = 100;
@JsonProperty
private int unsealedSenderPermitIncrement = 50;
@JsonProperty
private RateLimitConfiguration unsealedSenderIp = new RateLimitConfiguration(120, 2.0 / 60);
@JsonProperty
private RateLimitConfiguration rateLimitReset = new RateLimitConfiguration(2, 2.0 / (60 * 24));
@JsonProperty
private RateLimitConfiguration recaptchaChallengeAttempt = new RateLimitConfiguration(10, 10.0 / (60 * 24));
@JsonProperty
private RateLimitConfiguration recaptchaChallengeSuccess = new RateLimitConfiguration(2, 2.0 / (60 * 24));
@JsonProperty
private RateLimitConfiguration pushChallengeAttempt = new RateLimitConfiguration(10, 10.0 / (60 * 24));
@JsonProperty
private RateLimitConfiguration pushChallengeSuccess = new RateLimitConfiguration(2, 2.0 / (60 * 24));
@JsonProperty
private RateLimitConfiguration dailyPreKeys = new RateLimitConfiguration(50, 50.0 / (24.0 * 60));
public RateLimitConfiguration getUnsealedSenderIp() {
return unsealedSenderIp;
}
@ -20,4 +44,36 @@ public class DynamicRateLimitsConfiguration {
public CardinalityRateLimitConfiguration getUnsealedSenderNumber() {
return unsealedSenderNumber;
}
public RateLimitConfiguration getRateLimitReset() {
return rateLimitReset;
}
public RateLimitConfiguration getRecaptchaChallengeAttempt() {
return recaptchaChallengeAttempt;
}
public RateLimitConfiguration getRecaptchaChallengeSuccess() {
return recaptchaChallengeSuccess;
}
public RateLimitConfiguration getPushChallengeAttempt() {
return pushChallengeAttempt;
}
public RateLimitConfiguration getPushChallengeSuccess() {
return pushChallengeSuccess;
}
public int getUnsealedSenderDefaultCardinalityLimit() {
return unsealedSenderDefaultCardinalityLimit;
}
public int getUnsealedSenderPermitIncrement() {
return unsealedSenderPermitIncrement;
}
public RateLimitConfiguration getDailyPreKeys() {
return dailyPreKeys;
}
}

View File

@ -190,7 +190,7 @@ public class AccountController {
if ("fcm".equals(pushType)) {
gcmSender.sendMessage(new GcmMessage(pushToken, number, 0, GcmMessage.Type.CHALLENGE, Optional.of(storedVerificationCode.getPushCode())));
} else if ("apn".equals(pushType)) {
apnSender.sendMessage(new ApnMessage(pushToken, number, 0, true, Optional.of(storedVerificationCode.getPushCode())));
apnSender.sendMessage(new ApnMessage(pushToken, number, 0, true, ApnMessage.Type.CHALLENGE, Optional.of(storedVerificationCode.getPushCode())));
} else {
throw new AssertionError();
}

View File

@ -0,0 +1,80 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import com.codahale.metrics.annotation.Timed;
import io.dropwizard.auth.Auth;
import java.util.NoSuchElementException;
import javax.validation.Valid;
import javax.ws.rs.Consumes;
import javax.ws.rs.HeaderParam;
import javax.ws.rs.POST;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.whispersystems.textsecuregcm.entities.AnswerChallengeRequest;
import org.whispersystems.textsecuregcm.entities.AnswerPushChallengeRequest;
import org.whispersystems.textsecuregcm.entities.AnswerRecaptchaChallengeRequest;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.ForwardedIpUtil;
@Path("/v1/challenge")
public class ChallengeController {
private final RateLimitChallengeManager rateLimitChallengeManager;
public ChallengeController(final RateLimitChallengeManager rateLimitChallengeManager) {
this.rateLimitChallengeManager = rateLimitChallengeManager;
}
@Timed
@PUT
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
public Response handleChallengeResponse(@Auth final Account account,
@Valid final AnswerChallengeRequest answerRequest,
@HeaderParam("X-Forwarded-For") String forwardedFor) throws RetryLaterException {
try {
if (answerRequest instanceof AnswerPushChallengeRequest) {
final AnswerPushChallengeRequest pushChallengeRequest = (AnswerPushChallengeRequest) answerRequest;
rateLimitChallengeManager.answerPushChallenge(account, pushChallengeRequest.getChallenge());
} else if (answerRequest instanceof AnswerRecaptchaChallengeRequest) {
try {
final AnswerRecaptchaChallengeRequest recaptchaChallengeRequest = (AnswerRecaptchaChallengeRequest) answerRequest;
final String mostRecentProxy = ForwardedIpUtil.getMostRecentProxy(forwardedFor).orElseThrow();
rateLimitChallengeManager.answerRecaptchaChallenge(account, recaptchaChallengeRequest.getCaptcha(), mostRecentProxy);
} catch (final NoSuchElementException e) {
return Response.status(400).build();
}
}
} catch (final RateLimitExceededException e) {
throw new RetryLaterException(e);
}
return Response.status(200).build();
}
@Timed
@POST
@Path("/push")
public Response requestPushChallenge(@Auth final Account account) {
try {
rateLimitChallengeManager.sendPushChallenge(account);
return Response.status(200).build();
} catch (final NotPushRegisteredException e) {
return Response.status(404).build();
}
}
}

View File

@ -36,11 +36,15 @@ import org.whispersystems.textsecuregcm.entities.PreKeyResponse;
import org.whispersystems.textsecuregcm.entities.PreKeyResponseItem;
import org.whispersystems.textsecuregcm.entities.PreKeyState;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.limits.PreKeyRateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeException;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.KeysDynamoDb;
import org.whispersystems.textsecuregcm.util.Util;
@ -52,18 +56,30 @@ public class KeysController {
private final KeysDynamoDb keysDynamoDb;
private final AccountsManager accounts;
private final DirectoryQueue directoryQueue;
private final PreKeyRateLimiter preKeyRateLimiter;
private final DynamicConfigurationManager dynamicConfigurationManager;
private final RateLimitChallengeManager rateLimitChallengeManager;
private static final String PREKEY_REQUEST_COUNTER_NAME = name(KeysController.class, "preKeyGet");
private static final String RATE_LIMITED_GET_PREKEYS_COUNTER_NAME = name(KeysController.class, "rateLimitedGetPreKeys");
private static final String SOURCE_COUNTRY_TAG_NAME = "sourceCountry";
private static final String INTERNATIONAL_TAG_NAME = "international";
private static final String PREKEY_TARGET_IDENTIFIER_TAG_NAME = "identifierType";
public KeysController(RateLimiters rateLimiters, KeysDynamoDb keysDynamoDb, AccountsManager accounts, DirectoryQueue directoryQueue) {
public KeysController(RateLimiters rateLimiters, KeysDynamoDb keysDynamoDb, AccountsManager accounts,
DirectoryQueue directoryQueue, PreKeyRateLimiter preKeyRateLimiter,
DynamicConfigurationManager dynamicConfigurationManager,
RateLimitChallengeManager rateLimitChallengeManager) {
this.rateLimiters = rateLimiters;
this.keysDynamoDb = keysDynamoDb;
this.accounts = accounts;
this.directoryQueue = directoryQueue;
this.preKeyRateLimiter = preKeyRateLimiter;
this.dynamicConfigurationManager = dynamicConfigurationManager;
this.rateLimitChallengeManager = rateLimitChallengeManager;
}
@GET
@ -112,12 +128,12 @@ public class KeysController {
@GET
@Path("/{identifier}/{device_id}")
@Produces(MediaType.APPLICATION_JSON)
public Optional<PreKeyResponse> getDeviceKeys(@Auth Optional<Account> account,
public Response getDeviceKeys(@Auth Optional<Account> account,
@HeaderParam(OptionalAccess.UNIDENTIFIED) Optional<Anonymous> accessKey,
@PathParam("identifier") AmbiguousIdentifier targetName,
@PathParam("device_id") String deviceId)
throws RateLimitExceededException
{
@PathParam("device_id") String deviceId,
@HeaderParam("User-Agent") String userAgent)
throws RateLimitExceededException, RateLimitChallengeException {
if (!account.isPresent() && !accessKey.isPresent()) {
throw new WebApplicationException(Response.Status.UNAUTHORIZED);
}
@ -127,10 +143,6 @@ public class KeysController {
assert(target.isPresent());
if (account.isPresent()) {
rateLimiters.getPreKeysLimiter().validate(account.get().getNumber() + "." + account.get().getAuthenticatedDevice().get().getId() + "__" + target.get().getNumber() + "." + deviceId);
}
{
final String sourceCountryCode = account.map(a -> Util.getCountryCode(a.getNumber())).orElse("0");
final String targetCountryCode = target.map(a -> Util.getCountryCode(a.getNumber())).orElseThrow();
@ -142,6 +154,26 @@ public class KeysController {
)).increment();
}
if (account.isPresent()) {
rateLimiters.getPreKeysLimiter().validate(account.get().getNumber() + "." + account.get().getAuthenticatedDevice().get().getId() + "__" + target.get().getNumber() + "." + deviceId);
try {
preKeyRateLimiter.validate(account.get());
} catch (RateLimitExceededException e) {
final boolean enforceLimit = rateLimitChallengeManager.shouldIssueRateLimitChallenge(userAgent);
Metrics.counter(RATE_LIMITED_GET_PREKEYS_COUNTER_NAME,
SOURCE_COUNTRY_TAG_NAME, Util.getCountryCode(account.get().getNumber()),
"enforced", String.valueOf(enforceLimit))
.increment();
if (enforceLimit) {
throw new RateLimitChallengeException(account.get(), e.getRetryDuration());
}
}
}
Map<Long, PreKey> preKeysByDeviceId = getLocalKeys(target.get(), deviceId);
List<PreKeyResponseItem> responseItems = new LinkedList<>();
@ -156,8 +188,8 @@ public class KeysController {
}
}
if (responseItems.isEmpty()) return Optional.empty();
else return Optional.of(new PreKeyResponse(target.get().getIdentityKey(), responseItems));
if (responseItems.isEmpty()) return Response.status(404).build();
else return Response.ok().entity(new PreKeyResponse(target.get().getIdentityKey(), responseItems)).build();
}
@Timed

View File

@ -71,7 +71,10 @@ import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.SendMessageResponse;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeException;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.limits.UnsealedSenderRateLimiter;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.providers.MultiRecipientMessageProvider;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
@ -111,8 +114,10 @@ public class MessageController {
private final ReceiptSender receiptSender;
private final AccountsManager accountsManager;
private final MessagesManager messagesManager;
private final UnsealedSenderRateLimiter unsealedSenderRateLimiter;
private final ApnFallbackManager apnFallbackManager;
private final DynamicConfigurationManager dynamicConfigurationManager;
private final RateLimitChallengeManager rateLimitChallengeManager;
private final ScheduledExecutorService receiptExecutorService;
private final Random random = new Random();
@ -138,8 +143,10 @@ public class MessageController {
ReceiptSender receiptSender,
AccountsManager accountsManager,
MessagesManager messagesManager,
UnsealedSenderRateLimiter unsealedSenderRateLimiter,
ApnFallbackManager apnFallbackManager,
DynamicConfigurationManager dynamicConfigurationManager,
RateLimitChallengeManager rateLimitChallengeManager,
FaultTolerantRedisCluster metricsCluster,
ScheduledExecutorService receiptExecutorService)
{
@ -148,8 +155,10 @@ public class MessageController {
this.receiptSender = receiptSender;
this.accountsManager = accountsManager;
this.messagesManager = messagesManager;
this.unsealedSenderRateLimiter = unsealedSenderRateLimiter;
this.apnFallbackManager = apnFallbackManager;
this.dynamicConfigurationManager = dynamicConfigurationManager;
this.rateLimitChallengeManager = rateLimitChallengeManager;
this.receiptExecutorService = receiptExecutorService;
try {
@ -171,8 +180,7 @@ public class MessageController {
@HeaderParam("X-Forwarded-For") String forwardedFor,
@PathParam("destination") AmbiguousIdentifier destinationName,
@Valid IncomingMessageList messages)
throws RateLimitExceededException
{
throws RateLimitExceededException, RateLimitChallengeException {
if (source.isEmpty() && accessKey.isEmpty()) {
throw new WebApplicationException(Response.Status.UNAUTHORIZED);
}
@ -186,19 +194,6 @@ public class MessageController {
if (StringUtils.isAllBlank(masterDevice.getApnId(), masterDevice.getVoipApnId(), masterDevice.getGcmId()) || masterDevice.getUninstalledFeedbackTimestamp() > 0) {
Metrics.counter(UNSEALED_SENDER_WITHOUT_PUSH_TOKEN_COUNTER_NAME, SENDER_COUNTRY_TAG_NAME, senderCountryCode).increment();
}
try {
rateLimiters.getUnsealedSenderLimiter().validate(source.get().getNumber(), destinationName.toString());
} catch (RateLimitExceededException e) {
if (dynamicConfigurationManager.getConfiguration().getMessageRateConfiguration().isEnforceUnsealedSenderRateLimit()) {
Metrics.counter(REJECT_UNSEALED_SENDER_COUNTER_NAME, SENDER_COUNTRY_TAG_NAME, senderCountryCode).increment();
logger.debug("Rejected unsealed sender limit from: {}", source.get().getNumber());
throw e;
} else {
logger.debug("Would reject unsealed sender limit from: {}", source.get().getNumber());
}
}
}
final String senderType;
@ -247,6 +242,27 @@ public class MessageController {
rateLimiters.getMessagesLimiter().validate(source.get().getNumber() + "__" + destination.get().getUuid());
final String senderCountryCode = Util.getCountryCode(source.get().getNumber());
try {
unsealedSenderRateLimiter.validate(source.get(), destination.get());
} catch (final RateLimitExceededException e) {
final boolean enforceLimit = rateLimitChallengeManager.shouldIssueRateLimitChallenge(userAgent);
Metrics.counter(REJECT_UNSEALED_SENDER_COUNTER_NAME,
SENDER_COUNTRY_TAG_NAME, senderCountryCode,
"enforced", String.valueOf(enforceLimit))
.increment();
if (enforceLimit) {
logger.debug("Rejected unsealed sender limit from: {}", source.get().getNumber());
throw new RateLimitChallengeException(source.get(), e.getRetryDuration());
} else {
throw e;
}
}
final String destinationCountryCode = Util.getCountryCode(destination.get().getNumber());
final Device masterDevice = source.get().getMasterDevice().get();

View File

@ -0,0 +1,18 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import javax.validation.constraints.NotBlank;
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type")
@JsonSubTypes({
@JsonSubTypes.Type(value = AnswerPushChallengeRequest.class, name = "rateLimitPushChallenge"),
@JsonSubTypes.Type(value = AnswerRecaptchaChallengeRequest.class, name = "recaptcha")
})
public abstract class AnswerChallengeRequest {
}

View File

@ -0,0 +1,18 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import javax.validation.constraints.NotBlank;
public class AnswerPushChallengeRequest extends AnswerChallengeRequest {
@NotBlank
private String challenge;
public String getChallenge() {
return challenge;
}
}

View File

@ -0,0 +1,25 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import javax.validation.constraints.NotBlank;
public class AnswerRecaptchaChallengeRequest extends AnswerChallengeRequest {
@NotBlank
private String token;
@NotBlank
private String captcha;
public String getToken() {
return token;
}
public String getCaptcha() {
return captcha;
}
}

View File

@ -0,0 +1,32 @@
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;
import javax.validation.constraints.NotNull;
public class RateLimitChallenge {
@JsonProperty
@NotNull
private final String token;
@JsonProperty
@NotNull
private final List<String> options;
@JsonCreator
public RateLimitChallenge(@JsonProperty("token") final String token, @JsonProperty("options") final List<String> options) {
this.token = token;
this.options = options;
}
public String getToken() {
return token;
}
public List<String> getOptions() {
return options;
}
}

View File

@ -6,7 +6,6 @@
package org.whispersystems.textsecuregcm.limits;
import java.time.Duration;
import java.util.Random;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.CardinalityRateLimitConfiguration;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
@ -24,25 +23,22 @@ public class CardinalityRateLimiter {
private final String name;
private final Duration ttl;
private final Duration ttlJitter;
private final int maxCardinality;
private final int defaultMaxCardinality;
private final Random random = new Random();
public CardinalityRateLimiter(final FaultTolerantRedisCluster cacheCluster, final String name, final Duration ttl, final Duration ttlJitter, final int maxCardinality) {
public CardinalityRateLimiter(final FaultTolerantRedisCluster cacheCluster, final String name, final Duration ttl, final int defaultMaxCardinality) {
this.cacheCluster = cacheCluster;
this.name = name;
this.ttl = ttl;
this.ttlJitter = ttlJitter;
this.maxCardinality = maxCardinality;
this.defaultMaxCardinality = defaultMaxCardinality;
}
public void validate(final String key, final String target) throws RateLimitExceededException {
final String hllKey = getHllKey(key);
public void validate(final String key, final String target, final int maxCardinality) throws RateLimitExceededException {
final boolean rateLimitExceeded = cacheCluster.withCluster(connection -> {
final String hllKey = getHllKey(key);
final boolean changed = connection.sync().pfadd(hllKey, target) == 1;
final long cardinality = connection.sync().pfcount(hllKey);
@ -51,16 +47,14 @@ public class CardinalityRateLimiter {
// If the set already existed, we can assume it already had an expiration time and can save a round trip by
// skipping the ttl check.
if (mayNeedExpiration && connection.sync().ttl(hllKey) == -1) {
final long expireSeconds = ttl.plusSeconds(random.nextInt((int) ttlJitter.toSeconds())).toSeconds();
connection.sync().expire(hllKey, expireSeconds);
connection.sync().expire(hllKey, ttl.toSeconds());
}
return changed && cardinality > maxCardinality;
});
if (rateLimitExceeded) {
// Using the TTL as the "retry after" time isn't EXACTLY right, but it's a reasonable approximation
throw new RateLimitExceededException(ttl);
throw new RateLimitExceededException(Duration.ofSeconds(getRemainingTtl(key)));
}
}
@ -68,21 +62,20 @@ public class CardinalityRateLimiter {
return "hll_rate_limit::" + name + "::" + key;
}
public Duration getTtl() {
public Duration getInitialTtl() {
return ttl;
}
public Duration getTtlJitter() {
return ttlJitter;
public long getRemainingTtl(final String key) {
return cacheCluster.withCluster(connection -> connection.sync().ttl(getHllKey(key)));
}
public int getMaxCardinality() {
return maxCardinality;
public int getDefaultMaxCardinality() {
return defaultMaxCardinality;
}
public boolean hasConfiguration(final CardinalityRateLimitConfiguration configuration) {
return maxCardinality == configuration.getMaxCardinality() &&
ttl.equals(configuration.getTtl()) &&
ttlJitter.equals(configuration.getTtlJitter());
return defaultMaxCardinality == configuration.getMaxCardinality() && ttl.equals(configuration.getTtl());
}
}

View File

@ -0,0 +1,78 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static com.codahale.metrics.MetricRegistry.name;
import io.dropwizard.util.Duration;
import io.micrometer.core.instrument.Metrics;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.Util;
public class PreKeyRateLimiter {
private static final String RATE_LIMIT_RESET_COUNTER_NAME = name(PreKeyRateLimiter.class, "reset");
private static final String RATE_LIMITED_PREKEYS_COUNTER_NAME = name(PreKeyRateLimiter.class, "rateLimited");
private static final String RATE_LIMITED_PREKEYS_TOTAL_ACCOUNTS_COUNTER_NAME = name(PreKeyRateLimiter.class, "rateLimited");
private static final String RATE_LIMITED_PREKEYS_ACCOUNTS_ENFORCED_COUNTER_NAME = name(PreKeyRateLimiter.class, "rateLimitedAccountsEnforced");
private static final String RATE_LIMITED_PREKEYS_ACCOUNTS_UNENFORCED_COUNTER_NAME = name(PreKeyRateLimiter.class, "rateLimitedAccountsUnenforced");
private static final String RATE_LIMITED_ACCOUNTS_HLL_KEY = "PreKeyRateLimiter::rateLimitedAccounts";
private static final String RATE_LIMITED_ACCOUNTS_ENFORCED_HLL_KEY = "PreKeyRateLimiter::rateLimitedAccounts::enforced";
private static final String RATE_LIMITED_ACCOUNTS_UNENFORCED_HLL_KEY = "PreKeyRateLimiter::rateLimitedAccounts::unenforced";
private static final long RATE_LIMITED_ACCOUNTS_HLL_TTL_SECONDS = Duration.days(1).toSeconds();
private final RateLimiters rateLimiters;
private final DynamicConfigurationManager dynamicConfigurationManager;
private final RateLimitResetMetricsManager metricsManager;
public PreKeyRateLimiter(final RateLimiters rateLimiters,
final DynamicConfigurationManager dynamicConfigurationManager,
final RateLimitResetMetricsManager metricsManager) {
this.rateLimiters = rateLimiters;
this.dynamicConfigurationManager = dynamicConfigurationManager;
this.metricsManager = metricsManager;
metricsManager.initializeFunctionCounters(RATE_LIMITED_PREKEYS_TOTAL_ACCOUNTS_COUNTER_NAME,
RATE_LIMITED_ACCOUNTS_HLL_KEY);
metricsManager.initializeFunctionCounters(RATE_LIMITED_PREKEYS_ACCOUNTS_ENFORCED_COUNTER_NAME,
RATE_LIMITED_ACCOUNTS_ENFORCED_HLL_KEY);
metricsManager.initializeFunctionCounters(RATE_LIMITED_PREKEYS_ACCOUNTS_UNENFORCED_COUNTER_NAME,
RATE_LIMITED_ACCOUNTS_UNENFORCED_HLL_KEY);
}
public void validate(final Account account) throws RateLimitExceededException {
try {
rateLimiters.getDailyPreKeysLimiter().validate(account.getNumber());
} catch (final RateLimitExceededException e) {
final boolean enforceLimit = dynamicConfigurationManager.getConfiguration()
.getRateLimitChallengeConfiguration().isPreKeyLimitEnforced();
metricsManager.recordMetrics(account, enforceLimit,
RATE_LIMITED_PREKEYS_COUNTER_NAME,
enforceLimit ? RATE_LIMITED_ACCOUNTS_ENFORCED_HLL_KEY : RATE_LIMITED_ACCOUNTS_UNENFORCED_HLL_KEY,
RATE_LIMITED_ACCOUNTS_HLL_KEY,
RATE_LIMITED_ACCOUNTS_HLL_TTL_SECONDS
);
if (enforceLimit) {
throw e;
}
}
}
public void handleRateLimitReset(final Account account) {
rateLimiters.getDailyPreKeysLimiter().clear(account.getNumber());
Metrics.counter(RATE_LIMIT_RESET_COUNTER_NAME, "countryCode", Util.getCountryCode(account.getNumber()))
.increment();
}
}

View File

@ -0,0 +1,115 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import io.micrometer.core.instrument.Metrics;
import org.apache.commons.codec.DecoderException;
import org.apache.commons.codec.binary.Hex;
import org.apache.commons.lang3.StringUtils;
import org.whispersystems.textsecuregcm.push.APNSender;
import org.whispersystems.textsecuregcm.push.ApnMessage;
import org.whispersystems.textsecuregcm.push.ApnMessage.Type;
import org.whispersystems.textsecuregcm.push.GCMSender;
import org.whispersystems.textsecuregcm.push.GcmMessage;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.PushChallengeDynamoDb;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import java.security.SecureRandom;
import java.time.Duration;
import java.util.Optional;
import static com.codahale.metrics.MetricRegistry.name;
public class PushChallengeManager {
private final APNSender apnSender;
private final GCMSender gcmSender;
private final PushChallengeDynamoDb pushChallengeDynamoDb;
private final SecureRandom random = new SecureRandom();
private static final int CHALLENGE_TOKEN_LENGTH = 16;
private static final Duration CHALLENGE_TTL = Duration.ofMinutes(5);
private static final String CHALLENGE_REQUESTED_COUNTER_NAME = name(PushChallengeManager.class, "requested");
private static final String CHALLENGE_ANSWERED_COUNTER_NAME = name(PushChallengeManager.class, "answered");
private static final String PLATFORM_TAG_NAME = "platform";
private static final String SENT_TAG_NAME = "sent";
private static final String SUCCESS_TAG_NAME = "success";
public PushChallengeManager(final APNSender apnSender, final GCMSender gcmSender,
final PushChallengeDynamoDb pushChallengeDynamoDb) {
this.apnSender = apnSender;
this.gcmSender = gcmSender;
this.pushChallengeDynamoDb = pushChallengeDynamoDb;
}
public void sendChallenge(final Account account) throws NotPushRegisteredException {
final Device masterDevice = account.getMasterDevice().orElseThrow(NotPushRegisteredException::new);
if (StringUtils.isAllBlank(masterDevice.getGcmId(), masterDevice.getApnId())) {
throw new NotPushRegisteredException();
}
final byte[] token = new byte[CHALLENGE_TOKEN_LENGTH];
random.nextBytes(token);
final boolean sent;
final String platform;
if (pushChallengeDynamoDb.add(account.getUuid(), token, CHALLENGE_TTL)) {
final String tokenHex = Hex.encodeHexString(token);
sent = true;
if (StringUtils.isNotBlank(masterDevice.getGcmId())) {
gcmSender.sendMessage(new GcmMessage(masterDevice.getGcmId(), account.getNumber(), 0, GcmMessage.Type.RATE_LIMIT_CHALLENGE, Optional.of(tokenHex)));
platform = ClientPlatform.ANDROID.name().toLowerCase();
} else if (StringUtils.isNotBlank(masterDevice.getApnId())) {
apnSender.sendMessage(new ApnMessage(masterDevice.getApnId(), account.getNumber(), 0, false, Type.RATE_LIMIT_CHALLENGE, Optional.of(tokenHex)));
platform = ClientPlatform.IOS.name().toLowerCase();
} else {
throw new AssertionError();
}
} else {
sent = false;
platform = null;
}
Metrics.counter(CHALLENGE_REQUESTED_COUNTER_NAME,
PLATFORM_TAG_NAME, platform,
SENT_TAG_NAME, String.valueOf(sent)).increment();
}
public boolean answerChallenge(final Account account, final String challengeTokenHex) {
boolean success = false;
try {
success = pushChallengeDynamoDb.remove(account.getUuid(), Hex.decodeHex(challengeTokenHex));
} catch (final DecoderException ignored) {
}
final String platform = account.getMasterDevice().map(masterDevice -> {
if (StringUtils.isNotBlank(masterDevice.getGcmId())) {
return ClientPlatform.IOS.name().toLowerCase();
} else if (StringUtils.isNotBlank(masterDevice.getApnId())) {
return ClientPlatform.ANDROID.name().toLowerCase();
} else {
return "unknown";
}
}).orElse("unknown");
Metrics.counter(CHALLENGE_ANSWERED_COUNTER_NAME,
PLATFORM_TAG_NAME, platform,
SUCCESS_TAG_NAME, String.valueOf(success)).increment();
return success;
}
}

View File

@ -0,0 +1,23 @@
package org.whispersystems.textsecuregcm.limits;
import org.whispersystems.textsecuregcm.storage.Account;
import java.time.Duration;
public class RateLimitChallengeException extends Exception {
private final Account account;
private final Duration retryAfter;
public RateLimitChallengeException(final Account account, final Duration retryAfter) {
this.account = account;
this.retryAfter = retryAfter;
}
public Account getAccount() {
return account;
}
public Duration getRetryAfter() {
return retryAfter;
}
}

View File

@ -0,0 +1,114 @@
package org.whispersystems.textsecuregcm.limits;
import com.vdurmont.semver4j.Semver;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.recaptcha.RecaptchaClient;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
public class RateLimitChallengeManager {
private final PushChallengeManager pushChallengeManager;
private final RecaptchaClient recaptchaClient;
private final PreKeyRateLimiter preKeyRateLimiter;
private final UnsealedSenderRateLimiter unsealedSenderRateLimiter;
private final RateLimiters rateLimiters;
private final DynamicConfigurationManager dynamicConfigurationManager;
public static final String OPTION_RECAPTCHA = "recaptcha";
public static final String OPTION_PUSH_CHALLENGE = "pushChallenge";
public RateLimitChallengeManager(
final PushChallengeManager pushChallengeManager,
final RecaptchaClient recaptchaClient,
final PreKeyRateLimiter preKeyRateLimiter,
final UnsealedSenderRateLimiter unsealedSenderRateLimiter,
final RateLimiters rateLimiters,
final DynamicConfigurationManager dynamicConfigurationManager) {
this.pushChallengeManager = pushChallengeManager;
this.recaptchaClient = recaptchaClient;
this.preKeyRateLimiter = preKeyRateLimiter;
this.unsealedSenderRateLimiter = unsealedSenderRateLimiter;
this.rateLimiters = rateLimiters;
this.dynamicConfigurationManager = dynamicConfigurationManager;
}
public void answerPushChallenge(final Account account, final String challenge) throws RateLimitExceededException {
rateLimiters.getPushChallengeAttemptLimiter().validate(account.getNumber());
final boolean challengeSuccess = pushChallengeManager.answerChallenge(account, challenge);
if (challengeSuccess) {
rateLimiters.getPushChallengeSuccessLimiter().validate(account.getNumber());
resetRateLimits(account);
}
}
public void answerRecaptchaChallenge(final Account account, final String captcha, final String mostRecentProxyIp)
throws RateLimitExceededException {
rateLimiters.getRecaptchaChallengeAttemptLimiter().validate(account.getNumber());
final boolean challengeSuccess = recaptchaClient.verify(captcha, mostRecentProxyIp);
if (challengeSuccess) {
rateLimiters.getRecaptchaChallengeSuccessLimiter().validate(account.getNumber());
resetRateLimits(account);
}
}
private void resetRateLimits(final Account account) throws RateLimitExceededException {
rateLimiters.getRateLimitResetLimiter().validate(account.getNumber());
preKeyRateLimiter.handleRateLimitReset(account);
unsealedSenderRateLimiter.handleRateLimitReset(account);
}
public boolean shouldIssueRateLimitChallenge(final String userAgent) {
try {
final UserAgent client = UserAgentUtil.parseUserAgentString(userAgent);
final Optional<Semver> minimumClientVersion = dynamicConfigurationManager.getConfiguration()
.getRateLimitChallengeConfiguration()
.getMinimumSupportedVersion(client.getPlatform());
return minimumClientVersion.map(version -> version.isLowerThanOrEqualTo(client.getVersion()))
.orElse(false);
} catch (final UnrecognizedUserAgentException ignored) {
return false;
}
}
public List<String> getChallengeOptions(final Account account) {
final List<String> options = new ArrayList<>(2);
final String key = account.getNumber();
if (rateLimiters.getRecaptchaChallengeAttemptLimiter().hasAvailablePermits(key, 1) &&
rateLimiters.getRecaptchaChallengeSuccessLimiter().hasAvailablePermits(key, 1)) {
options.add(OPTION_RECAPTCHA);
}
if (rateLimiters.getPushChallengeAttemptLimiter().hasAvailablePermits(key, 1) &&
rateLimiters.getPushChallengeSuccessLimiter().hasAvailablePermits(key, 1)) {
options.add(OPTION_PUSH_CHALLENGE);
}
return options;
}
public void sendPushChallenge(final Account account) throws NotPushRegisteredException {
pushChallengeManager.sendChallenge(account);
}
}

View File

@ -0,0 +1,40 @@
package org.whispersystems.textsecuregcm.limits;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.FunctionCounter;
import io.micrometer.core.instrument.MeterRegistry;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.Account;
public class RateLimitResetMetricsManager {
private final FaultTolerantRedisCluster metricsCluster;
private final MeterRegistry meterRegistry;
public RateLimitResetMetricsManager(
final FaultTolerantRedisCluster metricsCluster, final MeterRegistry meterRegistry) {
this.metricsCluster = metricsCluster;
this.meterRegistry = meterRegistry;
}
void initializeFunctionCounters(String counterKey, String hllKey) {
FunctionCounter.builder(counterKey, null, (ignored) ->
metricsCluster.<Long>withCluster(conn -> conn.sync().pfcount(hllKey))).register(meterRegistry);
}
void recordMetrics(Account account, boolean enforced, String counterKey, String hllEnforcedKey, String hllTotalKey,
long hllTtl) {
Counter.builder(counterKey)
.tag("enforced", String.valueOf(enforced))
.register(meterRegistry)
.increment();
metricsCluster.useCluster(connection -> {
connection.sync().pfadd(hllEnforcedKey, account.getUuid().toString());
connection.sync().expire(hllEnforcedKey, hllTtl);
connection.sync().pfadd(hllTotalKey, account.getUuid().toString());
connection.sync().expire(hllTotalKey, hllTtl);
});
}
}

View File

@ -13,6 +13,7 @@ import com.codahale.metrics.Timer;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.time.Duration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.RateLimitConfiguration;
@ -64,6 +65,10 @@ public class RateLimiter {
validate(key, 1);
}
public boolean hasAvailablePermits(final String key, final int permits) {
return getBucket(key).getTimeUntilSpaceAvailable(permits).equals(Duration.ZERO);
}
public void clear(String key) {
cacheCluster.useCluster(connection -> connection.sync().del(getBucketName(key)));
}

View File

@ -6,6 +6,7 @@ package org.whispersystems.textsecuregcm.limits;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.CardinalityRateLimitConfiguration;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.RateLimitConfiguration;
@ -37,8 +38,14 @@ public class RateLimiters {
private final RateLimiter usernameLookupLimiter;
private final RateLimiter usernameSetLimiter;
private final AtomicReference<CardinalityRateLimiter> unsealedSenderLimiter;
private final AtomicReference<CardinalityRateLimiter> unsealedSenderCardinalityLimiter;
private final AtomicReference<RateLimiter> unsealedIpLimiter;
private final AtomicReference<RateLimiter> rateLimitResetLimiter;
private final AtomicReference<RateLimiter> recaptchaChallengeAttemptLimiter;
private final AtomicReference<RateLimiter> recaptchaChallengeSuccessLimiter;
private final AtomicReference<RateLimiter> pushChallengeAttemptLimiter;
private final AtomicReference<RateLimiter> pushChallengeSuccessLimiter;
private final AtomicReference<RateLimiter> dailyPreKeysLimiter;
private final FaultTolerantRedisCluster cacheCluster;
private final DynamicConfigurationManager dynamicConfig;
@ -119,30 +126,90 @@ public class RateLimiters {
config.getUsernameSet().getBucketSize(),
config.getUsernameSet().getLeakRatePerMinute());
this.unsealedSenderLimiter = new AtomicReference<>(createUnsealedSenderLimiter(cacheCluster, dynamicConfig.getConfiguration().getLimits().getUnsealedSenderNumber()));
this.dailyPreKeysLimiter = new AtomicReference<>(createDailyPreKeysLimiter(cacheCluster, dynamicConfig.getConfiguration().getLimits().getDailyPreKeys()));
this.unsealedSenderCardinalityLimiter = new AtomicReference<>(createUnsealedSenderCardinalityLimiter(cacheCluster, dynamicConfig.getConfiguration().getLimits().getUnsealedSenderNumber()));
this.unsealedIpLimiter = new AtomicReference<>(createUnsealedIpLimiter(cacheCluster, dynamicConfig.getConfiguration().getLimits().getUnsealedSenderIp()));
this.rateLimitResetLimiter = new AtomicReference<>(
createRateLimitResetLimiter(cacheCluster, dynamicConfig.getConfiguration().getLimits().getRateLimitReset()));
this.recaptchaChallengeAttemptLimiter = new AtomicReference<>(createRecaptchaChallengeAttemptLimiter(cacheCluster, dynamicConfig.getConfiguration().getLimits().getRecaptchaChallengeAttempt()));
this.recaptchaChallengeSuccessLimiter = new AtomicReference<>(createRecaptchaChallengeSuccessLimiter(cacheCluster, dynamicConfig.getConfiguration().getLimits().getRecaptchaChallengeSuccess()));
this.pushChallengeAttemptLimiter = new AtomicReference<>(createPushChallengeAttemptLimiter(cacheCluster, dynamicConfig.getConfiguration().getLimits().getPushChallengeAttempt()));
this.pushChallengeSuccessLimiter = new AtomicReference<>(createPushChallengeSuccessLimiter(cacheCluster, dynamicConfig.getConfiguration().getLimits().getPushChallengeSuccess()));
}
public CardinalityRateLimiter getUnsealedSenderLimiter() {
public CardinalityRateLimiter getUnsealedSenderCardinalityLimiter() {
CardinalityRateLimitConfiguration currentConfiguration = dynamicConfig.getConfiguration().getLimits().getUnsealedSenderNumber();
return this.unsealedSenderLimiter.updateAndGet(rateLimiter -> {
return this.unsealedSenderCardinalityLimiter.updateAndGet(rateLimiter -> {
if (rateLimiter.hasConfiguration(currentConfiguration)) {
return rateLimiter;
} else {
return createUnsealedSenderLimiter(cacheCluster, currentConfiguration);
return createUnsealedSenderCardinalityLimiter(cacheCluster, currentConfiguration);
}
});
}
public RateLimiter getUnsealedIpLimiter() {
RateLimitConfiguration currentConfiguration = dynamicConfig.getConfiguration().getLimits().getUnsealedSenderIp();
return updateAndGetRateLimiter(
unsealedIpLimiter,
dynamicConfig.getConfiguration().getLimits().getUnsealedSenderIp(),
this::createUnsealedIpLimiter);
}
return this.unsealedIpLimiter.updateAndGet(rateLimiter -> {
if (rateLimiter.hasConfiguration(currentConfiguration)) {
return rateLimiter;
public RateLimiter getRateLimitResetLimiter() {
return updateAndGetRateLimiter(
rateLimitResetLimiter,
dynamicConfig.getConfiguration().getLimits().getRateLimitReset(),
this::createRateLimitResetLimiter);
}
public RateLimiter getRecaptchaChallengeAttemptLimiter() {
return updateAndGetRateLimiter(
recaptchaChallengeAttemptLimiter,
dynamicConfig.getConfiguration().getLimits().getRecaptchaChallengeAttempt(),
this::createRecaptchaChallengeAttemptLimiter);
}
public RateLimiter getRecaptchaChallengeSuccessLimiter() {
return updateAndGetRateLimiter(
recaptchaChallengeSuccessLimiter,
dynamicConfig.getConfiguration().getLimits().getRecaptchaChallengeSuccess(),
this::createRecaptchaChallengeSuccessLimiter);
}
public RateLimiter getPushChallengeAttemptLimiter() {
return updateAndGetRateLimiter(
pushChallengeAttemptLimiter,
dynamicConfig.getConfiguration().getLimits().getPushChallengeAttempt(),
this::createPushChallengeAttemptLimiter);
}
public RateLimiter getPushChallengeSuccessLimiter() {
return updateAndGetRateLimiter(
pushChallengeSuccessLimiter,
dynamicConfig.getConfiguration().getLimits().getPushChallengeSuccess(),
this::createPushChallengeSuccessLimiter);
}
public RateLimiter getDailyPreKeysLimiter() {
return updateAndGetRateLimiter(
dailyPreKeysLimiter,
dynamicConfig.getConfiguration().getLimits().getDailyPreKeys(),
this::createDailyPreKeysLimiter);
}
private RateLimiter updateAndGetRateLimiter(final AtomicReference<RateLimiter> rateLimiter,
RateLimitConfiguration currentConfiguration,
BiFunction<FaultTolerantRedisCluster, RateLimitConfiguration, RateLimiter> rateLimitFactory) {
return rateLimiter.updateAndGet(limiter -> {
if (limiter.hasConfiguration(currentConfiguration)) {
return limiter;
} else {
return createUnsealedIpLimiter(cacheCluster, currentConfiguration);
return rateLimitFactory.apply(cacheCluster, currentConfiguration);
}
});
}
@ -219,8 +286,8 @@ public class RateLimiters {
return usernameSetLimiter;
}
private CardinalityRateLimiter createUnsealedSenderLimiter(FaultTolerantRedisCluster cacheCluster, CardinalityRateLimitConfiguration configuration) {
return new CardinalityRateLimiter(cacheCluster, "unsealedSender", configuration.getTtl(), configuration.getTtlJitter(), configuration.getMaxCardinality());
private CardinalityRateLimiter createUnsealedSenderCardinalityLimiter(FaultTolerantRedisCluster cacheCluster, CardinalityRateLimitConfiguration configuration) {
return new CardinalityRateLimiter(cacheCluster, "unsealedSender", configuration.getTtl(), configuration.getMaxCardinality());
}
private RateLimiter createUnsealedIpLimiter(FaultTolerantRedisCluster cacheCluster, RateLimitConfiguration configuration)
@ -228,6 +295,30 @@ public class RateLimiters {
return createLimiter(cacheCluster, configuration, "unsealedIp");
}
public RateLimiter createRateLimitResetLimiter(FaultTolerantRedisCluster cacheCluster, RateLimitConfiguration configuration) {
return createLimiter(cacheCluster, configuration, "rateLimitReset");
}
public RateLimiter createRecaptchaChallengeAttemptLimiter(FaultTolerantRedisCluster cacheCluster, RateLimitConfiguration configuration) {
return createLimiter(cacheCluster, configuration, "recaptchaChallengeAttempt");
}
public RateLimiter createRecaptchaChallengeSuccessLimiter(FaultTolerantRedisCluster cacheCluster, RateLimitConfiguration configuration) {
return createLimiter(cacheCluster, configuration, "recaptchaChallengeSuccess");
}
public RateLimiter createPushChallengeAttemptLimiter(FaultTolerantRedisCluster cacheCluster, RateLimitConfiguration configuration) {
return createLimiter(cacheCluster, configuration, "pushChallengeAttempt");
}
public RateLimiter createPushChallengeSuccessLimiter(FaultTolerantRedisCluster cacheCluster, RateLimitConfiguration configuration) {
return createLimiter(cacheCluster, configuration, "pushChallengeSuccess");
}
public RateLimiter createDailyPreKeysLimiter(FaultTolerantRedisCluster cacheCluster, RateLimitConfiguration configuration) {
return createLimiter(cacheCluster, configuration, "dailyPreKeys");
}
private RateLimiter createLimiter(FaultTolerantRedisCluster cacheCluster, RateLimitConfiguration configuration, String name) {
return new RateLimiter(cacheCluster, name,
configuration.getBucketSize(),

View File

@ -0,0 +1,114 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static com.codahale.metrics.MetricRegistry.name;
import io.dropwizard.util.Duration;
import io.lettuce.core.SetArgs;
import io.micrometer.core.instrument.Metrics;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitsConfiguration;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.Util;
public class UnsealedSenderRateLimiter {
private final RateLimiters rateLimiters;
private final FaultTolerantRedisCluster rateLimitCluster;
private final DynamicConfigurationManager dynamicConfigurationManager;
private final RateLimitResetMetricsManager metricsManager;
private static final String RATE_LIMIT_RESET_COUNTER_NAME = name(UnsealedSenderRateLimiter.class, "reset");
private static final String RATE_LIMITED_UNSEALED_SENDER_COUNTER_NAME = name(UnsealedSenderRateLimiter.class, "rateLimited");
private static final String RATE_LIMITED_UNSEALED_SENDER_ACCOUNTS_TOTAL_COUNTER_NAME = name(UnsealedSenderRateLimiter.class, "rateLimitedAccountsTotal");
private static final String RATE_LIMITED_UNSEALED_SENDER_ACCOUNTS_ENFORCED_COUNTER_NAME = name(UnsealedSenderRateLimiter.class, "rateLimitedAccountsEnforced");
private static final String RATE_LIMITED_UNSEALED_SENDER_ACCOUNTS_UNENFORCED_COUNTER_NAME = name(UnsealedSenderRateLimiter.class, "rateLimitedAccountsUnenforced");
private static final String RATE_LIMITED_ACCOUNTS_HLL_KEY = "UnsealedSenderRateLimiter::rateLimitedAccounts::total";
private static final String RATE_LIMITED_ACCOUNTS_ENFORCED_HLL_KEY = "UnsealedSenderRateLimiter::rateLimitedAccounts::enforced";
private static final String RATE_LIMITED_ACCOUNTS_UNENFORCED_HLL_KEY = "UnsealedSenderRateLimiter::rateLimitedAccounts::unenforced";
private static final long RATE_LIMITED_ACCOUNTS_HLL_TTL_SECONDS = Duration.days(1).toSeconds();
public UnsealedSenderRateLimiter(final RateLimiters rateLimiters,
final FaultTolerantRedisCluster rateLimitCluster,
final DynamicConfigurationManager dynamicConfigurationManager,
final RateLimitResetMetricsManager metricsManager) {
this.rateLimiters = rateLimiters;
this.rateLimitCluster = rateLimitCluster;
this.dynamicConfigurationManager = dynamicConfigurationManager;
this.metricsManager = metricsManager;
metricsManager.initializeFunctionCounters(RATE_LIMITED_UNSEALED_SENDER_ACCOUNTS_TOTAL_COUNTER_NAME,
RATE_LIMITED_ACCOUNTS_HLL_KEY);
metricsManager.initializeFunctionCounters(RATE_LIMITED_UNSEALED_SENDER_ACCOUNTS_ENFORCED_COUNTER_NAME,
RATE_LIMITED_ACCOUNTS_ENFORCED_HLL_KEY);
metricsManager.initializeFunctionCounters(RATE_LIMITED_UNSEALED_SENDER_ACCOUNTS_UNENFORCED_COUNTER_NAME,
RATE_LIMITED_ACCOUNTS_UNENFORCED_HLL_KEY);
}
public void validate(final Account sender, final Account destination) throws RateLimitExceededException {
final int maxCardinality = rateLimitCluster.withCluster(connection -> {
final String cardinalityString = connection.sync().get(getMaxCardinalityKey(sender));
return cardinalityString != null
? Integer.parseInt(cardinalityString)
: dynamicConfigurationManager.getConfiguration().getLimits().getUnsealedSenderDefaultCardinalityLimit();
});
try {
rateLimiters.getUnsealedSenderCardinalityLimiter()
.validate(sender.getNumber(), destination.getUuid().toString(), maxCardinality);
} catch (final RateLimitExceededException e) {
final boolean enforceLimit = dynamicConfigurationManager.getConfiguration()
.getRateLimitChallengeConfiguration().isUnsealedSenderLimitEnforced();
metricsManager.recordMetrics(sender, enforceLimit, RATE_LIMITED_UNSEALED_SENDER_COUNTER_NAME,
enforceLimit ? RATE_LIMITED_ACCOUNTS_ENFORCED_HLL_KEY : RATE_LIMITED_ACCOUNTS_UNENFORCED_HLL_KEY,
RATE_LIMITED_ACCOUNTS_HLL_KEY,
RATE_LIMITED_ACCOUNTS_HLL_TTL_SECONDS
);
if (enforceLimit) {
throw e;
}
}
}
public void handleRateLimitReset(final Account account) {
rateLimitCluster.useCluster(connection -> {
final CardinalityRateLimiter unsealedSenderCardinalityLimiter = rateLimiters.getUnsealedSenderCardinalityLimiter();
final DynamicRateLimitsConfiguration rateLimitsConfiguration =
dynamicConfigurationManager.getConfiguration().getLimits();
final long ttl;
{
final long remainingTtl = unsealedSenderCardinalityLimiter.getRemainingTtl(account.getNumber());
ttl = remainingTtl > 0 ? remainingTtl : unsealedSenderCardinalityLimiter.getInitialTtl().toSeconds();
}
final String key = getMaxCardinalityKey(account);
connection.sync().set(key,
String.valueOf(rateLimitsConfiguration.getUnsealedSenderDefaultCardinalityLimit()),
SetArgs.Builder.nx().ex(ttl));
connection.sync().incrby(key, rateLimitsConfiguration.getUnsealedSenderPermitIncrement());
});
Metrics.counter(RATE_LIMIT_RESET_COUNTER_NAME,
"countryCode", Util.getCountryCode(account.getNumber())).increment();
}
private static String getMaxCardinalityKey(final Account account) {
return "max_unsealed_sender_cardinality::" + account.getUuid();
}
}

View File

@ -0,0 +1,25 @@
package org.whispersystems.textsecuregcm.mappers;
import org.whispersystems.textsecuregcm.entities.RateLimitChallenge;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeException;
import javax.ws.rs.core.Response;
import javax.ws.rs.ext.ExceptionMapper;
import java.util.UUID;
public class RateLimitChallengeExceptionMapper implements ExceptionMapper<RateLimitChallengeException> {
private final RateLimitChallengeManager rateLimitChallengeManager;
public RateLimitChallengeExceptionMapper(final RateLimitChallengeManager rateLimitChallengeManager) {
this.rateLimitChallengeManager = rateLimitChallengeManager;
}
@Override
public Response toResponse(final RateLimitChallengeException exception) {
return Response.status(428)
.entity(new RateLimitChallenge(UUID.randomUUID().toString(), rateLimitChallengeManager.getChallengeOptions(exception.getAccount())))
.header("Retry-After", exception.getRetryAfter().toSeconds())
.build();
}
}

View File

@ -15,6 +15,7 @@ import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.cluster.SlotHash;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.push.ApnMessage.Type;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.redis.RedisException;
@ -192,7 +193,7 @@ public class ApnFallbackManager implements Managed {
return;
}
apnSender.sendMessage(new ApnMessage(apnId, account.getNumber(), device.getId(), true, Optional.empty()));
apnSender.sendMessage(new ApnMessage(apnId, account.getNumber(), device.getId(), true, Type.NOTIFICATION, Optional.empty()));
retry.mark();
}

View File

@ -12,21 +12,28 @@ import java.util.Optional;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class ApnMessage {
public enum Type {
NOTIFICATION, CHALLENGE, RATE_LIMIT_CHALLENGE
}
public static final String APN_NOTIFICATION_PAYLOAD = "{\"aps\":{\"sound\":\"default\",\"alert\":{\"loc-key\":\"APN_Message\"}}}";
public static final String APN_CHALLENGE_PAYLOAD = "{\"aps\":{\"sound\":\"default\",\"alert\":{\"loc-key\":\"APN_Message\"}}, \"challenge\" : \"%s\"}";
public static final String APN_RATE_LIMIT_CHALLENGE_PAYLOAD = "{\"aps\":{\"sound\":\"default\",\"alert\":{\"loc-key\":\"APN_Message\"}}, \"rateLimitChallenge\" : \"%s\"}";
public static final long MAX_EXPIRATION = Integer.MAX_VALUE * 1000L;
private final String apnId;
private final String number;
private final long deviceId;
private final boolean isVoip;
private final Type type;
private final Optional<String> challengeData;
public ApnMessage(String apnId, String number, long deviceId, boolean isVoip, Optional<String> challengeData) {
public ApnMessage(String apnId, String number, long deviceId, boolean isVoip, Type type, Optional<String> challengeData) {
this.apnId = apnId;
this.number = number;
this.deviceId = deviceId;
this.isVoip = isVoip;
this.type = type;
this.challengeData = challengeData;
}
@ -39,8 +46,19 @@ public class ApnMessage {
}
public String getMessage() {
if (!challengeData.isPresent()) return APN_NOTIFICATION_PAYLOAD;
else return String.format(APN_CHALLENGE_PAYLOAD, challengeData.get());
switch (type) {
case NOTIFICATION:
return APN_NOTIFICATION_PAYLOAD;
case CHALLENGE:
return String.format(APN_CHALLENGE_PAYLOAD, challengeData.orElseThrow(AssertionError::new));
case RATE_LIMIT_CHALLENGE:
return String.format(APN_RATE_LIMIT_CHALLENGE_PAYLOAD, challengeData.orElseThrow(AssertionError::new));
default:
throw new AssertionError();
}
}
@VisibleForTesting

View File

@ -5,10 +5,18 @@
package org.whispersystems.textsecuregcm.push;
import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.google.common.annotations.VisibleForTesting;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.gcm.server.Message;
@ -22,15 +30,6 @@ import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.Util;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import static com.codahale.metrics.MetricRegistry.name;
public class GCMSender {
private final Logger logger = LoggerFactory.getLogger(GCMSender.class);
@ -45,6 +44,7 @@ public class GCMSender {
put("receipt", metricRegistry.meter(name(getClass(), "outbound", "receipt")));
put("notification", metricRegistry.meter(name(getClass(), "outbound", "notification")));
put("challenge", metricRegistry.meter(name(getClass(), "outbound", "challenge")));
put("rateLimitChallenge", metricRegistry.meter(name(getClass(), "outbound", "rateLimitChallenge")));
}};
private final AccountsManager accountsManager;
@ -74,6 +74,7 @@ public class GCMSender {
switch (message.getType()) {
case NOTIFICATION: key = "notification"; break;
case CHALLENGE: key = "challenge"; break;
case RATE_LIMIT_CHALLENGE: key = "rateLimitChallenge"; break;
default: throw new AssertionError();
}

View File

@ -12,7 +12,7 @@ import java.util.Optional;
public class GcmMessage {
public enum Type {
NOTIFICATION, CHALLENGE
NOTIFICATION, CHALLENGE, RATE_LIMIT_CHALLENGE
}
private final String gcmId;

View File

@ -8,6 +8,7 @@ import io.dropwizard.lifecycle.Managed;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag;
import org.whispersystems.textsecuregcm.metrics.PushLatencyManager;
import org.whispersystems.textsecuregcm.push.ApnMessage.Type;
import org.whispersystems.textsecuregcm.redis.RedisOperation;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
@ -131,10 +132,10 @@ public class MessageSender implements Managed {
ApnMessage apnMessage;
if (!Util.isEmpty(device.getVoipApnId())) {
apnMessage = new ApnMessage(device.getVoipApnId(), account.getNumber(), device.getId(), true, Optional.empty());
apnMessage = new ApnMessage(device.getVoipApnId(), account.getNumber(), device.getId(), true, Type.NOTIFICATION, Optional.empty());
RedisOperation.unchecked(() -> apnFallbackManager.schedule(account, device));
} else {
apnMessage = new ApnMessage(device.getApnId(), account.getNumber(), device.getId(), false, Optional.empty());
apnMessage = new ApnMessage(device.getApnId(), account.getNumber(), device.getId(), false, Type.NOTIFICATION, Optional.empty());
}
apnSender.sendMessage(apnMessage);

View File

@ -6,6 +6,10 @@
package org.whispersystems.textsecuregcm.push;
public class NotPushRegisteredException extends Exception {
public NotPushRegisteredException() {
super();
}
public NotPushRegisteredException(String s) {
super(s);
}

View File

@ -0,0 +1,97 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import com.amazonaws.services.dynamodbv2.document.DynamoDB;
import com.amazonaws.services.dynamodbv2.document.Item;
import com.amazonaws.services.dynamodbv2.document.Table;
import com.amazonaws.services.dynamodbv2.document.spec.DeleteItemSpec;
import com.amazonaws.services.dynamodbv2.document.spec.PutItemSpec;
import com.amazonaws.services.dynamodbv2.model.ConditionalCheckFailedException;
import java.time.Clock;
import java.time.Duration;
import java.util.Map;
import java.util.UUID;
import com.google.common.annotations.VisibleForTesting;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
/**
* Stores push challenge tokens. Users may have at most one outstanding push challenge token at a time.
*/
public class PushChallengeDynamoDb extends AbstractDynamoDbStore {
private final Table table;
private final Clock clock;
static final String KEY_ACCOUNT_UUID = "U";
static final String ATTR_CHALLENGE_TOKEN = "C";
static final String ATTR_TTL = "T";
private static final Map<String, String> UUID_NAME_MAP = Map.of("#uuid", KEY_ACCOUNT_UUID);
private static final Map<String, String> CHALLENGE_TOKEN_NAME_MAP = Map.of("#challenge", ATTR_CHALLENGE_TOKEN);
public PushChallengeDynamoDb(final DynamoDB dynamoDB, final String tableName) {
this(dynamoDB, tableName, Clock.systemUTC());
}
@VisibleForTesting
PushChallengeDynamoDb(final DynamoDB dynamoDB, final String tableName, final Clock clock) {
super(dynamoDB);
this.table = dynamoDB.getTable(tableName);
this.clock = clock;
}
/**
* Stores a push challenge token for the given user if and only if the user doesn't already have a token stored. The
* existence check is strongly-consistent.
*
* @param accountUuid the UUID of the account for which to store a push challenge token
* @param challengeToken the challenge token itself
* @param ttl the time after which the token is no longer valid
* @return {@code true} if a new token was stored of {@code false} if another token already exists for the given
* account
*/
public boolean add(final UUID accountUuid, final byte[] challengeToken, final Duration ttl) {
try {
table.putItem( new PutItemSpec()
.withItem(new Item()
.withBinary(KEY_ACCOUNT_UUID, UUIDUtil.toByteBuffer(accountUuid))
.withBinary(ATTR_CHALLENGE_TOKEN, challengeToken)
.withNumber(ATTR_TTL, getExpirationTimestamp(ttl)))
.withConditionExpression("attribute_not_exists(#uuid)")
.withNameMap(UUID_NAME_MAP));
return true;
} catch (final ConditionalCheckFailedException e) {
return false;
}
}
long getExpirationTimestamp(final Duration ttl) {
return clock.instant().plus(ttl).getEpochSecond();
}
/**
* Clears a push challenge token for the given user if and only if the given challenge token matches the stored token.
* The token comparison is a strongly-consistent operation.
*
* @param accountUuid the account for which to remove a stored token
* @param challengeToken the token to remove
* @return {@code true} if the given token matched the stored token for the given user or {@code false} otherwise
*/
public boolean remove(final UUID accountUuid, final byte[] challengeToken) {
try {
table.deleteItem(new DeleteItemSpec()
.withPrimaryKey(KEY_ACCOUNT_UUID, UUIDUtil.toByteBuffer(accountUuid))
.withConditionExpression("#challenge = :challenge")
.withNameMap(CHALLENGE_TOKEN_NAME_MAP)
.withValueMap(Map.of(":challenge", challengeToken)));
return true;
} catch (final ConditionalCheckFailedException e) {
return false;
}
}
}

View File

@ -347,7 +347,6 @@ class DynamicConfigurationTest {
assertThat(emptyConfig.getLimits().getUnsealedSenderNumber().getMaxCardinality()).isEqualTo(100);
assertThat(emptyConfig.getLimits().getUnsealedSenderNumber().getTtl()).isEqualTo(Duration.ofDays(1));
assertThat(emptyConfig.getLimits().getUnsealedSenderNumber().getTtlJitter()).isEqualTo(Duration.ofDays(1));
}
{
@ -355,15 +354,46 @@ class DynamicConfigurationTest {
"limits:\n"
+ " unsealedSenderNumber:\n"
+ " maxCardinality: 99\n"
+ " ttl: PT23H\n"
+ " ttlJitter: PT22H";
+ " ttl: PT23H";
final CardinalityRateLimitConfiguration unsealedSenderNumber = DynamicConfigurationManager.OBJECT_MAPPER
.readValue(limitsConfig, DynamicConfiguration.class)
.getLimits().getUnsealedSenderNumber();
assertThat(unsealedSenderNumber.getMaxCardinality()).isEqualTo(99);
assertThat(unsealedSenderNumber.getTtl()).isEqualTo(Duration.ofHours(23));
assertThat(unsealedSenderNumber.getTtlJitter()).isEqualTo(Duration.ofHours(22));
}
}
@Test
void testParseRateLimitReset() throws JsonProcessingException {
{
final String emptyConfigYaml = "test: true";
final DynamicConfiguration emptyConfig = DynamicConfigurationManager.OBJECT_MAPPER.readValue(
emptyConfigYaml, DynamicConfiguration.class);
assertThat(emptyConfig.getRateLimitChallengeConfiguration().getClientSupportedVersions()).isEmpty();
assertThat(emptyConfig.getRateLimitChallengeConfiguration().isPreKeyLimitEnforced()).isFalse();
assertThat(emptyConfig.getRateLimitChallengeConfiguration().isUnsealedSenderLimitEnforced()).isFalse();
}
{
final String rateLimitChallengeConfig =
"rateLimitChallenge:\n"
+ " preKeyLimitEnforced: true\n"
+ " clientSupportedVersions:\n"
+ " IOS: 5.1.0\n"
+ " ANDROID: 5.2.0\n"
+ " DESKTOP: 5.0.0";
DynamicRateLimitChallengeConfiguration rateLimitChallengeConfiguration = DynamicConfigurationManager.OBJECT_MAPPER
.readValue(rateLimitChallengeConfig, DynamicConfiguration.class)
.getRateLimitChallengeConfiguration();
final Map<ClientPlatform, Semver> clientSupportedVersions = rateLimitChallengeConfiguration.getClientSupportedVersions();
assertThat(clientSupportedVersions.get(ClientPlatform.IOS)).isEqualTo(new Semver("5.1.0"));
assertThat(clientSupportedVersions.get(ClientPlatform.ANDROID)).isEqualTo(new Semver("5.2.0"));
assertThat(clientSupportedVersions.get(ClientPlatform.DESKTOP)).isEqualTo(new Semver("5.0.0"));
assertThat(rateLimitChallengeConfiguration.isPreKeyLimitEnforced()).isTrue();
assertThat(rateLimitChallengeConfiguration.isUnsealedSenderLimitEnforced()).isFalse();
}
}
}

View File

@ -0,0 +1,200 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.controllers;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import java.time.Duration;
import java.util.Set;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.Response;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.mappers.RetryLaterExceptionMapper;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@ExtendWith(DropwizardExtensionsSupport.class)
class ChallengeControllerTest {
private static final RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class);
private static final ChallengeController challengeController = new ChallengeController(rateLimitChallengeManager);
private static final ResourceExtension EXTENSION = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(Set.of(Account.class, DisabledPermittedAccount.class)))
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new RetryLaterExceptionMapper())
.addResource(challengeController)
.build();
@AfterEach
void teardown() {
reset(rateLimitChallengeManager);
}
@Test
void testHandlePushChallenge() throws RateLimitExceededException {
final String pushChallengeJson = "{\n"
+ " \"type\": \"rateLimitPushChallenge\",\n"
+ " \"challenge\": \"Hello I am a push challenge token\"\n"
+ "}";
final Response response = EXTENSION.target("/v1/challenge")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.put(Entity.json(pushChallengeJson));
assertEquals(200, response.getStatus());
verify(rateLimitChallengeManager).answerPushChallenge(AuthHelper.VALID_ACCOUNT, "Hello I am a push challenge token");
}
@Test
void testHandlePushChallengeRateLimited() throws RateLimitExceededException {
final String pushChallengeJson = "{\n"
+ " \"type\": \"rateLimitPushChallenge\",\n"
+ " \"challenge\": \"Hello I am a push challenge token\"\n"
+ "}";
final Duration retryAfter = Duration.ofMinutes(17);
doThrow(new RateLimitExceededException(retryAfter)).when(rateLimitChallengeManager).answerPushChallenge(any(), any());
final Response response = EXTENSION.target("/v1/challenge")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.put(Entity.json(pushChallengeJson));
assertEquals(413, response.getStatus());
assertEquals(String.valueOf(retryAfter.toSeconds()), response.getHeaderString("Retry-After"));
}
@Test
void testHandleRecaptcha() throws RateLimitExceededException {
final String recaptchaChallengeJson = "{\n"
+ " \"type\": \"recaptcha\",\n"
+ " \"token\": \"A server-generated token\",\n"
+ " \"captcha\": \"The value of the solved captcha token\"\n"
+ "}";
final Response response = EXTENSION.target("/v1/challenge")
.request()
.header("X-Forwarded-For", "10.0.0.1")
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.put(Entity.json(recaptchaChallengeJson));
assertEquals(200, response.getStatus());
verify(rateLimitChallengeManager).answerRecaptchaChallenge(AuthHelper.VALID_ACCOUNT, "The value of the solved captcha token", "10.0.0.1");
}
@Test
void testHandleRecaptchaRateLimited() throws RateLimitExceededException {
final String recaptchaChallengeJson = "{\n"
+ " \"type\": \"recaptcha\",\n"
+ " \"token\": \"A server-generated token\",\n"
+ " \"captcha\": \"The value of the solved captcha token\"\n"
+ "}";
final Duration retryAfter = Duration.ofMinutes(17);
doThrow(new RateLimitExceededException(retryAfter)).when(rateLimitChallengeManager).answerRecaptchaChallenge(any(), any(), any());
final Response response = EXTENSION.target("/v1/challenge")
.request()
.header("X-Forwarded-For", "10.0.0.1")
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.put(Entity.json(recaptchaChallengeJson));
assertEquals(413, response.getStatus());
assertEquals(String.valueOf(retryAfter.toSeconds()), response.getHeaderString("Retry-After"));
}
@Test
void testHandleRecaptchaNoForwardedFor() {
final String recaptchaChallengeJson = "{\n"
+ " \"type\": \"recaptcha\",\n"
+ " \"token\": \"A server-generated token\",\n"
+ " \"captcha\": \"The value of the solved captcha token\"\n"
+ "}";
final Response response = EXTENSION.target("/v1/challenge")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.put(Entity.json(recaptchaChallengeJson));
assertEquals(400, response.getStatus());
verifyZeroInteractions(rateLimitChallengeManager);
}
@Test
void testHandleUnrecognizedAnswer() {
final String unrecognizedJson = "{\n"
+ " \"type\": \"unrecognized\"\n"
+ "}";
final Response response = EXTENSION.target("/v1/challenge")
.request()
.header("X-Forwarded-For", "10.0.0.1")
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.put(Entity.json(unrecognizedJson));
assertEquals(400, response.getStatus());
verifyZeroInteractions(rateLimitChallengeManager);
}
@Test
void testRequestPushChallenge() throws NotPushRegisteredException {
{
final Response response = EXTENSION.target("/v1/challenge/push")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.post(Entity.text(""));
assertEquals(200, response.getStatus());
}
{
doThrow(NotPushRegisteredException.class).when(rateLimitChallengeManager).sendPushChallenge(AuthHelper.VALID_ACCOUNT_TWO);
final Response response = EXTENSION.target("/v1/challenge/push")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER_TWO, AuthHelper.VALID_PASSWORD_TWO))
.post(Entity.text(""));
assertEquals(404, response.getStatus());
}
}
@Test
void testValidationError() {
final String unrecognizedJson = "{\n"
+ " \"type\": \"rateLimitPushChallenge\"\n"
+ "}";
final Response response = EXTENSION.target("/v1/challenge")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.put(Entity.json(unrecognizedJson));
assertEquals(422, response.getStatus());
}
}

View File

@ -5,9 +5,16 @@
package org.whispersystems.textsecuregcm.controllers;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import java.util.concurrent.ScheduledExecutorService;
import org.junit.Before;
import org.junit.Test;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.limits.UnsealedSenderRateLimiter;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.MessageSender;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
@ -16,12 +23,6 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import java.util.concurrent.ScheduledExecutorService;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
public class MessageControllerMetricsTest extends AbstractRedisClusterTest {
private MessageController messageController;
@ -35,8 +36,10 @@ public class MessageControllerMetricsTest extends AbstractRedisClusterTest {
mock(ReceiptSender.class),
mock(AccountsManager.class),
mock(MessagesManager.class),
mock(UnsealedSenderRateLimiter.class),
mock(ApnFallbackManager.class),
mock(DynamicConfigurationManager.class),
mock(RateLimitChallengeManager.class),
getRedisCluster(),
mock(ScheduledExecutorService.class));
}

View File

@ -0,0 +1,63 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.exc.InvalidTypeIdException;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import static org.junit.jupiter.api.Assertions.*;
class AnswerChallengeRequestTest {
@Test
void parse() throws JsonProcessingException {
{
final String pushChallengeJson = "{\n"
+ " \"type\": \"rateLimitPushChallenge\",\n"
+ " \"challenge\": \"Hello I am a push challenge token\"\n"
+ "}";
final AnswerChallengeRequest answerChallengeRequest =
SystemMapper.getMapper().readValue(pushChallengeJson, AnswerChallengeRequest.class);
assertTrue(answerChallengeRequest instanceof AnswerPushChallengeRequest);
assertEquals("Hello I am a push challenge token",
((AnswerPushChallengeRequest) answerChallengeRequest).getChallenge());
}
{
final String recaptchaChallengeJson = "{\n"
+ " \"type\": \"recaptcha\",\n"
+ " \"token\": \"A server-generated token\",\n"
+ " \"captcha\": \"The value of the solved captcha token\"\n"
+ "}";
final AnswerChallengeRequest answerChallengeRequest =
SystemMapper.getMapper().readValue(recaptchaChallengeJson, AnswerChallengeRequest.class);
assertTrue(answerChallengeRequest instanceof AnswerRecaptchaChallengeRequest);
assertEquals("A server-generated token",
((AnswerRecaptchaChallengeRequest) answerChallengeRequest).getToken());
assertEquals("The value of the solved captcha token",
((AnswerRecaptchaChallengeRequest) answerChallengeRequest).getCaptcha());
}
{
final String unrecognizedTypeJson = "{\n"
+ " \"type\": \"unrecognized\",\n"
+ " \"token\": \"A server-generated token\",\n"
+ " \"captcha\": \"The value of the solved captcha token\"\n"
+ "}";
assertThrows(InvalidTypeIdException.class,
() -> SystemMapper.getMapper().readValue(unrecognizedTypeJson, AnswerChallengeRequest.class));
}
}
}

View File

@ -30,7 +30,8 @@ public class CardinalityRateLimiterTest extends AbstractRedisClusterTest {
@Test
public void testValidate() {
final int maxCardinality = 10;
final CardinalityRateLimiter rateLimiter = new CardinalityRateLimiter(getRedisCluster(), "test", Duration.ofDays(1), Duration.ofDays(1), maxCardinality);
final CardinalityRateLimiter rateLimiter =
new CardinalityRateLimiter(getRedisCluster(), "test", Duration.ofDays(1), maxCardinality);
final String source = "+18005551234";
int validatedAttempts = 0;
@ -38,7 +39,7 @@ public class CardinalityRateLimiterTest extends AbstractRedisClusterTest {
for (int i = 0; i < maxCardinality * 2; i++) {
try {
rateLimiter.validate(source, String.valueOf(i));
rateLimiter.validate(source, String.valueOf(i), rateLimiter.getDefaultMaxCardinality());
validatedAttempts++;
} catch (final RateLimitExceededException e) {
blockedAttempts++;
@ -51,9 +52,10 @@ public class CardinalityRateLimiterTest extends AbstractRedisClusterTest {
final String secondSource = "+18005554321";
try {
rateLimiter.validate(secondSource, "test");
rateLimiter.validate(secondSource, "test", rateLimiter.getDefaultMaxCardinality());
} catch (final RateLimitExceededException e) {
fail("New source should not trigger a rate limit exception on first attempted validation");
}
}
}

View File

@ -0,0 +1,66 @@
package org.whispersystems.textsecuregcm.limits;
import static org.junit.Assert.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.util.UUID;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitChallengeConfiguration;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
class PreKeyRateLimiterTest {
private Account account;
private PreKeyRateLimiter preKeyRateLimiter;
private DynamicRateLimitChallengeConfiguration rateLimitChallengeConfiguration;
private RateLimiter dailyPreKeyLimiter;
@BeforeEach
void setup() {
final RateLimiters rateLimiters = mock(RateLimiters.class);
dailyPreKeyLimiter = mock(RateLimiter.class);
when(rateLimiters.getDailyPreKeysLimiter()).thenReturn(dailyPreKeyLimiter);
final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
rateLimitChallengeConfiguration = mock(DynamicRateLimitChallengeConfiguration.class);
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
when(dynamicConfiguration.getRateLimitChallengeConfiguration()).thenReturn(rateLimitChallengeConfiguration);
preKeyRateLimiter = new PreKeyRateLimiter(rateLimiters, dynamicConfigurationManager, mock(RateLimitResetMetricsManager.class));
account = mock(Account.class);
when(account.getNumber()).thenReturn("+18005551111");
when(account.getUuid()).thenReturn(UUID.randomUUID());
}
@Test
void enforcementConfiguration() throws RateLimitExceededException {
doThrow(RateLimitExceededException.class)
.when(dailyPreKeyLimiter).validate(any());
when(rateLimitChallengeConfiguration.isPreKeyLimitEnforced()).thenReturn(false);
preKeyRateLimiter.validate(account);
when(rateLimitChallengeConfiguration.isPreKeyLimitEnforced()).thenReturn(true);
assertThrows(RateLimitExceededException.class, () -> preKeyRateLimiter.validate(account));
when(rateLimitChallengeConfiguration.isPreKeyLimitEnforced()).thenReturn(false);
preKeyRateLimiter.validate(account);
}
}

View File

@ -0,0 +1,190 @@
package org.whispersystems.textsecuregcm.limits;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
import com.vdurmont.semver4j.Semver;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitChallengeConfiguration;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.recaptcha.RecaptchaClient;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
class RateLimitChallengeManagerTest {
private PushChallengeManager pushChallengeManager;
private RecaptchaClient recaptchaClient;
private PreKeyRateLimiter preKeyRateLimiter;
private UnsealedSenderRateLimiter unsealedSenderRateLimiter;
private DynamicRateLimitChallengeConfiguration rateLimitChallengeConfiguration;
private RateLimiters rateLimiters;
private RateLimitChallengeManager rateLimitChallengeManager;
@BeforeEach
void setUp() {
pushChallengeManager = mock(PushChallengeManager.class);
recaptchaClient = mock(RecaptchaClient.class);
preKeyRateLimiter = mock(PreKeyRateLimiter.class);
unsealedSenderRateLimiter = mock(UnsealedSenderRateLimiter.class);
rateLimiters = mock(RateLimiters.class);
final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
rateLimitChallengeConfiguration = mock(DynamicRateLimitChallengeConfiguration.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
when(dynamicConfiguration.getRateLimitChallengeConfiguration()).thenReturn(rateLimitChallengeConfiguration);
rateLimitChallengeManager = new RateLimitChallengeManager(
pushChallengeManager,
recaptchaClient,
preKeyRateLimiter,
unsealedSenderRateLimiter,
rateLimiters,
dynamicConfigurationManager);
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void answerPushChallenge(final boolean successfulChallenge) throws RateLimitExceededException {
final Account account = mock(Account.class);
when(pushChallengeManager.answerChallenge(eq(account), any())).thenReturn(successfulChallenge);
when(rateLimiters.getPushChallengeAttemptLimiter()).thenReturn(mock(RateLimiter.class));
when(rateLimiters.getPushChallengeSuccessLimiter()).thenReturn(mock(RateLimiter.class));
when(rateLimiters.getRateLimitResetLimiter()).thenReturn(mock(RateLimiter.class));
rateLimitChallengeManager.answerPushChallenge(account, "challenge");
if (successfulChallenge) {
verify(preKeyRateLimiter).handleRateLimitReset(account);
verify(unsealedSenderRateLimiter).handleRateLimitReset(account);
} else {
verifyZeroInteractions(preKeyRateLimiter);
verifyZeroInteractions(unsealedSenderRateLimiter);
}
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void answerRecaptchaChallenge(final boolean successfulChallenge) throws RateLimitExceededException {
final Account account = mock(Account.class);
when(recaptchaClient.verify(any(), any())).thenReturn(successfulChallenge);
when(rateLimiters.getRecaptchaChallengeAttemptLimiter()).thenReturn(mock(RateLimiter.class));
when(rateLimiters.getRecaptchaChallengeSuccessLimiter()).thenReturn(mock(RateLimiter.class));
when(rateLimiters.getRateLimitResetLimiter()).thenReturn(mock(RateLimiter.class));
rateLimitChallengeManager.answerRecaptchaChallenge(account, "captcha", "10.0.0.1");
if (successfulChallenge) {
verify(preKeyRateLimiter).handleRateLimitReset(account);
verify(unsealedSenderRateLimiter).handleRateLimitReset(account);
} else {
verifyZeroInteractions(preKeyRateLimiter);
verifyZeroInteractions(unsealedSenderRateLimiter);
}
}
@ParameterizedTest
@MethodSource
void shouldIssueRateLimitChallenge(final String userAgent, final boolean expectIssueChallenge) {
when(rateLimitChallengeConfiguration.getMinimumSupportedVersion(any())).thenReturn(Optional.empty());
when(rateLimitChallengeConfiguration.getMinimumSupportedVersion(ClientPlatform.ANDROID))
.thenReturn(Optional.of(new Semver("5.6.0")));
when(rateLimitChallengeConfiguration.getMinimumSupportedVersion(ClientPlatform.DESKTOP))
.thenReturn(Optional.of(new Semver("5.0.0-beta.2")));
assertEquals(expectIssueChallenge, rateLimitChallengeManager.shouldIssueRateLimitChallenge(userAgent));
}
private static Stream<Arguments> shouldIssueRateLimitChallenge() {
return Stream.of(
Arguments.of("Signal-Android/5.1.2 Android/30", false),
Arguments.of("Signal-Android/5.6.0 Android/30", true),
Arguments.of("Signal-Android/5.11.1 Android/30", true),
Arguments.of("Signal-Desktop/5.0.0-beta.3 macOS/11", true),
Arguments.of("Signal-Desktop/5.0.0-beta.1 Windows/3.1", false),
Arguments.of("Signal-Desktop/5.2.0 Debian/11", true),
Arguments.of("Signal-iOS/5.1.2 iOS/12.2", false),
Arguments.of("anything-else", false)
);
}
@ParameterizedTest
@MethodSource
void getChallengeOptions(final boolean captchaAttemptPermitted,
final boolean captchaSuccessPermitted,
final boolean pushAttemptPermitted,
final boolean pushSuccessPermitted,
final boolean expectCaptcha,
final boolean expectPushChallenge) {
final RateLimiter recaptchaChallengeAttemptLimiter = mock(RateLimiter.class);
final RateLimiter recaptchaChallengeSuccessLimiter = mock(RateLimiter.class);
final RateLimiter pushChallengeAttemptLimiter = mock(RateLimiter.class);
final RateLimiter pushChallengeSuccessLimiter = mock(RateLimiter.class);
when(rateLimiters.getRecaptchaChallengeAttemptLimiter()).thenReturn(recaptchaChallengeAttemptLimiter);
when(rateLimiters.getRecaptchaChallengeSuccessLimiter()).thenReturn(recaptchaChallengeSuccessLimiter);
when(rateLimiters.getPushChallengeAttemptLimiter()).thenReturn(pushChallengeAttemptLimiter);
when(rateLimiters.getPushChallengeSuccessLimiter()).thenReturn(pushChallengeSuccessLimiter);
when(recaptchaChallengeAttemptLimiter.hasAvailablePermits(any(), anyInt())).thenReturn(captchaAttemptPermitted);
when(recaptchaChallengeSuccessLimiter.hasAvailablePermits(any(), anyInt())).thenReturn(captchaSuccessPermitted);
when(pushChallengeAttemptLimiter.hasAvailablePermits(any(), anyInt())).thenReturn(pushAttemptPermitted);
when(pushChallengeSuccessLimiter.hasAvailablePermits(any(), anyInt())).thenReturn(pushSuccessPermitted);
final int expectedLength = (expectCaptcha ? 1 : 0) + (expectPushChallenge ? 1 : 0);
final List<String> options = rateLimitChallengeManager.getChallengeOptions(mock(Account.class));
assertEquals(expectedLength, options.size());
if (expectCaptcha) {
assertTrue(options.contains(RateLimitChallengeManager.OPTION_RECAPTCHA));
}
if (expectPushChallenge) {
assertTrue(options.contains(RateLimitChallengeManager.OPTION_PUSH_CHALLENGE));
}
}
private static Stream<Arguments> getChallengeOptions() {
return Stream.of(
Arguments.of(false, false, false, false, false, false),
Arguments.of(false, false, false, true, false, false),
Arguments.of(false, false, true, false, false, false),
Arguments.of(false, false, true, true, false, true),
Arguments.of(false, true, false, false, false, false),
Arguments.of(false, true, false, true, false, false),
Arguments.of(false, true, true, false, false, false),
Arguments.of(false, true, true, true, false, true),
Arguments.of(true, false, false, false, false, false),
Arguments.of(true, false, false, true, false, false),
Arguments.of(true, false, true, false, false, false),
Arguments.of(true, false, true, true, false, true),
Arguments.of(true, true, false, false, true, false),
Arguments.of(true, true, false, true, true, false),
Arguments.of(true, true, true, false, true, false),
Arguments.of(true, true, true, true, true, true)
);
}
}

View File

@ -0,0 +1,58 @@
package org.whispersystems.textsecuregcm.limits;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import io.dropwizard.util.Duration;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.simple.SimpleMeterRegistry;
import java.util.UUID;
import org.junit.Before;
import org.junit.Test;
import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest;
import org.whispersystems.textsecuregcm.storage.Account;
public class RateLimitResetMetricsManagerTest extends AbstractRedisClusterTest {
private RateLimitResetMetricsManager metricsManager;
private SimpleMeterRegistry meterRegistry;
@Before
@Override
public void setUp() throws Exception {
super.setUp();
meterRegistry = new SimpleMeterRegistry();
metricsManager = new RateLimitResetMetricsManager(getRedisCluster(), meterRegistry);
}
@Test
public void testRecordMetrics() {
final Account firstAccount = mock(Account.class);
when(firstAccount.getUuid()).thenReturn(UUID.randomUUID());
final Account secondAccount = mock(Account.class);
when(secondAccount.getUuid()).thenReturn(UUID.randomUUID());
metricsManager.recordMetrics(firstAccount, true, "counter", "enforced", "total", Duration.hours(1).toSeconds());
metricsManager.recordMetrics(firstAccount, true, "counter", "enforced", "total", Duration.hours(1).toSeconds());
metricsManager.recordMetrics(secondAccount, false, "counter", "unenforced", "total", Duration.hours(1).toSeconds());
final double counterTotal = meterRegistry.get("counter").counters().stream()
.map(Counter::count)
.reduce(Double::sum)
.orElseThrow();
assertEquals(3, counterTotal, 0.0);
final long enforcedCount = getRedisCluster().withCluster(conn -> conn.sync().pfcount("enforced"));
assertEquals(1L, enforcedCount);
final long unenforcedCount = getRedisCluster().withCluster(conn -> conn.sync().pfcount("unenforced"));
assertEquals(1L, unenforcedCount);
final long total = getRedisCluster().withCluster(conn -> conn.sync().pfcount("total"));
assertEquals(2L, total);
}
}

View File

@ -0,0 +1,118 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.time.Duration;
import java.util.UUID;
import org.junit.Before;
import org.junit.Test;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitChallengeConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitsConfiguration;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
public class UnsealedSenderRateLimiterTest extends AbstractRedisClusterTest {
private Account sender;
private Account firstDestination;
private Account secondDestination;
private UnsealedSenderRateLimiter unsealedSenderRateLimiter;
private DynamicRateLimitChallengeConfiguration rateLimitChallengeConfiguration;
@Before
@Override
public void setUp() throws Exception {
super.setUp();
final RateLimiters rateLimiters = mock(RateLimiters.class);
final CardinalityRateLimiter cardinalityRateLimiter =
new CardinalityRateLimiter(getRedisCluster(), "test", Duration.ofDays(1), 1);
when(rateLimiters.getUnsealedSenderCardinalityLimiter()).thenReturn(cardinalityRateLimiter);
when(rateLimiters.getRateLimitResetLimiter()).thenReturn(mock(RateLimiter.class));
final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
final DynamicRateLimitsConfiguration rateLimitsConfiguration = mock(DynamicRateLimitsConfiguration.class);
rateLimitChallengeConfiguration = mock(DynamicRateLimitChallengeConfiguration.class);
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(dynamicConfiguration);
when(dynamicConfiguration.getLimits()).thenReturn(rateLimitsConfiguration);
when(rateLimitsConfiguration.getUnsealedSenderDefaultCardinalityLimit()).thenReturn(1);
when(rateLimitsConfiguration.getUnsealedSenderPermitIncrement()).thenReturn(1);
when(dynamicConfiguration.getRateLimitChallengeConfiguration()).thenReturn(rateLimitChallengeConfiguration);
when(rateLimitChallengeConfiguration.isUnsealedSenderLimitEnforced()).thenReturn(true);
unsealedSenderRateLimiter = new UnsealedSenderRateLimiter(rateLimiters, getRedisCluster(), dynamicConfigurationManager,
mock(RateLimitResetMetricsManager.class));
sender = mock(Account.class);
when(sender.getNumber()).thenReturn("+18005551111");
when(sender.getUuid()).thenReturn(UUID.randomUUID());
firstDestination = mock(Account.class);
when(firstDestination.getNumber()).thenReturn("+18005552222");
when(firstDestination.getUuid()).thenReturn(UUID.randomUUID());
secondDestination = mock(Account.class);
when(secondDestination.getNumber()).thenReturn("+18005553333");
when(secondDestination.getUuid()).thenReturn(UUID.randomUUID());
}
@Test
public void validate() throws RateLimitExceededException {
unsealedSenderRateLimiter.validate(sender, firstDestination);
assertThrows(RateLimitExceededException.class, () -> unsealedSenderRateLimiter.validate(sender, secondDestination));
unsealedSenderRateLimiter.validate(sender, firstDestination);
}
@Test
public void handleRateLimitReset() throws RateLimitExceededException {
unsealedSenderRateLimiter.validate(sender, firstDestination);
assertThrows(RateLimitExceededException.class, () -> unsealedSenderRateLimiter.validate(sender, secondDestination));
unsealedSenderRateLimiter.handleRateLimitReset(sender);
unsealedSenderRateLimiter.validate(sender, firstDestination);
unsealedSenderRateLimiter.validate(sender, secondDestination);
}
@Test
public void enforcementConfiguration() throws RateLimitExceededException {
when(rateLimitChallengeConfiguration.isUnsealedSenderLimitEnforced()).thenReturn(false);
unsealedSenderRateLimiter.validate(sender, firstDestination);
unsealedSenderRateLimiter.validate(sender, secondDestination);
when(rateLimitChallengeConfiguration.isUnsealedSenderLimitEnforced()).thenReturn(true);
final Account thirdDestination = mock(Account.class);
when(thirdDestination.getNumber()).thenReturn("+18005554444");
when(thirdDestination.getUuid()).thenReturn(UUID.randomUUID());
assertThrows(RateLimitExceededException.class, () -> unsealedSenderRateLimiter.validate(sender, thirdDestination));
when(rateLimitChallengeConfiguration.isUnsealedSenderLimitEnforced()).thenReturn(false);
final Account fourthDestination = mock(Account.class);
when(fourthDestination.getNumber()).thenReturn("+18005555555");
when(fourthDestination.getUuid()).thenReturn(UUID.randomUUID());
unsealedSenderRateLimiter.validate(sender, fourthDestination);
}
}

View File

@ -0,0 +1,76 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import com.amazonaws.services.dynamodbv2.model.AttributeDefinition;
import com.amazonaws.services.dynamodbv2.model.ScalarAttributeType;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.ZoneId;
import java.util.Random;
import java.util.UUID;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
class PushChallengeDynamoDbTest {
private PushChallengeDynamoDb pushChallengeDynamoDb;
private static final long CURRENT_TIME_MILLIS = 1_000_000_000;
private static final Random RANDOM = new Random();
private static final String TABLE_NAME = "push_challenge_test";
@RegisterExtension
static DynamoDbExtension dynamoDbExtension = DynamoDbExtension.builder()
.tableName(TABLE_NAME)
.hashKey(PushChallengeDynamoDb.KEY_ACCOUNT_UUID)
.attributeDefinition(new AttributeDefinition(PushChallengeDynamoDb.KEY_ACCOUNT_UUID, ScalarAttributeType.B))
.build();
@BeforeEach
void setUp() {
this.pushChallengeDynamoDb = new PushChallengeDynamoDb(dynamoDbExtension.getDynamoDB(), TABLE_NAME, Clock.fixed(
Instant.ofEpochMilli(CURRENT_TIME_MILLIS), ZoneId.systemDefault()));
}
@Test
void add() {
final UUID uuid = UUID.randomUUID();
assertTrue(pushChallengeDynamoDb.add(uuid, generateRandomToken(), Duration.ofMinutes(1)));
assertFalse(pushChallengeDynamoDb.add(uuid, generateRandomToken(), Duration.ofMinutes(1)));
}
@Test
void remove() {
final UUID uuid = UUID.randomUUID();
final byte[] token = generateRandomToken();
assertFalse(pushChallengeDynamoDb.remove(uuid, token));
assertTrue(pushChallengeDynamoDb.add(uuid, token, Duration.ofMinutes(1)));
assertTrue(pushChallengeDynamoDb.remove(uuid, token));
}
@Test
void getExpirationTimestamp() {
assertEquals((CURRENT_TIME_MILLIS / 1000) + 3600,
pushChallengeDynamoDb.getExpirationTimestamp(Duration.ofHours(1)));
}
private static byte[] generateRandomToken() {
final byte[] token = new byte[16];
RANDOM.nextBytes(token);
return token;
}
}

View File

@ -5,35 +5,23 @@
package org.whispersystems.textsecuregcm.tests.controllers;
import com.google.common.collect.ImmutableSet;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatcher;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.controllers.KeysController;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.PreKeyCount;
import org.whispersystems.textsecuregcm.entities.PreKeyResponse;
import org.whispersystems.textsecuregcm.entities.PreKeyState;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysDynamoDb;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.argThat;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import com.google.common.collect.ImmutableSet;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import java.time.Duration;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
@ -41,13 +29,42 @@ import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatcher;
import org.whispersystems.textsecuregcm.auth.AmbiguousIdentifier;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.controllers.KeysController;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.entities.PreKey;
import org.whispersystems.textsecuregcm.entities.PreKeyCount;
import org.whispersystems.textsecuregcm.entities.PreKeyResponse;
import org.whispersystems.textsecuregcm.entities.PreKeyState;
import org.whispersystems.textsecuregcm.entities.RateLimitChallenge;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.limits.PreKeyRateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.RateLimitChallengeExceptionMapper;
import org.whispersystems.textsecuregcm.sqs.DirectoryQueue;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.KeysDynamoDb;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit.ResourceTestRule;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.*;
public class KeysControllerTest {
@ExtendWith(DropwizardExtensionsSupport.class)
class KeysControllerTest {
private static final String EXISTS_NUMBER = "+14152222222";
private static final UUID EXISTS_UUID = UUID.randomUUID();
@ -70,24 +87,28 @@ public class KeysControllerTest {
private final SignedPreKey SAMPLE_SIGNED_KEY3 = new SignedPreKey( 3333, "barfoo", "sig33" );
private final SignedPreKey VALID_DEVICE_SIGNED_KEY = new SignedPreKey(89898, "zoofarb", "sigvalid");
private final KeysDynamoDb keysDynamoDb = mock(KeysDynamoDb.class );
private final AccountsManager accounts = mock(AccountsManager.class );
private final DirectoryQueue directoryQueue = mock(DirectoryQueue.class );
private final Account existsAccount = mock(Account.class );
private final static KeysDynamoDb keysDynamoDb = mock(KeysDynamoDb.class );
private final static AccountsManager accounts = mock(AccountsManager.class );
private final static DirectoryQueue directoryQueue = mock(DirectoryQueue.class );
private final static PreKeyRateLimiter preKeyRateLimiter = mock(PreKeyRateLimiter.class );
private final static RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class );
private final static Account existsAccount = mock(Account.class );
private RateLimiters rateLimiters = mock(RateLimiters.class);
private RateLimiter rateLimiter = mock(RateLimiter.class );
private final static DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
@Rule
public final ResourceTestRule resources = ResourceTestRule.builder()
private static final RateLimiters rateLimiters = mock(RateLimiters.class);
private static final RateLimiter rateLimiter = mock(RateLimiter.class );
private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new KeysController(rateLimiters, keysDynamoDb, accounts, directoryQueue))
.addResource(new RateLimitChallengeExceptionMapper(rateLimitChallengeManager))
.addResource(new KeysController(rateLimiters, keysDynamoDb, accounts, directoryQueue, preKeyRateLimiter, dynamicConfigurationManager, rateLimitChallengeManager))
.build();
@Before
public void setup() {
@BeforeEach
void setup() {
final Device sampleDevice = mock(Device.class);
final Device sampleDevice2 = mock(Device.class);
final Device sampleDevice3 = mock(Device.class);
@ -153,8 +174,23 @@ public class KeysControllerTest {
when(AuthHelper.VALID_ACCOUNT.getIdentityKey()).thenReturn(null);
}
@AfterEach
void teardown() {
reset(
keysDynamoDb,
accounts,
directoryQueue,
preKeyRateLimiter,
existsAccount,
rateLimiters,
rateLimiter,
dynamicConfigurationManager,
rateLimitChallengeManager
);
}
@Test
public void validKeyStatusTestByNumberV2() throws Exception {
void validKeyStatusTestByNumberV2() throws Exception {
PreKeyCount result = resources.getJerseyTest()
.target("/v2/keys")
.request()
@ -168,7 +204,7 @@ public class KeysControllerTest {
}
@Test
public void validKeyStatusTestByUuidV2() throws Exception {
void validKeyStatusTestByUuidV2() throws Exception {
PreKeyCount result = resources.getJerseyTest()
.target("/v2/keys")
.request()
@ -183,7 +219,7 @@ public class KeysControllerTest {
@Test
public void getSignedPreKeyV2ByNumber() throws Exception {
void getSignedPreKeyV2ByNumber() throws Exception {
SignedPreKey result = resources.getJerseyTest()
.target("/v2/keys/signed")
.request()
@ -196,7 +232,7 @@ public class KeysControllerTest {
}
@Test
public void getSignedPreKeyV2ByUuid() throws Exception {
void getSignedPreKeyV2ByUuid() throws Exception {
SignedPreKey result = resources.getJerseyTest()
.target("/v2/keys/signed")
.request()
@ -209,7 +245,7 @@ public class KeysControllerTest {
}
@Test
public void putSignedPreKeyV2ByNumber() throws Exception {
void putSignedPreKeyV2ByNumber() throws Exception {
SignedPreKey test = new SignedPreKey(9999, "fooozzz", "baaarzzz");
Response response = resources.getJerseyTest()
.target("/v2/keys/signed")
@ -224,7 +260,7 @@ public class KeysControllerTest {
}
@Test
public void putSignedPreKeyV2ByUuid() throws Exception {
void putSignedPreKeyV2ByUuid() throws Exception {
SignedPreKey test = new SignedPreKey(9998, "fooozzz", "baaarzzz");
Response response = resources.getJerseyTest()
.target("/v2/keys/signed")
@ -240,7 +276,7 @@ public class KeysControllerTest {
@Test
public void disabledPutSignedPreKeyV2ByNumber() throws Exception {
void disabledPutSignedPreKeyV2ByNumber() throws Exception {
SignedPreKey test = new SignedPreKey(9999, "fooozzz", "baaarzzz");
Response response = resources.getJerseyTest()
.target("/v2/keys/signed")
@ -252,7 +288,7 @@ public class KeysControllerTest {
}
@Test
public void disabledPutSignedPreKeyV2ByUuid() throws Exception {
void disabledPutSignedPreKeyV2ByUuid() throws Exception {
SignedPreKey test = new SignedPreKey(9999, "fooozzz", "baaarzzz");
Response response = resources.getJerseyTest()
.target("/v2/keys/signed")
@ -265,7 +301,7 @@ public class KeysControllerTest {
@Test
public void validSingleRequestTestV2ByNumber() throws Exception {
void validSingleRequestTestV2ByNumber() throws Exception {
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_NUMBER))
.request()
@ -283,7 +319,7 @@ public class KeysControllerTest {
}
@Test
public void validSingleRequestTestV2ByUuid() throws Exception {
void validSingleRequestTestV2ByUuid() throws Exception {
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID))
.request()
@ -302,7 +338,7 @@ public class KeysControllerTest {
@Test
public void testUnidentifiedRequestByNumber() throws Exception {
void testUnidentifiedRequestByNumber() throws Exception {
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_NUMBER))
.request()
@ -320,7 +356,7 @@ public class KeysControllerTest {
}
@Test
public void testUnidentifiedRequestByUuid() throws Exception {
void testUnidentifiedRequestByUuid() throws Exception {
PreKeyResponse result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_UUID.toString()))
.request()
@ -337,9 +373,23 @@ public class KeysControllerTest {
verifyNoMoreInteractions(keysDynamoDb);
}
@Test
void testNoDevices() {
when(existsAccount.getDevices()).thenReturn(Collections.emptySet());
Response result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID.toString()))
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes()))
.get();
assertThat(result).isNotNull();
assertThat(result.getStatus()).isEqualTo(404);
}
@Test
public void testUnauthorizedUnidentifiedRequest() throws Exception {
void testUnauthorizedUnidentifiedRequest() throws Exception {
Response response = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_NUMBER))
.request()
@ -351,7 +401,7 @@ public class KeysControllerTest {
}
@Test
public void testMalformedUnidentifiedRequest() throws Exception {
void testMalformedUnidentifiedRequest() throws Exception {
Response response = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_NUMBER))
.request()
@ -364,7 +414,7 @@ public class KeysControllerTest {
@Test
public void validMultiRequestTestV2ByNumber() throws Exception {
void validMultiRequestTestV2ByNumber() throws Exception {
PreKeyResponse results = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_NUMBER))
.request()
@ -414,7 +464,7 @@ public class KeysControllerTest {
}
@Test
public void validMultiRequestTestV2ByUuid() throws Exception {
void validMultiRequestTestV2ByUuid() throws Exception {
PreKeyResponse results = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID.toString()))
.request()
@ -465,7 +515,7 @@ public class KeysControllerTest {
@Test
public void invalidRequestTestV2() throws Exception {
void invalidRequestTestV2() throws Exception {
Response response = resources.getJerseyTest()
.target(String.format("/v2/keys/%s", NOT_EXISTS_NUMBER))
.request()
@ -476,7 +526,7 @@ public class KeysControllerTest {
}
@Test
public void anotherInvalidRequestTestV2() throws Exception {
void anotherInvalidRequestTestV2() throws Exception {
Response response = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/22", EXISTS_NUMBER))
.request()
@ -487,7 +537,7 @@ public class KeysControllerTest {
}
@Test
public void unauthorizedRequestTestV2() throws Exception {
void unauthorizedRequestTestV2() throws Exception {
Response response =
resources.getJerseyTest()
.target(String.format("/v2/keys/%s/1", EXISTS_NUMBER))
@ -507,7 +557,7 @@ public class KeysControllerTest {
}
@Test
public void putKeysTestV2() throws Exception {
void putKeysTestV2() throws Exception {
final PreKey preKey = new PreKey(31337, "foobar");
final SignedPreKey signedPreKey = new SignedPreKey(31338, "foobaz", "myvalidsig");
final String identityKey = "barbar";
@ -541,7 +591,7 @@ public class KeysControllerTest {
}
@Test
public void disabledPutKeysTestV2() throws Exception {
void disabledPutKeysTestV2() throws Exception {
final PreKey preKey = new PreKey(31337, "foobar");
final SignedPreKey signedPreKey = new SignedPreKey(31338, "foobaz", "myvalidsig");
final String identityKey = "barbar";
@ -574,5 +624,42 @@ public class KeysControllerTest {
verify(accounts).update(AuthHelper.DISABLED_ACCOUNT);
}
@Test
void testRateLimitChallenge() throws RateLimitExceededException {
Duration retryAfter = Duration.ofMinutes(1);
doThrow(new RateLimitExceededException(retryAfter))
.when(preKeyRateLimiter).validate(any());
when(rateLimitChallengeManager.shouldIssueRateLimitChallenge("Signal-Android/5.1.2 Android/30")).thenReturn(true);
when(rateLimitChallengeManager.getChallengeOptions(AuthHelper.VALID_ACCOUNT))
.thenReturn(List.of(RateLimitChallengeManager.OPTION_PUSH_CHALLENGE, RateLimitChallengeManager.OPTION_RECAPTCHA));
Response result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID.toString()))
.request()
.header(OptionalAccess.UNIDENTIFIED, AuthHelper.getUnidentifiedAccessHeader("1337".getBytes()))
.header("User-Agent", "Signal-Android/5.1.2 Android/30")
.get();
// unidentified access should not be rate limited
assertThat(result.getStatus()).isEqualTo(200);
result = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID.toString()))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.header("User-Agent", "Signal-Android/5.1.2 Android/30")
.get();
assertThat(result.getStatus()).isEqualTo(428);
RateLimitChallenge rateLimitChallenge = result.readEntity(RateLimitChallenge.class);
assertThat(rateLimitChallenge.getToken()).isNotBlank();
assertThat(rateLimitChallenge.getOptions()).isNotEmpty();
assertThat(rateLimitChallenge.getOptions()).contains("recaptcha");
assertThat(rateLimitChallenge.getOptions()).contains("pushChallenge");
assertThat(Long.parseLong(result.getHeaderString("Retry-After"))).isEqualTo(retryAfter.toSeconds());
}
}

View File

@ -31,6 +31,7 @@ import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixtur
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableSet;
import com.vdurmont.semver4j.Semver;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
@ -57,8 +58,8 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatcher;
import org.mockito.stubbing.Answer;
@ -67,6 +68,7 @@ import org.whispersystems.textsecuregcm.auth.DisabledPermittedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicMessageRateConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitChallengeConfiguration;
import org.whispersystems.textsecuregcm.controllers.MessageController;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.entities.IncomingMessageList;
@ -74,11 +76,15 @@ import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.MismatchedDevices;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.RateLimitChallenge;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.limits.CardinalityRateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.limits.UnsealedSenderRateLimiter;
import org.whispersystems.textsecuregcm.mappers.RateLimitChallengeExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.MessageSender;
@ -91,6 +97,7 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
@ExtendWith(DropwizardExtensionsSupport.class)
class MessageControllerTest {
@ -104,6 +111,8 @@ class MessageControllerTest {
private static final String INTERNATIONAL_RECIPIENT = "+61123456789";
private static final UUID INTERNATIONAL_UUID = UUID.randomUUID();
private Account internationalAccount;
@SuppressWarnings("unchecked")
private static final RedisAdvancedClusterCommands<String, String> redisCommands = mock(RedisAdvancedClusterCommands.class);
@ -114,8 +123,10 @@ class MessageControllerTest {
private static final RateLimiters rateLimiters = mock(RateLimiters.class);
private static final RateLimiter rateLimiter = mock(RateLimiter.class);
private static final CardinalityRateLimiter unsealedSenderLimiter = mock(CardinalityRateLimiter.class);
private static final UnsealedSenderRateLimiter unsealedSenderRateLimiter = mock(UnsealedSenderRateLimiter.class);
private static final ApnFallbackManager apnFallbackManager = mock(ApnFallbackManager.class);
private static final DynamicConfigurationManager dynamicConfigurationManager = mock(DynamicConfigurationManager.class);
private static final RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class);
private static final FaultTolerantRedisCluster metricsCluster = RedisClusterHelper.buildMockRedisCluster(redisCommands);
private static final ScheduledExecutorService receiptExecutor = mock(ScheduledExecutorService.class);
@ -125,9 +136,10 @@ class MessageControllerTest {
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(Account.class, DisabledPermittedAccount.class)))
.addProvider(RateLimitExceededExceptionMapper.class)
.addProvider(new RateLimitChallengeExceptionMapper(rateLimitChallengeManager))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new MessageController(rateLimiters, messageSender, receiptSender, accountsManager,
messagesManager, apnFallbackManager, dynamicConfigurationManager, metricsCluster, receiptExecutor))
messagesManager, unsealedSenderRateLimiter, apnFallbackManager, dynamicConfigurationManager, rateLimitChallengeManager, metricsCluster, receiptExecutor))
.build();
@BeforeEach
@ -148,7 +160,7 @@ class MessageControllerTest {
Account singleDeviceAccount = new Account(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, singleDeviceList, "1234".getBytes());
Account multiDeviceAccount = new Account(MULTI_DEVICE_RECIPIENT, MULTI_DEVICE_UUID, multiDeviceList, "1234".getBytes());
Account internationalAccount = new Account(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID, singleDeviceList, "1234".getBytes());
internationalAccount = new Account(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID, singleDeviceList, "1234".getBytes());
when(accountsManager.get(eq(SINGLE_DEVICE_RECIPIENT))).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(SINGLE_DEVICE_RECIPIENT)))).thenReturn(Optional.of(singleDeviceAccount));
@ -158,7 +170,6 @@ class MessageControllerTest {
when(accountsManager.get(argThat((ArgumentMatcher<AmbiguousIdentifier>) identifier -> identifier != null && identifier.hasNumber() && identifier.getNumber().equals(INTERNATIONAL_RECIPIENT)))).thenReturn(Optional.of(internationalAccount));
when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getUnsealedSenderLimiter()).thenReturn(unsealedSenderLimiter);
when(dynamicConfigurationManager.getConfiguration()).thenReturn(new DynamicConfiguration());
@ -179,9 +190,10 @@ class MessageControllerTest {
messagesManager,
rateLimiters,
rateLimiter,
unsealedSenderLimiter,
unsealedSenderRateLimiter,
apnFallbackManager,
dynamicConfigurationManager,
rateLimitChallengeManager,
metricsCluster,
receiptExecutor
);
@ -254,8 +266,8 @@ class MessageControllerTest {
}
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testUnsealedSenderCardinalityRateLimited(final boolean rateLimited) throws Exception {
@CsvSource({"true, 5.1.0, 413", "true, 5.6.4, 428", "false, 5.6.4, 200"})
void testUnsealedSenderCardinalityRateLimited(final boolean rateLimited, final String clientVersion, final int expectedStatusCode) throws Exception {
final DynamicConfiguration dynamicConfiguration = mock(DynamicConfiguration.class);
final DynamicMessageRateConfiguration messageRateConfiguration = mock(DynamicMessageRateConfiguration.class);
@ -268,11 +280,23 @@ class MessageControllerTest {
when(messageRateConfiguration.getReceiptDelayJitter()).thenReturn(Duration.ofMillis(1));
when(messageRateConfiguration.getReceiptProbability()).thenReturn(1.0);
DynamicRateLimitChallengeConfiguration dynamicRateLimitChallengeConfiguration = mock(
DynamicRateLimitChallengeConfiguration.class);
when(dynamicConfiguration.getRateLimitChallengeConfiguration())
.thenReturn(dynamicRateLimitChallengeConfiguration);
when(dynamicRateLimitChallengeConfiguration.getMinimumSupportedVersion(any())).thenReturn(Optional.empty());
when(dynamicRateLimitChallengeConfiguration.getMinimumSupportedVersion(ClientPlatform.ANDROID))
.thenReturn(Optional.of(new Semver("5.5.0")));
when(redisCommands.evalsha(any(), any(), any(), any())).thenReturn(List.of(1L, 1L));
if (rateLimited) {
doThrow(RateLimitExceededException.class)
.when(unsealedSenderLimiter).validate(eq(AuthHelper.VALID_NUMBER), eq(INTERNATIONAL_RECIPIENT));
doThrow(new RateLimitExceededException(Duration.ofHours(1)))
.when(unsealedSenderRateLimiter).validate(eq(AuthHelper.VALID_ACCOUNT), eq(internationalAccount));
when(rateLimitChallengeManager.shouldIssueRateLimitChallenge(String.format("Signal-Android/%s Android/30", clientVersion)))
.thenReturn(true);
}
Response response =
@ -280,18 +304,50 @@ class MessageControllerTest {
.target(String.format("/v1/messages/%s", INTERNATIONAL_RECIPIENT))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.header("User-Agent", "Signal-Android/5.6.4 Android/30")
.put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_single_device.json"), IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE));
if (rateLimited) {
assertThat("Error Response", response.getStatus(), is(equalTo(413)));
assertThat("Error Response", response.getStatus(), is(equalTo(expectedStatusCode)));
} else {
assertThat("Good Response", response.getStatus(), is(equalTo(200)));
assertThat("Good Response", response.getStatus(), is(equalTo(expectedStatusCode)));
}
verify(messageSender, rateLimited ? never() : times(1)).sendMessage(any(), any(), any(), anyBoolean());
}
@Test
void testRateLimitResetRequirement() throws Exception {
Duration retryAfter = Duration.ofMinutes(1);
doThrow(new RateLimitExceededException(retryAfter))
.when(unsealedSenderRateLimiter).validate(any(), any());
when(rateLimitChallengeManager.shouldIssueRateLimitChallenge("Signal-Android/5.1.2 Android/30")).thenReturn(true);
when(rateLimitChallengeManager.getChallengeOptions(AuthHelper.VALID_ACCOUNT))
.thenReturn(List.of(RateLimitChallengeManager.OPTION_PUSH_CHALLENGE, RateLimitChallengeManager.OPTION_RECAPTCHA));
Response response =
resources.getJerseyTest()
.target(String.format("/v1/messages/%s", INTERNATIONAL_RECIPIENT))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_NUMBER, AuthHelper.VALID_PASSWORD))
.header("User-Agent", "Signal-Android/5.1.2 Android/30")
.put(Entity.entity(mapper.readValue(jsonFixture("fixtures/current_message_single_device.json"), IncomingMessageList.class),
MediaType.APPLICATION_JSON_TYPE));
assertEquals(428, response.getStatus());
RateLimitChallenge rateLimitChallenge = response.readEntity(RateLimitChallenge.class);
assertFalse(rateLimitChallenge.getToken().isBlank());
assertFalse(rateLimitChallenge.getOptions().isEmpty());
assertTrue(rateLimitChallenge.getOptions().contains("recaptcha"));
assertTrue(rateLimitChallenge.getOptions().contains("pushChallenge"));
assertEquals(retryAfter.toSeconds(), Long.parseLong(response.getHeaderString("Retry-After")));
}
@Test
void testSingleDeviceCurrentUnidentified() throws Exception {
Response response =

View File

@ -1,15 +1,16 @@
package org.whispersystems.textsecuregcm.tests.limits;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertSame;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.time.Duration;
import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.RateLimitConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitsConfiguration;
import org.whispersystems.textsecuregcm.limits.CardinalityRateLimiter;
@ -18,13 +19,13 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
public class DynamicRateLimitsTest {
class DynamicRateLimitsTest {
private DynamicConfigurationManager dynamicConfig;
private FaultTolerantRedisCluster redisCluster;
@Before
public void setup() {
@BeforeEach
void setup() {
this.dynamicConfig = mock(DynamicConfigurationManager.class);
this.redisCluster = mock(FaultTolerantRedisCluster.class);
@ -34,7 +35,7 @@ public class DynamicRateLimitsTest {
}
@Test
public void testUnchangingConfiguration() {
void testUnchangingConfiguration() {
RateLimiters rateLimiters = new RateLimiters(new RateLimitsConfiguration(), dynamicConfig, redisCluster);
RateLimiter limiter = rateLimiters.getUnsealedIpLimiter();
@ -45,34 +46,39 @@ public class DynamicRateLimitsTest {
}
@Test
public void testChangingConfiguration() {
void testChangingConfiguration() {
DynamicConfiguration configuration = mock(DynamicConfiguration.class);
DynamicRateLimitsConfiguration limitsConfiguration = mock(DynamicRateLimitsConfiguration.class);
when(configuration.getLimits()).thenReturn(limitsConfiguration);
when(limitsConfiguration.getUnsealedSenderNumber()).thenReturn(new RateLimitsConfiguration.CardinalityRateLimitConfiguration(10, Duration.ofHours(1), Duration.ofMinutes(10)));
when(limitsConfiguration.getUnsealedSenderIp()).thenReturn(new RateLimitsConfiguration.RateLimitConfiguration(4, 1.0));
when(limitsConfiguration.getUnsealedSenderNumber()).thenReturn(new RateLimitsConfiguration.CardinalityRateLimitConfiguration(10, Duration.ofHours(1)));
when(limitsConfiguration.getRecaptchaChallengeAttempt()).thenReturn(new RateLimitConfiguration());
when(limitsConfiguration.getRecaptchaChallengeSuccess()).thenReturn(new RateLimitConfiguration());
when(limitsConfiguration.getPushChallengeAttempt()).thenReturn(new RateLimitConfiguration());
when(limitsConfiguration.getPushChallengeSuccess()).thenReturn(new RateLimitConfiguration());
when(limitsConfiguration.getDailyPreKeys()).thenReturn(new RateLimitConfiguration());
final RateLimitConfiguration initialRateLimitConfiguration = new RateLimitConfiguration(4, 1.0);
when(limitsConfiguration.getUnsealedSenderIp()).thenReturn(initialRateLimitConfiguration);
when(limitsConfiguration.getRateLimitReset()).thenReturn(initialRateLimitConfiguration);
when(dynamicConfig.getConfiguration()).thenReturn(configuration);
RateLimiters rateLimiters = new RateLimiters(new RateLimitsConfiguration(), dynamicConfig, redisCluster);
CardinalityRateLimiter limiter = rateLimiters.getUnsealedSenderLimiter();
CardinalityRateLimiter limiter = rateLimiters.getUnsealedSenderCardinalityLimiter();
assertThat(limiter.getMaxCardinality()).isEqualTo(10);
assertThat(limiter.getTtl()).isEqualTo(Duration.ofHours(1));
assertThat(limiter.getTtlJitter()).isEqualTo(Duration.ofMinutes(10));
assertSame(rateLimiters.getUnsealedSenderLimiter(), limiter);
assertThat(limiter.getDefaultMaxCardinality()).isEqualTo(10);
assertThat(limiter.getInitialTtl()).isEqualTo(Duration.ofHours(1));
assertSame(rateLimiters.getUnsealedSenderCardinalityLimiter(), limiter);
when(limitsConfiguration.getUnsealedSenderNumber()).thenReturn(new RateLimitsConfiguration.CardinalityRateLimitConfiguration(20, Duration.ofHours(2), Duration.ofMinutes(7)));
when(limitsConfiguration.getUnsealedSenderNumber()).thenReturn(new RateLimitsConfiguration.CardinalityRateLimitConfiguration(20, Duration.ofHours(2)));
CardinalityRateLimiter changed = rateLimiters.getUnsealedSenderLimiter();
CardinalityRateLimiter changed = rateLimiters.getUnsealedSenderCardinalityLimiter();
assertThat(changed.getMaxCardinality()).isEqualTo(20);
assertThat(changed.getTtl()).isEqualTo(Duration.ofHours(2));
assertThat(changed.getTtlJitter()).isEqualTo(Duration.ofMinutes(7));
assertThat(changed.getDefaultMaxCardinality()).isEqualTo(20);
assertThat(changed.getInitialTtl()).isEqualTo(Duration.ofHours(2));
assertNotSame(limiter, changed);
}
}

View File

@ -19,6 +19,7 @@ import org.mockito.stubbing.Answer;
import org.whispersystems.textsecuregcm.push.APNSender;
import org.whispersystems.textsecuregcm.push.ApnFallbackManager;
import org.whispersystems.textsecuregcm.push.ApnMessage;
import org.whispersystems.textsecuregcm.push.ApnMessage.Type;
import org.whispersystems.textsecuregcm.push.RetryingApnsClient;
import org.whispersystems.textsecuregcm.push.RetryingApnsClient.ApnResult;
import org.whispersystems.textsecuregcm.storage.Account;
@ -65,7 +66,7 @@ public class APNSenderTest {
.thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(invocationOnMock.getArgument(0), response));
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Optional.empty());
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Type.NOTIFICATION, Optional.empty());
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager);
@ -99,7 +100,7 @@ public class APNSenderTest {
.thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(invocationOnMock.getArgument(0), response));
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, false, Optional.empty());
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, false, Type.NOTIFICATION, Optional.empty());
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager);
@ -135,7 +136,7 @@ public class APNSenderTest {
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Optional.empty());
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Type.NOTIFICATION, Optional.empty());
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager);
@ -238,7 +239,7 @@ public class APNSenderTest {
.thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(invocationOnMock.getArgument(0), response));
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Optional.empty());
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Type.NOTIFICATION, Optional.empty());
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager);
@ -333,7 +334,7 @@ public class APNSenderTest {
.thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(invocationOnMock.getArgument(0), response));
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Optional.empty());
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Type.NOTIFICATION, Optional.empty());
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager);
@ -366,7 +367,7 @@ public class APNSenderTest {
.thenAnswer((Answer) invocationOnMock -> new MockPushNotificationFuture<>(invocationOnMock.getArgument(0), new Exception("lost connection")));
RetryingApnsClient retryingApnsClient = new RetryingApnsClient(apnsClient);
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Optional.empty());
ApnMessage message = new ApnMessage(DESTINATION_APN_ID, DESTINATION_NUMBER, 1, true, Type.NOTIFICATION, Optional.empty());
APNSender apnSender = new APNSender(new SynchronousExecutorService(), accountsManager, retryingApnsClient, "foo", false);
apnSender.setApnFallbackManager(fallbackManager);