From 6919354520ba159e82b55d5c933a24c4d0675032 Mon Sep 17 00:00:00 2001 From: Jon Chambers <63609320+jon-signal@users.noreply.github.com> Date: Tue, 23 Nov 2021 17:28:39 -0500 Subject: [PATCH] Fix a counting bug with reported messages --- .../controllers/MessageController.java | 2 +- .../storage/ReportMessageManager.java | 4 +- .../controllers/MessageControllerTest.java | 2 +- .../storage/ReportMessageManagerTest.java | 61 ++++++++++++++++--- 4 files changed, 58 insertions(+), 11 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 036da397c..1ededf174 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -503,7 +503,7 @@ public class MessageController { public Response reportMessage(@Auth AuthenticatedAccount auth, @PathParam("sourceNumber") String sourceNumber, @PathParam("messageGuid") UUID messageGuid) { - reportMessageManager.report(sourceNumber, messageGuid); + reportMessageManager.report(sourceNumber, messageGuid, auth.getAccount().getUuid()); return Response.status(Status.ACCEPTED) .build(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ReportMessageManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ReportMessageManager.java index 347f0d1a4..fa972de3b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ReportMessageManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ReportMessageManager.java @@ -54,7 +54,7 @@ public class ReportMessageManager { } } - public void report(String sourceNumber, UUID messageGuid) { + public void report(String sourceNumber, UUID messageGuid, UUID reporterUuid) { final boolean found = reportMessageDynamoDb.remove(hash(messageGuid, sourceNumber)); @@ -62,7 +62,7 @@ public class ReportMessageManager { rateLimitCluster.useCluster(connection -> { final String reportedSenderKey = getReportedSenderKey(sourceNumber); - connection.sync().pfadd(reportedSenderKey, sourceNumber); + connection.sync().pfadd(reportedSenderKey, reporterUuid.toString()); connection.sync().expire(reportedSenderKey, counterTtl.toSeconds()); }); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index 69e59a7be..08e86097c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -575,7 +575,7 @@ class MessageControllerTest { assertThat(response.getStatus(), is(equalTo(202))); - verify(reportMessageManager).report(senderNumber, messageGuid); + verify(reportMessageManager).report(senderNumber, messageGuid, AuthHelper.VALID_UUID); } static Account mockAccountWithDeviceAndRegId(Object... deviceAndRegistrationIds) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ReportMessageManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ReportMessageManagerTest.java index 44ed907b8..b3ed73664 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ReportMessageManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ReportMessageManagerTest.java @@ -2,6 +2,7 @@ package org.whispersystems.textsecuregcm.storage; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; @@ -14,16 +15,29 @@ import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.simple.SimpleMeterRegistry; import java.time.Duration; import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; class ReportMessageManagerTest { - private final ReportMessageDynamoDb reportMessageDynamoDb = mock(ReportMessageDynamoDb.class); - private final MeterRegistry meterRegistry = new SimpleMeterRegistry(); + private ReportMessageDynamoDb reportMessageDynamoDb; + private MeterRegistry meterRegistry; - private final ReportMessageManager reportMessageManager = new ReportMessageManager(reportMessageDynamoDb, - mock(FaultTolerantRedisCluster.class), meterRegistry, Duration.ofDays(1)); + private ReportMessageManager reportMessageManager; + + @RegisterExtension + static RedisClusterExtension RATE_LIMIT_CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); + + @BeforeEach + void setUp() { + reportMessageDynamoDb = mock(ReportMessageDynamoDb.class); + meterRegistry = new SimpleMeterRegistry(); + + reportMessageManager = new ReportMessageManager(reportMessageDynamoDb, + RATE_LIMIT_CLUSTER_EXTENSION.getRedisCluster(), meterRegistry, Duration.ofDays(1)); + } @Test void testStore() { @@ -49,16 +63,19 @@ class ReportMessageManagerTest { void testReport() { final String sourceNumber = "+15105551111"; final UUID messageGuid = UUID.randomUUID(); + final UUID reporterUuid = UUID.randomUUID(); when(reportMessageDynamoDb.remove(any())).thenReturn(false); - reportMessageManager.report(sourceNumber, messageGuid); + reportMessageManager.report(sourceNumber, messageGuid, reporterUuid); assertEquals(0, getCounterTotal(ReportMessageManager.REPORT_COUNTER_NAME)); + assertEquals(0, reportMessageManager.getRecentReportCount(sourceNumber)); when(reportMessageDynamoDb.remove(any())).thenReturn(true); - reportMessageManager.report(sourceNumber, messageGuid); + reportMessageManager.report(sourceNumber, messageGuid, reporterUuid); assertEquals(1, getCounterTotal(ReportMessageManager.REPORT_COUNTER_NAME)); + assertEquals(1, reportMessageManager.getRecentReportCount(sourceNumber)); } private double getCounterTotal(final String counterName) { @@ -68,4 +85,34 @@ class ReportMessageManagerTest { .orElse(0.0); } + @Test + void testReportMultipleReporters() { + final String sourceNumber = "+15105551111"; + final UUID messageGuid = UUID.randomUUID(); + + when(reportMessageDynamoDb.remove(any())).thenReturn(true); + assertEquals(0, reportMessageManager.getRecentReportCount(sourceNumber)); + + for (int i = 0; i < 100; i++) { + reportMessageManager.report(sourceNumber, messageGuid, UUID.randomUUID()); + } + + assertTrue(reportMessageManager.getRecentReportCount(sourceNumber) > 10); + } + + @Test + void testReportSingleReporter() { + final String sourceNumber = "+15105551111"; + final UUID messageGuid = UUID.randomUUID(); + final UUID reporterUuid = UUID.randomUUID(); + + when(reportMessageDynamoDb.remove(any())).thenReturn(true); + assertEquals(0, reportMessageManager.getRecentReportCount(sourceNumber)); + + for (int i = 0; i < 100; i++) { + reportMessageManager.report(sourceNumber, messageGuid, reporterUuid); + } + + assertEquals(1, reportMessageManager.getRecentReportCount(sourceNumber)); + } }