diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index f1d19d6fc..0e4d13769 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -643,13 +643,14 @@ public class MessageController { UUID spamReporterUuid = auth.getAccount().getUuid(); // spam report token is optional, but if provided ensure it is valid base64. - @Nullable final byte[] spamReportToken = spamReport != null ? spamReport.token() : null; + final Optional maybeSpamReportToken = + spamReport != null ? Optional.of(spamReport.token()) : Optional.empty(); // fire-and-forget: we don't want to block the response on this action. CompletableFuture ignored = - reportSpamTokenHandler.handle(sourceNumber, sourceAci, sourcePni, messageGuid, spamReporterUuid, spamReportToken); + reportSpamTokenHandler.handle(sourceNumber, sourceAci, sourcePni, messageGuid, spamReporterUuid, maybeSpamReportToken.orElse(null)); - reportMessageManager.report(sourceNumber, sourceAci, sourcePni, messageGuid, spamReporterUuid); + reportMessageManager.report(sourceNumber, sourceAci, sourcePni, messageGuid, spamReporterUuid, maybeSpamReportToken); return Response.status(Status.ACCEPTED) .build(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/ReportedMessageMetricsListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/ReportedMessageMetricsListener.java index f0c83e0f1..3c1c6eb85 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/ReportedMessageMetricsListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/ReportedMessageMetricsListener.java @@ -9,6 +9,7 @@ import static com.codahale.metrics.MetricRegistry.name; import io.micrometer.core.instrument.Metrics; import java.util.Map; +import java.util.Optional; import java.util.UUID; import net.logstash.logback.marker.Markers; import org.slf4j.Logger; @@ -35,7 +36,9 @@ public class ReportedMessageMetricsListener implements ReportedMessageListener { } @Override - public void handleMessageReported(final String sourceNumber, final UUID messageGuid, final UUID reporterUuid) { + public void handleMessageReported(final String sourceNumber, final UUID messageGuid, final UUID reporterUuid, + final Optional reportSpamToken) { + final String sourceCountryCode = Util.getCountryCode(sourceNumber); Metrics.counter(REPORTED_COUNTER_NAME, COUNTRY_CODE_TAG_NAME, sourceCountryCode).increment(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ReportMessageManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ReportMessageManager.java index 810f4768a..7ff4e3bd9 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ReportMessageManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ReportMessageManager.java @@ -15,8 +15,10 @@ import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.UUID; +import io.micrometer.core.instrument.Metrics; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.util.UUIDUtil; @@ -29,6 +31,10 @@ public class ReportMessageManager { private final List reportedMessageListeners = new ArrayList<>(); + private static final String REPORT_MESSAGE_COUNTER_NAME = MetricsUtil.name(ReportMessageManager.class); + private static final String FOUND_MESSAGE_TAG = "foundMessage"; + private static final String TOKEN_PRESENT_TAG = "hasReportSpamToken"; + private static final Logger logger = LoggerFactory.getLogger(ReportMessageManager.class); public ReportMessageManager(final ReportMessageDynamoDb reportMessageDynamoDb, @@ -56,12 +62,21 @@ public class ReportMessageManager { } } - public void report(Optional sourceNumber, Optional sourceAci, Optional sourcePni, - UUID messageGuid, UUID reporterUuid) { + public void report(final Optional sourceNumber, + final Optional sourceAci, + final Optional sourcePni, + final UUID messageGuid, + final UUID reporterUuid, + final Optional reportSpamToken) { final boolean found = sourceAci.map(uuid -> reportMessageDynamoDb.remove(hash(messageGuid, uuid.toString()))) .orElse(false); + Metrics.counter(REPORT_MESSAGE_COUNTER_NAME, + FOUND_MESSAGE_TAG, String.valueOf(found), + TOKEN_PRESENT_TAG, String.valueOf(reportSpamToken.isPresent())) + .increment(); + if (found) { rateLimitCluster.useCluster(connection -> { sourcePni.ifPresent(pni -> { @@ -80,7 +95,7 @@ public class ReportMessageManager { sourceNumber.ifPresent(number -> reportedMessageListeners.forEach(listener -> { try { - listener.handleMessageReported(number, messageGuid, reporterUuid); + listener.handleMessageReported(number, messageGuid, reporterUuid, reportSpamToken); } catch (final Exception e) { logger.error("Failed to notify listener of reported message", e); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ReportedMessageListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ReportedMessageListener.java index 1dc3eee02..96e580b11 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ReportedMessageListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ReportedMessageListener.java @@ -5,9 +5,10 @@ package org.whispersystems.textsecuregcm.storage; +import java.util.Optional; import java.util.UUID; public interface ReportedMessageListener { - void handleMessageReported(String sourceNumber, UUID messageGuid, UUID reporterUuid); + void handleMessageReported(String sourceNumber, UUID messageGuid, UUID reporterUuid, Optional reportSpamToken); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index 36abcb6a1..55eeb9fcd 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -14,6 +14,7 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.anyBoolean; import static org.mockito.Mockito.anyString; @@ -67,6 +68,7 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatcher; import org.whispersystems.textsecuregcm.spam.ReportSpamTokenHandler; import org.whispersystems.textsecuregcm.spam.ReportSpamTokenProvider; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; @@ -634,7 +636,7 @@ class MessageControllerTest { assertThat(response.getStatus(), is(equalTo(202))); verify(reportMessageManager).report(Optional.of(senderNumber), Optional.of(senderAci), Optional.of(senderPni), - messageGuid, AuthHelper.VALID_UUID); + messageGuid, AuthHelper.VALID_UUID, Optional.empty()); verify(deletedAccountsManager, never()).findDeletedAccountE164(any(UUID.class)); verify(accountsManager, never()).getPhoneNumberIdentifier(anyString()); @@ -651,7 +653,7 @@ class MessageControllerTest { assertThat(response.getStatus(), is(equalTo(202))); verify(reportMessageManager).report(Optional.of(senderNumber), Optional.of(senderAci), Optional.of(senderPni), - messageGuid, AuthHelper.VALID_UUID); + messageGuid, AuthHelper.VALID_UUID, Optional.empty()); } @Test @@ -681,7 +683,7 @@ class MessageControllerTest { assertThat(response.getStatus(), is(equalTo(202))); verify(reportMessageManager).report(Optional.of(senderNumber), Optional.of(senderAci), Optional.of(senderPni), - messageGuid, AuthHelper.VALID_UUID); + messageGuid, AuthHelper.VALID_UUID, Optional.empty()); verify(deletedAccountsManager, never()).findDeletedAccountE164(any(UUID.class)); verify(accountsManager, never()).getPhoneNumberIdentifier(anyString()); @@ -699,7 +701,7 @@ class MessageControllerTest { assertThat(response.getStatus(), is(equalTo(202))); verify(reportMessageManager).report(Optional.of(senderNumber), Optional.of(senderAci), Optional.of(senderPni), - messageGuid, AuthHelper.VALID_UUID); + messageGuid, AuthHelper.VALID_UUID, Optional.empty()); } @Test @@ -733,8 +735,12 @@ class MessageControllerTest { assertThat(response.getStatus(), is(equalTo(202))); verify(REPORT_SPAM_TOKEN_HANDLER).handle(any(), any(), any(), any(), any(), captor.capture()); assertArrayEquals(new byte[3], captor.getValue()); - verify(reportMessageManager).report(Optional.of(senderNumber), Optional.of(senderAci), Optional.of(senderPni), - messageGuid, AuthHelper.VALID_UUID); + verify(reportMessageManager).report(eq(Optional.of(senderNumber)), + eq(Optional.of(senderAci)), + eq(Optional.of(senderPni)), + eq(messageGuid), + eq(AuthHelper.VALID_UUID), + argThat(maybeBytes -> maybeBytes.map(bytes -> Arrays.equals(bytes, new byte[3])).orElse(false))); verify(deletedAccountsManager, never()).findDeletedAccountE164(any(UUID.class)); verify(accountsManager, never()).getPhoneNumberIdentifier(anyString()); when(accountsManager.getByAccountIdentifier(senderAci)).thenReturn(Optional.empty()); @@ -754,8 +760,12 @@ class MessageControllerTest { assertThat(response.getStatus(), is(equalTo(202))); verify(REPORT_SPAM_TOKEN_HANDLER).handle(any(), any(), any(), any(), any(), captor.capture()); assertArrayEquals(new byte[5], captor.getValue()); - verify(reportMessageManager).report(Optional.of(senderNumber), Optional.of(senderAci), Optional.of(senderPni), - messageGuid, AuthHelper.VALID_UUID); + verify(reportMessageManager).report(eq(Optional.of(senderNumber)), + eq(Optional.of(senderAci)), + eq(Optional.of(senderPni)), + eq(messageGuid), + eq(AuthHelper.VALID_UUID), + argThat(maybeBytes -> maybeBytes.map(bytes -> Arrays.equals(bytes, new byte[5])).orElse(false))); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ReportMessageManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ReportMessageManagerTest.java index cd1345fc3..c5599e3b1 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ReportMessageManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ReportMessageManagerTest.java @@ -81,16 +81,16 @@ class ReportMessageManagerTest { when(reportMessageDynamoDb.remove(any())).thenReturn(false); reportMessageManager.report(Optional.of(sourceNumber), Optional.of(sourceAci), Optional.of(sourcePni), messageGuid, - reporterUuid); + reporterUuid, Optional.empty()); assertEquals(0, reportMessageManager.getRecentReportCount(sourceAccount)); when(reportMessageDynamoDb.remove(any())).thenReturn(true); reportMessageManager.report(Optional.of(sourceNumber), Optional.of(sourceAci), Optional.of(sourcePni), messageGuid, - reporterUuid); + reporterUuid, Optional.empty()); assertEquals(1, reportMessageManager.getRecentReportCount(sourceAccount)); - verify(listener).handleMessageReported(sourceNumber, messageGuid, reporterUuid); + verify(listener).handleMessageReported(sourceNumber, messageGuid, reporterUuid, Optional.empty()); } @Test @@ -100,7 +100,7 @@ class ReportMessageManagerTest { for (int i = 0; i < 100; i++) { reportMessageManager.report(Optional.of(sourceNumber), Optional.of(sourceAci), Optional.of(sourcePni), - messageGuid, UUID.randomUUID()); + messageGuid, UUID.randomUUID(), Optional.empty()); } assertTrue(reportMessageManager.getRecentReportCount(sourceAccount) > 10); @@ -114,7 +114,7 @@ class ReportMessageManagerTest { for (int i = 0; i < 100; i++) { reportMessageManager.report(Optional.of(sourceNumber), Optional.of(sourceAci), Optional.of(sourcePni), messageGuid, - reporterUuid); + reporterUuid, Optional.empty()); } assertEquals(1, reportMessageManager.getRecentReportCount(sourceAccount)); @@ -127,11 +127,11 @@ class ReportMessageManagerTest { for (int i = 0; i < 100; i++) { reportMessageManager.report(Optional.empty(), Optional.of(sourceAci), Optional.of(sourcePni), - messageGuid, UUID.randomUUID()); + messageGuid, UUID.randomUUID(), Optional.empty()); } reportMessageManager.report(Optional.empty(), Optional.of(sourceAci), Optional.empty(), - messageGuid, UUID.randomUUID()); + messageGuid, UUID.randomUUID(), Optional.empty()); final int recentReportCount = reportMessageManager.getRecentReportCount(sourceAccount); assertTrue(recentReportCount > 10);