Rate limiters code refactored

This commit is contained in:
Sergey Skrobotov 2023-02-23 10:21:39 -08:00
parent 378b32d44d
commit 7529c35013
35 changed files with 738 additions and 774 deletions

View File

@ -12,11 +12,11 @@ import java.util.List;
import java.util.Map;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import org.whispersystems.textsecuregcm.configuration.SpamFilterConfiguration;
import org.whispersystems.textsecuregcm.configuration.AccountDatabaseCrawlerConfiguration;
import org.whispersystems.textsecuregcm.configuration.AdminEventLoggingConfiguration;
import org.whispersystems.textsecuregcm.configuration.ApnConfiguration;
import org.whispersystems.textsecuregcm.configuration.AppConfigConfiguration;
import org.whispersystems.textsecuregcm.configuration.ArtServiceConfiguration;
import org.whispersystems.textsecuregcm.configuration.AwsAttachmentsConfiguration;
import org.whispersystems.textsecuregcm.configuration.BadgesConfiguration;
import org.whispersystems.textsecuregcm.configuration.BraintreeConfiguration;
@ -33,8 +33,7 @@ import org.whispersystems.textsecuregcm.configuration.MaxDeviceConfiguration;
import org.whispersystems.textsecuregcm.configuration.MessageCacheConfiguration;
import org.whispersystems.textsecuregcm.configuration.OneTimeDonationConfiguration;
import org.whispersystems.textsecuregcm.configuration.PaymentsServiceConfiguration;
import org.whispersystems.textsecuregcm.configuration.ArtServiceConfiguration;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration;
import org.whispersystems.textsecuregcm.limits.RateLimiterConfig;
import org.whispersystems.textsecuregcm.configuration.RecaptchaConfiguration;
import org.whispersystems.textsecuregcm.configuration.RedisClusterConfiguration;
import org.whispersystems.textsecuregcm.configuration.RedisConfiguration;
@ -44,6 +43,7 @@ import org.whispersystems.textsecuregcm.configuration.ReportMessageConfiguration
import org.whispersystems.textsecuregcm.configuration.SecureBackupServiceConfiguration;
import org.whispersystems.textsecuregcm.configuration.SecureStorageServiceConfiguration;
import org.whispersystems.textsecuregcm.configuration.SecureValueRecovery2Configuration;
import org.whispersystems.textsecuregcm.configuration.SpamFilterConfiguration;
import org.whispersystems.textsecuregcm.configuration.StripeConfiguration;
import org.whispersystems.textsecuregcm.configuration.SubscriptionConfiguration;
import org.whispersystems.textsecuregcm.configuration.TestDeviceConfiguration;
@ -167,7 +167,7 @@ public class WhisperServerConfiguration extends Configuration {
@Valid
@NotNull
@JsonProperty
private RateLimitsConfiguration limits = new RateLimitsConfiguration();
private Map<String, RateLimiterConfig> limits = new HashMap<>();
@Valid
@NotNull
@ -351,7 +351,7 @@ public class WhisperServerConfiguration extends Configuration {
return rateLimitersCluster;
}
public RateLimitsConfiguration getLimitsConfiguration() {
public Map<String, RateLimiterConfig> getLimitsConfiguration() {
return limits;
}

View File

@ -118,7 +118,6 @@ import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter;
import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter;
import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter;
import org.whispersystems.textsecuregcm.limits.DynamicRateLimiters;
import org.whispersystems.textsecuregcm.limits.PushChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
@ -506,8 +505,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
FcmSender fcmSender = new FcmSender(fcmSenderExecutor, config.getFcmConfiguration().credentials());
ApnPushNotificationScheduler apnPushNotificationScheduler = new ApnPushNotificationScheduler(pushSchedulerCluster, apnSender, accountsManager);
PushNotificationManager pushNotificationManager = new PushNotificationManager(accountsManager, apnSender, fcmSender, apnPushNotificationScheduler, pushLatencyManager, dynamicConfigurationManager);
RateLimiters rateLimiters = new RateLimiters(config.getLimitsConfiguration(), rateLimitersCluster);
DynamicRateLimiters dynamicRateLimiters = new DynamicRateLimiters(rateLimitersCluster, dynamicConfigurationManager);
RateLimiters rateLimiters = RateLimiters.createAndValidate(config.getLimitsConfiguration(), dynamicConfigurationManager, rateLimitersCluster);
ProvisioningManager provisioningManager = new ProvisioningManager(pubSubManager);
IssuedReceiptsManager issuedReceiptsManager = new IssuedReceiptsManager(
config.getDynamoDbTables().getIssuedReceipts().getTableName(),
@ -551,7 +549,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
PushChallengeManager pushChallengeManager = new PushChallengeManager(pushNotificationManager, pushChallengeDynamoDb);
RateLimitChallengeManager rateLimitChallengeManager = new RateLimitChallengeManager(pushChallengeManager,
captchaChecker, dynamicRateLimiters);
captchaChecker, rateLimiters);
MessagePersister messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, dynamicConfigurationManager, Duration.ofMinutes(config.getMessageCacheConfiguration().getPersistDelayMinutes()));
ChangeNumberManager changeNumberManager = new ChangeNumberManager(messageSender, accountsManager);

View File

@ -1,201 +0,0 @@
/*
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.configuration;
import com.fasterxml.jackson.annotation.JsonProperty;
public class RateLimitsConfiguration {
@JsonProperty
private RateLimitConfiguration smsDestination = new RateLimitConfiguration(2, 2);
@JsonProperty
private RateLimitConfiguration voiceDestination = new RateLimitConfiguration(2, 1.0 / 2.0);
@JsonProperty
private RateLimitConfiguration voiceDestinationDaily = new RateLimitConfiguration(10, 10.0 / (24.0 * 60.0));
@JsonProperty
private RateLimitConfiguration smsVoiceIp = new RateLimitConfiguration(1000, 1000);
@JsonProperty
private RateLimitConfiguration smsVoicePrefix = new RateLimitConfiguration(1000, 1000);
@JsonProperty
private RateLimitConfiguration verifyNumber = new RateLimitConfiguration(2, 2);
@JsonProperty
private RateLimitConfiguration verifyPin = new RateLimitConfiguration(10, 1 / (24.0 * 60.0));
@JsonProperty
private RateLimitConfiguration verificationCaptcha = new RateLimitConfiguration(10, 2);
@JsonProperty
private RateLimitConfiguration verificationPushChallenge = new RateLimitConfiguration(5, 2);
@JsonProperty
private RateLimitConfiguration registration = new RateLimitConfiguration(2, 2);
@JsonProperty
private RateLimitConfiguration attachments = new RateLimitConfiguration(50, 50);
@JsonProperty
private RateLimitConfiguration prekeys = new RateLimitConfiguration(6, 1.0 / 10.0);
@JsonProperty
private RateLimitConfiguration messages = new RateLimitConfiguration(60, 60);
@JsonProperty
private RateLimitConfiguration allocateDevice = new RateLimitConfiguration(2, 1.0 / 2.0);
@JsonProperty
private RateLimitConfiguration verifyDevice = new RateLimitConfiguration(6, 1.0 / 10.0);
@JsonProperty
private RateLimitConfiguration turnAllocations = new RateLimitConfiguration(60, 60);
@JsonProperty
private RateLimitConfiguration profile = new RateLimitConfiguration(4320, 3);
@JsonProperty
private RateLimitConfiguration stickerPack = new RateLimitConfiguration(50, 20 / (24.0 * 60.0));
@JsonProperty
private RateLimitConfiguration artPack = new RateLimitConfiguration(50, 20 / (24.0 * 60.0));
@JsonProperty
private RateLimitConfiguration usernameLookup = new RateLimitConfiguration(100, 100 / (24.0 * 60.0));
@JsonProperty
private RateLimitConfiguration usernameSet = new RateLimitConfiguration(100, 100 / (24.0 * 60.0));
@JsonProperty
private RateLimitConfiguration usernameReserve = new RateLimitConfiguration(100, 100 / (24.0 * 60.0));
@JsonProperty
private RateLimitConfiguration checkAccountExistence = new RateLimitConfiguration(1_000, 1_000 / 60.0);
@JsonProperty
private RateLimitConfiguration backupAuthCheck = new RateLimitConfiguration(100, 100 / (24.0 * 60.0));
public RateLimitConfiguration getAllocateDevice() {
return allocateDevice;
}
public RateLimitConfiguration getVerifyDevice() {
return verifyDevice;
}
public RateLimitConfiguration getMessages() {
return messages;
}
public RateLimitConfiguration getPreKeys() {
return prekeys;
}
public RateLimitConfiguration getAttachments() {
return attachments;
}
public RateLimitConfiguration getSmsDestination() {
return smsDestination;
}
public RateLimitConfiguration getVoiceDestination() {
return voiceDestination;
}
public RateLimitConfiguration getVoiceDestinationDaily() {
return voiceDestinationDaily;
}
public RateLimitConfiguration getSmsVoiceIp() {
return smsVoiceIp;
}
public RateLimitConfiguration getSmsVoicePrefix() {
return smsVoicePrefix;
}
public RateLimitConfiguration getVerifyNumber() {
return verifyNumber;
}
public RateLimitConfiguration getVerifyPin() {
return verifyPin;
}
public RateLimitConfiguration getVerificationCaptcha() {
return verificationCaptcha;
}
public RateLimitConfiguration getVerificationPushChallenge() {
return verificationPushChallenge;
}
public RateLimitConfiguration getRegistration() {
return registration;
}
public RateLimitConfiguration getTurnAllocations() {
return turnAllocations;
}
public RateLimitConfiguration getProfile() {
return profile;
}
public RateLimitConfiguration getStickerPack() {
return stickerPack;
}
public RateLimitConfiguration getArtPack() {
return artPack;
}
public RateLimitConfiguration getUsernameLookup() {
return usernameLookup;
}
public RateLimitConfiguration getUsernameSet() {
return usernameSet;
}
public RateLimitConfiguration getUsernameReserve() {
return usernameReserve;
}
public RateLimitConfiguration getCheckAccountExistence() {
return checkAccountExistence;
}
public RateLimitConfiguration getBackupAuthCheck() {
return backupAuthCheck;
}
public static class RateLimitConfiguration {
@JsonProperty
private int bucketSize;
@JsonProperty
private double leakRatePerMinute;
public RateLimitConfiguration(int bucketSize, double leakRatePerMinute) {
this.bucketSize = bucketSize;
this.leakRatePerMinute = leakRatePerMinute;
}
public RateLimitConfiguration() {}
public int getBucketSize() {
return bucketSize;
}
public double getLeakRatePerMinute() {
return leakRatePerMinute;
}
}
}

View File

@ -1,10 +1,17 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.configuration.dynamic;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import javax.validation.Valid;
import org.whispersystems.textsecuregcm.limits.RateLimiterConfig;
public class DynamicConfiguration {
@ -18,7 +25,7 @@ public class DynamicConfiguration {
@JsonProperty
@Valid
private DynamicRateLimitsConfiguration limits = new DynamicRateLimitsConfiguration();
private Map<String, RateLimiterConfig> limits = new HashMap<>();
@JsonProperty
@Valid
@ -65,7 +72,7 @@ public class DynamicConfiguration {
return Optional.ofNullable(preRegistrationExperiments.get(experimentName));
}
public DynamicRateLimitsConfiguration getLimits() {
public Map<String, RateLimiterConfig> getLimits() {
return limits;
}

View File

@ -1,42 +0,0 @@
package org.whispersystems.textsecuregcm.configuration.dynamic;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.RateLimitConfiguration;
public class DynamicRateLimitsConfiguration {
@JsonProperty
private RateLimitConfiguration rateLimitReset = new RateLimitConfiguration(2, 2.0 / (60 * 24));
@JsonProperty
private RateLimitConfiguration recaptchaChallengeAttempt = new RateLimitConfiguration(10, 10.0 / (60 * 24));
@JsonProperty
private RateLimitConfiguration recaptchaChallengeSuccess = new RateLimitConfiguration(2, 2.0 / (60 * 24));
@JsonProperty
private RateLimitConfiguration pushChallengeAttempt = new RateLimitConfiguration(10, 10.0 / (60 * 24));
@JsonProperty
private RateLimitConfiguration pushChallengeSuccess = new RateLimitConfiguration(2, 2.0 / (60 * 24));
public RateLimitConfiguration getRateLimitReset() {
return rateLimitReset;
}
public RateLimitConfiguration getRecaptchaChallengeAttempt() {
return recaptchaChallengeAttempt;
}
public RateLimitConfiguration getRecaptchaChallengeSuccess() {
return recaptchaChallengeSuccess;
}
public RateLimitConfiguration getPushChallengeAttempt() {
return pushChallengeAttempt;
}
public RateLimitConfiguration getPushChallengeSuccess() {
return pushChallengeSuccess;
}
}

View File

@ -746,7 +746,7 @@ public class AccountController {
@GET
@Path("/username_hash/{usernameHash}")
@Produces(MediaType.APPLICATION_JSON)
@RateLimitedByIp(RateLimiters.Handle.USERNAME_LOOKUP)
@RateLimitedByIp(RateLimiters.For.USERNAME_LOOKUP)
public AccountIdentifierResponse lookupUsernameHash(
@HeaderParam(HeaderUtils.X_SIGNAL_AGENT) final String userAgent,
@HeaderParam(HttpHeaders.X_FORWARDED_FOR) final String forwardedFor,
@ -780,7 +780,7 @@ public class AccountController {
@HEAD
@Path("/account/{uuid}")
@RateLimitedByIp(RateLimiters.Handle.CHECK_ACCOUNT_EXISTENCE)
@RateLimitedByIp(RateLimiters.For.CHECK_ACCOUNT_EXISTENCE)
public Response accountExists(
@PathParam("uuid") final UUID uuid,
@Context HttpServletRequest request) throws RateLimitExceededException {

View File

@ -79,7 +79,7 @@ public class SecureBackupController {
@Path("/auth/check")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
@RateLimitedByIp(RateLimiters.Handle.BACKUP_AUTH_CHECK)
@RateLimitedByIp(RateLimiters.For.BACKUP_AUTH_CHECK)
public AuthCheckResponse authCheck(@NotNull @Valid final AuthCheckRequest request) {
final Map<String, AuthCheckResponse.Result> results = new HashMap<>();
final Map<String, Pair<UUID, Long>> tokenToUuid = new HashMap<>();

View File

@ -0,0 +1,82 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static java.util.Objects.requireNonNull;
import java.lang.invoke.MethodHandles;
import java.util.Arrays;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
private final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
private final Map<T, RateLimiter> rateLimiterByDescriptor;
private final Map<String, RateLimiterConfig> configs;
protected BaseRateLimiters(
final T[] values,
final Map<String, RateLimiterConfig> configs,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final FaultTolerantRedisCluster cacheCluster) {
this.configs = configs;
this.rateLimiterByDescriptor = Arrays.stream(values)
.map(descriptor -> Pair.of(
descriptor,
createForDescriptor(descriptor, configs, dynamicConfigurationManager, cacheCluster)))
.collect(Collectors.toUnmodifiableMap(Pair::getKey, Pair::getValue));
}
public RateLimiter forDescriptor(final T handle) {
return requireNonNull(rateLimiterByDescriptor.get(handle));
}
public void validateValuesAndConfigs() {
final Set<String> ids = rateLimiterByDescriptor.keySet().stream()
.map(RateLimiterDescriptor::id)
.collect(Collectors.toSet());
for (final String key: configs.keySet()) {
if (!ids.contains(key)) {
final String message = String.format(
"Static configuration has an unexpected field '%s' that doesn't match any RateLimiterDescriptor",
key
);
logger.error(message);
throw new IllegalArgumentException(message);
}
}
}
private static RateLimiter createForDescriptor(
final RateLimiterDescriptor descriptor,
final Map<String, RateLimiterConfig> configs,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final FaultTolerantRedisCluster cacheCluster) {
if (descriptor.isDynamic()) {
final Supplier<RateLimiterConfig> configResolver = () -> {
final RateLimiterConfig config = dynamicConfigurationManager.getConfiguration().getLimits().get(descriptor.id());
return config != null
? config
: configs.getOrDefault(descriptor.id(), descriptor.defaultConfig());
};
return new DynamicRateLimiter(descriptor.id(), configResolver, cacheCluster);
}
final RateLimiterConfig cfg = configs.getOrDefault(descriptor.id(), descriptor.defaultConfig());
return new StaticRateLimiter(descriptor.id(), cfg, cacheCluster);
}
}

View File

@ -0,0 +1,63 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static java.util.Objects.requireNonNull;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import org.apache.commons.lang3.tuple.Pair;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
public class DynamicRateLimiter implements RateLimiter {
private final String name;
private final Supplier<RateLimiterConfig> configResolver;
private final FaultTolerantRedisCluster cluster;
private final AtomicReference<Pair<RateLimiterConfig, RateLimiter>> currentHolder = new AtomicReference<>();
public DynamicRateLimiter(
final String name,
final Supplier<RateLimiterConfig> configResolver,
final FaultTolerantRedisCluster cluster) {
this.name = requireNonNull(name);
this.configResolver = requireNonNull(configResolver);
this.cluster = requireNonNull(cluster);
}
@Override
public void validate(final String key, final int amount) throws RateLimitExceededException {
current().getRight().validate(key, amount);
}
@Override
public boolean hasAvailablePermits(final String key, final int permits) {
return current().getRight().hasAvailablePermits(key, permits);
}
@Override
public void clear(final String key) {
current().getRight().clear(key);
}
@Override
public RateLimiterConfig config() {
return current().getLeft();
}
private Pair<RateLimiterConfig, RateLimiter> current() {
final RateLimiterConfig cfg = configResolver.get();
return currentHolder.updateAndGet(p -> p != null && p.getLeft().equals(cfg)
? p
: Pair.of(cfg, new StaticRateLimiter(name, cfg, cluster))
);
}
}

View File

@ -1,130 +0,0 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.RateLimitConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
public class DynamicRateLimiters {
private final FaultTolerantRedisCluster cacheCluster;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private final AtomicReference<RateLimiter> rateLimitResetLimiter;
private final AtomicReference<RateLimiter> recaptchaChallengeAttemptLimiter;
private final AtomicReference<RateLimiter> recaptchaChallengeSuccessLimiter;
private final AtomicReference<RateLimiter> pushChallengeAttemptLimiter;
private final AtomicReference<RateLimiter> pushChallengeSuccessLimiter;
public DynamicRateLimiters(final FaultTolerantRedisCluster rateLimitCluster,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager) {
this.cacheCluster = rateLimitCluster;
this.dynamicConfigurationManager = dynamicConfigurationManager;
this.rateLimitResetLimiter = new AtomicReference<>(
createRateLimitResetLimiter(this.cacheCluster,
this.dynamicConfigurationManager.getConfiguration().getLimits().getRateLimitReset()));
this.recaptchaChallengeAttemptLimiter = new AtomicReference<>(createRecaptchaChallengeAttemptLimiter(
this.cacheCluster,
this.dynamicConfigurationManager.getConfiguration().getLimits().getRecaptchaChallengeAttempt()));
this.recaptchaChallengeSuccessLimiter = new AtomicReference<>(createRecaptchaChallengeSuccessLimiter(
this.cacheCluster,
this.dynamicConfigurationManager.getConfiguration().getLimits().getRecaptchaChallengeSuccess()));
this.pushChallengeAttemptLimiter = new AtomicReference<>(createPushChallengeAttemptLimiter(this.cacheCluster,
this.dynamicConfigurationManager.getConfiguration().getLimits().getPushChallengeAttempt()));
this.pushChallengeSuccessLimiter = new AtomicReference<>(createPushChallengeSuccessLimiter(this.cacheCluster,
this.dynamicConfigurationManager.getConfiguration().getLimits().getPushChallengeSuccess()));
}
public RateLimiter getRateLimitResetLimiter() {
return updateAndGetRateLimiter(
rateLimitResetLimiter,
dynamicConfigurationManager.getConfiguration().getLimits().getRateLimitReset(),
this::createRateLimitResetLimiter);
}
public RateLimiter getRecaptchaChallengeAttemptLimiter() {
return updateAndGetRateLimiter(
recaptchaChallengeAttemptLimiter,
dynamicConfigurationManager.getConfiguration().getLimits().getRecaptchaChallengeAttempt(),
this::createRecaptchaChallengeAttemptLimiter);
}
public RateLimiter getRecaptchaChallengeSuccessLimiter() {
return updateAndGetRateLimiter(
recaptchaChallengeSuccessLimiter,
dynamicConfigurationManager.getConfiguration().getLimits().getRecaptchaChallengeSuccess(),
this::createRecaptchaChallengeSuccessLimiter);
}
public RateLimiter getPushChallengeAttemptLimiter() {
return updateAndGetRateLimiter(
pushChallengeAttemptLimiter,
dynamicConfigurationManager.getConfiguration().getLimits().getPushChallengeAttempt(),
this::createPushChallengeAttemptLimiter);
}
public RateLimiter getPushChallengeSuccessLimiter() {
return updateAndGetRateLimiter(
pushChallengeSuccessLimiter,
dynamicConfigurationManager.getConfiguration().getLimits().getPushChallengeSuccess(),
this::createPushChallengeSuccessLimiter);
}
private RateLimiter updateAndGetRateLimiter(final AtomicReference<RateLimiter> rateLimiter,
RateLimitConfiguration currentConfiguration,
BiFunction<FaultTolerantRedisCluster, RateLimitConfiguration, RateLimiter> rateLimitFactory) {
return rateLimiter.updateAndGet(limiter -> {
if (limiter.hasConfiguration(currentConfiguration)) {
return limiter;
} else {
return rateLimitFactory.apply(cacheCluster, currentConfiguration);
}
});
}
public RateLimiter createRateLimitResetLimiter(FaultTolerantRedisCluster cacheCluster,
RateLimitConfiguration configuration) {
return createLimiter(cacheCluster, configuration, "rateLimitReset");
}
public RateLimiter createRecaptchaChallengeAttemptLimiter(FaultTolerantRedisCluster cacheCluster,
RateLimitConfiguration configuration) {
return createLimiter(cacheCluster, configuration, "recaptchaChallengeAttempt");
}
public RateLimiter createRecaptchaChallengeSuccessLimiter(FaultTolerantRedisCluster cacheCluster,
RateLimitConfiguration configuration) {
return createLimiter(cacheCluster, configuration, "recaptchaChallengeSuccess");
}
public RateLimiter createPushChallengeAttemptLimiter(FaultTolerantRedisCluster cacheCluster,
RateLimitConfiguration configuration) {
return createLimiter(cacheCluster, configuration, "pushChallengeAttempt");
}
public RateLimiter createPushChallengeSuccessLimiter(FaultTolerantRedisCluster cacheCluster,
RateLimitConfiguration configuration) {
return createLimiter(cacheCluster, configuration, "pushChallengeSuccess");
}
private RateLimiter createLimiter(FaultTolerantRedisCluster cacheCluster, RateLimitConfiguration configuration,
String name) {
return new RateLimiter(cacheCluster, name,
configuration.getBucketSize(),
configuration.getLeakRatePerMinute());
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -16,22 +16,28 @@ import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.util.Constants;
public class LockingRateLimiter extends RateLimiter {
public class LockingRateLimiter extends StaticRateLimiter {
private static final RateLimitExceededException REUSABLE_RATE_LIMIT_EXCEEDED_EXCEPTION
= new RateLimitExceededException(Duration.ZERO, true);
private final Meter meter;
public LockingRateLimiter(FaultTolerantRedisCluster cacheCluster, String name, int bucketSize, double leakRatePerMinute) {
super(cacheCluster, name, bucketSize, leakRatePerMinute);
MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
public LockingRateLimiter(
final String name,
final RateLimiterConfig config,
final FaultTolerantRedisCluster cacheCluster) {
super(name, config, cacheCluster);
final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
this.meter = metricRegistry.meter(name(getClass(), name, "locked"));
}
@Override
public void validate(String key, int amount) throws RateLimitExceededException {
public void validate(final String key, final int amount) throws RateLimitExceededException {
if (!acquireLock(key)) {
meter.mark();
throw new RateLimitExceededException(Duration.ZERO, true);
throw REUSABLE_RATE_LIMIT_EXCEEDED_EXCEPTION;
}
try {
@ -41,22 +47,15 @@ public class LockingRateLimiter extends RateLimiter {
}
}
@Override
public void validate(String key) throws RateLimitExceededException {
validate(key, 1);
}
private void releaseLock(String key) {
private void releaseLock(final String key) {
cacheCluster.useCluster(connection -> connection.sync().del(getLockName(key)));
}
private boolean acquireLock(String key) {
private boolean acquireLock(final String key) {
return cacheCluster.withCluster(connection -> connection.sync().set(getLockName(key), "L", SetArgs.Builder.nx().ex(10))) != null;
}
private String getLockName(String key) {
private String getLockName(final String key) {
return "leaky_lock::" + name + "::" + key;
}
}

View File

@ -58,7 +58,7 @@ public class RateLimitByIpFilter implements ContainerRequestFilter {
return;
}
final RateLimiters.Handle handle = annotation.value();
final RateLimiters.For handle = annotation.value();
try {
final String xffHeader = requestContext.getHeaders().getFirst(HttpHeaders.X_FORWARDED_FOR);
@ -77,13 +77,8 @@ public class RateLimitByIpFilter implements ContainerRequestFilter {
return;
}
final Optional<RateLimiter> maybeRateLimiter = rateLimiters.byHandle(handle);
if (maybeRateLimiter.isEmpty()) {
logger.warn("RateLimiter not found for {}. Make sure it's initialized in RateLimiters class", handle);
return;
}
maybeRateLimiter.get().validate(maybeMostRecentProxy.get());
final RateLimiter rateLimiter = rateLimiters.forDescriptor(handle);
rateLimiter.validate(maybeMostRecentProxy.get());
} catch (RateLimitExceededException e) {
final Response response = EXCEPTION_MAPPER.toResponse(e);
throw new ClientErrorException(response);

View File

@ -1,3 +1,8 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static com.codahale.metrics.MetricRegistry.name;
@ -9,11 +14,11 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.whispersystems.textsecuregcm.spam.RateLimitChallengeListener;
import org.whispersystems.textsecuregcm.captcha.CaptchaChecker;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.spam.RateLimitChallengeListener;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.Util;
@ -21,7 +26,7 @@ public class RateLimitChallengeManager {
private final PushChallengeManager pushChallengeManager;
private final CaptchaChecker captchaChecker;
private final DynamicRateLimiters rateLimiters;
private final RateLimiters rateLimiters;
private final List<RateLimitChallengeListener> rateLimitChallengeListeners =
Collections.synchronizedList(new ArrayList<>());
@ -35,7 +40,7 @@ public class RateLimitChallengeManager {
public RateLimitChallengeManager(
final PushChallengeManager pushChallengeManager,
final CaptchaChecker captchaChecker,
final DynamicRateLimiters rateLimiters) {
final RateLimiters rateLimiters) {
this.pushChallengeManager = pushChallengeManager;
this.captchaChecker = captchaChecker;

View File

@ -1,30 +1,30 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import com.vdurmont.semver4j.Semver;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
public class RateLimitChallengeOptionManager {
private final DynamicRateLimiters rateLimiters;
private final RateLimiters rateLimiters;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
public static final String OPTION_RECAPTCHA = "recaptcha";
public static final String OPTION_PUSH_CHALLENGE = "pushChallenge";
public RateLimitChallengeOptionManager(final DynamicRateLimiters rateLimiters,
public RateLimitChallengeOptionManager(final RateLimiters rateLimiters,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager) {
this.rateLimiters = rateLimiters;

View File

@ -14,7 +14,7 @@ import java.lang.annotation.Target;
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimitedByIp {
RateLimiters.Handle value();
RateLimiters.For value();
boolean failOnUnresolvedIp() default true;
}

View File

@ -1,143 +1,48 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.time.Duration;
import java.util.UUID;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.RateLimitConfiguration;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.SystemMapper;
public class RateLimiter {
public interface RateLimiter {
private final Logger logger = LoggerFactory.getLogger(RateLimiter.class);
private final ObjectMapper mapper = SystemMapper.getMapper();
void validate(String key, int amount) throws RateLimitExceededException;
private final Meter meter;
private final Timer validateTimer;
protected final FaultTolerantRedisCluster cacheCluster;
protected final String name;
private final int bucketSize;
private final double leakRatePerMinute;
private final double leakRatePerMillis;
boolean hasAvailablePermits(String key, int permits);
public RateLimiter(FaultTolerantRedisCluster cacheCluster, String name, int bucketSize, double leakRatePerMinute)
{
MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
void clear(String key);
this.meter = metricRegistry.meter(name(getClass(), name, "exceeded"));
this.validateTimer = metricRegistry.timer(name(getClass(), name, "validate"));
this.cacheCluster = cacheCluster;
this.name = name;
this.bucketSize = bucketSize;
this.leakRatePerMinute = leakRatePerMinute;
this.leakRatePerMillis = leakRatePerMinute / (60.0 * 1000.0);
}
RateLimiterConfig config();
public void validate(String key, int amount) throws RateLimitExceededException {
try (final Timer.Context ignored = validateTimer.time()) {
LeakyBucket bucket = getBucket(key);
if (bucket.add(amount)) {
setBucket(key, bucket);
} else {
meter.mark();
throw new RateLimitExceededException(bucket.getTimeUntilSpaceAvailable(amount), true);
}
}
}
public void validate(final UUID accountUuid) throws RateLimitExceededException {
validate(accountUuid.toString());
}
public void validate(final UUID sourceAccountUuid, final UUID destinationAccountUuid)
throws RateLimitExceededException {
validate(sourceAccountUuid.toString() + "__" + destinationAccountUuid.toString());
}
public void validate(String key) throws RateLimitExceededException {
default void validate(final String key) throws RateLimitExceededException {
validate(key, 1);
}
public boolean hasAvailablePermits(final UUID accountUuid, final int permits) {
default void validate(final UUID accountUuid) throws RateLimitExceededException {
validate(accountUuid.toString());
}
default void validate(final UUID srcAccountUuid, final UUID dstAccountUuid) throws RateLimitExceededException {
validate(srcAccountUuid.toString() + "__" + dstAccountUuid.toString());
}
default boolean hasAvailablePermits(final UUID accountUuid, final int permits) {
return hasAvailablePermits(accountUuid.toString(), permits);
}
public boolean hasAvailablePermits(final String key, final int permits) {
return getBucket(key).getTimeUntilSpaceAvailable(permits).equals(Duration.ZERO);
}
public void clear(final UUID accountUuid) {
default void clear(final UUID accountUuid) {
clear(accountUuid.toString());
}
public void clear(String key) {
cacheCluster.useCluster(connection -> connection.sync().del(getBucketName(key)));
}
public int getBucketSize() {
return bucketSize;
}
public double getLeakRatePerMinute() {
return leakRatePerMinute;
}
private void setBucket(String key, LeakyBucket bucket) {
try {
final String serialized = bucket.serialize(mapper);
cacheCluster.useCluster(connection -> connection.sync().setex(getBucketName(key), (int) Math.ceil((bucketSize / leakRatePerMillis) / 1000), serialized));
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e);
}
}
private LeakyBucket getBucket(String key) {
try {
final String serialized = cacheCluster.withCluster(connection -> connection.sync().get(getBucketName(key)));
if (serialized != null) {
return LeakyBucket.fromSerialized(mapper, serialized);
}
} catch (IOException e) {
logger.warn("Deserialization error", e);
}
return new LeakyBucket(bucketSize, leakRatePerMillis);
}
private String getBucketName(String key) {
return "leaky_bucket::" + name + "::" + key;
}
public boolean hasConfiguration(final RateLimitConfiguration configuration) {
return bucketSize == configuration.getBucketSize() && leakRatePerMinute == configuration.getLeakRatePerMinute();
}
/**
* If the wrapped {@code validate()} call throws a {@link RateLimitExceededException}, it will adapt it to ensure that
* {@link RateLimitExceededException#isLegacy()} returns {@code true}
*/
public static void adaptLegacyException(final RateLimitValidator validator) throws RateLimitExceededException {
static void adaptLegacyException(final RateLimitValidator validator) throws RateLimitExceededException {
try {
validator.validate();
} catch (final RateLimitExceededException e) {
@ -146,9 +51,8 @@ public class RateLimiter {
}
@FunctionalInterface
public interface RateLimitValidator {
interface RateLimitValidator {
void validate() throws RateLimitExceededException;
}
}

View File

@ -0,0 +1,13 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
public record RateLimiterConfig(int bucketSize, double leakRatePerMinute) {
public double leakRatePerMillis() {
return leakRatePerMinute / (60.0 * 1000.0);
}
}

View File

@ -0,0 +1,28 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
/**
* Represents an information that defines a rate limiter.
*/
public interface RateLimiterDescriptor {
/**
* Implementing classes will likely be Enums, so name is chosen not to clash with {@link Enum#name()}.
* @return id of this rate limiter to be used in `yml` config files and as a part of the bucket key.
*/
String id();
/**
* @return {@code true} if this rate limiter needs to watch for dynamic configuration changes.
*/
boolean isDynamic();
/**
* @return an instance of {@link RateLimiterConfig} to be used by default,
* i.e. if there is no overrides in the application configuration files (static or dynamic).
*/
RateLimiterConfig defaultConfig();
}

View File

@ -5,193 +5,232 @@
package org.whispersystems.textsecuregcm.limits;
import com.google.common.annotations.VisibleForTesting;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang3.tuple.Pair;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
public class RateLimiters {
public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
public enum Handle {
USERNAME_LOOKUP("usernameLookup"),
CHECK_ACCOUNT_EXISTENCE("checkAccountExistence"),
BACKUP_AUTH_CHECK;
public enum For implements RateLimiterDescriptor {
BACKUP_AUTH_CHECK("backupAuthCheck", false, new RateLimiterConfig(100, 100 / (24.0 * 60.0))),
SMS_DESTINATION("smsDestination", false, new RateLimiterConfig(2, 2)),
VOICE_DESTINATION("voxDestination", false, new RateLimiterConfig(2, 1.0 / 2.0)),
VOICE_DESTINATION_DAILY("voxDestinationDaily", false, new RateLimiterConfig(10, 10.0 / (24.0 * 60.0))),
SMS_VOICE_IP("smsVoiceIp", false, new RateLimiterConfig(1000, 1000)),
SMS_VOICE_PREFIX("smsVoicePrefix", false, new RateLimiterConfig(1000, 1000)),
VERIFY("verify", false, new RateLimiterConfig(2, 2)),
PIN("pin", false, new RateLimiterConfig(10, 1 / (24.0 * 60.0))),
ATTACHMENT("attachmentCreate", false, new RateLimiterConfig(50, 50)),
PRE_KEYS("prekeys", false, new RateLimiterConfig(6, 1.0 / 10.0)),
MESSAGES("messages", false, new RateLimiterConfig(60, 60)),
ALLOCATE_DEVICE("allocateDevice", false, new RateLimiterConfig(2, 1.0 / 2.0)),
VERIFY_DEVICE("verifyDevice", false, new RateLimiterConfig(6, 1.0 / 10.0)),
TURN("turnAllocate", false, new RateLimiterConfig(60, 60)),
PROFILE("profile", false, new RateLimiterConfig(4320, 3)),
STICKER_PACK("stickerPack", false, new RateLimiterConfig(50, 20 / (24.0 * 60.0))),
ART_PACK("artPack", false, new RateLimiterConfig(50, 20 / (24.0 * 60.0))),
USERNAME_LOOKUP("usernameLookup", false, new RateLimiterConfig(100, 100 / (24.0 * 60.0))),
USERNAME_SET("usernameSet", false, new RateLimiterConfig(100, 100 / (24.0 * 60.0))),
USERNAME_RESERVE("usernameReserve", false, new RateLimiterConfig(100, 100 / (24.0 * 60.0))),
CHECK_ACCOUNT_EXISTENCE("checkAccountExistence", false, new RateLimiterConfig(1_000, 1_000 / 60.0)),
STORIES("stories", false, new RateLimiterConfig(10_000, 10_000 / (24.0 * 60.0))),
REGISTRATION("registration", false, new RateLimiterConfig(2, 2)),
VERIFICATION_PUSH_CHALLENGE("verificationPushChallenge", false, new RateLimiterConfig(5, 2)),
VERIFICATION_CAPTCHA("verificationCaptcha", false, new RateLimiterConfig(10, 2)),
RATE_LIMIT_RESET("rateLimitReset", true, new RateLimiterConfig(2, 2.0 / (60 * 24))),
RECAPTCHA_CHALLENGE_ATTEMPT("recaptchaChallengeAttempt", true, new RateLimiterConfig(10, 10.0 / (60 * 24))),
RECAPTCHA_CHALLENGE_SUCCESS("recaptchaChallengeSuccess", true, new RateLimiterConfig(2, 2.0 / (60 * 24))),
PUSH_CHALLENGE_ATTEMPT("pushChallengeAttempt", true, new RateLimiterConfig(10, 10.0 / (60 * 24))),
PUSH_CHALLENGE_SUCCESS("pushChallengeSuccess", true, new RateLimiterConfig(2, 2.0 / (60 * 24))),
;
private final String id;
private final boolean dynamic;
Handle(final String id) {
private final RateLimiterConfig defaultConfig;
For(final String id, final boolean dynamic, final RateLimiterConfig defaultConfig) {
this.id = id;
}
Handle() {
this.id = name();
this.dynamic = dynamic;
this.defaultConfig = defaultConfig;
}
public String id() {
return id;
}
@Override
public boolean isDynamic() {
return dynamic;
}
public RateLimiterConfig defaultConfig() {
return defaultConfig;
}
}
private final RateLimiter smsDestinationLimiter;
private final RateLimiter voiceDestinationLimiter;
private final RateLimiter voiceDestinationDailyLimiter;
private final RateLimiter smsVoiceIpLimiter;
private final RateLimiter smsVoicePrefixLimiter;
private final RateLimiter verifyLimiter;
private final RateLimiter verificationCaptchaLimiter;
private final RateLimiter verificationPushChallengeLimiter;
private final RateLimiter pinLimiter;
private final RateLimiter registrationLimiter;
private final RateLimiter attachmentLimiter;
private final RateLimiter preKeysLimiter;
private final RateLimiter messagesLimiter;
private final RateLimiter allocateDeviceLimiter;
private final RateLimiter verifyDeviceLimiter;
private final RateLimiter turnLimiter;
private final RateLimiter profileLimiter;
private final RateLimiter stickerPackLimiter;
private final RateLimiter artPackLimiter;
private final RateLimiter usernameSetLimiter;
private final RateLimiter usernameReserveLimiter;
private final Map<String, RateLimiter> rateLimiterByHandle;
public RateLimiters(final RateLimitsConfiguration config, final FaultTolerantRedisCluster cacheCluster) {
this.smsDestinationLimiter = fromConfig("smsDestination", config.getSmsDestination(), cacheCluster);
this.voiceDestinationLimiter = fromConfig("voxDestination", config.getVoiceDestination(), cacheCluster);
this.voiceDestinationDailyLimiter = fromConfig("voxDestinationDaily", config.getVoiceDestinationDaily(),
cacheCluster);
this.smsVoiceIpLimiter = fromConfig("smsVoiceIp", config.getSmsVoiceIp(), cacheCluster);
this.smsVoicePrefixLimiter = fromConfig("smsVoicePrefix", config.getSmsVoicePrefix(), cacheCluster);
this.verifyLimiter = fromConfig("verify", config.getVerifyNumber(), cacheCluster);
this.verificationCaptchaLimiter = fromConfig("verificationCaptcha", config.getVerificationCaptcha(), cacheCluster);
this.verificationPushChallengeLimiter = fromConfig("verificationPushChallenge",
config.getVerificationPushChallenge(), cacheCluster);
this.pinLimiter = fromConfig("pin", config.getVerifyPin(), cacheCluster);
this.registrationLimiter = fromConfig("registration", config.getRegistration(), cacheCluster);
this.attachmentLimiter = fromConfig("attachmentCreate", config.getAttachments(), cacheCluster);
this.preKeysLimiter = fromConfig("prekeys", config.getPreKeys(), cacheCluster);
this.messagesLimiter = fromConfig("messages", config.getMessages(), cacheCluster);
this.allocateDeviceLimiter = fromConfig("allocateDevice", config.getAllocateDevice(), cacheCluster);
this.verifyDeviceLimiter = fromConfig("verifyDevice", config.getVerifyDevice(), cacheCluster);
this.turnLimiter = fromConfig("turnAllocate", config.getTurnAllocations(), cacheCluster);
this.profileLimiter = fromConfig("profile", config.getProfile(), cacheCluster);
this.stickerPackLimiter = fromConfig("stickerPack", config.getStickerPack(), cacheCluster);
this.artPackLimiter = fromConfig("artPack", config.getArtPack(), cacheCluster);
this.usernameSetLimiter = fromConfig("usernameSet", config.getUsernameSet(), cacheCluster);
this.usernameReserveLimiter = fromConfig("usernameReserve", config.getUsernameReserve(), cacheCluster);
this.rateLimiterByHandle = Stream.of(
fromConfig(Handle.BACKUP_AUTH_CHECK.id(), config.getBackupAuthCheck(), cacheCluster),
fromConfig(Handle.CHECK_ACCOUNT_EXISTENCE.id(), config.getCheckAccountExistence(), cacheCluster),
fromConfig(Handle.USERNAME_LOOKUP.id(), config.getUsernameLookup(), cacheCluster)
).map(rl -> Pair.of(rl.name, rl)).collect(Collectors.toMap(Pair::getKey, Pair::getValue));
public static RateLimiters createAndValidate(
final Map<String, RateLimiterConfig> configs,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final FaultTolerantRedisCluster cacheCluster) {
final RateLimiters rateLimiters = new RateLimiters(configs, dynamicConfigurationManager, cacheCluster);
rateLimiters.validateValuesAndConfigs();
return rateLimiters;
}
public Optional<RateLimiter> byHandle(final Handle handle) {
return Optional.ofNullable(rateLimiterByHandle.get(handle.id()));
@VisibleForTesting
RateLimiters(
final Map<String, RateLimiterConfig> configs,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final FaultTolerantRedisCluster cacheCluster) {
super(For.values(), configs, dynamicConfigurationManager, cacheCluster);
}
public RateLimiter getAllocateDeviceLimiter() {
return allocateDeviceLimiter;
return forDescriptor(For.ALLOCATE_DEVICE);
}
public RateLimiter getVerifyDeviceLimiter() {
return verifyDeviceLimiter;
return forDescriptor(For.VERIFY_DEVICE);
}
public RateLimiter getMessagesLimiter() {
return messagesLimiter;
return forDescriptor(For.MESSAGES);
}
public RateLimiter getPreKeysLimiter() {
return preKeysLimiter;
return forDescriptor(For.PRE_KEYS);
}
public RateLimiter getAttachmentLimiter() {
return this.attachmentLimiter;
return forDescriptor(For.ATTACHMENT);
}
public RateLimiter getSmsDestinationLimiter() {
return smsDestinationLimiter;
return forDescriptor(For.SMS_DESTINATION);
}
public RateLimiter getSmsVoiceIpLimiter() {
return smsVoiceIpLimiter;
return forDescriptor(For.SMS_VOICE_IP);
}
public RateLimiter getSmsVoicePrefixLimiter() {
return smsVoicePrefixLimiter;
return forDescriptor(For.SMS_VOICE_PREFIX);
}
public RateLimiter getVoiceDestinationLimiter() {
return voiceDestinationLimiter;
return forDescriptor(For.VOICE_DESTINATION);
}
public RateLimiter getVoiceDestinationDailyLimiter() {
return voiceDestinationDailyLimiter;
return forDescriptor(For.VOICE_DESTINATION_DAILY);
}
public RateLimiter getVerifyLimiter() {
return verifyLimiter;
}
public RateLimiter getVerificationCaptchaLimiter() {
return verificationCaptchaLimiter;
}
public RateLimiter getVerificationPushChallengeLimiter() {
return verificationPushChallengeLimiter;
return forDescriptor(For.VERIFY);
}
public RateLimiter getPinLimiter() {
return pinLimiter;
}
public RateLimiter getRegistrationLimiter() {
return registrationLimiter;
return forDescriptor(For.PIN);
}
public RateLimiter getTurnLimiter() {
return turnLimiter;
return forDescriptor(For.TURN);
}
public RateLimiter getProfileLimiter() {
return profileLimiter;
return forDescriptor(For.PROFILE);
}
public RateLimiter getStickerPackLimiter() {
return stickerPackLimiter;
return forDescriptor(For.STICKER_PACK);
}
public RateLimiter getArtPackLimiter() {
return artPackLimiter;
return forDescriptor(For.ART_PACK);
}
public RateLimiter getUsernameLookupLimiter() {
return byHandle(Handle.USERNAME_LOOKUP).orElseThrow();
return forDescriptor(For.USERNAME_LOOKUP);
}
public RateLimiter getUsernameSetLimiter() {
return usernameSetLimiter;
return forDescriptor(For.USERNAME_SET);
}
public RateLimiter getUsernameReserveLimiter() {
return usernameReserveLimiter;
return forDescriptor(For.USERNAME_RESERVE);
}
public RateLimiter getCheckAccountExistenceLimiter() {
return byHandle(Handle.CHECK_ACCOUNT_EXISTENCE).orElseThrow();
return forDescriptor(For.CHECK_ACCOUNT_EXISTENCE);
}
private static RateLimiter fromConfig(
final String name,
final RateLimitsConfiguration.RateLimitConfiguration cfg,
final FaultTolerantRedisCluster cacheCluster) {
return new RateLimiter(cacheCluster, name, cfg.getBucketSize(), cfg.getLeakRatePerMinute());
public RateLimiter getStoriesLimiter() {
return forDescriptor(For.STORIES);
}
public RateLimiter getRegistrationLimiter() {
return forDescriptor(For.REGISTRATION);
}
public RateLimiter getRateLimitResetLimiter() {
return forDescriptor(For.RATE_LIMIT_RESET);
}
public RateLimiter getRecaptchaChallengeAttemptLimiter() {
return forDescriptor(For.RECAPTCHA_CHALLENGE_ATTEMPT);
}
public RateLimiter getRecaptchaChallengeSuccessLimiter() {
return forDescriptor(For.RECAPTCHA_CHALLENGE_SUCCESS);
}
public RateLimiter getPushChallengeAttemptLimiter() {
return forDescriptor(For.PUSH_CHALLENGE_ATTEMPT);
}
public RateLimiter getPushChallengeSuccessLimiter() {
return forDescriptor(For.PUSH_CHALLENGE_SUCCESS);
}
public RateLimiter getVerificationPushChallengeLimiter() {
return forDescriptor(For.VERIFICATION_PUSH_CHALLENGE);
}
public RateLimiter getVerificationCaptchaLimiter() {
return forDescriptor(For.VERIFICATION_CAPTCHA);
}
}

View File

@ -0,0 +1,111 @@
/*
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static com.codahale.metrics.MetricRegistry.name;
import static java.util.Objects.requireNonNull;
import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.time.Duration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.SystemMapper;
public class StaticRateLimiter implements RateLimiter {
private static final Logger logger = LoggerFactory.getLogger(StaticRateLimiter.class);
private static final ObjectMapper MAPPER = SystemMapper.getMapper();
protected final String name;
private final RateLimiterConfig config;
protected final FaultTolerantRedisCluster cacheCluster;
private final Meter meter;
private final Timer validateTimer;
public StaticRateLimiter(
final String name,
final RateLimiterConfig config,
final FaultTolerantRedisCluster cacheCluster) {
final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
this.name = requireNonNull(name);
this.config = requireNonNull(config);
this.cacheCluster = requireNonNull(cacheCluster);
this.meter = metricRegistry.meter(name(getClass(), name, "exceeded"));
this.validateTimer = metricRegistry.timer(name(getClass(), name, "validate"));
}
@Override
public void validate(final String key, final int amount) throws RateLimitExceededException {
try (final Timer.Context ignored = validateTimer.time()) {
final LeakyBucket bucket = getBucket(key);
if (bucket.add(amount)) {
setBucket(key, bucket);
} else {
meter.mark();
throw new RateLimitExceededException(bucket.getTimeUntilSpaceAvailable(amount), true);
}
}
}
@Override
public boolean hasAvailablePermits(final String key, final int permits) {
return getBucket(key).getTimeUntilSpaceAvailable(permits).equals(Duration.ZERO);
}
@Override
public void clear(final String key) {
cacheCluster.useCluster(connection -> connection.sync().del(getBucketName(key)));
}
@Override
public RateLimiterConfig config() {
return config;
}
private void setBucket(final String key, final LeakyBucket bucket) {
try {
final String serialized = bucket.serialize(MAPPER);
cacheCluster.useCluster(connection -> connection.sync().setex(
getBucketName(key),
(int) Math.ceil((config.bucketSize() / config.leakRatePerMillis()) / 1000),
serialized));
} catch (final JsonProcessingException e) {
throw new IllegalArgumentException(e);
}
}
private LeakyBucket getBucket(final String key) {
try {
final String serialized = cacheCluster.withCluster(connection -> connection.sync().get(getBucketName(key)));
if (serialized != null) {
return LeakyBucket.fromSerialized(MAPPER, serialized);
}
} catch (final IOException e) {
logger.warn("Deserialization error", e);
}
return new LeakyBucket(config.bucketSize(), config.leakRatePerMillis());
}
private String getBucketName(final String key) {
return "leaky_bucket::" + name + "::" + key;
}
}

View File

@ -20,7 +20,8 @@ import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.RateLimitConfiguration;
import org.whispersystems.textsecuregcm.limits.RateLimiterConfig;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
@ -274,30 +275,19 @@ class DynamicConfigurationTest {
@Test
void testParseLimits() throws JsonProcessingException {
{
final String emptyConfigYaml = REQUIRED_CONFIG.concat("test: true");
final DynamicConfiguration emptyConfig =
DynamicConfigurationManager.parseConfiguration(emptyConfigYaml, DynamicConfiguration.class).orElseThrow();
assertThat(emptyConfig.getLimits().getRateLimitReset().getBucketSize()).isEqualTo(2);
assertThat(emptyConfig.getLimits().getRateLimitReset().getLeakRatePerMinute()).isEqualTo(2.0 / (60 * 24));
}
{
final String limitsConfig = REQUIRED_CONFIG.concat("""
final String limitsConfig = REQUIRED_CONFIG.concat("""
limits:
rateLimitReset:
bucketSize: 17
leakRatePerMinute: 44
""");
final RateLimitConfiguration resetRateLimitConfiguration =
DynamicConfigurationManager.parseConfiguration(limitsConfig, DynamicConfiguration.class).orElseThrow()
.getLimits().getRateLimitReset();
final RateLimiterConfig resetRateLimiterConfig =
DynamicConfigurationManager.parseConfiguration(limitsConfig, DynamicConfiguration.class).orElseThrow()
.getLimits().get(RateLimiters.For.RATE_LIMIT_RESET.id());
assertThat(resetRateLimitConfiguration.getBucketSize()).isEqualTo(17);
assertThat(resetRateLimitConfiguration.getLeakRatePerMinute()).isEqualTo(44);
}
assertThat(resetRateLimiterConfig.bucketSize()).isEqualTo(17);
assertThat(resetRateLimiterConfig.leakRatePerMinute()).isEqualTo(44);
}
@Test

View File

@ -171,6 +171,7 @@ class AccountControllerTest {
private static RateLimiter usernameSetLimiter = mock(RateLimiter.class);
private static RateLimiter usernameReserveLimiter = mock(RateLimiter.class);
private static RateLimiter usernameLookupLimiter = mock(RateLimiter.class);
private static RateLimiter checkAccountExistence = mock(RateLimiter.class);
private static RegistrationServiceClient registrationServiceClient = mock(RegistrationServiceClient.class);
private static TurnTokenGenerator turnTokenGenerator = mock(TurnTokenGenerator.class);
private static Account senderPinAccount = mock(Account.class);
@ -250,6 +251,8 @@ class AccountControllerTest {
when(rateLimiters.getUsernameSetLimiter()).thenReturn(usernameSetLimiter);
when(rateLimiters.getUsernameReserveLimiter()).thenReturn(usernameReserveLimiter);
when(rateLimiters.getUsernameLookupLimiter()).thenReturn(usernameLookupLimiter);
when(rateLimiters.forDescriptor(eq(RateLimiters.For.USERNAME_LOOKUP))).thenReturn(usernameLookupLimiter);
when(rateLimiters.forDescriptor(eq(RateLimiters.For.CHECK_ACCOUNT_EXISTENCE))).thenReturn(checkAccountExistence);
when(senderPinAccount.getLastSeen()).thenReturn(System.currentTimeMillis());
when(senderPinAccount.getRegistrationLock()).thenReturn(
@ -2124,7 +2127,7 @@ class AccountControllerTest {
when(accountsManager.getByAccountIdentifier(accountIdentifier)).thenReturn(Optional.of(account));
MockUtils.updateRateLimiterResponseToFail(
rateLimiters, RateLimiters.Handle.CHECK_ACCOUNT_EXISTENCE, "127.0.0.1", expectedRetryAfter, true);
rateLimiters, RateLimiters.For.CHECK_ACCOUNT_EXISTENCE, "127.0.0.1", expectedRetryAfter, true);
final Response response = resources.getJerseyTest()
.target(String.format("/v1/accounts/account/%s", accountIdentifier))
@ -2189,7 +2192,7 @@ class AccountControllerTest {
void testLookupUsernameRateLimited() throws RateLimitExceededException {
final Duration expectedRetryAfter = Duration.ofSeconds(13);
MockUtils.updateRateLimiterResponseToFail(
rateLimiters, RateLimiters.Handle.USERNAME_LOOKUP, "127.0.0.1", expectedRetryAfter, true);
rateLimiters, RateLimiters.For.USERNAME_LOOKUP, "127.0.0.1", expectedRetryAfter, true);
final Response response = resources.getJerseyTest()
.target(String.format("v1/accounts/username_hash/%s", BASE_64_URL_USERNAME_HASH_1))
.request()

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2022 Signal Messenger, LLC
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2022 Signal Messenger, LLC
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/

View File

@ -1,9 +1,9 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.tests.limits;
package org.whispersystems.textsecuregcm.limits;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
@ -15,7 +15,6 @@ import java.io.IOException;
import java.time.Duration;
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.limits.LeakyBucket;
class LeakyBucketTest {

View File

@ -1,3 +1,8 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static org.mockito.ArgumentMatchers.any;
@ -12,17 +17,17 @@ import java.util.UUID;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.spam.RateLimitChallengeListener;
import org.whispersystems.textsecuregcm.captcha.AssessmentResult;
import org.whispersystems.textsecuregcm.captcha.CaptchaChecker;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.spam.RateLimitChallengeListener;
import org.whispersystems.textsecuregcm.storage.Account;
class RateLimitChallengeManagerTest {
private PushChallengeManager pushChallengeManager;
private CaptchaChecker captchaChecker;
private DynamicRateLimiters rateLimiters;
private RateLimiters rateLimiters;
private RateLimitChallengeListener rateLimitChallengeListener;
private RateLimitChallengeManager rateLimitChallengeManager;
@ -31,7 +36,7 @@ class RateLimitChallengeManagerTest {
void setUp() {
pushChallengeManager = mock(PushChallengeManager.class);
captchaChecker = mock(CaptchaChecker.class);
rateLimiters = mock(DynamicRateLimiters.class);
rateLimiters = mock(RateLimiters.class);
rateLimitChallengeListener = mock(RateLimitChallengeListener.class);
rateLimitChallengeManager = new RateLimitChallengeManager(

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -30,13 +30,13 @@ import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
class RateLimitChallengeOptionManagerTest {
private DynamicRateLimitChallengeConfiguration rateLimitChallengeConfiguration;
private DynamicRateLimiters rateLimiters;
private RateLimiters rateLimiters;
private RateLimitChallengeOptionManager rateLimitChallengeOptionManager;
@BeforeEach
void setUp() {
rateLimiters = mock(DynamicRateLimiters.class);
rateLimiters = mock(RateLimiters.class);
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager =
mock(DynamicConfigurationManager.class);

View File

@ -11,7 +11,6 @@ import com.google.common.net.HttpHeaders;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import java.time.Duration;
import java.util.Optional;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.core.Response;
@ -43,14 +42,14 @@ public class RateLimitedByIpTest {
public static class Controller {
@GET
@Path("/strict")
@RateLimitedByIp(RateLimiters.Handle.BACKUP_AUTH_CHECK)
@RateLimitedByIp(RateLimiters.For.BACKUP_AUTH_CHECK)
public Response strict() {
return Response.ok().build();
}
@GET
@Path("/loose")
@RateLimitedByIp(value = RateLimiters.Handle.BACKUP_AUTH_CHECK, failOnUnresolvedIp = false)
@RateLimitedByIp(value = RateLimiters.For.BACKUP_AUTH_CHECK, failOnUnresolvedIp = false)
public Response loose() {
return Response.ok().build();
}
@ -59,7 +58,7 @@ public class RateLimitedByIpTest {
private static final RateLimiter RATE_LIMITER = Mockito.mock(RateLimiter.class);
private static final RateLimiters RATE_LIMITERS = MockUtils.buildMock(RateLimiters.class, rl ->
Mockito.when(rl.byHandle(Mockito.eq(RateLimiters.Handle.BACKUP_AUTH_CHECK))).thenReturn(Optional.of(RATE_LIMITER)));
Mockito.when(rl.forDescriptor(Mockito.eq(RateLimiters.For.BACKUP_AUTH_CHECK))).thenReturn(RATE_LIMITER));
private static final ResourceExtension RESOURCES = ResourceExtension.builder()
.setMapper(SystemMapper.getMapper())

View File

@ -0,0 +1,176 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.MockUtils;
@SuppressWarnings("unchecked")
public class RateLimitersTest {
private final DynamicConfiguration configuration = mock(DynamicConfiguration.class);
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfig =
MockUtils.buildMock(DynamicConfigurationManager.class, cfg -> when(cfg.getConfiguration()).thenReturn(configuration));
private final FaultTolerantRedisCluster redisCluster = mock(FaultTolerantRedisCluster.class);
private static final String BAD_YAML = """
limits:
smsVoicePrefix:
bucketSize: 150
leakRatePerMinute: 10
unexpected:
bucketSize: 4
leakRatePerMinute: 2
""";
private static final String GOOD_YAML = """
limits:
smsVoicePrefix:
bucketSize: 150
leakRatePerMinute: 10
attachmentCreate:
bucketSize: 4
leakRatePerMinute: 2
""";
public record GenericHolder(@Valid @NotNull @JsonProperty Map<String, RateLimiterConfig> limits) {
}
@Test
public void testValidateConfigs() throws Exception {
assertThrows(IllegalArgumentException.class, () -> {
final GenericHolder cfg = DynamicConfigurationManager.parseConfiguration(BAD_YAML, GenericHolder.class).orElseThrow();
final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, redisCluster);
rateLimiters.validateValuesAndConfigs();
});
final GenericHolder cfg = DynamicConfigurationManager.parseConfiguration(GOOD_YAML, GenericHolder.class).orElseThrow();
final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, redisCluster);
rateLimiters.validateValuesAndConfigs();
}
@Test
public void testValidateDuplicates() throws Exception {
final TestDescriptor td1 = new TestDescriptor("id1");
final TestDescriptor td2 = new TestDescriptor("id2");
final TestDescriptor td3 = new TestDescriptor("id3");
final TestDescriptor tdDup = new TestDescriptor("id1");
assertThrows(IllegalStateException.class, () -> new BaseRateLimiters<>(
new TestDescriptor[] { td1, td2, td3, tdDup },
Collections.emptyMap(),
dynamicConfig,
redisCluster) {});
new BaseRateLimiters<>(
new TestDescriptor[] { td1, td2, td3 },
Collections.emptyMap(),
dynamicConfig,
redisCluster) {};
}
@Test
void testUnchangingConfiguration() {
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, redisCluster);
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
final RateLimiterConfig expected = RateLimiters.For.RATE_LIMIT_RESET.defaultConfig();
assertEquals(expected, limiter.config());
}
@Test
void testChangingConfiguration() {
final RateLimiterConfig initialRateLimiterConfig = new RateLimiterConfig(4, 1);
final RateLimiterConfig updatedRateLimiterCongig = new RateLimiterConfig(17, 19);
final RateLimiterConfig baseConfig = new RateLimiterConfig(1, 1);
final Map<String, RateLimiterConfig> limitsConfigMap = new HashMap<>();
limitsConfigMap.put(RateLimiters.For.RECAPTCHA_CHALLENGE_ATTEMPT.id(), baseConfig);
limitsConfigMap.put(RateLimiters.For.RECAPTCHA_CHALLENGE_SUCCESS.id(), baseConfig);
when(configuration.getLimits()).thenReturn(limitsConfigMap);
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, redisCluster);
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), initialRateLimiterConfig);
assertEquals(initialRateLimiterConfig, limiter.config());
assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeAttemptLimiter().config());
assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeSuccessLimiter().config());
limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), updatedRateLimiterCongig);
assertEquals(updatedRateLimiterCongig, limiter.config());
assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeAttemptLimiter().config());
assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeSuccessLimiter().config());
}
@Test
public void testRateLimiterHasItsPrioritiesStraight() throws Exception {
final RateLimiters.For descriptor = RateLimiters.For.RECAPTCHA_CHALLENGE_ATTEMPT;
final RateLimiterConfig configForDynamic = new RateLimiterConfig(1, 1);
final RateLimiterConfig configForStatic = new RateLimiterConfig(2, 2);
final RateLimiterConfig defaultConfig = descriptor.defaultConfig();
final Map<String, RateLimiterConfig> mapForDynamic = new HashMap<>();
final Map<String, RateLimiterConfig> mapForStatic = new HashMap<>();
when(configuration.getLimits()).thenReturn(mapForDynamic);
final RateLimiters rateLimiters = new RateLimiters(mapForStatic, dynamicConfig, redisCluster);
final RateLimiter limiter = rateLimiters.forDescriptor(descriptor);
// test only default is present
mapForDynamic.remove(descriptor.id());
mapForStatic.remove(descriptor.id());
assertEquals(defaultConfig, limiter.config());
// test dynamic and no static
mapForDynamic.put(descriptor.id(), configForDynamic);
mapForStatic.remove(descriptor.id());
assertEquals(configForDynamic, limiter.config());
// test dynamic and static
mapForDynamic.put(descriptor.id(), configForDynamic);
mapForStatic.put(descriptor.id(), configForStatic);
assertEquals(configForDynamic, limiter.config());
// test static, but no dynamic
mapForDynamic.remove(descriptor.id());
mapForStatic.put(descriptor.id(), configForStatic);
assertEquals(configForStatic, limiter.config());
}
private record TestDescriptor(String id) implements RateLimiterDescriptor {
@Override
public boolean isDynamic() {
return false;
}
@Override
public RateLimiterConfig defaultConfig() {
return new RateLimiterConfig(1, 1);
}
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@ -42,17 +42,16 @@ import org.whispersystems.textsecuregcm.entities.AttachmentDescriptorV3;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.MockUtils;
import org.whispersystems.textsecuregcm.util.SystemMapper;
@ExtendWith(DropwizardExtensionsSupport.class)
class AttachmentControllerTest {
private static RateLimiters rateLimiters = mock(RateLimiters.class );
private static RateLimiter rateLimiter = mock(RateLimiter.class );
private static final RateLimiter RATE_LIMITER = mock(RateLimiter.class);
static {
when(rateLimiters.getAttachmentLimiter()).thenReturn(rateLimiter);
}
private static final RateLimiters RATE_LIMITERS = MockUtils.buildMock(RateLimiters.class, rateLimiters ->
when(rateLimiters.getAttachmentLimiter()).thenReturn(RATE_LIMITER));
public static final String RSA_PRIVATE_KEY_PEM;
@ -80,8 +79,8 @@ class AttachmentControllerTest {
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new AttachmentControllerV2(rateLimiters, "accessKey", "accessSecret", "us-east-1", "attachmentv2-bucket"))
.addResource(new AttachmentControllerV3(rateLimiters, "some-cdn.signal.org", "signal@example.com", 1000, "/attach-here", RSA_PRIVATE_KEY_PEM))
.addResource(new AttachmentControllerV2(RATE_LIMITERS, "accessKey", "accessSecret", "us-east-1", "attachmentv2-bucket"))
.addResource(new AttachmentControllerV3(RATE_LIMITERS, "some-cdn.signal.org", "signal@example.com", 1000, "/attach-here", RSA_PRIVATE_KEY_PEM))
.build();
} catch (IOException | InvalidKeyException | InvalidKeySpecException e) {
throw new AssertionError(e);

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2022 Signal Messenger, LLC
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.tests.controllers;

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/

View File

@ -1,5 +1,5 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/

View File

@ -1,78 +0,0 @@
package org.whispersystems.textsecuregcm.tests.limits;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.RateLimitConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitsConfiguration;
import org.whispersystems.textsecuregcm.limits.DynamicRateLimiters;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
class DynamicRateLimitsTest {
private DynamicConfigurationManager<DynamicConfiguration> dynamicConfig;
private FaultTolerantRedisCluster redisCluster;
@BeforeEach
void setup() {
this.dynamicConfig = mock(DynamicConfigurationManager.class);
this.redisCluster = mock(FaultTolerantRedisCluster.class);
DynamicConfiguration defaultConfig = new DynamicConfiguration();
when(dynamicConfig.getConfiguration()).thenReturn(defaultConfig);
}
@Test
void testUnchangingConfiguration() {
DynamicRateLimiters rateLimiters = new DynamicRateLimiters(redisCluster, dynamicConfig);
RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
assertThat(limiter.getBucketSize()).isEqualTo(dynamicConfig.getConfiguration().getLimits().getRateLimitReset().getBucketSize());
assertThat(limiter.getLeakRatePerMinute()).isEqualTo(dynamicConfig.getConfiguration().getLimits().getRateLimitReset().getLeakRatePerMinute());
assertSame(rateLimiters.getRateLimitResetLimiter(), limiter);
}
@Test
void testChangingConfiguration() {
DynamicConfiguration configuration = mock(DynamicConfiguration.class);
DynamicRateLimitsConfiguration limitsConfiguration = mock(DynamicRateLimitsConfiguration.class);
when(configuration.getLimits()).thenReturn(limitsConfiguration);
when(limitsConfiguration.getRecaptchaChallengeAttempt()).thenReturn(new RateLimitConfiguration());
when(limitsConfiguration.getRecaptchaChallengeSuccess()).thenReturn(new RateLimitConfiguration());
when(limitsConfiguration.getPushChallengeAttempt()).thenReturn(new RateLimitConfiguration());
when(limitsConfiguration.getPushChallengeSuccess()).thenReturn(new RateLimitConfiguration());
final RateLimitConfiguration initialRateLimitConfiguration = new RateLimitConfiguration(4, 1);
when(limitsConfiguration.getRateLimitReset()).thenReturn(initialRateLimitConfiguration);
when(dynamicConfig.getConfiguration()).thenReturn(configuration);
DynamicRateLimiters rateLimiters = new DynamicRateLimiters(redisCluster, dynamicConfig);
RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
assertThat(limiter.getBucketSize()).isEqualTo(4);
assertThat(limiter.getLeakRatePerMinute()).isEqualTo(1);
assertSame(rateLimiters.getRateLimitResetLimiter(), limiter);
when(limitsConfiguration.getRateLimitReset()).thenReturn(new RateLimitConfiguration(17, 19));
RateLimiter changed = rateLimiters.getRateLimitResetLimiter();
assertThat(changed.getBucketSize()).isEqualTo(17);
assertThat(changed.getLeakRatePerMinute()).isEqualTo(19);
assertNotSame(limiter, changed);
}
}

View File

@ -45,10 +45,10 @@ public final class MockUtils {
public static void updateRateLimiterResponseToAllow(
final RateLimiters rateLimitersMock,
final RateLimiters.Handle handle,
final RateLimiters.For handle,
final String input) {
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
doReturn(Optional.of(mockRateLimiter)).when(rateLimitersMock).byHandle(eq(handle));
doReturn(Optional.of(mockRateLimiter)).when(rateLimitersMock).forDescriptor(eq(handle));
try {
doNothing().when(mockRateLimiter).validate(eq(input));
} catch (final RateLimitExceededException e) {
@ -58,12 +58,12 @@ public final class MockUtils {
public static void updateRateLimiterResponseToFail(
final RateLimiters rateLimitersMock,
final RateLimiters.Handle handle,
final RateLimiters.For handle,
final String input,
final Duration retryAfter,
final boolean legacyStatusCode) {
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
doReturn(Optional.of(mockRateLimiter)).when(rateLimitersMock).byHandle(eq(handle));
doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle));
try {
doThrow(new RateLimitExceededException(retryAfter, legacyStatusCode)).when(mockRateLimiter).validate(eq(input));
} catch (final RateLimitExceededException e) {