diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java index a585bbb72..69728f76c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerConfiguration.java @@ -35,6 +35,7 @@ import org.whispersystems.textsecuregcm.configuration.GcpAttachmentsConfiguratio import org.whispersystems.textsecuregcm.configuration.GenericZkConfig; import org.whispersystems.textsecuregcm.configuration.HCaptchaConfiguration; import org.whispersystems.textsecuregcm.configuration.MaxDeviceConfiguration; +import org.whispersystems.textsecuregcm.configuration.MessageByteLimitCardinalityEstimatorConfiguration; import org.whispersystems.textsecuregcm.configuration.MessageCacheConfiguration; import org.whispersystems.textsecuregcm.configuration.OneTimeDonationConfiguration; import org.whispersystems.textsecuregcm.configuration.PaymentsServiceConfiguration; @@ -288,6 +289,11 @@ public class WhisperServerConfiguration extends Configuration { @JsonProperty private ClientReleaseConfiguration clientRelease = new ClientReleaseConfiguration(Duration.ofHours(4)); + @Valid + @NotNull + @JsonProperty + private MessageByteLimitCardinalityEstimatorConfiguration messageByteLimitCardinalityEstimator = new MessageByteLimitCardinalityEstimatorConfiguration(Duration.ofDays(1)); + public AdminEventLoggingConfiguration getAdminEventLoggingConfiguration() { return adminEventLoggingConfiguration; } @@ -478,4 +484,8 @@ public class WhisperServerConfiguration extends Configuration { public ClientReleaseConfiguration getClientReleaseConfiguration() { return clientRelease; } + + public MessageByteLimitCardinalityEstimatorConfiguration getMessageByteLimitCardinalityEstimator() { + return messageByteLimitCardinalityEstimator; + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 223290f16..e80644169 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -123,6 +123,7 @@ import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter; import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter; import org.whispersystems.textsecuregcm.grpc.KeysGrpcService; import org.whispersystems.textsecuregcm.grpc.KeysAnonymousGrpcService; +import org.whispersystems.textsecuregcm.limits.CardinalityEstimator; import org.whispersystems.textsecuregcm.limits.PushChallengeManager; import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager; import org.whispersystems.textsecuregcm.limits.RateLimiters; @@ -571,6 +572,11 @@ public class WhisperServerService extends Application dynamicConfigurationManager) { this.rateLimiters = rateLimiters; + this.messageByteLimitEstimator = messageByteLimitEstimator; this.messageSender = messageSender; this.receiptSender = receiptSender; this.accountsManager = accountsManager; @@ -237,6 +241,7 @@ public class MessageController { rateLimiters.getInboundMessageBytes().validate(destinationIdentifier.uuid(), totalContentLength); } catch (final RateLimitExceededException e) { if (dynamicConfigurationManager.getConfiguration().getInboundMessageByteLimitConfiguration().enforceInboundLimit()) { + messageByteLimitEstimator.add(destinationIdentifier.uuid().toString()); throw e; } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/CardinalityEstimator.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/CardinalityEstimator.java new file mode 100644 index 000000000..40c62dc5d --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/CardinalityEstimator.java @@ -0,0 +1,76 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.limits; + +import com.google.common.annotations.VisibleForTesting; +import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Tags; +import java.time.Duration; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; +import org.whispersystems.textsecuregcm.util.Util; + +/** + * Estimate the number of unique items seen over a configurable period and update a metric + */ +public class CardinalityEstimator { + + private volatile double uniqueElementCount; + private final FaultTolerantRedisCluster redisCluster; + private final String hllName; + private final Duration period; + + public CardinalityEstimator(final FaultTolerantRedisCluster redisCluster, final String name, final Duration period) { + this.redisCluster = redisCluster; + this.hllName = "cardinality_estimator::" + name; + this.period = period; + Metrics.gauge( + MetricsUtil.name(getClass(), "unique"), + Tags.of("name", name), + this, + obj -> obj.uniqueElementCount); + } + + public void add(String element) { + addAsync(element).toCompletableFuture().join(); + } + + public CompletionStage addAsync(String element) { + return redisCluster.withCluster(connection -> connection.async() + .pfadd(hllName, element) + .thenCompose(modCount -> { + if (modCount == 0) { + return CompletableFuture.completedFuture(false); + } + + // The hll changed - update our local view of the cardinality, and + // initialize the TTL if required + return connection.async() + .pfcount(hllName) + .thenCompose(count -> { + uniqueElementCount = count; + // check if this is a new hll with no TTL set + return connection.async().ttl(hllName).thenApply(ttl -> ttl == -1); + }); + }) + .thenCompose(isNewHll -> { + if (!isNewHll) { + return CompletableFuture.completedFuture(null); + } + + // If this is a new hll, we need to set the TTL. This could be + // a single atomic op in redis 7.x with EXPIRE NX + return connection.async().expire(hllName, period).thenRun(Util.NOOP); + })); + } + + @VisibleForTesting + long estimate() { + return (long) this.uniqueElementCount; + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index 993e46ea6..937fdfc91 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -100,6 +100,7 @@ import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.identity.AciServiceIdentifier; import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; import org.whispersystems.textsecuregcm.identity.ServiceIdentifier; +import org.whispersystems.textsecuregcm.limits.CardinalityEstimator; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; @@ -161,6 +162,7 @@ class MessageControllerTest { private static final DeletedAccounts deletedAccounts = mock(DeletedAccounts.class); private static final MessagesManager messagesManager = mock(MessagesManager.class); private static final RateLimiters rateLimiters = mock(RateLimiters.class); + private static final CardinalityEstimator cardinalityEstimator = mock(CardinalityEstimator.class); private static final RateLimiter rateLimiter = mock(RateLimiter.class); private static final PushNotificationManager pushNotificationManager = mock(PushNotificationManager.class); private static final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class); @@ -177,7 +179,7 @@ class MessageControllerTest { .addProvider(MultiRecipientMessageProvider.class) .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .addResource( - new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, deletedAccounts, + new MessageController(rateLimiters, cardinalityEstimator, messageSender, receiptSender, accountsManager, deletedAccounts, messagesManager, pushNotificationManager, reportMessageManager, multiRecipientMessageExecutor, messageDeliveryScheduler, ReportSpamTokenProvider.noop(), mock(ClientReleaseManager.class), dynamicConfigurationManager)) .build(); @@ -247,6 +249,7 @@ class MessageControllerTest { messagesManager, rateLimiters, rateLimiter, + cardinalityEstimator, pushNotificationManager, reportMessageManager, multiRecipientMessageExecutor diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/CardinalityEstimatorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/CardinalityEstimatorTest.java new file mode 100644 index 000000000..643d2a9fa --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/CardinalityEstimatorTest.java @@ -0,0 +1,53 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.limits; + +import static org.assertj.core.api.Assertions.assertThat; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; +import java.time.Duration; + +public class CardinalityEstimatorTest { + + @RegisterExtension + private static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); + + @Test + public void testAdd() throws Exception { + final FaultTolerantRedisCluster redisCluster = REDIS_CLUSTER_EXTENSION.getRedisCluster(); + final CardinalityEstimator estimator = new CardinalityEstimator(redisCluster, "test", Duration.ofSeconds(1)); + + estimator.add("1"); + + long count = redisCluster.withCluster(conn -> conn.sync().pfcount("cardinality_estimator::test")); + assertThat(count).isEqualTo(1).isEqualTo(estimator.estimate()); + + estimator.add("2"); + count = redisCluster.withCluster(conn -> conn.sync().pfcount("cardinality_estimator::test")); + assertThat(count).isEqualTo(2).isEqualTo(estimator.estimate()); + + estimator.add("1"); + count = redisCluster.withCluster(conn -> conn.sync().pfcount("cardinality_estimator::test")); + assertThat(count).isEqualTo(2).isEqualTo(estimator.estimate()); + } + + @Test + @Timeout(5) + public void testEventuallyExpires() throws InterruptedException { + final FaultTolerantRedisCluster redisCluster = REDIS_CLUSTER_EXTENSION.getRedisCluster(); + final CardinalityEstimator estimator = new CardinalityEstimator(redisCluster, "test", Duration.ofMillis(100)); + estimator.add("1"); + long count; + do { + count = redisCluster.withCluster(conn -> conn.sync().pfcount("cardinality_estimator::test")); + Thread.sleep(1); + } while (count != 0); + } + +}