diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java index 2d3004286..405c8c1cc 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java @@ -109,6 +109,11 @@ public class WhisperServerConfiguration extends Configuration { @JsonProperty private RedisClusterConfiguration pushSchedulerCluster; + @NotNull + @Valid + @JsonProperty + private RedisClusterConfiguration rateLimitersCluster; + @NotNull @Valid @JsonProperty @@ -309,6 +314,10 @@ public class WhisperServerConfiguration extends Configuration { return pushSchedulerCluster; } + public RedisClusterConfiguration getRateLimitersCluster() { + return rateLimitersCluster; + } + public MessageDynamoDbConfiguration getMessageDynamoDbConfiguration() { return messageDynamoDb; } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index e96c4fec0..67a028ff8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -44,7 +44,6 @@ import io.micrometer.core.instrument.distribution.DistributionStatisticConfig; import io.micrometer.wavefront.WavefrontConfig; import io.micrometer.wavefront.WavefrontMeterRegistry; import java.net.http.HttpClient; -import java.security.Security; import java.time.Duration; import java.util.ArrayList; import java.util.Collections; @@ -345,6 +344,7 @@ public class WhisperServerService extends Application keyspaceNotificationDispatchQueue = new ArrayBlockingQueue<>(10_000); Metrics.gaugeCollectionSize(name(getClass(), "keyspaceNotificationDispatchQueueSize"), Collections.emptyList(), keyspaceNotificationDispatchQueue); @@ -403,7 +404,7 @@ public class WhisperServerService extends Application maxCardinality; }); - if (rateLimitExceeded) { + final boolean secondaryRateLimitExceeded; + if (secondaryCacheCluster != null) { + secondaryRateLimitExceeded = secondaryCacheCluster.withCluster(connection -> { + final boolean changed = connection.sync().pfadd(hllKey, target) == 1; + final long cardinality = connection.sync().pfcount(hllKey); + + final boolean mayNeedExpiration = changed && cardinality == 1; + + // If the set already existed, we can assume it already had an expiration time and can save a round trip by + // skipping the ttl check. + if (mayNeedExpiration && connection.sync().ttl(hllKey) == -1) { + final long expireSeconds = ttl.plusSeconds(random.nextInt((int) ttlJitter.toSeconds())).toSeconds(); + connection.sync().expire(hllKey, expireSeconds); + } + + return changed && cardinality > maxCardinality; + }); + } else { + secondaryRateLimitExceeded = false; + } + + if (rateLimitExceeded || secondaryRateLimitExceeded) { // Using the TTL as the "retry after" time isn't EXACTLY right, but it's a reasonable approximation throw new RateLimitExceededException(ttl); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java index 935cabfa3..1a2f46e0b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/LockingRateLimiter.java @@ -5,22 +5,29 @@ package org.whispersystems.textsecuregcm.limits; +import static com.codahale.metrics.MetricRegistry.name; + import com.codahale.metrics.Meter; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.SharedMetricRegistries; import io.lettuce.core.SetArgs; +import java.time.Duration; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.util.Constants; -import java.time.Duration; - -import static com.codahale.metrics.MetricRegistry.name; - public class LockingRateLimiter extends RateLimiter { private final Meter meter; + public LockingRateLimiter(FaultTolerantRedisCluster cacheCluster, FaultTolerantRedisCluster secondaryCacheCluster, String name, int bucketSize, double leakRatePerMinute) { + super(cacheCluster, secondaryCacheCluster, name, bucketSize, leakRatePerMinute); + + MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); + this.meter = metricRegistry.meter(name(getClass(), name, "locked")); + } + + public LockingRateLimiter(FaultTolerantRedisCluster cacheCluster, String name, int bucketSize, double leakRatePerMinute) { super(cacheCluster, name, bucketSize, leakRatePerMinute); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java index 5da6b6dc9..dfa01e7e1 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiter.java @@ -4,12 +4,16 @@ */ package org.whispersystems.textsecuregcm.limits; +import static com.codahale.metrics.MetricRegistry.name; + import com.codahale.metrics.Meter; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.SharedMetricRegistries; import com.codahale.metrics.Timer; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.IOException; +import javax.annotation.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.RateLimitConfiguration; @@ -18,10 +22,6 @@ import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.util.Constants; import org.whispersystems.textsecuregcm.util.SystemMapper; -import java.io.IOException; - -import static com.codahale.metrics.MetricRegistry.name; - public class RateLimiter { private final Logger logger = LoggerFactory.getLogger(RateLimiter.class); @@ -34,28 +34,30 @@ public class RateLimiter { private final int bucketSize; private final double leakRatePerMinute; private final double leakRatePerMillis; - private final boolean reportLimits; + + @Nullable + private final FaultTolerantRedisCluster secondaryCacheCluster; public RateLimiter(FaultTolerantRedisCluster cacheCluster, String name, int bucketSize, double leakRatePerMinute) { - this(cacheCluster, name, bucketSize, leakRatePerMinute, false); + this(cacheCluster, null, name, bucketSize, leakRatePerMinute); } - public RateLimiter(FaultTolerantRedisCluster cacheCluster, String name, - int bucketSize, double leakRatePerMinute, - boolean reportLimits) + public RateLimiter(FaultTolerantRedisCluster cacheCluster, @Nullable FaultTolerantRedisCluster secondaryCacheCluster, + String name, + int bucketSize, double leakRatePerMinute) { MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); this.meter = metricRegistry.meter(name(getClass(), name, "exceeded")); this.validateTimer = metricRegistry.timer(name(getClass(), name, "validate")); this.cacheCluster = cacheCluster; + this.secondaryCacheCluster = secondaryCacheCluster; this.name = name; this.bucketSize = bucketSize; this.leakRatePerMinute = leakRatePerMinute; this.leakRatePerMillis = leakRatePerMinute / (60.0 * 1000.0); - this.reportLimits = reportLimits; } public void validate(String key, int amount) throws RateLimitExceededException { @@ -77,6 +79,10 @@ public class RateLimiter { public void clear(String key) { cacheCluster.useCluster(connection -> connection.sync().del(getBucketName(key))); + + if (secondaryCacheCluster != null) { + secondaryCacheCluster.useCluster(connection -> connection.sync().del(getBucketName(key))); + } } public int getBucketSize() { @@ -88,13 +94,31 @@ public class RateLimiter { } private void setBucket(String key, LeakyBucket bucket) { + + IllegalArgumentException ex = null; try { final String serialized = bucket.serialize(mapper); cacheCluster.useCluster(connection -> connection.sync().setex(getBucketName(key), (int) Math.ceil((bucketSize / leakRatePerMillis) / 1000), serialized)); } catch (JsonProcessingException e) { - throw new IllegalArgumentException(e); + ex = new IllegalArgumentException(e); } + + if (secondaryCacheCluster != null) { + try { + final String serialized = bucket.serialize(mapper); + + secondaryCacheCluster.useCluster(connection -> connection.sync() + .setex(getBucketName(key), (int) Math.ceil((bucketSize / leakRatePerMillis) / 1000), serialized)); + } catch (JsonProcessingException e) { + ex = ex == null ? new IllegalArgumentException(e) : ex; + } + } + + if (ex != null) { + throw ex; + } + } private LeakyBucket getBucket(String key) { @@ -108,6 +132,16 @@ public class RateLimiter { logger.warn("Deserialization error", e); } + try { + final String serialized = secondaryCacheCluster.withCluster(connection -> connection.sync().get(getBucketName(key))); + + if (serialized != null) { + return LeakyBucket.fromSerialized(mapper, serialized); + } + } catch (IOException e) { + logger.warn("Deserialization error", e); + } + return new LeakyBucket(bucketSize, leakRatePerMillis); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java index a7c6f21aa..91a53ff25 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java @@ -11,6 +11,7 @@ import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.Ca import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration.RateLimitConfiguration; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; +import javax.annotation.Nullable; public class RateLimiters { @@ -41,11 +42,13 @@ public class RateLimiters { private final AtomicReference unsealedIpLimiter; private final FaultTolerantRedisCluster cacheCluster; + private final FaultTolerantRedisCluster newCacheCluster; private final DynamicConfigurationManager dynamicConfig; - public RateLimiters(RateLimitsConfiguration config, DynamicConfigurationManager dynamicConfig, FaultTolerantRedisCluster cacheCluster) { - this.cacheCluster = cacheCluster; - this.dynamicConfig = dynamicConfig; + public RateLimiters(RateLimitsConfiguration config, DynamicConfigurationManager dynamicConfig, FaultTolerantRedisCluster cacheCluster, FaultTolerantRedisCluster newCacheCluster) { + this.cacheCluster = cacheCluster; + this.newCacheCluster = newCacheCluster; + this.dynamicConfig = dynamicConfig; this.smsDestinationLimiter = new RateLimiter(cacheCluster, "smsDestination", config.getSmsDestination().getBucketSize(), @@ -67,11 +70,11 @@ public class RateLimiters { config.getSmsVoicePrefix().getBucketSize(), config.getSmsVoicePrefix().getLeakRatePerMinute()); - this.autoBlockLimiter = new RateLimiter(cacheCluster, "autoBlock", + this.autoBlockLimiter = new RateLimiter(cacheCluster, newCacheCluster, "autoBlock", config.getAutoBlock().getBucketSize(), config.getAutoBlock().getLeakRatePerMinute()); - this.verifyLimiter = new LockingRateLimiter(cacheCluster, "verify", + this.verifyLimiter = new LockingRateLimiter(cacheCluster, newCacheCluster, "verify", config.getVerifyNumber().getBucketSize(), config.getVerifyNumber().getLeakRatePerMinute()); @@ -103,7 +106,7 @@ public class RateLimiters { config.getTurnAllocations().getBucketSize(), config.getTurnAllocations().getLeakRatePerMinute()); - this.profileLimiter = new RateLimiter(cacheCluster, "profile", + this.profileLimiter = new RateLimiter(cacheCluster, newCacheCluster, "profile", config.getProfile().getBucketSize(), config.getProfile().getLeakRatePerMinute()); @@ -119,8 +122,8 @@ public class RateLimiters { config.getUsernameSet().getBucketSize(), config.getUsernameSet().getLeakRatePerMinute()); - this.unsealedSenderLimiter = new AtomicReference<>(createUnsealedSenderLimiter(cacheCluster, dynamicConfig.getConfiguration().getLimits().getUnsealedSenderNumber())); - this.unsealedIpLimiter = new AtomicReference<>(createUnsealedIpLimiter(cacheCluster, dynamicConfig.getConfiguration().getLimits().getUnsealedSenderIp())); + this.unsealedSenderLimiter = new AtomicReference<>(createUnsealedSenderLimiter(cacheCluster, null, dynamicConfig.getConfiguration().getLimits().getUnsealedSenderNumber())); + this.unsealedIpLimiter = new AtomicReference<>(createUnsealedIpLimiter(cacheCluster, newCacheCluster, dynamicConfig.getConfiguration().getLimits().getUnsealedSenderIp())); } public CardinalityRateLimiter getUnsealedSenderLimiter() { @@ -130,7 +133,7 @@ public class RateLimiters { if (rateLimiter.hasConfiguration(currentConfiguration)) { return rateLimiter; } else { - return createUnsealedSenderLimiter(cacheCluster, currentConfiguration); + return createUnsealedSenderLimiter(cacheCluster, null, currentConfiguration); } }); } @@ -142,7 +145,7 @@ public class RateLimiters { if (rateLimiter.hasConfiguration(currentConfiguration)) { return rateLimiter; } else { - return createUnsealedIpLimiter(cacheCluster, currentConfiguration); + return createUnsealedIpLimiter(cacheCluster, newCacheCluster, currentConfiguration); } }); } @@ -219,18 +222,19 @@ public class RateLimiters { return usernameSetLimiter; } - private CardinalityRateLimiter createUnsealedSenderLimiter(FaultTolerantRedisCluster cacheCluster, CardinalityRateLimitConfiguration configuration) { - return new CardinalityRateLimiter(cacheCluster, "unsealedSender", configuration.getTtl(), configuration.getTtlJitter(), configuration.getMaxCardinality()); + private CardinalityRateLimiter createUnsealedSenderLimiter(FaultTolerantRedisCluster cacheCluster, FaultTolerantRedisCluster secondaryCacheCluster, CardinalityRateLimitConfiguration configuration) { + return new CardinalityRateLimiter(cacheCluster, secondaryCacheCluster, "unsealedSender", configuration.getTtl(), configuration.getTtlJitter(), configuration.getMaxCardinality()); } private RateLimiter createUnsealedIpLimiter(FaultTolerantRedisCluster cacheCluster, + @Nullable FaultTolerantRedisCluster secondaryCacheCluster, RateLimitConfiguration configuration) { - return createLimiter(cacheCluster, configuration, "unsealedIp"); + return createLimiter(cacheCluster, secondaryCacheCluster, configuration, "unsealedIp"); } - private RateLimiter createLimiter(FaultTolerantRedisCluster cacheCluster, RateLimitConfiguration configuration, String name) { - return new RateLimiter(cacheCluster, name, + private RateLimiter createLimiter(FaultTolerantRedisCluster cacheCluster, @Nullable FaultTolerantRedisCluster secondaryCacheCluster, RateLimitConfiguration configuration, String name) { + return new RateLimiter(cacheCluster, secondaryCacheCluster, name, configuration.getBucketSize(), configuration.getLeakRatePerMinute()); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/CardinalityRateLimiterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/CardinalityRateLimiterTest.java index 725d09628..03311d84c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/limits/CardinalityRateLimiterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/CardinalityRateLimiterTest.java @@ -5,16 +5,16 @@ package org.whispersystems.textsecuregcm.limits; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import java.time.Duration; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; -import java.time.Duration; - -import static org.junit.Assert.*; - public class CardinalityRateLimiterTest extends AbstractRedisClusterTest { @Before @@ -30,7 +30,7 @@ public class CardinalityRateLimiterTest extends AbstractRedisClusterTest { @Test public void testValidate() { final int maxCardinality = 10; - final CardinalityRateLimiter rateLimiter = new CardinalityRateLimiter(getRedisCluster(), "test", Duration.ofDays(1), Duration.ofDays(1), maxCardinality); + final CardinalityRateLimiter rateLimiter = new CardinalityRateLimiter(getRedisCluster(), null, "test", Duration.ofDays(1), Duration.ofDays(1), maxCardinality); final String source = "+18005551234"; int validatedAttempts = 0; diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/limits/DynamicRateLimitsTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/limits/DynamicRateLimitsTest.java index 4c12e4412..1e09044c1 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/limits/DynamicRateLimitsTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/limits/DynamicRateLimitsTest.java @@ -1,5 +1,12 @@ package org.whispersystems.textsecuregcm.tests.limits; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.time.Duration; import org.junit.Before; import org.junit.Test; import org.whispersystems.textsecuregcm.configuration.RateLimitsConfiguration; @@ -11,22 +18,17 @@ import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; -import java.time.Duration; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.Assert.*; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - public class DynamicRateLimitsTest { private DynamicConfigurationManager dynamicConfig; private FaultTolerantRedisCluster redisCluster; + private FaultTolerantRedisCluster newRedisCluster; @Before public void setup() { - this.dynamicConfig = mock(DynamicConfigurationManager.class); - this.redisCluster = mock(FaultTolerantRedisCluster.class); + this.dynamicConfig = mock(DynamicConfigurationManager.class); + this.redisCluster = mock(FaultTolerantRedisCluster.class); + this.newRedisCluster = mock(FaultTolerantRedisCluster.class); DynamicConfiguration defaultConfig = new DynamicConfiguration(); when(dynamicConfig.getConfiguration()).thenReturn(defaultConfig); @@ -35,7 +37,7 @@ public class DynamicRateLimitsTest { @Test public void testUnchangingConfiguration() { - RateLimiters rateLimiters = new RateLimiters(new RateLimitsConfiguration(), dynamicConfig, redisCluster); + RateLimiters rateLimiters = new RateLimiters(new RateLimitsConfiguration(), dynamicConfig, redisCluster, newRedisCluster); RateLimiter limiter = rateLimiters.getUnsealedIpLimiter(); @@ -55,7 +57,7 @@ public class DynamicRateLimitsTest { when(dynamicConfig.getConfiguration()).thenReturn(configuration); - RateLimiters rateLimiters = new RateLimiters(new RateLimitsConfiguration(), dynamicConfig, redisCluster); + RateLimiters rateLimiters = new RateLimiters(new RateLimitsConfiguration(), dynamicConfig, redisCluster, newRedisCluster); CardinalityRateLimiter limiter = rateLimiters.getUnsealedSenderLimiter();