Rate limiters code refactored

This commit is contained in:
Sergey Skrobotov 2023-02-23 10:21:39 -08:00
parent 378b32d44d
commit 7529c35013
35 changed files with 738 additions and 774 deletions

View File

@ -12,11 +12,11 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import org.whispersystems.textsecuregcm.configuration.SpamFilterConfiguration;
import org.whispersystems.textsecuregcm.configuration.AccountDatabaseCrawlerConfiguration; import org.whispersystems.textsecuregcm.configuration.AccountDatabaseCrawlerConfiguration;
import org.whispersystems.textsecuregcm.configuration.AdminEventLoggingConfiguration; import org.whispersystems.textsecuregcm.configuration.AdminEventLoggingConfiguration;
import org.whispersystems.textsecuregcm.configuration.ApnConfiguration; import org.whispersystems.textsecuregcm.configuration.ApnConfiguration;
import org.whispersystems.textsecuregcm.configuration.AppConfigConfiguration; import org.whispersystems.textsecuregcm.configuration.AppConfigConfiguration;
import org.whispersystems.textsecuregcm.configuration.ArtServiceConfiguration;
import org.whispersystems.textsecuregcm.configuration.AwsAttachmentsConfiguration; import org.whispersystems.textsecuregcm.configuration.AwsAttachmentsConfiguration;
import org.whispersystems.textsecuregcm.configuration.BadgesConfiguration; import org.whispersystems.textsecuregcm.configuration.BadgesConfiguration;
import org.whispersystems.textsecuregcm.configuration.BraintreeConfiguration; import org.whispersystems.textsecuregcm.configuration.BraintreeConfiguration;
@ -33,8 +33,7 @@ import org.whispersystems.textsecuregcm.configuration.MaxDeviceConfiguration;
import org.whispersystems.textsecuregcm.configuration.MessageCacheConfiguration; import org.whispersystems.textsecuregcm.configuration.MessageCacheConfiguration;
import org.whispersystems.textsecuregcm.configuration.OneTimeDonationConfiguration; import org.whispersystems.textsecuregcm.configuration.OneTimeDonationConfiguration;
import org.whispersystems.textsecuregcm.configuration.PaymentsServiceConfiguration; import org.whispersystems.textsecuregcm.configuration.PaymentsServiceConfiguration;
import org.whispersystems.textsecuregcm.configuration.ArtServiceConfiguration; import org.whispersystems.textsecuregcm.limits.RateLimiterConfig;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration;
import org.whispersystems.textsecuregcm.configuration.RecaptchaConfiguration; import org.whispersystems.textsecuregcm.configuration.RecaptchaConfiguration;
import org.whispersystems.textsecuregcm.configuration.RedisClusterConfiguration; import org.whispersystems.textsecuregcm.configuration.RedisClusterConfiguration;
import org.whispersystems.textsecuregcm.configuration.RedisConfiguration; import org.whispersystems.textsecuregcm.configuration.RedisConfiguration;
@ -44,6 +43,7 @@ import org.whispersystems.textsecuregcm.configuration.ReportMessageConfiguration
import org.whispersystems.textsecuregcm.configuration.SecureBackupServiceConfiguration; import org.whispersystems.textsecuregcm.configuration.SecureBackupServiceConfiguration;
import org.whispersystems.textsecuregcm.configuration.SecureStorageServiceConfiguration; import org.whispersystems.textsecuregcm.configuration.SecureStorageServiceConfiguration;
import org.whispersystems.textsecuregcm.configuration.SecureValueRecovery2Configuration; import org.whispersystems.textsecuregcm.configuration.SecureValueRecovery2Configuration;
import org.whispersystems.textsecuregcm.configuration.SpamFilterConfiguration;
import org.whispersystems.textsecuregcm.configuration.StripeConfiguration; import org.whispersystems.textsecuregcm.configuration.StripeConfiguration;
import org.whispersystems.textsecuregcm.configuration.SubscriptionConfiguration; import org.whispersystems.textsecuregcm.configuration.SubscriptionConfiguration;
import org.whispersystems.textsecuregcm.configuration.TestDeviceConfiguration; import org.whispersystems.textsecuregcm.configuration.TestDeviceConfiguration;
@ -167,7 +167,7 @@ public class WhisperServerConfiguration extends Configuration {
@Valid @Valid
@NotNull @NotNull
@JsonProperty @JsonProperty
private RateLimitsConfiguration limits = new RateLimitsConfiguration(); private Map<String, RateLimiterConfig> limits = new HashMap<>();
@Valid @Valid
@NotNull @NotNull
@ -351,7 +351,7 @@ public class WhisperServerConfiguration extends Configuration {
return rateLimitersCluster; return rateLimitersCluster;
} }
public RateLimitsConfiguration getLimitsConfiguration() { public Map<String, RateLimiterConfig> getLimitsConfiguration() {
return limits; return limits;
} }

View File

@ -118,7 +118,6 @@ import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter; import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter;
import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter; import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter;
import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter; import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter;
import org.whispersystems.textsecuregcm.limits.DynamicRateLimiters;
import org.whispersystems.textsecuregcm.limits.PushChallengeManager; import org.whispersystems.textsecuregcm.limits.PushChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager; import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
@ -506,8 +505,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
FcmSender fcmSender = new FcmSender(fcmSenderExecutor, config.getFcmConfiguration().credentials()); FcmSender fcmSender = new FcmSender(fcmSenderExecutor, config.getFcmConfiguration().credentials());
ApnPushNotificationScheduler apnPushNotificationScheduler = new ApnPushNotificationScheduler(pushSchedulerCluster, apnSender, accountsManager); ApnPushNotificationScheduler apnPushNotificationScheduler = new ApnPushNotificationScheduler(pushSchedulerCluster, apnSender, accountsManager);
PushNotificationManager pushNotificationManager = new PushNotificationManager(accountsManager, apnSender, fcmSender, apnPushNotificationScheduler, pushLatencyManager, dynamicConfigurationManager); PushNotificationManager pushNotificationManager = new PushNotificationManager(accountsManager, apnSender, fcmSender, apnPushNotificationScheduler, pushLatencyManager, dynamicConfigurationManager);
RateLimiters rateLimiters = new RateLimiters(config.getLimitsConfiguration(), rateLimitersCluster); RateLimiters rateLimiters = RateLimiters.createAndValidate(config.getLimitsConfiguration(), dynamicConfigurationManager, rateLimitersCluster);
DynamicRateLimiters dynamicRateLimiters = new DynamicRateLimiters(rateLimitersCluster, dynamicConfigurationManager);
ProvisioningManager provisioningManager = new ProvisioningManager(pubSubManager); ProvisioningManager provisioningManager = new ProvisioningManager(pubSubManager);
IssuedReceiptsManager issuedReceiptsManager = new IssuedReceiptsManager( IssuedReceiptsManager issuedReceiptsManager = new IssuedReceiptsManager(
config.getDynamoDbTables().getIssuedReceipts().getTableName(), config.getDynamoDbTables().getIssuedReceipts().getTableName(),
@ -551,7 +549,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
PushChallengeManager pushChallengeManager = new PushChallengeManager(pushNotificationManager, pushChallengeDynamoDb); PushChallengeManager pushChallengeManager = new PushChallengeManager(pushNotificationManager, pushChallengeDynamoDb);
RateLimitChallengeManager rateLimitChallengeManager = new RateLimitChallengeManager(pushChallengeManager, RateLimitChallengeManager rateLimitChallengeManager = new RateLimitChallengeManager(pushChallengeManager,
captchaChecker, dynamicRateLimiters); captchaChecker, rateLimiters);
MessagePersister messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, dynamicConfigurationManager, Duration.ofMinutes(config.getMessageCacheConfiguration().getPersistDelayMinutes())); MessagePersister messagePersister = new MessagePersister(messagesCache, messagesManager, accountsManager, dynamicConfigurationManager, Duration.ofMinutes(config.getMessageCacheConfiguration().getPersistDelayMinutes()));
ChangeNumberManager changeNumberManager = new ChangeNumberManager(messageSender, accountsManager); ChangeNumberManager changeNumberManager = new ChangeNumberManager(messageSender, accountsManager);

View File

@ -1,201 +0,0 @@
/*
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.configuration;
import com.fasterxml.jackson.annotation.JsonProperty;
public class RateLimitsConfiguration {
@JsonProperty
private RateLimitConfiguration smsDestination = new RateLimitConfiguration(2, 2);
@JsonProperty
private RateLimitConfiguration voiceDestination = new RateLimitConfiguration(2, 1.0 / 2.0);
@JsonProperty
private RateLimitConfiguration voiceDestinationDaily = new RateLimitConfiguration(10, 10.0 / (24.0 * 60.0));
@JsonProperty
private RateLimitConfiguration smsVoiceIp = new RateLimitConfiguration(1000, 1000);
@JsonProperty
private RateLimitConfiguration smsVoicePrefix = new RateLimitConfiguration(1000, 1000);
@JsonProperty
private RateLimitConfiguration verifyNumber = new RateLimitConfiguration(2, 2);
@JsonProperty
private RateLimitConfiguration verifyPin = new RateLimitConfiguration(10, 1 / (24.0 * 60.0));
@JsonProperty
private RateLimitConfiguration verificationCaptcha = new RateLimitConfiguration(10, 2);
@JsonProperty
private RateLimitConfiguration verificationPushChallenge = new RateLimitConfiguration(5, 2);
@JsonProperty
private RateLimitConfiguration registration = new RateLimitConfiguration(2, 2);
@JsonProperty
private RateLimitConfiguration attachments = new RateLimitConfiguration(50, 50);
@JsonProperty
private RateLimitConfiguration prekeys = new RateLimitConfiguration(6, 1.0 / 10.0);
@JsonProperty
private RateLimitConfiguration messages = new RateLimitConfiguration(60, 60);
@JsonProperty
private RateLimitConfiguration allocateDevice = new RateLimitConfiguration(2, 1.0 / 2.0);
@JsonProperty
private RateLimitConfiguration verifyDevice = new RateLimitConfiguration(6, 1.0 / 10.0);
@JsonProperty
private RateLimitConfiguration turnAllocations = new RateLimitConfiguration(60, 60);
@JsonProperty
private RateLimitConfiguration profile = new RateLimitConfiguration(4320, 3);
@JsonProperty
private RateLimitConfiguration stickerPack = new RateLimitConfiguration(50, 20 / (24.0 * 60.0));
@JsonProperty
private RateLimitConfiguration artPack = new RateLimitConfiguration(50, 20 / (24.0 * 60.0));
@JsonProperty
private RateLimitConfiguration usernameLookup = new RateLimitConfiguration(100, 100 / (24.0 * 60.0));
@JsonProperty
private RateLimitConfiguration usernameSet = new RateLimitConfiguration(100, 100 / (24.0 * 60.0));
@JsonProperty
private RateLimitConfiguration usernameReserve = new RateLimitConfiguration(100, 100 / (24.0 * 60.0));
@JsonProperty
private RateLimitConfiguration checkAccountExistence = new RateLimitConfiguration(1_000, 1_000 / 60.0);
@JsonProperty
private RateLimitConfiguration backupAuthCheck = new RateLimitConfiguration(100, 100 / (24.0 * 60.0));
public RateLimitConfiguration getAllocateDevice() {
return allocateDevice;
}
public RateLimitConfiguration getVerifyDevice() {
return verifyDevice;
}
public RateLimitConfiguration getMessages() {
return messages;
}
public RateLimitConfiguration getPreKeys() {
return prekeys;
}
public RateLimitConfiguration getAttachments() {
return attachments;
}
public RateLimitConfiguration getSmsDestination() {
return smsDestination;
}
public RateLimitConfiguration getVoiceDestination() {
return voiceDestination;
}
public RateLimitConfiguration getVoiceDestinationDaily() {
return voiceDestinationDaily;
}
public RateLimitConfiguration getSmsVoiceIp() {
return smsVoiceIp;
}
public RateLimitConfiguration getSmsVoicePrefix() {
return smsVoicePrefix;
}
public RateLimitConfiguration getVerifyNumber() {
return verifyNumber;
}
public RateLimitConfiguration getVerifyPin() {
return verifyPin;
}
public RateLimitConfiguration getVerificationCaptcha() {
return verificationCaptcha;
}
public RateLimitConfiguration getVerificationPushChallenge() {
return verificationPushChallenge;
}
public RateLimitConfiguration getRegistration() {
return registration;
}
public RateLimitConfiguration getTurnAllocations() {
return turnAllocations;
}
public RateLimitConfiguration getProfile() {
return profile;
}
public RateLimitConfiguration getStickerPack() {
return stickerPack;
}
public RateLimitConfiguration getArtPack() {
return artPack;
}
public RateLimitConfiguration getUsernameLookup() {
return usernameLookup;
}
public RateLimitConfiguration getUsernameSet() {
return usernameSet;
}
public RateLimitConfiguration getUsernameReserve() {
return usernameReserve;
}
public RateLimitConfiguration getCheckAccountExistence() {
return checkAccountExistence;
}
public RateLimitConfiguration getBackupAuthCheck() {
return backupAuthCheck;
}
public static class RateLimitConfiguration {
@JsonProperty
private int bucketSize;
@JsonProperty
private double leakRatePerMinute;
public RateLimitConfiguration(int bucketSize, double leakRatePerMinute) {
this.bucketSize = bucketSize;
this.leakRatePerMinute = leakRatePerMinute;
}
public RateLimitConfiguration() {}
public int getBucketSize() {
return bucketSize;
}
public double getLeakRatePerMinute() {
return leakRatePerMinute;
}
}
}

View File

@ -1,10 +1,17 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.configuration.dynamic; package org.whispersystems.textsecuregcm.configuration.dynamic;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import javax.validation.Valid; import javax.validation.Valid;
import org.whispersystems.textsecuregcm.limits.RateLimiterConfig;
public class DynamicConfiguration { public class DynamicConfiguration {
@ -18,7 +25,7 @@ public class DynamicConfiguration {
@JsonProperty @JsonProperty
@Valid @Valid
private DynamicRateLimitsConfiguration limits = new DynamicRateLimitsConfiguration(); private Map<String, RateLimiterConfig> limits = new HashMap<>();
@JsonProperty @JsonProperty
@Valid @Valid
@ -65,7 +72,7 @@ public class DynamicConfiguration {
return Optional.ofNullable(preRegistrationExperiments.get(experimentName)); return Optional.ofNullable(preRegistrationExperiments.get(experimentName));
} }
public DynamicRateLimitsConfiguration getLimits() { public Map<String, RateLimiterConfig> getLimits() {
return limits; return limits;
} }

View File

@ -1,42 +0,0 @@
package org.whispersystems.textsecuregcm.configuration.dynamic;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.RateLimitConfiguration;
public class DynamicRateLimitsConfiguration {
@JsonProperty
private RateLimitConfiguration rateLimitReset = new RateLimitConfiguration(2, 2.0 / (60 * 24));
@JsonProperty
private RateLimitConfiguration recaptchaChallengeAttempt = new RateLimitConfiguration(10, 10.0 / (60 * 24));
@JsonProperty
private RateLimitConfiguration recaptchaChallengeSuccess = new RateLimitConfiguration(2, 2.0 / (60 * 24));
@JsonProperty
private RateLimitConfiguration pushChallengeAttempt = new RateLimitConfiguration(10, 10.0 / (60 * 24));
@JsonProperty
private RateLimitConfiguration pushChallengeSuccess = new RateLimitConfiguration(2, 2.0 / (60 * 24));
public RateLimitConfiguration getRateLimitReset() {
return rateLimitReset;
}
public RateLimitConfiguration getRecaptchaChallengeAttempt() {
return recaptchaChallengeAttempt;
}
public RateLimitConfiguration getRecaptchaChallengeSuccess() {
return recaptchaChallengeSuccess;
}
public RateLimitConfiguration getPushChallengeAttempt() {
return pushChallengeAttempt;
}
public RateLimitConfiguration getPushChallengeSuccess() {
return pushChallengeSuccess;
}
}

View File

@ -746,7 +746,7 @@ public class AccountController {
@GET @GET
@Path("/username_hash/{usernameHash}") @Path("/username_hash/{usernameHash}")
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@RateLimitedByIp(RateLimiters.Handle.USERNAME_LOOKUP) @RateLimitedByIp(RateLimiters.For.USERNAME_LOOKUP)
public AccountIdentifierResponse lookupUsernameHash( public AccountIdentifierResponse lookupUsernameHash(
@HeaderParam(HeaderUtils.X_SIGNAL_AGENT) final String userAgent, @HeaderParam(HeaderUtils.X_SIGNAL_AGENT) final String userAgent,
@HeaderParam(HttpHeaders.X_FORWARDED_FOR) final String forwardedFor, @HeaderParam(HttpHeaders.X_FORWARDED_FOR) final String forwardedFor,
@ -780,7 +780,7 @@ public class AccountController {
@HEAD @HEAD
@Path("/account/{uuid}") @Path("/account/{uuid}")
@RateLimitedByIp(RateLimiters.Handle.CHECK_ACCOUNT_EXISTENCE) @RateLimitedByIp(RateLimiters.For.CHECK_ACCOUNT_EXISTENCE)
public Response accountExists( public Response accountExists(
@PathParam("uuid") final UUID uuid, @PathParam("uuid") final UUID uuid,
@Context HttpServletRequest request) throws RateLimitExceededException { @Context HttpServletRequest request) throws RateLimitExceededException {

View File

@ -79,7 +79,7 @@ public class SecureBackupController {
@Path("/auth/check") @Path("/auth/check")
@Consumes(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON)
@RateLimitedByIp(RateLimiters.Handle.BACKUP_AUTH_CHECK) @RateLimitedByIp(RateLimiters.For.BACKUP_AUTH_CHECK)
public AuthCheckResponse authCheck(@NotNull @Valid final AuthCheckRequest request) { public AuthCheckResponse authCheck(@NotNull @Valid final AuthCheckRequest request) {
final Map<String, AuthCheckResponse.Result> results = new HashMap<>(); final Map<String, AuthCheckResponse.Result> results = new HashMap<>();
final Map<String, Pair<UUID, Long>> tokenToUuid = new HashMap<>(); final Map<String, Pair<UUID, Long>> tokenToUuid = new HashMap<>();

View File

@ -0,0 +1,82 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static java.util.Objects.requireNonNull;
import java.lang.invoke.MethodHandles;
import java.util.Arrays;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
public abstract class BaseRateLimiters<T extends RateLimiterDescriptor> {
private final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
private final Map<T, RateLimiter> rateLimiterByDescriptor;
private final Map<String, RateLimiterConfig> configs;
protected BaseRateLimiters(
final T[] values,
final Map<String, RateLimiterConfig> configs,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final FaultTolerantRedisCluster cacheCluster) {
this.configs = configs;
this.rateLimiterByDescriptor = Arrays.stream(values)
.map(descriptor -> Pair.of(
descriptor,
createForDescriptor(descriptor, configs, dynamicConfigurationManager, cacheCluster)))
.collect(Collectors.toUnmodifiableMap(Pair::getKey, Pair::getValue));
}
public RateLimiter forDescriptor(final T handle) {
return requireNonNull(rateLimiterByDescriptor.get(handle));
}
public void validateValuesAndConfigs() {
final Set<String> ids = rateLimiterByDescriptor.keySet().stream()
.map(RateLimiterDescriptor::id)
.collect(Collectors.toSet());
for (final String key: configs.keySet()) {
if (!ids.contains(key)) {
final String message = String.format(
"Static configuration has an unexpected field '%s' that doesn't match any RateLimiterDescriptor",
key
);
logger.error(message);
throw new IllegalArgumentException(message);
}
}
}
private static RateLimiter createForDescriptor(
final RateLimiterDescriptor descriptor,
final Map<String, RateLimiterConfig> configs,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final FaultTolerantRedisCluster cacheCluster) {
if (descriptor.isDynamic()) {
final Supplier<RateLimiterConfig> configResolver = () -> {
final RateLimiterConfig config = dynamicConfigurationManager.getConfiguration().getLimits().get(descriptor.id());
return config != null
? config
: configs.getOrDefault(descriptor.id(), descriptor.defaultConfig());
};
return new DynamicRateLimiter(descriptor.id(), configResolver, cacheCluster);
}
final RateLimiterConfig cfg = configs.getOrDefault(descriptor.id(), descriptor.defaultConfig());
return new StaticRateLimiter(descriptor.id(), cfg, cacheCluster);
}
}

View File

@ -0,0 +1,63 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static java.util.Objects.requireNonNull;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import org.apache.commons.lang3.tuple.Pair;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
public class DynamicRateLimiter implements RateLimiter {
private final String name;
private final Supplier<RateLimiterConfig> configResolver;
private final FaultTolerantRedisCluster cluster;
private final AtomicReference<Pair<RateLimiterConfig, RateLimiter>> currentHolder = new AtomicReference<>();
public DynamicRateLimiter(
final String name,
final Supplier<RateLimiterConfig> configResolver,
final FaultTolerantRedisCluster cluster) {
this.name = requireNonNull(name);
this.configResolver = requireNonNull(configResolver);
this.cluster = requireNonNull(cluster);
}
@Override
public void validate(final String key, final int amount) throws RateLimitExceededException {
current().getRight().validate(key, amount);
}
@Override
public boolean hasAvailablePermits(final String key, final int permits) {
return current().getRight().hasAvailablePermits(key, permits);
}
@Override
public void clear(final String key) {
current().getRight().clear(key);
}
@Override
public RateLimiterConfig config() {
return current().getLeft();
}
private Pair<RateLimiterConfig, RateLimiter> current() {
final RateLimiterConfig cfg = configResolver.get();
return currentHolder.updateAndGet(p -> p != null && p.getLeft().equals(cfg)
? p
: Pair.of(cfg, new StaticRateLimiter(name, cfg, cluster))
);
}
}

View File

@ -1,130 +0,0 @@
/*
* Copyright 2013-2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.RateLimitConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
public class DynamicRateLimiters {
private final FaultTolerantRedisCluster cacheCluster;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private final AtomicReference<RateLimiter> rateLimitResetLimiter;
private final AtomicReference<RateLimiter> recaptchaChallengeAttemptLimiter;
private final AtomicReference<RateLimiter> recaptchaChallengeSuccessLimiter;
private final AtomicReference<RateLimiter> pushChallengeAttemptLimiter;
private final AtomicReference<RateLimiter> pushChallengeSuccessLimiter;
public DynamicRateLimiters(final FaultTolerantRedisCluster rateLimitCluster,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager) {
this.cacheCluster = rateLimitCluster;
this.dynamicConfigurationManager = dynamicConfigurationManager;
this.rateLimitResetLimiter = new AtomicReference<>(
createRateLimitResetLimiter(this.cacheCluster,
this.dynamicConfigurationManager.getConfiguration().getLimits().getRateLimitReset()));
this.recaptchaChallengeAttemptLimiter = new AtomicReference<>(createRecaptchaChallengeAttemptLimiter(
this.cacheCluster,
this.dynamicConfigurationManager.getConfiguration().getLimits().getRecaptchaChallengeAttempt()));
this.recaptchaChallengeSuccessLimiter = new AtomicReference<>(createRecaptchaChallengeSuccessLimiter(
this.cacheCluster,
this.dynamicConfigurationManager.getConfiguration().getLimits().getRecaptchaChallengeSuccess()));
this.pushChallengeAttemptLimiter = new AtomicReference<>(createPushChallengeAttemptLimiter(this.cacheCluster,
this.dynamicConfigurationManager.getConfiguration().getLimits().getPushChallengeAttempt()));
this.pushChallengeSuccessLimiter = new AtomicReference<>(createPushChallengeSuccessLimiter(this.cacheCluster,
this.dynamicConfigurationManager.getConfiguration().getLimits().getPushChallengeSuccess()));
}
public RateLimiter getRateLimitResetLimiter() {
return updateAndGetRateLimiter(
rateLimitResetLimiter,
dynamicConfigurationManager.getConfiguration().getLimits().getRateLimitReset(),
this::createRateLimitResetLimiter);
}
public RateLimiter getRecaptchaChallengeAttemptLimiter() {
return updateAndGetRateLimiter(
recaptchaChallengeAttemptLimiter,
dynamicConfigurationManager.getConfiguration().getLimits().getRecaptchaChallengeAttempt(),
this::createRecaptchaChallengeAttemptLimiter);
}
public RateLimiter getRecaptchaChallengeSuccessLimiter() {
return updateAndGetRateLimiter(
recaptchaChallengeSuccessLimiter,
dynamicConfigurationManager.getConfiguration().getLimits().getRecaptchaChallengeSuccess(),
this::createRecaptchaChallengeSuccessLimiter);
}
public RateLimiter getPushChallengeAttemptLimiter() {
return updateAndGetRateLimiter(
pushChallengeAttemptLimiter,
dynamicConfigurationManager.getConfiguration().getLimits().getPushChallengeAttempt(),
this::createPushChallengeAttemptLimiter);
}
public RateLimiter getPushChallengeSuccessLimiter() {
return updateAndGetRateLimiter(
pushChallengeSuccessLimiter,
dynamicConfigurationManager.getConfiguration().getLimits().getPushChallengeSuccess(),
this::createPushChallengeSuccessLimiter);
}
private RateLimiter updateAndGetRateLimiter(final AtomicReference<RateLimiter> rateLimiter,
RateLimitConfiguration currentConfiguration,
BiFunction<FaultTolerantRedisCluster, RateLimitConfiguration, RateLimiter> rateLimitFactory) {
return rateLimiter.updateAndGet(limiter -> {
if (limiter.hasConfiguration(currentConfiguration)) {
return limiter;
} else {
return rateLimitFactory.apply(cacheCluster, currentConfiguration);
}
});
}
public RateLimiter createRateLimitResetLimiter(FaultTolerantRedisCluster cacheCluster,
RateLimitConfiguration configuration) {
return createLimiter(cacheCluster, configuration, "rateLimitReset");
}
public RateLimiter createRecaptchaChallengeAttemptLimiter(FaultTolerantRedisCluster cacheCluster,
RateLimitConfiguration configuration) {
return createLimiter(cacheCluster, configuration, "recaptchaChallengeAttempt");
}
public RateLimiter createRecaptchaChallengeSuccessLimiter(FaultTolerantRedisCluster cacheCluster,
RateLimitConfiguration configuration) {
return createLimiter(cacheCluster, configuration, "recaptchaChallengeSuccess");
}
public RateLimiter createPushChallengeAttemptLimiter(FaultTolerantRedisCluster cacheCluster,
RateLimitConfiguration configuration) {
return createLimiter(cacheCluster, configuration, "pushChallengeAttempt");
}
public RateLimiter createPushChallengeSuccessLimiter(FaultTolerantRedisCluster cacheCluster,
RateLimitConfiguration configuration) {
return createLimiter(cacheCluster, configuration, "pushChallengeSuccess");
}
private RateLimiter createLimiter(FaultTolerantRedisCluster cacheCluster, RateLimitConfiguration configuration,
String name) {
return new RateLimiter(cacheCluster, name,
configuration.getBucketSize(),
configuration.getLeakRatePerMinute());
}
}

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
@ -16,22 +16,28 @@ import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
public class LockingRateLimiter extends RateLimiter { public class LockingRateLimiter extends StaticRateLimiter {
private static final RateLimitExceededException REUSABLE_RATE_LIMIT_EXCEEDED_EXCEPTION
= new RateLimitExceededException(Duration.ZERO, true);
private final Meter meter; private final Meter meter;
public LockingRateLimiter(FaultTolerantRedisCluster cacheCluster, String name, int bucketSize, double leakRatePerMinute) {
super(cacheCluster, name, bucketSize, leakRatePerMinute);
MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); public LockingRateLimiter(
final String name,
final RateLimiterConfig config,
final FaultTolerantRedisCluster cacheCluster) {
super(name, config, cacheCluster);
final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
this.meter = metricRegistry.meter(name(getClass(), name, "locked")); this.meter = metricRegistry.meter(name(getClass(), name, "locked"));
} }
@Override @Override
public void validate(String key, int amount) throws RateLimitExceededException { public void validate(final String key, final int amount) throws RateLimitExceededException {
if (!acquireLock(key)) { if (!acquireLock(key)) {
meter.mark(); meter.mark();
throw new RateLimitExceededException(Duration.ZERO, true); throw REUSABLE_RATE_LIMIT_EXCEEDED_EXCEPTION;
} }
try { try {
@ -41,22 +47,15 @@ public class LockingRateLimiter extends RateLimiter {
} }
} }
@Override private void releaseLock(final String key) {
public void validate(String key) throws RateLimitExceededException {
validate(key, 1);
}
private void releaseLock(String key) {
cacheCluster.useCluster(connection -> connection.sync().del(getLockName(key))); cacheCluster.useCluster(connection -> connection.sync().del(getLockName(key)));
} }
private boolean acquireLock(String key) { private boolean acquireLock(final String key) {
return cacheCluster.withCluster(connection -> connection.sync().set(getLockName(key), "L", SetArgs.Builder.nx().ex(10))) != null; return cacheCluster.withCluster(connection -> connection.sync().set(getLockName(key), "L", SetArgs.Builder.nx().ex(10))) != null;
} }
private String getLockName(String key) { private String getLockName(final String key) {
return "leaky_lock::" + name + "::" + key; return "leaky_lock::" + name + "::" + key;
} }
} }

View File

@ -58,7 +58,7 @@ public class RateLimitByIpFilter implements ContainerRequestFilter {
return; return;
} }
final RateLimiters.Handle handle = annotation.value(); final RateLimiters.For handle = annotation.value();
try { try {
final String xffHeader = requestContext.getHeaders().getFirst(HttpHeaders.X_FORWARDED_FOR); final String xffHeader = requestContext.getHeaders().getFirst(HttpHeaders.X_FORWARDED_FOR);
@ -77,13 +77,8 @@ public class RateLimitByIpFilter implements ContainerRequestFilter {
return; return;
} }
final Optional<RateLimiter> maybeRateLimiter = rateLimiters.byHandle(handle); final RateLimiter rateLimiter = rateLimiters.forDescriptor(handle);
if (maybeRateLimiter.isEmpty()) { rateLimiter.validate(maybeMostRecentProxy.get());
logger.warn("RateLimiter not found for {}. Make sure it's initialized in RateLimiters class", handle);
return;
}
maybeRateLimiter.get().validate(maybeMostRecentProxy.get());
} catch (RateLimitExceededException e) { } catch (RateLimitExceededException e) {
final Response response = EXCEPTION_MAPPER.toResponse(e); final Response response = EXCEPTION_MAPPER.toResponse(e);
throw new ClientErrorException(response); throw new ClientErrorException(response);

View File

@ -1,3 +1,8 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits; package org.whispersystems.textsecuregcm.limits;
import static com.codahale.metrics.MetricRegistry.name; import static com.codahale.metrics.MetricRegistry.name;
@ -9,11 +14,11 @@ import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import org.whispersystems.textsecuregcm.spam.RateLimitChallengeListener;
import org.whispersystems.textsecuregcm.captcha.CaptchaChecker; import org.whispersystems.textsecuregcm.captcha.CaptchaChecker;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.spam.RateLimitChallengeListener;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.util.Util; import org.whispersystems.textsecuregcm.util.Util;
@ -21,7 +26,7 @@ public class RateLimitChallengeManager {
private final PushChallengeManager pushChallengeManager; private final PushChallengeManager pushChallengeManager;
private final CaptchaChecker captchaChecker; private final CaptchaChecker captchaChecker;
private final DynamicRateLimiters rateLimiters; private final RateLimiters rateLimiters;
private final List<RateLimitChallengeListener> rateLimitChallengeListeners = private final List<RateLimitChallengeListener> rateLimitChallengeListeners =
Collections.synchronizedList(new ArrayList<>()); Collections.synchronizedList(new ArrayList<>());
@ -35,7 +40,7 @@ public class RateLimitChallengeManager {
public RateLimitChallengeManager( public RateLimitChallengeManager(
final PushChallengeManager pushChallengeManager, final PushChallengeManager pushChallengeManager,
final CaptchaChecker captchaChecker, final CaptchaChecker captchaChecker,
final DynamicRateLimiters rateLimiters) { final RateLimiters rateLimiters) {
this.pushChallengeManager = pushChallengeManager; this.pushChallengeManager = pushChallengeManager;
this.captchaChecker = captchaChecker; this.captchaChecker = captchaChecker;

View File

@ -1,30 +1,30 @@
/* /*
* Copyright 2013-2021 Signal Messenger, LLC * Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
package org.whispersystems.textsecuregcm.limits; package org.whispersystems.textsecuregcm.limits;
import com.vdurmont.semver4j.Semver; import com.vdurmont.semver4j.Semver;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgent; import org.whispersystems.textsecuregcm.util.ua.UserAgent;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
public class RateLimitChallengeOptionManager { public class RateLimitChallengeOptionManager {
private final DynamicRateLimiters rateLimiters; private final RateLimiters rateLimiters;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager; private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
public static final String OPTION_RECAPTCHA = "recaptcha"; public static final String OPTION_RECAPTCHA = "recaptcha";
public static final String OPTION_PUSH_CHALLENGE = "pushChallenge"; public static final String OPTION_PUSH_CHALLENGE = "pushChallenge";
public RateLimitChallengeOptionManager(final DynamicRateLimiters rateLimiters, public RateLimitChallengeOptionManager(final RateLimiters rateLimiters,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager) { final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager) {
this.rateLimiters = rateLimiters; this.rateLimiters = rateLimiters;

View File

@ -14,7 +14,7 @@ import java.lang.annotation.Target;
@Retention(RetentionPolicy.RUNTIME) @Retention(RetentionPolicy.RUNTIME)
public @interface RateLimitedByIp { public @interface RateLimitedByIp {
RateLimiters.Handle value(); RateLimiters.For value();
boolean failOnUnresolvedIp() default true; boolean failOnUnresolvedIp() default true;
} }

View File

@ -1,143 +1,48 @@
/* /*
* Copyright 2013-2020 Signal Messenger, LLC * Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
package org.whispersystems.textsecuregcm.limits; package org.whispersystems.textsecuregcm.limits;
import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.time.Duration;
import java.util.UUID; import java.util.UUID;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.RateLimitConfiguration;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.SystemMapper;
public class RateLimiter { public interface RateLimiter {
private final Logger logger = LoggerFactory.getLogger(RateLimiter.class); void validate(String key, int amount) throws RateLimitExceededException;
private final ObjectMapper mapper = SystemMapper.getMapper();
private final Meter meter; boolean hasAvailablePermits(String key, int permits);
private final Timer validateTimer;
protected final FaultTolerantRedisCluster cacheCluster;
protected final String name;
private final int bucketSize;
private final double leakRatePerMinute;
private final double leakRatePerMillis;
public RateLimiter(FaultTolerantRedisCluster cacheCluster, String name, int bucketSize, double leakRatePerMinute) void clear(String key);
{
MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
this.meter = metricRegistry.meter(name(getClass(), name, "exceeded")); RateLimiterConfig config();
this.validateTimer = metricRegistry.timer(name(getClass(), name, "validate"));
this.cacheCluster = cacheCluster;
this.name = name;
this.bucketSize = bucketSize;
this.leakRatePerMinute = leakRatePerMinute;
this.leakRatePerMillis = leakRatePerMinute / (60.0 * 1000.0);
}
public void validate(String key, int amount) throws RateLimitExceededException { default void validate(final String key) throws RateLimitExceededException {
try (final Timer.Context ignored = validateTimer.time()) {
LeakyBucket bucket = getBucket(key);
if (bucket.add(amount)) {
setBucket(key, bucket);
} else {
meter.mark();
throw new RateLimitExceededException(bucket.getTimeUntilSpaceAvailable(amount), true);
}
}
}
public void validate(final UUID accountUuid) throws RateLimitExceededException {
validate(accountUuid.toString());
}
public void validate(final UUID sourceAccountUuid, final UUID destinationAccountUuid)
throws RateLimitExceededException {
validate(sourceAccountUuid.toString() + "__" + destinationAccountUuid.toString());
}
public void validate(String key) throws RateLimitExceededException {
validate(key, 1); validate(key, 1);
} }
public boolean hasAvailablePermits(final UUID accountUuid, final int permits) { default void validate(final UUID accountUuid) throws RateLimitExceededException {
validate(accountUuid.toString());
}
default void validate(final UUID srcAccountUuid, final UUID dstAccountUuid) throws RateLimitExceededException {
validate(srcAccountUuid.toString() + "__" + dstAccountUuid.toString());
}
default boolean hasAvailablePermits(final UUID accountUuid, final int permits) {
return hasAvailablePermits(accountUuid.toString(), permits); return hasAvailablePermits(accountUuid.toString(), permits);
} }
public boolean hasAvailablePermits(final String key, final int permits) { default void clear(final UUID accountUuid) {
return getBucket(key).getTimeUntilSpaceAvailable(permits).equals(Duration.ZERO);
}
public void clear(final UUID accountUuid) {
clear(accountUuid.toString()); clear(accountUuid.toString());
} }
public void clear(String key) {
cacheCluster.useCluster(connection -> connection.sync().del(getBucketName(key)));
}
public int getBucketSize() {
return bucketSize;
}
public double getLeakRatePerMinute() {
return leakRatePerMinute;
}
private void setBucket(String key, LeakyBucket bucket) {
try {
final String serialized = bucket.serialize(mapper);
cacheCluster.useCluster(connection -> connection.sync().setex(getBucketName(key), (int) Math.ceil((bucketSize / leakRatePerMillis) / 1000), serialized));
} catch (JsonProcessingException e) {
throw new IllegalArgumentException(e);
}
}
private LeakyBucket getBucket(String key) {
try {
final String serialized = cacheCluster.withCluster(connection -> connection.sync().get(getBucketName(key)));
if (serialized != null) {
return LeakyBucket.fromSerialized(mapper, serialized);
}
} catch (IOException e) {
logger.warn("Deserialization error", e);
}
return new LeakyBucket(bucketSize, leakRatePerMillis);
}
private String getBucketName(String key) {
return "leaky_bucket::" + name + "::" + key;
}
public boolean hasConfiguration(final RateLimitConfiguration configuration) {
return bucketSize == configuration.getBucketSize() && leakRatePerMinute == configuration.getLeakRatePerMinute();
}
/** /**
* If the wrapped {@code validate()} call throws a {@link RateLimitExceededException}, it will adapt it to ensure that * If the wrapped {@code validate()} call throws a {@link RateLimitExceededException}, it will adapt it to ensure that
* {@link RateLimitExceededException#isLegacy()} returns {@code true} * {@link RateLimitExceededException#isLegacy()} returns {@code true}
*/ */
public static void adaptLegacyException(final RateLimitValidator validator) throws RateLimitExceededException { static void adaptLegacyException(final RateLimitValidator validator) throws RateLimitExceededException {
try { try {
validator.validate(); validator.validate();
} catch (final RateLimitExceededException e) { } catch (final RateLimitExceededException e) {
@ -146,9 +51,8 @@ public class RateLimiter {
} }
@FunctionalInterface @FunctionalInterface
public interface RateLimitValidator { interface RateLimitValidator {
void validate() throws RateLimitExceededException; void validate() throws RateLimitExceededException;
} }
} }

View File

@ -0,0 +1,13 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
public record RateLimiterConfig(int bucketSize, double leakRatePerMinute) {
public double leakRatePerMillis() {
return leakRatePerMinute / (60.0 * 1000.0);
}
}

View File

@ -0,0 +1,28 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
/**
* Represents an information that defines a rate limiter.
*/
public interface RateLimiterDescriptor {
/**
* Implementing classes will likely be Enums, so name is chosen not to clash with {@link Enum#name()}.
* @return id of this rate limiter to be used in `yml` config files and as a part of the bucket key.
*/
String id();
/**
* @return {@code true} if this rate limiter needs to watch for dynamic configuration changes.
*/
boolean isDynamic();
/**
* @return an instance of {@link RateLimiterConfig} to be used by default,
* i.e. if there is no overrides in the application configuration files (static or dynamic).
*/
RateLimiterConfig defaultConfig();
}

View File

@ -5,193 +5,232 @@
package org.whispersystems.textsecuregcm.limits; package org.whispersystems.textsecuregcm.limits;
import com.google.common.annotations.VisibleForTesting;
import java.util.Map; import java.util.Map;
import java.util.Optional; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang3.tuple.Pair;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
public class RateLimiters { public class RateLimiters extends BaseRateLimiters<RateLimiters.For> {
public enum Handle { public enum For implements RateLimiterDescriptor {
USERNAME_LOOKUP("usernameLookup"), BACKUP_AUTH_CHECK("backupAuthCheck", false, new RateLimiterConfig(100, 100 / (24.0 * 60.0))),
CHECK_ACCOUNT_EXISTENCE("checkAccountExistence"),
BACKUP_AUTH_CHECK; SMS_DESTINATION("smsDestination", false, new RateLimiterConfig(2, 2)),
VOICE_DESTINATION("voxDestination", false, new RateLimiterConfig(2, 1.0 / 2.0)),
VOICE_DESTINATION_DAILY("voxDestinationDaily", false, new RateLimiterConfig(10, 10.0 / (24.0 * 60.0))),
SMS_VOICE_IP("smsVoiceIp", false, new RateLimiterConfig(1000, 1000)),
SMS_VOICE_PREFIX("smsVoicePrefix", false, new RateLimiterConfig(1000, 1000)),
VERIFY("verify", false, new RateLimiterConfig(2, 2)),
PIN("pin", false, new RateLimiterConfig(10, 1 / (24.0 * 60.0))),
ATTACHMENT("attachmentCreate", false, new RateLimiterConfig(50, 50)),
PRE_KEYS("prekeys", false, new RateLimiterConfig(6, 1.0 / 10.0)),
MESSAGES("messages", false, new RateLimiterConfig(60, 60)),
ALLOCATE_DEVICE("allocateDevice", false, new RateLimiterConfig(2, 1.0 / 2.0)),
VERIFY_DEVICE("verifyDevice", false, new RateLimiterConfig(6, 1.0 / 10.0)),
TURN("turnAllocate", false, new RateLimiterConfig(60, 60)),
PROFILE("profile", false, new RateLimiterConfig(4320, 3)),
STICKER_PACK("stickerPack", false, new RateLimiterConfig(50, 20 / (24.0 * 60.0))),
ART_PACK("artPack", false, new RateLimiterConfig(50, 20 / (24.0 * 60.0))),
USERNAME_LOOKUP("usernameLookup", false, new RateLimiterConfig(100, 100 / (24.0 * 60.0))),
USERNAME_SET("usernameSet", false, new RateLimiterConfig(100, 100 / (24.0 * 60.0))),
USERNAME_RESERVE("usernameReserve", false, new RateLimiterConfig(100, 100 / (24.0 * 60.0))),
CHECK_ACCOUNT_EXISTENCE("checkAccountExistence", false, new RateLimiterConfig(1_000, 1_000 / 60.0)),
STORIES("stories", false, new RateLimiterConfig(10_000, 10_000 / (24.0 * 60.0))),
REGISTRATION("registration", false, new RateLimiterConfig(2, 2)),
VERIFICATION_PUSH_CHALLENGE("verificationPushChallenge", false, new RateLimiterConfig(5, 2)),
VERIFICATION_CAPTCHA("verificationCaptcha", false, new RateLimiterConfig(10, 2)),
RATE_LIMIT_RESET("rateLimitReset", true, new RateLimiterConfig(2, 2.0 / (60 * 24))),
RECAPTCHA_CHALLENGE_ATTEMPT("recaptchaChallengeAttempt", true, new RateLimiterConfig(10, 10.0 / (60 * 24))),
RECAPTCHA_CHALLENGE_SUCCESS("recaptchaChallengeSuccess", true, new RateLimiterConfig(2, 2.0 / (60 * 24))),
PUSH_CHALLENGE_ATTEMPT("pushChallengeAttempt", true, new RateLimiterConfig(10, 10.0 / (60 * 24))),
PUSH_CHALLENGE_SUCCESS("pushChallengeSuccess", true, new RateLimiterConfig(2, 2.0 / (60 * 24))),
;
private final String id; private final String id;
private final boolean dynamic;
Handle(final String id) { private final RateLimiterConfig defaultConfig;
For(final String id, final boolean dynamic, final RateLimiterConfig defaultConfig) {
this.id = id; this.id = id;
} this.dynamic = dynamic;
this.defaultConfig = defaultConfig;
Handle() {
this.id = name();
} }
public String id() { public String id() {
return id; return id;
} }
@Override
public boolean isDynamic() {
return dynamic;
}
public RateLimiterConfig defaultConfig() {
return defaultConfig;
}
} }
private final RateLimiter smsDestinationLimiter; public static RateLimiters createAndValidate(
private final RateLimiter voiceDestinationLimiter; final Map<String, RateLimiterConfig> configs,
private final RateLimiter voiceDestinationDailyLimiter; final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
private final RateLimiter smsVoiceIpLimiter; final FaultTolerantRedisCluster cacheCluster) {
private final RateLimiter smsVoicePrefixLimiter; final RateLimiters rateLimiters = new RateLimiters(configs, dynamicConfigurationManager, cacheCluster);
private final RateLimiter verifyLimiter; rateLimiters.validateValuesAndConfigs();
private final RateLimiter verificationCaptchaLimiter; return rateLimiters;
private final RateLimiter verificationPushChallengeLimiter;
private final RateLimiter pinLimiter;
private final RateLimiter registrationLimiter;
private final RateLimiter attachmentLimiter;
private final RateLimiter preKeysLimiter;
private final RateLimiter messagesLimiter;
private final RateLimiter allocateDeviceLimiter;
private final RateLimiter verifyDeviceLimiter;
private final RateLimiter turnLimiter;
private final RateLimiter profileLimiter;
private final RateLimiter stickerPackLimiter;
private final RateLimiter artPackLimiter;
private final RateLimiter usernameSetLimiter;
private final RateLimiter usernameReserveLimiter;
private final Map<String, RateLimiter> rateLimiterByHandle;
public RateLimiters(final RateLimitsConfiguration config, final FaultTolerantRedisCluster cacheCluster) {
this.smsDestinationLimiter = fromConfig("smsDestination", config.getSmsDestination(), cacheCluster);
this.voiceDestinationLimiter = fromConfig("voxDestination", config.getVoiceDestination(), cacheCluster);
this.voiceDestinationDailyLimiter = fromConfig("voxDestinationDaily", config.getVoiceDestinationDaily(),
cacheCluster);
this.smsVoiceIpLimiter = fromConfig("smsVoiceIp", config.getSmsVoiceIp(), cacheCluster);
this.smsVoicePrefixLimiter = fromConfig("smsVoicePrefix", config.getSmsVoicePrefix(), cacheCluster);
this.verifyLimiter = fromConfig("verify", config.getVerifyNumber(), cacheCluster);
this.verificationCaptchaLimiter = fromConfig("verificationCaptcha", config.getVerificationCaptcha(), cacheCluster);
this.verificationPushChallengeLimiter = fromConfig("verificationPushChallenge",
config.getVerificationPushChallenge(), cacheCluster);
this.pinLimiter = fromConfig("pin", config.getVerifyPin(), cacheCluster);
this.registrationLimiter = fromConfig("registration", config.getRegistration(), cacheCluster);
this.attachmentLimiter = fromConfig("attachmentCreate", config.getAttachments(), cacheCluster);
this.preKeysLimiter = fromConfig("prekeys", config.getPreKeys(), cacheCluster);
this.messagesLimiter = fromConfig("messages", config.getMessages(), cacheCluster);
this.allocateDeviceLimiter = fromConfig("allocateDevice", config.getAllocateDevice(), cacheCluster);
this.verifyDeviceLimiter = fromConfig("verifyDevice", config.getVerifyDevice(), cacheCluster);
this.turnLimiter = fromConfig("turnAllocate", config.getTurnAllocations(), cacheCluster);
this.profileLimiter = fromConfig("profile", config.getProfile(), cacheCluster);
this.stickerPackLimiter = fromConfig("stickerPack", config.getStickerPack(), cacheCluster);
this.artPackLimiter = fromConfig("artPack", config.getArtPack(), cacheCluster);
this.usernameSetLimiter = fromConfig("usernameSet", config.getUsernameSet(), cacheCluster);
this.usernameReserveLimiter = fromConfig("usernameReserve", config.getUsernameReserve(), cacheCluster);
this.rateLimiterByHandle = Stream.of(
fromConfig(Handle.BACKUP_AUTH_CHECK.id(), config.getBackupAuthCheck(), cacheCluster),
fromConfig(Handle.CHECK_ACCOUNT_EXISTENCE.id(), config.getCheckAccountExistence(), cacheCluster),
fromConfig(Handle.USERNAME_LOOKUP.id(), config.getUsernameLookup(), cacheCluster)
).map(rl -> Pair.of(rl.name, rl)).collect(Collectors.toMap(Pair::getKey, Pair::getValue));
} }
public Optional<RateLimiter> byHandle(final Handle handle) { @VisibleForTesting
return Optional.ofNullable(rateLimiterByHandle.get(handle.id())); RateLimiters(
final Map<String, RateLimiterConfig> configs,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final FaultTolerantRedisCluster cacheCluster) {
super(For.values(), configs, dynamicConfigurationManager, cacheCluster);
} }
public RateLimiter getAllocateDeviceLimiter() { public RateLimiter getAllocateDeviceLimiter() {
return allocateDeviceLimiter; return forDescriptor(For.ALLOCATE_DEVICE);
} }
public RateLimiter getVerifyDeviceLimiter() { public RateLimiter getVerifyDeviceLimiter() {
return verifyDeviceLimiter; return forDescriptor(For.VERIFY_DEVICE);
} }
public RateLimiter getMessagesLimiter() { public RateLimiter getMessagesLimiter() {
return messagesLimiter; return forDescriptor(For.MESSAGES);
} }
public RateLimiter getPreKeysLimiter() { public RateLimiter getPreKeysLimiter() {
return preKeysLimiter; return forDescriptor(For.PRE_KEYS);
} }
public RateLimiter getAttachmentLimiter() { public RateLimiter getAttachmentLimiter() {
return this.attachmentLimiter; return forDescriptor(For.ATTACHMENT);
} }
public RateLimiter getSmsDestinationLimiter() { public RateLimiter getSmsDestinationLimiter() {
return smsDestinationLimiter; return forDescriptor(For.SMS_DESTINATION);
} }
public RateLimiter getSmsVoiceIpLimiter() { public RateLimiter getSmsVoiceIpLimiter() {
return smsVoiceIpLimiter; return forDescriptor(For.SMS_VOICE_IP);
} }
public RateLimiter getSmsVoicePrefixLimiter() { public RateLimiter getSmsVoicePrefixLimiter() {
return smsVoicePrefixLimiter; return forDescriptor(For.SMS_VOICE_PREFIX);
} }
public RateLimiter getVoiceDestinationLimiter() { public RateLimiter getVoiceDestinationLimiter() {
return voiceDestinationLimiter; return forDescriptor(For.VOICE_DESTINATION);
} }
public RateLimiter getVoiceDestinationDailyLimiter() { public RateLimiter getVoiceDestinationDailyLimiter() {
return voiceDestinationDailyLimiter; return forDescriptor(For.VOICE_DESTINATION_DAILY);
} }
public RateLimiter getVerifyLimiter() { public RateLimiter getVerifyLimiter() {
return verifyLimiter; return forDescriptor(For.VERIFY);
}
public RateLimiter getVerificationCaptchaLimiter() {
return verificationCaptchaLimiter;
}
public RateLimiter getVerificationPushChallengeLimiter() {
return verificationPushChallengeLimiter;
} }
public RateLimiter getPinLimiter() { public RateLimiter getPinLimiter() {
return pinLimiter; return forDescriptor(For.PIN);
}
public RateLimiter getRegistrationLimiter() {
return registrationLimiter;
} }
public RateLimiter getTurnLimiter() { public RateLimiter getTurnLimiter() {
return turnLimiter; return forDescriptor(For.TURN);
} }
public RateLimiter getProfileLimiter() { public RateLimiter getProfileLimiter() {
return profileLimiter; return forDescriptor(For.PROFILE);
} }
public RateLimiter getStickerPackLimiter() { public RateLimiter getStickerPackLimiter() {
return stickerPackLimiter; return forDescriptor(For.STICKER_PACK);
} }
public RateLimiter getArtPackLimiter() { public RateLimiter getArtPackLimiter() {
return artPackLimiter; return forDescriptor(For.ART_PACK);
} }
public RateLimiter getUsernameLookupLimiter() { public RateLimiter getUsernameLookupLimiter() {
return byHandle(Handle.USERNAME_LOOKUP).orElseThrow(); return forDescriptor(For.USERNAME_LOOKUP);
} }
public RateLimiter getUsernameSetLimiter() { public RateLimiter getUsernameSetLimiter() {
return usernameSetLimiter; return forDescriptor(For.USERNAME_SET);
} }
public RateLimiter getUsernameReserveLimiter() { public RateLimiter getUsernameReserveLimiter() {
return usernameReserveLimiter; return forDescriptor(For.USERNAME_RESERVE);
} }
public RateLimiter getCheckAccountExistenceLimiter() { public RateLimiter getCheckAccountExistenceLimiter() {
return byHandle(Handle.CHECK_ACCOUNT_EXISTENCE).orElseThrow(); return forDescriptor(For.CHECK_ACCOUNT_EXISTENCE);
} }
private static RateLimiter fromConfig( public RateLimiter getStoriesLimiter() {
final String name, return forDescriptor(For.STORIES);
final RateLimitsConfiguration.RateLimitConfiguration cfg, }
final FaultTolerantRedisCluster cacheCluster) {
return new RateLimiter(cacheCluster, name, cfg.getBucketSize(), cfg.getLeakRatePerMinute()); public RateLimiter getRegistrationLimiter() {
return forDescriptor(For.REGISTRATION);
}
public RateLimiter getRateLimitResetLimiter() {
return forDescriptor(For.RATE_LIMIT_RESET);
}
public RateLimiter getRecaptchaChallengeAttemptLimiter() {
return forDescriptor(For.RECAPTCHA_CHALLENGE_ATTEMPT);
}
public RateLimiter getRecaptchaChallengeSuccessLimiter() {
return forDescriptor(For.RECAPTCHA_CHALLENGE_SUCCESS);
}
public RateLimiter getPushChallengeAttemptLimiter() {
return forDescriptor(For.PUSH_CHALLENGE_ATTEMPT);
}
public RateLimiter getPushChallengeSuccessLimiter() {
return forDescriptor(For.PUSH_CHALLENGE_SUCCESS);
}
public RateLimiter getVerificationPushChallengeLimiter() {
return forDescriptor(For.VERIFICATION_PUSH_CHALLENGE);
}
public RateLimiter getVerificationCaptchaLimiter() {
return forDescriptor(For.VERIFICATION_CAPTCHA);
} }
} }

View File

@ -0,0 +1,111 @@
/*
* Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static com.codahale.metrics.MetricRegistry.name;
import static java.util.Objects.requireNonNull;
import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.time.Duration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.SystemMapper;
public class StaticRateLimiter implements RateLimiter {
private static final Logger logger = LoggerFactory.getLogger(StaticRateLimiter.class);
private static final ObjectMapper MAPPER = SystemMapper.getMapper();
protected final String name;
private final RateLimiterConfig config;
protected final FaultTolerantRedisCluster cacheCluster;
private final Meter meter;
private final Timer validateTimer;
public StaticRateLimiter(
final String name,
final RateLimiterConfig config,
final FaultTolerantRedisCluster cacheCluster) {
final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
this.name = requireNonNull(name);
this.config = requireNonNull(config);
this.cacheCluster = requireNonNull(cacheCluster);
this.meter = metricRegistry.meter(name(getClass(), name, "exceeded"));
this.validateTimer = metricRegistry.timer(name(getClass(), name, "validate"));
}
@Override
public void validate(final String key, final int amount) throws RateLimitExceededException {
try (final Timer.Context ignored = validateTimer.time()) {
final LeakyBucket bucket = getBucket(key);
if (bucket.add(amount)) {
setBucket(key, bucket);
} else {
meter.mark();
throw new RateLimitExceededException(bucket.getTimeUntilSpaceAvailable(amount), true);
}
}
}
@Override
public boolean hasAvailablePermits(final String key, final int permits) {
return getBucket(key).getTimeUntilSpaceAvailable(permits).equals(Duration.ZERO);
}
@Override
public void clear(final String key) {
cacheCluster.useCluster(connection -> connection.sync().del(getBucketName(key)));
}
@Override
public RateLimiterConfig config() {
return config;
}
private void setBucket(final String key, final LeakyBucket bucket) {
try {
final String serialized = bucket.serialize(MAPPER);
cacheCluster.useCluster(connection -> connection.sync().setex(
getBucketName(key),
(int) Math.ceil((config.bucketSize() / config.leakRatePerMillis()) / 1000),
serialized));
} catch (final JsonProcessingException e) {
throw new IllegalArgumentException(e);
}
}
private LeakyBucket getBucket(final String key) {
try {
final String serialized = cacheCluster.withCluster(connection -> connection.sync().get(getBucketName(key)));
if (serialized != null) {
return LeakyBucket.fromSerialized(MAPPER, serialized);
}
} catch (final IOException e) {
logger.warn("Deserialization error", e);
}
return new LeakyBucket(config.bucketSize(), config.leakRatePerMillis());
}
private String getBucketName(final String key) {
return "leaky_bucket::" + name + "::" + key;
}
}

View File

@ -20,7 +20,8 @@ import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.UUID; import java.util.UUID;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.RateLimitConfiguration; import org.whispersystems.textsecuregcm.limits.RateLimiterConfig;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
@ -274,30 +275,19 @@ class DynamicConfigurationTest {
@Test @Test
void testParseLimits() throws JsonProcessingException { void testParseLimits() throws JsonProcessingException {
{ final String limitsConfig = REQUIRED_CONFIG.concat("""
final String emptyConfigYaml = REQUIRED_CONFIG.concat("test: true");
final DynamicConfiguration emptyConfig =
DynamicConfigurationManager.parseConfiguration(emptyConfigYaml, DynamicConfiguration.class).orElseThrow();
assertThat(emptyConfig.getLimits().getRateLimitReset().getBucketSize()).isEqualTo(2);
assertThat(emptyConfig.getLimits().getRateLimitReset().getLeakRatePerMinute()).isEqualTo(2.0 / (60 * 24));
}
{
final String limitsConfig = REQUIRED_CONFIG.concat("""
limits: limits:
rateLimitReset: rateLimitReset:
bucketSize: 17 bucketSize: 17
leakRatePerMinute: 44 leakRatePerMinute: 44
"""); """);
final RateLimitConfiguration resetRateLimitConfiguration = final RateLimiterConfig resetRateLimiterConfig =
DynamicConfigurationManager.parseConfiguration(limitsConfig, DynamicConfiguration.class).orElseThrow() DynamicConfigurationManager.parseConfiguration(limitsConfig, DynamicConfiguration.class).orElseThrow()
.getLimits().getRateLimitReset(); .getLimits().get(RateLimiters.For.RATE_LIMIT_RESET.id());
assertThat(resetRateLimitConfiguration.getBucketSize()).isEqualTo(17); assertThat(resetRateLimiterConfig.bucketSize()).isEqualTo(17);
assertThat(resetRateLimitConfiguration.getLeakRatePerMinute()).isEqualTo(44); assertThat(resetRateLimiterConfig.leakRatePerMinute()).isEqualTo(44);
}
} }
@Test @Test

View File

@ -171,6 +171,7 @@ class AccountControllerTest {
private static RateLimiter usernameSetLimiter = mock(RateLimiter.class); private static RateLimiter usernameSetLimiter = mock(RateLimiter.class);
private static RateLimiter usernameReserveLimiter = mock(RateLimiter.class); private static RateLimiter usernameReserveLimiter = mock(RateLimiter.class);
private static RateLimiter usernameLookupLimiter = mock(RateLimiter.class); private static RateLimiter usernameLookupLimiter = mock(RateLimiter.class);
private static RateLimiter checkAccountExistence = mock(RateLimiter.class);
private static RegistrationServiceClient registrationServiceClient = mock(RegistrationServiceClient.class); private static RegistrationServiceClient registrationServiceClient = mock(RegistrationServiceClient.class);
private static TurnTokenGenerator turnTokenGenerator = mock(TurnTokenGenerator.class); private static TurnTokenGenerator turnTokenGenerator = mock(TurnTokenGenerator.class);
private static Account senderPinAccount = mock(Account.class); private static Account senderPinAccount = mock(Account.class);
@ -250,6 +251,8 @@ class AccountControllerTest {
when(rateLimiters.getUsernameSetLimiter()).thenReturn(usernameSetLimiter); when(rateLimiters.getUsernameSetLimiter()).thenReturn(usernameSetLimiter);
when(rateLimiters.getUsernameReserveLimiter()).thenReturn(usernameReserveLimiter); when(rateLimiters.getUsernameReserveLimiter()).thenReturn(usernameReserveLimiter);
when(rateLimiters.getUsernameLookupLimiter()).thenReturn(usernameLookupLimiter); when(rateLimiters.getUsernameLookupLimiter()).thenReturn(usernameLookupLimiter);
when(rateLimiters.forDescriptor(eq(RateLimiters.For.USERNAME_LOOKUP))).thenReturn(usernameLookupLimiter);
when(rateLimiters.forDescriptor(eq(RateLimiters.For.CHECK_ACCOUNT_EXISTENCE))).thenReturn(checkAccountExistence);
when(senderPinAccount.getLastSeen()).thenReturn(System.currentTimeMillis()); when(senderPinAccount.getLastSeen()).thenReturn(System.currentTimeMillis());
when(senderPinAccount.getRegistrationLock()).thenReturn( when(senderPinAccount.getRegistrationLock()).thenReturn(
@ -2124,7 +2127,7 @@ class AccountControllerTest {
when(accountsManager.getByAccountIdentifier(accountIdentifier)).thenReturn(Optional.of(account)); when(accountsManager.getByAccountIdentifier(accountIdentifier)).thenReturn(Optional.of(account));
MockUtils.updateRateLimiterResponseToFail( MockUtils.updateRateLimiterResponseToFail(
rateLimiters, RateLimiters.Handle.CHECK_ACCOUNT_EXISTENCE, "127.0.0.1", expectedRetryAfter, true); rateLimiters, RateLimiters.For.CHECK_ACCOUNT_EXISTENCE, "127.0.0.1", expectedRetryAfter, true);
final Response response = resources.getJerseyTest() final Response response = resources.getJerseyTest()
.target(String.format("/v1/accounts/account/%s", accountIdentifier)) .target(String.format("/v1/accounts/account/%s", accountIdentifier))
@ -2189,7 +2192,7 @@ class AccountControllerTest {
void testLookupUsernameRateLimited() throws RateLimitExceededException { void testLookupUsernameRateLimited() throws RateLimitExceededException {
final Duration expectedRetryAfter = Duration.ofSeconds(13); final Duration expectedRetryAfter = Duration.ofSeconds(13);
MockUtils.updateRateLimiterResponseToFail( MockUtils.updateRateLimiterResponseToFail(
rateLimiters, RateLimiters.Handle.USERNAME_LOOKUP, "127.0.0.1", expectedRetryAfter, true); rateLimiters, RateLimiters.For.USERNAME_LOOKUP, "127.0.0.1", expectedRetryAfter, true);
final Response response = resources.getJerseyTest() final Response response = resources.getJerseyTest()
.target(String.format("v1/accounts/username_hash/%s", BASE_64_URL_USERNAME_HASH_1)) .target(String.format("v1/accounts/username_hash/%s", BASE_64_URL_USERNAME_HASH_1))
.request() .request()

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2022 Signal Messenger, LLC * Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2022 Signal Messenger, LLC * Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */

View File

@ -1,9 +1,9 @@
/* /*
* Copyright 2013-2021 Signal Messenger, LLC * Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
package org.whispersystems.textsecuregcm.tests.limits; package org.whispersystems.textsecuregcm.limits;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
@ -15,7 +15,6 @@ import java.io.IOException;
import java.time.Duration; import java.time.Duration;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.limits.LeakyBucket;
class LeakyBucketTest { class LeakyBucketTest {

View File

@ -1,3 +1,8 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits; package org.whispersystems.textsecuregcm.limits;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
@ -12,17 +17,17 @@ import java.util.UUID;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource; import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.spam.RateLimitChallengeListener;
import org.whispersystems.textsecuregcm.captcha.AssessmentResult; import org.whispersystems.textsecuregcm.captcha.AssessmentResult;
import org.whispersystems.textsecuregcm.captcha.CaptchaChecker; import org.whispersystems.textsecuregcm.captcha.CaptchaChecker;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.spam.RateLimitChallengeListener;
import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.Account;
class RateLimitChallengeManagerTest { class RateLimitChallengeManagerTest {
private PushChallengeManager pushChallengeManager; private PushChallengeManager pushChallengeManager;
private CaptchaChecker captchaChecker; private CaptchaChecker captchaChecker;
private DynamicRateLimiters rateLimiters; private RateLimiters rateLimiters;
private RateLimitChallengeListener rateLimitChallengeListener; private RateLimitChallengeListener rateLimitChallengeListener;
private RateLimitChallengeManager rateLimitChallengeManager; private RateLimitChallengeManager rateLimitChallengeManager;
@ -31,7 +36,7 @@ class RateLimitChallengeManagerTest {
void setUp() { void setUp() {
pushChallengeManager = mock(PushChallengeManager.class); pushChallengeManager = mock(PushChallengeManager.class);
captchaChecker = mock(CaptchaChecker.class); captchaChecker = mock(CaptchaChecker.class);
rateLimiters = mock(DynamicRateLimiters.class); rateLimiters = mock(RateLimiters.class);
rateLimitChallengeListener = mock(RateLimitChallengeListener.class); rateLimitChallengeListener = mock(RateLimitChallengeListener.class);
rateLimitChallengeManager = new RateLimitChallengeManager( rateLimitChallengeManager = new RateLimitChallengeManager(

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2021 Signal Messenger, LLC * Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
@ -30,13 +30,13 @@ import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
class RateLimitChallengeOptionManagerTest { class RateLimitChallengeOptionManagerTest {
private DynamicRateLimitChallengeConfiguration rateLimitChallengeConfiguration; private DynamicRateLimitChallengeConfiguration rateLimitChallengeConfiguration;
private DynamicRateLimiters rateLimiters; private RateLimiters rateLimiters;
private RateLimitChallengeOptionManager rateLimitChallengeOptionManager; private RateLimitChallengeOptionManager rateLimitChallengeOptionManager;
@BeforeEach @BeforeEach
void setUp() { void setUp() {
rateLimiters = mock(DynamicRateLimiters.class); rateLimiters = mock(RateLimiters.class);
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager = final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager =
mock(DynamicConfigurationManager.class); mock(DynamicConfigurationManager.class);

View File

@ -11,7 +11,6 @@ import com.google.common.net.HttpHeaders;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension; import io.dropwizard.testing.junit5.ResourceExtension;
import java.time.Duration; import java.time.Duration;
import java.util.Optional;
import javax.ws.rs.GET; import javax.ws.rs.GET;
import javax.ws.rs.Path; import javax.ws.rs.Path;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
@ -43,14 +42,14 @@ public class RateLimitedByIpTest {
public static class Controller { public static class Controller {
@GET @GET
@Path("/strict") @Path("/strict")
@RateLimitedByIp(RateLimiters.Handle.BACKUP_AUTH_CHECK) @RateLimitedByIp(RateLimiters.For.BACKUP_AUTH_CHECK)
public Response strict() { public Response strict() {
return Response.ok().build(); return Response.ok().build();
} }
@GET @GET
@Path("/loose") @Path("/loose")
@RateLimitedByIp(value = RateLimiters.Handle.BACKUP_AUTH_CHECK, failOnUnresolvedIp = false) @RateLimitedByIp(value = RateLimiters.For.BACKUP_AUTH_CHECK, failOnUnresolvedIp = false)
public Response loose() { public Response loose() {
return Response.ok().build(); return Response.ok().build();
} }
@ -59,7 +58,7 @@ public class RateLimitedByIpTest {
private static final RateLimiter RATE_LIMITER = Mockito.mock(RateLimiter.class); private static final RateLimiter RATE_LIMITER = Mockito.mock(RateLimiter.class);
private static final RateLimiters RATE_LIMITERS = MockUtils.buildMock(RateLimiters.class, rl -> private static final RateLimiters RATE_LIMITERS = MockUtils.buildMock(RateLimiters.class, rl ->
Mockito.when(rl.byHandle(Mockito.eq(RateLimiters.Handle.BACKUP_AUTH_CHECK))).thenReturn(Optional.of(RATE_LIMITER))); Mockito.when(rl.forDescriptor(Mockito.eq(RateLimiters.For.BACKUP_AUTH_CHECK))).thenReturn(RATE_LIMITER));
private static final ResourceExtension RESOURCES = ResourceExtension.builder() private static final ResourceExtension RESOURCES = ResourceExtension.builder()
.setMapper(SystemMapper.getMapper()) .setMapper(SystemMapper.getMapper())

View File

@ -0,0 +1,176 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.limits;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.MockUtils;
@SuppressWarnings("unchecked")
public class RateLimitersTest {
private final DynamicConfiguration configuration = mock(DynamicConfiguration.class);
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfig =
MockUtils.buildMock(DynamicConfigurationManager.class, cfg -> when(cfg.getConfiguration()).thenReturn(configuration));
private final FaultTolerantRedisCluster redisCluster = mock(FaultTolerantRedisCluster.class);
private static final String BAD_YAML = """
limits:
smsVoicePrefix:
bucketSize: 150
leakRatePerMinute: 10
unexpected:
bucketSize: 4
leakRatePerMinute: 2
""";
private static final String GOOD_YAML = """
limits:
smsVoicePrefix:
bucketSize: 150
leakRatePerMinute: 10
attachmentCreate:
bucketSize: 4
leakRatePerMinute: 2
""";
public record GenericHolder(@Valid @NotNull @JsonProperty Map<String, RateLimiterConfig> limits) {
}
@Test
public void testValidateConfigs() throws Exception {
assertThrows(IllegalArgumentException.class, () -> {
final GenericHolder cfg = DynamicConfigurationManager.parseConfiguration(BAD_YAML, GenericHolder.class).orElseThrow();
final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, redisCluster);
rateLimiters.validateValuesAndConfigs();
});
final GenericHolder cfg = DynamicConfigurationManager.parseConfiguration(GOOD_YAML, GenericHolder.class).orElseThrow();
final RateLimiters rateLimiters = new RateLimiters(cfg.limits(), dynamicConfig, redisCluster);
rateLimiters.validateValuesAndConfigs();
}
@Test
public void testValidateDuplicates() throws Exception {
final TestDescriptor td1 = new TestDescriptor("id1");
final TestDescriptor td2 = new TestDescriptor("id2");
final TestDescriptor td3 = new TestDescriptor("id3");
final TestDescriptor tdDup = new TestDescriptor("id1");
assertThrows(IllegalStateException.class, () -> new BaseRateLimiters<>(
new TestDescriptor[] { td1, td2, td3, tdDup },
Collections.emptyMap(),
dynamicConfig,
redisCluster) {});
new BaseRateLimiters<>(
new TestDescriptor[] { td1, td2, td3 },
Collections.emptyMap(),
dynamicConfig,
redisCluster) {};
}
@Test
void testUnchangingConfiguration() {
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, redisCluster);
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
final RateLimiterConfig expected = RateLimiters.For.RATE_LIMIT_RESET.defaultConfig();
assertEquals(expected, limiter.config());
}
@Test
void testChangingConfiguration() {
final RateLimiterConfig initialRateLimiterConfig = new RateLimiterConfig(4, 1);
final RateLimiterConfig updatedRateLimiterCongig = new RateLimiterConfig(17, 19);
final RateLimiterConfig baseConfig = new RateLimiterConfig(1, 1);
final Map<String, RateLimiterConfig> limitsConfigMap = new HashMap<>();
limitsConfigMap.put(RateLimiters.For.RECAPTCHA_CHALLENGE_ATTEMPT.id(), baseConfig);
limitsConfigMap.put(RateLimiters.For.RECAPTCHA_CHALLENGE_SUCCESS.id(), baseConfig);
when(configuration.getLimits()).thenReturn(limitsConfigMap);
final RateLimiters rateLimiters = new RateLimiters(Collections.emptyMap(), dynamicConfig, redisCluster);
final RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), initialRateLimiterConfig);
assertEquals(initialRateLimiterConfig, limiter.config());
assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeAttemptLimiter().config());
assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeSuccessLimiter().config());
limitsConfigMap.put(RateLimiters.For.RATE_LIMIT_RESET.id(), updatedRateLimiterCongig);
assertEquals(updatedRateLimiterCongig, limiter.config());
assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeAttemptLimiter().config());
assertEquals(baseConfig, rateLimiters.getRecaptchaChallengeSuccessLimiter().config());
}
@Test
public void testRateLimiterHasItsPrioritiesStraight() throws Exception {
final RateLimiters.For descriptor = RateLimiters.For.RECAPTCHA_CHALLENGE_ATTEMPT;
final RateLimiterConfig configForDynamic = new RateLimiterConfig(1, 1);
final RateLimiterConfig configForStatic = new RateLimiterConfig(2, 2);
final RateLimiterConfig defaultConfig = descriptor.defaultConfig();
final Map<String, RateLimiterConfig> mapForDynamic = new HashMap<>();
final Map<String, RateLimiterConfig> mapForStatic = new HashMap<>();
when(configuration.getLimits()).thenReturn(mapForDynamic);
final RateLimiters rateLimiters = new RateLimiters(mapForStatic, dynamicConfig, redisCluster);
final RateLimiter limiter = rateLimiters.forDescriptor(descriptor);
// test only default is present
mapForDynamic.remove(descriptor.id());
mapForStatic.remove(descriptor.id());
assertEquals(defaultConfig, limiter.config());
// test dynamic and no static
mapForDynamic.put(descriptor.id(), configForDynamic);
mapForStatic.remove(descriptor.id());
assertEquals(configForDynamic, limiter.config());
// test dynamic and static
mapForDynamic.put(descriptor.id(), configForDynamic);
mapForStatic.put(descriptor.id(), configForStatic);
assertEquals(configForDynamic, limiter.config());
// test static, but no dynamic
mapForDynamic.remove(descriptor.id());
mapForStatic.put(descriptor.id(), configForStatic);
assertEquals(configForStatic, limiter.config());
}
private record TestDescriptor(String id) implements RateLimiterDescriptor {
@Override
public boolean isDynamic() {
return false;
}
@Override
public RateLimiterConfig defaultConfig() {
return new RateLimiterConfig(1, 1);
}
}
}

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2021 Signal Messenger, LLC * Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
@ -42,17 +42,16 @@ import org.whispersystems.textsecuregcm.entities.AttachmentDescriptorV3;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.MockUtils;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
@ExtendWith(DropwizardExtensionsSupport.class) @ExtendWith(DropwizardExtensionsSupport.class)
class AttachmentControllerTest { class AttachmentControllerTest {
private static RateLimiters rateLimiters = mock(RateLimiters.class ); private static final RateLimiter RATE_LIMITER = mock(RateLimiter.class);
private static RateLimiter rateLimiter = mock(RateLimiter.class );
static { private static final RateLimiters RATE_LIMITERS = MockUtils.buildMock(RateLimiters.class, rateLimiters ->
when(rateLimiters.getAttachmentLimiter()).thenReturn(rateLimiter); when(rateLimiters.getAttachmentLimiter()).thenReturn(RATE_LIMITER));
}
public static final String RSA_PRIVATE_KEY_PEM; public static final String RSA_PRIVATE_KEY_PEM;
@ -80,8 +79,8 @@ class AttachmentControllerTest {
ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) ImmutableSet.of(AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.setMapper(SystemMapper.getMapper()) .setMapper(SystemMapper.getMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new AttachmentControllerV2(rateLimiters, "accessKey", "accessSecret", "us-east-1", "attachmentv2-bucket")) .addResource(new AttachmentControllerV2(RATE_LIMITERS, "accessKey", "accessSecret", "us-east-1", "attachmentv2-bucket"))
.addResource(new AttachmentControllerV3(rateLimiters, "some-cdn.signal.org", "signal@example.com", 1000, "/attach-here", RSA_PRIVATE_KEY_PEM)) .addResource(new AttachmentControllerV3(RATE_LIMITERS, "some-cdn.signal.org", "signal@example.com", 1000, "/attach-here", RSA_PRIVATE_KEY_PEM))
.build(); .build();
} catch (IOException | InvalidKeyException | InvalidKeySpecException e) { } catch (IOException | InvalidKeyException | InvalidKeySpecException e) {
throw new AssertionError(e); throw new AssertionError(e);

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2022 Signal Messenger, LLC * Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */
package org.whispersystems.textsecuregcm.tests.controllers; package org.whispersystems.textsecuregcm.tests.controllers;

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2021 Signal Messenger, LLC * Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */

View File

@ -1,5 +1,5 @@
/* /*
* Copyright 2013-2021 Signal Messenger, LLC * Copyright 2013 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only * SPDX-License-Identifier: AGPL-3.0-only
*/ */

View File

@ -1,78 +0,0 @@
package org.whispersystems.textsecuregcm.tests.limits;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.RateLimitConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRateLimitsConfiguration;
import org.whispersystems.textsecuregcm.limits.DynamicRateLimiters;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
class DynamicRateLimitsTest {
private DynamicConfigurationManager<DynamicConfiguration> dynamicConfig;
private FaultTolerantRedisCluster redisCluster;
@BeforeEach
void setup() {
this.dynamicConfig = mock(DynamicConfigurationManager.class);
this.redisCluster = mock(FaultTolerantRedisCluster.class);
DynamicConfiguration defaultConfig = new DynamicConfiguration();
when(dynamicConfig.getConfiguration()).thenReturn(defaultConfig);
}
@Test
void testUnchangingConfiguration() {
DynamicRateLimiters rateLimiters = new DynamicRateLimiters(redisCluster, dynamicConfig);
RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
assertThat(limiter.getBucketSize()).isEqualTo(dynamicConfig.getConfiguration().getLimits().getRateLimitReset().getBucketSize());
assertThat(limiter.getLeakRatePerMinute()).isEqualTo(dynamicConfig.getConfiguration().getLimits().getRateLimitReset().getLeakRatePerMinute());
assertSame(rateLimiters.getRateLimitResetLimiter(), limiter);
}
@Test
void testChangingConfiguration() {
DynamicConfiguration configuration = mock(DynamicConfiguration.class);
DynamicRateLimitsConfiguration limitsConfiguration = mock(DynamicRateLimitsConfiguration.class);
when(configuration.getLimits()).thenReturn(limitsConfiguration);
when(limitsConfiguration.getRecaptchaChallengeAttempt()).thenReturn(new RateLimitConfiguration());
when(limitsConfiguration.getRecaptchaChallengeSuccess()).thenReturn(new RateLimitConfiguration());
when(limitsConfiguration.getPushChallengeAttempt()).thenReturn(new RateLimitConfiguration());
when(limitsConfiguration.getPushChallengeSuccess()).thenReturn(new RateLimitConfiguration());
final RateLimitConfiguration initialRateLimitConfiguration = new RateLimitConfiguration(4, 1);
when(limitsConfiguration.getRateLimitReset()).thenReturn(initialRateLimitConfiguration);
when(dynamicConfig.getConfiguration()).thenReturn(configuration);
DynamicRateLimiters rateLimiters = new DynamicRateLimiters(redisCluster, dynamicConfig);
RateLimiter limiter = rateLimiters.getRateLimitResetLimiter();
assertThat(limiter.getBucketSize()).isEqualTo(4);
assertThat(limiter.getLeakRatePerMinute()).isEqualTo(1);
assertSame(rateLimiters.getRateLimitResetLimiter(), limiter);
when(limitsConfiguration.getRateLimitReset()).thenReturn(new RateLimitConfiguration(17, 19));
RateLimiter changed = rateLimiters.getRateLimitResetLimiter();
assertThat(changed.getBucketSize()).isEqualTo(17);
assertThat(changed.getLeakRatePerMinute()).isEqualTo(19);
assertNotSame(limiter, changed);
}
}

View File

@ -45,10 +45,10 @@ public final class MockUtils {
public static void updateRateLimiterResponseToAllow( public static void updateRateLimiterResponseToAllow(
final RateLimiters rateLimitersMock, final RateLimiters rateLimitersMock,
final RateLimiters.Handle handle, final RateLimiters.For handle,
final String input) { final String input) {
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class); final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
doReturn(Optional.of(mockRateLimiter)).when(rateLimitersMock).byHandle(eq(handle)); doReturn(Optional.of(mockRateLimiter)).when(rateLimitersMock).forDescriptor(eq(handle));
try { try {
doNothing().when(mockRateLimiter).validate(eq(input)); doNothing().when(mockRateLimiter).validate(eq(input));
} catch (final RateLimitExceededException e) { } catch (final RateLimitExceededException e) {
@ -58,12 +58,12 @@ public final class MockUtils {
public static void updateRateLimiterResponseToFail( public static void updateRateLimiterResponseToFail(
final RateLimiters rateLimitersMock, final RateLimiters rateLimitersMock,
final RateLimiters.Handle handle, final RateLimiters.For handle,
final String input, final String input,
final Duration retryAfter, final Duration retryAfter,
final boolean legacyStatusCode) { final boolean legacyStatusCode) {
final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class); final RateLimiter mockRateLimiter = Mockito.mock(RateLimiter.class);
doReturn(Optional.of(mockRateLimiter)).when(rateLimitersMock).byHandle(eq(handle)); doReturn(mockRateLimiter).when(rateLimitersMock).forDescriptor(eq(handle));
try { try {
doThrow(new RateLimitExceededException(retryAfter, legacyStatusCode)).when(mockRateLimiter).validate(eq(input)); doThrow(new RateLimitExceededException(retryAfter, legacyStatusCode)).when(mockRateLimiter).validate(eq(input));
} catch (final RateLimitExceededException e) { } catch (final RateLimitExceededException e) {