Add secondaryCacheCluster to RateLimiter

This commit is contained in:
Chris Eager 2021-04-30 11:15:25 -05:00 committed by Chris Eager
parent b7c611a466
commit 4f6b132449
8 changed files with 132 additions and 50 deletions

View File

@ -109,6 +109,11 @@ public class WhisperServerConfiguration extends Configuration {
@JsonProperty @JsonProperty
private RedisClusterConfiguration pushSchedulerCluster; private RedisClusterConfiguration pushSchedulerCluster;
@NotNull
@Valid
@JsonProperty
private RedisClusterConfiguration rateLimitersCluster;
@NotNull @NotNull
@Valid @Valid
@JsonProperty @JsonProperty
@ -309,6 +314,10 @@ public class WhisperServerConfiguration extends Configuration {
return pushSchedulerCluster; return pushSchedulerCluster;
} }
public RedisClusterConfiguration getRateLimitersCluster() {
return rateLimitersCluster;
}
public MessageDynamoDbConfiguration getMessageDynamoDbConfiguration() { public MessageDynamoDbConfiguration getMessageDynamoDbConfiguration() {
return messageDynamoDb; return messageDynamoDb;
} }

View File

@ -44,7 +44,6 @@ import io.micrometer.core.instrument.distribution.DistributionStatisticConfig;
import io.micrometer.wavefront.WavefrontConfig; import io.micrometer.wavefront.WavefrontConfig;
import io.micrometer.wavefront.WavefrontMeterRegistry; import io.micrometer.wavefront.WavefrontMeterRegistry;
import java.net.http.HttpClient; import java.net.http.HttpClient;
import java.security.Security;
import java.time.Duration; import java.time.Duration;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
@ -345,6 +344,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
ClientResources presenceClientResources = ClientResources.builder().build(); ClientResources presenceClientResources = ClientResources.builder().build();
ClientResources metricsCacheClientResources = ClientResources.builder().build(); ClientResources metricsCacheClientResources = ClientResources.builder().build();
ClientResources pushSchedulerCacheClientResources = ClientResources.builder().ioThreadPoolSize(4).build(); ClientResources pushSchedulerCacheClientResources = ClientResources.builder().ioThreadPoolSize(4).build();
ClientResources rateLimitersCacheClientResources = ClientResources.builder().build();
ConnectionEventLogger.logConnectionEvents(generalCacheClientResources); ConnectionEventLogger.logConnectionEvents(generalCacheClientResources);
ConnectionEventLogger.logConnectionEvents(messageCacheClientResources); ConnectionEventLogger.logConnectionEvents(messageCacheClientResources);
@ -356,6 +356,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
FaultTolerantRedisCluster clientPresenceCluster = new FaultTolerantRedisCluster("client_presence_cluster", config.getClientPresenceClusterConfiguration(), presenceClientResources); FaultTolerantRedisCluster clientPresenceCluster = new FaultTolerantRedisCluster("client_presence_cluster", config.getClientPresenceClusterConfiguration(), presenceClientResources);
FaultTolerantRedisCluster metricsCluster = new FaultTolerantRedisCluster("metrics_cluster", config.getMetricsClusterConfiguration(), metricsCacheClientResources); FaultTolerantRedisCluster metricsCluster = new FaultTolerantRedisCluster("metrics_cluster", config.getMetricsClusterConfiguration(), metricsCacheClientResources);
FaultTolerantRedisCluster pushSchedulerCluster = new FaultTolerantRedisCluster("push_scheduler", config.getPushSchedulerCluster(), pushSchedulerCacheClientResources); FaultTolerantRedisCluster pushSchedulerCluster = new FaultTolerantRedisCluster("push_scheduler", config.getPushSchedulerCluster(), pushSchedulerCacheClientResources);
FaultTolerantRedisCluster rateLimitersCluster = new FaultTolerantRedisCluster("rate_limiters", config.getRateLimitersCluster(), rateLimitersCacheClientResources);
BlockingQueue<Runnable> keyspaceNotificationDispatchQueue = new ArrayBlockingQueue<>(10_000); BlockingQueue<Runnable> keyspaceNotificationDispatchQueue = new ArrayBlockingQueue<>(10_000);
Metrics.gaugeCollectionSize(name(getClass(), "keyspaceNotificationDispatchQueueSize"), Collections.emptyList(), keyspaceNotificationDispatchQueue); Metrics.gaugeCollectionSize(name(getClass(), "keyspaceNotificationDispatchQueueSize"), Collections.emptyList(), keyspaceNotificationDispatchQueue);
@ -403,7 +404,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
PubSubManager pubSubManager = new PubSubManager(pubsubClient, dispatchManager); PubSubManager pubSubManager = new PubSubManager(pubsubClient, dispatchManager);
APNSender apnSender = new APNSender(apnSenderExecutor, accountsManager, config.getApnConfiguration()); APNSender apnSender = new APNSender(apnSenderExecutor, accountsManager, config.getApnConfiguration());
GCMSender gcmSender = new GCMSender(gcmSenderExecutor, accountsManager, config.getGcmConfiguration().getApiKey()); GCMSender gcmSender = new GCMSender(gcmSenderExecutor, accountsManager, config.getGcmConfiguration().getApiKey());
RateLimiters rateLimiters = new RateLimiters(config.getLimitsConfiguration(), dynamicConfigurationManager, cacheCluster); RateLimiters rateLimiters = new RateLimiters(config.getLimitsConfiguration(), dynamicConfigurationManager, cacheCluster, rateLimitersCluster);
ProvisioningManager provisioningManager = new ProvisioningManager(pubSubManager); ProvisioningManager provisioningManager = new ProvisioningManager(pubSubManager);
AccountAuthenticator accountAuthenticator = new AccountAuthenticator(accountsManager); AccountAuthenticator accountAuthenticator = new AccountAuthenticator(accountsManager);

View File

@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.limits;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.CardinalityRateLimitConfiguration; import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.CardinalityRateLimitConfiguration;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import javax.annotation.Nullable;
import java.time.Duration; import java.time.Duration;
import java.util.Random; import java.util.Random;
@ -20,6 +21,8 @@ import java.util.Random;
public class CardinalityRateLimiter { public class CardinalityRateLimiter {
private final FaultTolerantRedisCluster cacheCluster; private final FaultTolerantRedisCluster cacheCluster;
@Nullable
private final FaultTolerantRedisCluster secondaryCacheCluster;
private final String name; private final String name;
@ -29,8 +32,9 @@ public class CardinalityRateLimiter {
private final Random random = new Random(); private final Random random = new Random();
public CardinalityRateLimiter(final FaultTolerantRedisCluster cacheCluster, final String name, final Duration ttl, final Duration ttlJitter, final int maxCardinality) { public CardinalityRateLimiter(final FaultTolerantRedisCluster cacheCluster, @Nullable final FaultTolerantRedisCluster secondaryCacheCluster, final String name, final Duration ttl, final Duration ttlJitter, final int maxCardinality) {
this.cacheCluster = cacheCluster; this.cacheCluster = cacheCluster;
this.secondaryCacheCluster = secondaryCacheCluster;
this.name = name; this.name = name;
@ -58,7 +62,28 @@ public class CardinalityRateLimiter {
return changed && cardinality > maxCardinality; return changed && cardinality > maxCardinality;
}); });
if (rateLimitExceeded) { final boolean secondaryRateLimitExceeded;
if (secondaryCacheCluster != null) {
secondaryRateLimitExceeded = secondaryCacheCluster.withCluster(connection -> {
final boolean changed = connection.sync().pfadd(hllKey, target) == 1;
final long cardinality = connection.sync().pfcount(hllKey);
final boolean mayNeedExpiration = changed && cardinality == 1;
// If the set already existed, we can assume it already had an expiration time and can save a round trip by
// skipping the ttl check.
if (mayNeedExpiration && connection.sync().ttl(hllKey) == -1) {
final long expireSeconds = ttl.plusSeconds(random.nextInt((int) ttlJitter.toSeconds())).toSeconds();
connection.sync().expire(hllKey, expireSeconds);
}
return changed && cardinality > maxCardinality;
});
} else {
secondaryRateLimitExceeded = false;
}
if (rateLimitExceeded || secondaryRateLimitExceeded) {
// Using the TTL as the "retry after" time isn't EXACTLY right, but it's a reasonable approximation // Using the TTL as the "retry after" time isn't EXACTLY right, but it's a reasonable approximation
throw new RateLimitExceededException(ttl); throw new RateLimitExceededException(ttl);
} }

View File

@ -5,22 +5,29 @@
package org.whispersystems.textsecuregcm.limits; package org.whispersystems.textsecuregcm.limits;
import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.Meter; import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.SharedMetricRegistries;
import io.lettuce.core.SetArgs; import io.lettuce.core.SetArgs;
import java.time.Duration;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
import java.time.Duration;
import static com.codahale.metrics.MetricRegistry.name;
public class LockingRateLimiter extends RateLimiter { public class LockingRateLimiter extends RateLimiter {
private final Meter meter; private final Meter meter;
public LockingRateLimiter(FaultTolerantRedisCluster cacheCluster, FaultTolerantRedisCluster secondaryCacheCluster, String name, int bucketSize, double leakRatePerMinute) {
super(cacheCluster, secondaryCacheCluster, name, bucketSize, leakRatePerMinute);
MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
this.meter = metricRegistry.meter(name(getClass(), name, "locked"));
}
public LockingRateLimiter(FaultTolerantRedisCluster cacheCluster, String name, int bucketSize, double leakRatePerMinute) { public LockingRateLimiter(FaultTolerantRedisCluster cacheCluster, String name, int bucketSize, double leakRatePerMinute) {
super(cacheCluster, name, bucketSize, leakRatePerMinute); super(cacheCluster, name, bucketSize, leakRatePerMinute);

View File

@ -4,12 +4,16 @@
*/ */
package org.whispersystems.textsecuregcm.limits; package org.whispersystems.textsecuregcm.limits;
import static com.codahale.metrics.MetricRegistry.name;
import com.codahale.metrics.Meter; import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer; import com.codahale.metrics.Timer;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import javax.annotation.Nullable;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.RateLimitConfiguration; import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.RateLimitConfiguration;
@ -18,10 +22,6 @@ import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import java.io.IOException;
import static com.codahale.metrics.MetricRegistry.name;
public class RateLimiter { public class RateLimiter {
private final Logger logger = LoggerFactory.getLogger(RateLimiter.class); private final Logger logger = LoggerFactory.getLogger(RateLimiter.class);
@ -34,28 +34,30 @@ public class RateLimiter {
private final int bucketSize; private final int bucketSize;
private final double leakRatePerMinute; private final double leakRatePerMinute;
private final double leakRatePerMillis; private final double leakRatePerMillis;
private final boolean reportLimits;
@Nullable
private final FaultTolerantRedisCluster secondaryCacheCluster;
public RateLimiter(FaultTolerantRedisCluster cacheCluster, String name, public RateLimiter(FaultTolerantRedisCluster cacheCluster, String name,
int bucketSize, double leakRatePerMinute) int bucketSize, double leakRatePerMinute)
{ {
this(cacheCluster, name, bucketSize, leakRatePerMinute, false); this(cacheCluster, null, name, bucketSize, leakRatePerMinute);
} }
public RateLimiter(FaultTolerantRedisCluster cacheCluster, String name, public RateLimiter(FaultTolerantRedisCluster cacheCluster, @Nullable FaultTolerantRedisCluster secondaryCacheCluster,
int bucketSize, double leakRatePerMinute, String name,
boolean reportLimits) int bucketSize, double leakRatePerMinute)
{ {
MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
this.meter = metricRegistry.meter(name(getClass(), name, "exceeded")); this.meter = metricRegistry.meter(name(getClass(), name, "exceeded"));
this.validateTimer = metricRegistry.timer(name(getClass(), name, "validate")); this.validateTimer = metricRegistry.timer(name(getClass(), name, "validate"));
this.cacheCluster = cacheCluster; this.cacheCluster = cacheCluster;
this.secondaryCacheCluster = secondaryCacheCluster;
this.name = name; this.name = name;
this.bucketSize = bucketSize; this.bucketSize = bucketSize;
this.leakRatePerMinute = leakRatePerMinute; this.leakRatePerMinute = leakRatePerMinute;
this.leakRatePerMillis = leakRatePerMinute / (60.0 * 1000.0); this.leakRatePerMillis = leakRatePerMinute / (60.0 * 1000.0);
this.reportLimits = reportLimits;
} }
public void validate(String key, int amount) throws RateLimitExceededException { public void validate(String key, int amount) throws RateLimitExceededException {
@ -77,6 +79,10 @@ public class RateLimiter {
public void clear(String key) { public void clear(String key) {
cacheCluster.useCluster(connection -> connection.sync().del(getBucketName(key))); cacheCluster.useCluster(connection -> connection.sync().del(getBucketName(key)));
if (secondaryCacheCluster != null) {
secondaryCacheCluster.useCluster(connection -> connection.sync().del(getBucketName(key)));
}
} }
public int getBucketSize() { public int getBucketSize() {
@ -88,13 +94,31 @@ public class RateLimiter {
} }
private void setBucket(String key, LeakyBucket bucket) { private void setBucket(String key, LeakyBucket bucket) {
IllegalArgumentException ex = null;
try { try {
final String serialized = bucket.serialize(mapper); final String serialized = bucket.serialize(mapper);
cacheCluster.useCluster(connection -> connection.sync().setex(getBucketName(key), (int) Math.ceil((bucketSize / leakRatePerMillis) / 1000), serialized)); cacheCluster.useCluster(connection -> connection.sync().setex(getBucketName(key), (int) Math.ceil((bucketSize / leakRatePerMillis) / 1000), serialized));
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
throw new IllegalArgumentException(e); ex = new IllegalArgumentException(e);
} }
if (secondaryCacheCluster != null) {
try {
final String serialized = bucket.serialize(mapper);
secondaryCacheCluster.useCluster(connection -> connection.sync()
.setex(getBucketName(key), (int) Math.ceil((bucketSize / leakRatePerMillis) / 1000), serialized));
} catch (JsonProcessingException e) {
ex = ex == null ? new IllegalArgumentException(e) : ex;
}
}
if (ex != null) {
throw ex;
}
} }
private LeakyBucket getBucket(String key) { private LeakyBucket getBucket(String key) {
@ -108,6 +132,16 @@ public class RateLimiter {
logger.warn("Deserialization error", e); logger.warn("Deserialization error", e);
} }
try {
final String serialized = secondaryCacheCluster.withCluster(connection -> connection.sync().get(getBucketName(key)));
if (serialized != null) {
return LeakyBucket.fromSerialized(mapper, serialized);
}
} catch (IOException e) {
logger.warn("Deserialization error", e);
}
return new LeakyBucket(bucketSize, leakRatePerMillis); return new LeakyBucket(bucketSize, leakRatePerMillis);
} }

View File

@ -11,6 +11,7 @@ import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.Ca
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.RateLimitConfiguration; import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.RateLimitConfiguration;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import javax.annotation.Nullable;
public class RateLimiters { public class RateLimiters {
@ -41,10 +42,12 @@ public class RateLimiters {
private final AtomicReference<RateLimiter> unsealedIpLimiter; private final AtomicReference<RateLimiter> unsealedIpLimiter;
private final FaultTolerantRedisCluster cacheCluster; private final FaultTolerantRedisCluster cacheCluster;
private final FaultTolerantRedisCluster newCacheCluster;
private final DynamicConfigurationManager dynamicConfig; private final DynamicConfigurationManager dynamicConfig;
public RateLimiters(RateLimitsConfiguration config, DynamicConfigurationManager dynamicConfig, FaultTolerantRedisCluster cacheCluster) { public RateLimiters(RateLimitsConfiguration config, DynamicConfigurationManager dynamicConfig, FaultTolerantRedisCluster cacheCluster, FaultTolerantRedisCluster newCacheCluster) {
this.cacheCluster = cacheCluster; this.cacheCluster = cacheCluster;
this.newCacheCluster = newCacheCluster;
this.dynamicConfig = dynamicConfig; this.dynamicConfig = dynamicConfig;
this.smsDestinationLimiter = new RateLimiter(cacheCluster, "smsDestination", this.smsDestinationLimiter = new RateLimiter(cacheCluster, "smsDestination",
@ -67,11 +70,11 @@ public class RateLimiters {
config.getSmsVoicePrefix().getBucketSize(), config.getSmsVoicePrefix().getBucketSize(),
config.getSmsVoicePrefix().getLeakRatePerMinute()); config.getSmsVoicePrefix().getLeakRatePerMinute());
this.autoBlockLimiter = new RateLimiter(cacheCluster, "autoBlock", this.autoBlockLimiter = new RateLimiter(cacheCluster, newCacheCluster, "autoBlock",
config.getAutoBlock().getBucketSize(), config.getAutoBlock().getBucketSize(),
config.getAutoBlock().getLeakRatePerMinute()); config.getAutoBlock().getLeakRatePerMinute());
this.verifyLimiter = new LockingRateLimiter(cacheCluster, "verify", this.verifyLimiter = new LockingRateLimiter(cacheCluster, newCacheCluster, "verify",
config.getVerifyNumber().getBucketSize(), config.getVerifyNumber().getBucketSize(),
config.getVerifyNumber().getLeakRatePerMinute()); config.getVerifyNumber().getLeakRatePerMinute());
@ -103,7 +106,7 @@ public class RateLimiters {
config.getTurnAllocations().getBucketSize(), config.getTurnAllocations().getBucketSize(),
config.getTurnAllocations().getLeakRatePerMinute()); config.getTurnAllocations().getLeakRatePerMinute());
this.profileLimiter = new RateLimiter(cacheCluster, "profile", this.profileLimiter = new RateLimiter(cacheCluster, newCacheCluster, "profile",
config.getProfile().getBucketSize(), config.getProfile().getBucketSize(),
config.getProfile().getLeakRatePerMinute()); config.getProfile().getLeakRatePerMinute());
@ -119,8 +122,8 @@ public class RateLimiters {
config.getUsernameSet().getBucketSize(), config.getUsernameSet().getBucketSize(),
config.getUsernameSet().getLeakRatePerMinute()); config.getUsernameSet().getLeakRatePerMinute());
this.unsealedSenderLimiter = new AtomicReference<>(createUnsealedSenderLimiter(cacheCluster, dynamicConfig.getConfiguration().getLimits().getUnsealedSenderNumber())); this.unsealedSenderLimiter = new AtomicReference<>(createUnsealedSenderLimiter(cacheCluster, null, dynamicConfig.getConfiguration().getLimits().getUnsealedSenderNumber()));
this.unsealedIpLimiter = new AtomicReference<>(createUnsealedIpLimiter(cacheCluster, dynamicConfig.getConfiguration().getLimits().getUnsealedSenderIp())); this.unsealedIpLimiter = new AtomicReference<>(createUnsealedIpLimiter(cacheCluster, newCacheCluster, dynamicConfig.getConfiguration().getLimits().getUnsealedSenderIp()));
} }
public CardinalityRateLimiter getUnsealedSenderLimiter() { public CardinalityRateLimiter getUnsealedSenderLimiter() {
@ -130,7 +133,7 @@ public class RateLimiters {
if (rateLimiter.hasConfiguration(currentConfiguration)) { if (rateLimiter.hasConfiguration(currentConfiguration)) {
return rateLimiter; return rateLimiter;
} else { } else {
return createUnsealedSenderLimiter(cacheCluster, currentConfiguration); return createUnsealedSenderLimiter(cacheCluster, null, currentConfiguration);
} }
}); });
} }
@ -142,7 +145,7 @@ public class RateLimiters {
if (rateLimiter.hasConfiguration(currentConfiguration)) { if (rateLimiter.hasConfiguration(currentConfiguration)) {
return rateLimiter; return rateLimiter;
} else { } else {
return createUnsealedIpLimiter(cacheCluster, currentConfiguration); return createUnsealedIpLimiter(cacheCluster, newCacheCluster, currentConfiguration);
} }
}); });
} }
@ -219,18 +222,19 @@ public class RateLimiters {
return usernameSetLimiter; return usernameSetLimiter;
} }
private CardinalityRateLimiter createUnsealedSenderLimiter(FaultTolerantRedisCluster cacheCluster, CardinalityRateLimitConfiguration configuration) { private CardinalityRateLimiter createUnsealedSenderLimiter(FaultTolerantRedisCluster cacheCluster, FaultTolerantRedisCluster secondaryCacheCluster, CardinalityRateLimitConfiguration configuration) {
return new CardinalityRateLimiter(cacheCluster, "unsealedSender", configuration.getTtl(), configuration.getTtlJitter(), configuration.getMaxCardinality()); return new CardinalityRateLimiter(cacheCluster, secondaryCacheCluster, "unsealedSender", configuration.getTtl(), configuration.getTtlJitter(), configuration.getMaxCardinality());
} }
private RateLimiter createUnsealedIpLimiter(FaultTolerantRedisCluster cacheCluster, private RateLimiter createUnsealedIpLimiter(FaultTolerantRedisCluster cacheCluster,
@Nullable FaultTolerantRedisCluster secondaryCacheCluster,
RateLimitConfiguration configuration) RateLimitConfiguration configuration)
{ {
return createLimiter(cacheCluster, configuration, "unsealedIp"); return createLimiter(cacheCluster, secondaryCacheCluster, configuration, "unsealedIp");
} }
private RateLimiter createLimiter(FaultTolerantRedisCluster cacheCluster, RateLimitConfiguration configuration, String name) { private RateLimiter createLimiter(FaultTolerantRedisCluster cacheCluster, @Nullable FaultTolerantRedisCluster secondaryCacheCluster, RateLimitConfiguration configuration, String name) {
return new RateLimiter(cacheCluster, name, return new RateLimiter(cacheCluster, secondaryCacheCluster, name,
configuration.getBucketSize(), configuration.getBucketSize(),
configuration.getLeakRatePerMinute()); configuration.getLeakRatePerMinute());
} }

View File

@ -5,16 +5,16 @@
package org.whispersystems.textsecuregcm.limits; package org.whispersystems.textsecuregcm.limits;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.time.Duration;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest;
import java.time.Duration;
import static org.junit.Assert.*;
public class CardinalityRateLimiterTest extends AbstractRedisClusterTest { public class CardinalityRateLimiterTest extends AbstractRedisClusterTest {
@Before @Before
@ -30,7 +30,7 @@ public class CardinalityRateLimiterTest extends AbstractRedisClusterTest {
@Test @Test
public void testValidate() { public void testValidate() {
final int maxCardinality = 10; final int maxCardinality = 10;
final CardinalityRateLimiter rateLimiter = new CardinalityRateLimiter(getRedisCluster(), "test", Duration.ofDays(1), Duration.ofDays(1), maxCardinality); final CardinalityRateLimiter rateLimiter = new CardinalityRateLimiter(getRedisCluster(), null, "test", Duration.ofDays(1), Duration.ofDays(1), maxCardinality);
final String source = "+18005551234"; final String source = "+18005551234";
int validatedAttempts = 0; int validatedAttempts = 0;

View File

@ -1,5 +1,12 @@
package org.whispersystems.textsecuregcm.tests.limits; package org.whispersystems.textsecuregcm.tests.limits;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertSame;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.time.Duration;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration; import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration;
@ -11,22 +18,17 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import java.time.Duration;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.*;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class DynamicRateLimitsTest { public class DynamicRateLimitsTest {
private DynamicConfigurationManager dynamicConfig; private DynamicConfigurationManager dynamicConfig;
private FaultTolerantRedisCluster redisCluster; private FaultTolerantRedisCluster redisCluster;
private FaultTolerantRedisCluster newRedisCluster;
@Before @Before
public void setup() { public void setup() {
this.dynamicConfig = mock(DynamicConfigurationManager.class); this.dynamicConfig = mock(DynamicConfigurationManager.class);
this.redisCluster = mock(FaultTolerantRedisCluster.class); this.redisCluster = mock(FaultTolerantRedisCluster.class);
this.newRedisCluster = mock(FaultTolerantRedisCluster.class);
DynamicConfiguration defaultConfig = new DynamicConfiguration(); DynamicConfiguration defaultConfig = new DynamicConfiguration();
when(dynamicConfig.getConfiguration()).thenReturn(defaultConfig); when(dynamicConfig.getConfiguration()).thenReturn(defaultConfig);
@ -35,7 +37,7 @@ public class DynamicRateLimitsTest {
@Test @Test
public void testUnchangingConfiguration() { public void testUnchangingConfiguration() {
RateLimiters rateLimiters = new RateLimiters(new RateLimitsConfiguration(), dynamicConfig, redisCluster); RateLimiters rateLimiters = new RateLimiters(new RateLimitsConfiguration(), dynamicConfig, redisCluster, newRedisCluster);
RateLimiter limiter = rateLimiters.getUnsealedIpLimiter(); RateLimiter limiter = rateLimiters.getUnsealedIpLimiter();
@ -55,7 +57,7 @@ public class DynamicRateLimitsTest {
when(dynamicConfig.getConfiguration()).thenReturn(configuration); when(dynamicConfig.getConfiguration()).thenReturn(configuration);
RateLimiters rateLimiters = new RateLimiters(new RateLimitsConfiguration(), dynamicConfig, redisCluster); RateLimiters rateLimiters = new RateLimiters(new RateLimitsConfiguration(), dynamicConfig, redisCluster, newRedisCluster);
CardinalityRateLimiter limiter = rateLimiters.getUnsealedSenderLimiter(); CardinalityRateLimiter limiter = rateLimiters.getUnsealedSenderLimiter();