diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 87a16fb51..406491ee2 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -603,7 +603,8 @@ public class MessageController { @Auth AuthenticatedAccount auth, @PathParam("source") String source, @PathParam("messageGuid") UUID messageGuid, - @Nullable @Valid SpamReport spamReport + @Nullable @Valid SpamReport spamReport, + @HeaderParam(HttpHeaders.USER_AGENT) String userAgent ) { final Optional sourceNumber; @@ -640,7 +641,7 @@ public class MessageController { final Optional maybeSpamReportToken = spamReport != null ? Optional.of(spamReport.token()) : Optional.empty(); - reportMessageManager.report(sourceNumber, sourceAci, sourcePni, messageGuid, spamReporterUuid, maybeSpamReportToken); + reportMessageManager.report(sourceNumber, sourceAci, sourcePni, messageGuid, spamReporterUuid, maybeSpamReportToken, userAgent); return Response.status(Status.ACCEPTED) .build(); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ReportMessageManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ReportMessageManager.java index 08f9307e1..26ebe939b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/ReportMessageManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/ReportMessageManager.java @@ -16,9 +16,11 @@ import java.util.Objects; import java.util.Optional; import java.util.UUID; import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Tags; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; +import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; import org.whispersystems.textsecuregcm.util.UUIDUtil; @@ -67,14 +69,16 @@ public class ReportMessageManager { final Optional sourcePni, final UUID messageGuid, final UUID reporterUuid, - final Optional reportSpamToken) { + final Optional reportSpamToken, + final String reporterUserAgent) { final boolean found = sourceAci.map(uuid -> reportMessageDynamoDb.remove(hash(messageGuid, uuid.toString()))) .orElse(false); Metrics.counter(REPORT_MESSAGE_COUNTER_NAME, - FOUND_MESSAGE_TAG, String.valueOf(found), - TOKEN_PRESENT_TAG, String.valueOf(reportSpamToken.isPresent())) + Tags.of(FOUND_MESSAGE_TAG, String.valueOf(found), + TOKEN_PRESENT_TAG, String.valueOf(reportSpamToken.isPresent())) + .and(UserAgentTagUtil.getPlatformTag(reporterUserAgent))) .increment(); if (found) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index b02cc0ead..538ca603c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -29,7 +29,6 @@ import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.asJson; import static org.whispersystems.textsecuregcm.tests.util.JsonHelpers.jsonFixture; import com.google.common.collect.ImmutableSet; -import com.google.common.net.HttpHeaders; import com.google.protobuf.ByteString; import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; @@ -54,6 +53,7 @@ import java.util.concurrent.TimeUnit; import java.util.stream.Stream; import javax.ws.rs.client.Entity; import javax.ws.rs.client.Invocation; +import javax.ws.rs.core.HttpHeaders; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import org.glassfish.jersey.server.ServerProperties; @@ -607,6 +607,7 @@ class MessageControllerTest { final String senderNumber = "+12125550001"; final UUID senderAci = UUID.randomUUID(); final UUID senderPni = UUID.randomUUID(); + final String userAgent = "user-agent"; UUID messageGuid = UUID.randomUUID(); final Account account = mock(Account.class); @@ -623,12 +624,13 @@ class MessageControllerTest { .target(String.format("/v1/messages/report/%s/%s", senderNumber, messageGuid)) .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.USER_AGENT, userAgent) .post(null); assertThat(response.getStatus(), is(equalTo(202))); verify(reportMessageManager).report(Optional.of(senderNumber), Optional.of(senderAci), Optional.of(senderPni), - messageGuid, AuthHelper.VALID_UUID, Optional.empty()); + messageGuid, AuthHelper.VALID_UUID, Optional.empty(), userAgent); verify(deletedAccountsManager, never()).findDeletedAccountE164(any(UUID.class)); verify(accountsManager, never()).getPhoneNumberIdentifier(anyString()); @@ -640,12 +642,13 @@ class MessageControllerTest { .target(String.format("/v1/messages/report/%s/%s", senderNumber, messageGuid)) .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.USER_AGENT, userAgent) .post(null); assertThat(response.getStatus(), is(equalTo(202))); verify(reportMessageManager).report(Optional.of(senderNumber), Optional.of(senderAci), Optional.of(senderPni), - messageGuid, AuthHelper.VALID_UUID, Optional.empty()); + messageGuid, AuthHelper.VALID_UUID, Optional.empty(), userAgent); } @Test @@ -654,6 +657,7 @@ class MessageControllerTest { final String senderNumber = "+12125550001"; final UUID senderAci = UUID.randomUUID(); final UUID senderPni = UUID.randomUUID(); + final String userAgent = "user-agent"; UUID messageGuid = UUID.randomUUID(); final Account account = mock(Account.class); @@ -670,12 +674,13 @@ class MessageControllerTest { .target(String.format("/v1/messages/report/%s/%s", senderAci, messageGuid)) .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.USER_AGENT, userAgent) .post(null); assertThat(response.getStatus(), is(equalTo(202))); verify(reportMessageManager).report(Optional.of(senderNumber), Optional.of(senderAci), Optional.of(senderPni), - messageGuid, AuthHelper.VALID_UUID, Optional.empty()); + messageGuid, AuthHelper.VALID_UUID, Optional.empty(), userAgent); verify(deletedAccountsManager, never()).findDeletedAccountE164(any(UUID.class)); verify(accountsManager, never()).getPhoneNumberIdentifier(anyString()); @@ -688,12 +693,13 @@ class MessageControllerTest { .target(String.format("/v1/messages/report/%s/%s", senderAci, messageGuid)) .request() .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .header(HttpHeaders.USER_AGENT, userAgent) .post(null); assertThat(response.getStatus(), is(equalTo(202))); verify(reportMessageManager).report(Optional.of(senderNumber), Optional.of(senderAci), Optional.of(senderPni), - messageGuid, AuthHelper.VALID_UUID, Optional.empty()); + messageGuid, AuthHelper.VALID_UUID, Optional.empty(), userAgent); } @Test @@ -727,7 +733,8 @@ class MessageControllerTest { eq(Optional.of(senderPni)), eq(messageGuid), eq(AuthHelper.VALID_UUID), - argThat(maybeBytes -> maybeBytes.map(bytes -> Arrays.equals(bytes, new byte[3])).orElse(false))); + argThat(maybeBytes -> maybeBytes.map(bytes -> Arrays.equals(bytes, new byte[3])).orElse(false)), + any()); verify(deletedAccountsManager, never()).findDeletedAccountE164(any(UUID.class)); verify(accountsManager, never()).getPhoneNumberIdentifier(anyString()); when(accountsManager.getByAccountIdentifier(senderAci)).thenReturn(Optional.empty()); @@ -748,7 +755,8 @@ class MessageControllerTest { eq(Optional.of(senderPni)), eq(messageGuid), eq(AuthHelper.VALID_UUID), - argThat(maybeBytes -> maybeBytes.map(bytes -> Arrays.equals(bytes, new byte[5])).orElse(false))); + argThat(maybeBytes -> maybeBytes.map(bytes -> Arrays.equals(bytes, new byte[5])).orElse(false)), + any()); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ReportMessageManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ReportMessageManagerTest.java index c5599e3b1..2cc6d9adb 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ReportMessageManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ReportMessageManagerTest.java @@ -81,13 +81,13 @@ class ReportMessageManagerTest { when(reportMessageDynamoDb.remove(any())).thenReturn(false); reportMessageManager.report(Optional.of(sourceNumber), Optional.of(sourceAci), Optional.of(sourcePni), messageGuid, - reporterUuid, Optional.empty()); + reporterUuid, Optional.empty(), "user-agent"); assertEquals(0, reportMessageManager.getRecentReportCount(sourceAccount)); when(reportMessageDynamoDb.remove(any())).thenReturn(true); reportMessageManager.report(Optional.of(sourceNumber), Optional.of(sourceAci), Optional.of(sourcePni), messageGuid, - reporterUuid, Optional.empty()); + reporterUuid, Optional.empty(), "user-agent"); assertEquals(1, reportMessageManager.getRecentReportCount(sourceAccount)); verify(listener).handleMessageReported(sourceNumber, messageGuid, reporterUuid, Optional.empty()); @@ -100,7 +100,7 @@ class ReportMessageManagerTest { for (int i = 0; i < 100; i++) { reportMessageManager.report(Optional.of(sourceNumber), Optional.of(sourceAci), Optional.of(sourcePni), - messageGuid, UUID.randomUUID(), Optional.empty()); + messageGuid, UUID.randomUUID(), Optional.empty(), "user-agent"); } assertTrue(reportMessageManager.getRecentReportCount(sourceAccount) > 10); @@ -114,7 +114,7 @@ class ReportMessageManagerTest { for (int i = 0; i < 100; i++) { reportMessageManager.report(Optional.of(sourceNumber), Optional.of(sourceAci), Optional.of(sourcePni), messageGuid, - reporterUuid, Optional.empty()); + reporterUuid, Optional.empty(), "user-agent"); } assertEquals(1, reportMessageManager.getRecentReportCount(sourceAccount)); @@ -127,11 +127,11 @@ class ReportMessageManagerTest { for (int i = 0; i < 100; i++) { reportMessageManager.report(Optional.empty(), Optional.of(sourceAci), Optional.of(sourcePni), - messageGuid, UUID.randomUUID(), Optional.empty()); + messageGuid, UUID.randomUUID(), Optional.empty(), "user-agent"); } reportMessageManager.report(Optional.empty(), Optional.of(sourceAci), Optional.empty(), - messageGuid, UUID.randomUUID(), Optional.empty()); + messageGuid, UUID.randomUUID(), Optional.empty(), "user-agent"); final int recentReportCount = reportMessageManager.getRecentReportCount(sourceAccount); assertTrue(recentReportCount > 10);