Add spam report token support to `ReportedMessageListener`
This commit is contained in:
parent
00e08b8402
commit
4a2768b81d
|
@ -643,13 +643,14 @@ public class MessageController {
|
|||
UUID spamReporterUuid = auth.getAccount().getUuid();
|
||||
|
||||
// spam report token is optional, but if provided ensure it is valid base64.
|
||||
@Nullable final byte[] spamReportToken = spamReport != null ? spamReport.token() : null;
|
||||
final Optional<byte[]> maybeSpamReportToken =
|
||||
spamReport != null ? Optional.of(spamReport.token()) : Optional.empty();
|
||||
|
||||
// fire-and-forget: we don't want to block the response on this action.
|
||||
CompletableFuture<Boolean> ignored =
|
||||
reportSpamTokenHandler.handle(sourceNumber, sourceAci, sourcePni, messageGuid, spamReporterUuid, spamReportToken);
|
||||
reportSpamTokenHandler.handle(sourceNumber, sourceAci, sourcePni, messageGuid, spamReporterUuid, maybeSpamReportToken.orElse(null));
|
||||
|
||||
reportMessageManager.report(sourceNumber, sourceAci, sourcePni, messageGuid, spamReporterUuid);
|
||||
reportMessageManager.report(sourceNumber, sourceAci, sourcePni, messageGuid, spamReporterUuid, maybeSpamReportToken);
|
||||
|
||||
return Response.status(Status.ACCEPTED)
|
||||
.build();
|
||||
|
|
|
@ -9,6 +9,7 @@ import static com.codahale.metrics.MetricRegistry.name;
|
|||
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import net.logstash.logback.marker.Markers;
|
||||
import org.slf4j.Logger;
|
||||
|
@ -35,7 +36,9 @@ public class ReportedMessageMetricsListener implements ReportedMessageListener {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void handleMessageReported(final String sourceNumber, final UUID messageGuid, final UUID reporterUuid) {
|
||||
public void handleMessageReported(final String sourceNumber, final UUID messageGuid, final UUID reporterUuid,
|
||||
final Optional<byte[]> reportSpamToken) {
|
||||
|
||||
final String sourceCountryCode = Util.getCountryCode(sourceNumber);
|
||||
|
||||
Metrics.counter(REPORTED_COUNTER_NAME, COUNTRY_CODE_TAG_NAME, sourceCountryCode).increment();
|
||||
|
|
|
@ -15,8 +15,10 @@ import java.util.List;
|
|||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import io.micrometer.core.instrument.Metrics;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
|
||||
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
|
||||
import org.whispersystems.textsecuregcm.util.UUIDUtil;
|
||||
|
||||
|
@ -29,6 +31,10 @@ public class ReportMessageManager {
|
|||
|
||||
private final List<ReportedMessageListener> reportedMessageListeners = new ArrayList<>();
|
||||
|
||||
private static final String REPORT_MESSAGE_COUNTER_NAME = MetricsUtil.name(ReportMessageManager.class);
|
||||
private static final String FOUND_MESSAGE_TAG = "foundMessage";
|
||||
private static final String TOKEN_PRESENT_TAG = "hasReportSpamToken";
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(ReportMessageManager.class);
|
||||
|
||||
public ReportMessageManager(final ReportMessageDynamoDb reportMessageDynamoDb,
|
||||
|
@ -56,12 +62,21 @@ public class ReportMessageManager {
|
|||
}
|
||||
}
|
||||
|
||||
public void report(Optional<String> sourceNumber, Optional<UUID> sourceAci, Optional<UUID> sourcePni,
|
||||
UUID messageGuid, UUID reporterUuid) {
|
||||
public void report(final Optional<String> sourceNumber,
|
||||
final Optional<UUID> sourceAci,
|
||||
final Optional<UUID> sourcePni,
|
||||
final UUID messageGuid,
|
||||
final UUID reporterUuid,
|
||||
final Optional<byte[]> reportSpamToken) {
|
||||
|
||||
final boolean found = sourceAci.map(uuid -> reportMessageDynamoDb.remove(hash(messageGuid, uuid.toString())))
|
||||
.orElse(false);
|
||||
|
||||
Metrics.counter(REPORT_MESSAGE_COUNTER_NAME,
|
||||
FOUND_MESSAGE_TAG, String.valueOf(found),
|
||||
TOKEN_PRESENT_TAG, String.valueOf(reportSpamToken.isPresent()))
|
||||
.increment();
|
||||
|
||||
if (found) {
|
||||
rateLimitCluster.useCluster(connection -> {
|
||||
sourcePni.ifPresent(pni -> {
|
||||
|
@ -80,7 +95,7 @@ public class ReportMessageManager {
|
|||
sourceNumber.ifPresent(number ->
|
||||
reportedMessageListeners.forEach(listener -> {
|
||||
try {
|
||||
listener.handleMessageReported(number, messageGuid, reporterUuid);
|
||||
listener.handleMessageReported(number, messageGuid, reporterUuid, reportSpamToken);
|
||||
} catch (final Exception e) {
|
||||
logger.error("Failed to notify listener of reported message", e);
|
||||
}
|
||||
|
|
|
@ -5,9 +5,10 @@
|
|||
|
||||
package org.whispersystems.textsecuregcm.storage;
|
||||
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
||||
public interface ReportedMessageListener {
|
||||
|
||||
void handleMessageReported(String sourceNumber, UUID messageGuid, UUID reporterUuid);
|
||||
void handleMessageReported(String sourceNumber, UUID messageGuid, UUID reporterUuid, Optional<byte[]> reportSpamToken);
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
|
|||
import static org.junit.jupiter.api.Assertions.assertNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.argThat;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.anyBoolean;
|
||||
import static org.mockito.Mockito.anyString;
|
||||
|
@ -67,6 +68,7 @@ import org.junit.jupiter.params.ParameterizedTest;
|
|||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.ArgumentMatcher;
|
||||
import org.whispersystems.textsecuregcm.spam.ReportSpamTokenHandler;
|
||||
import org.whispersystems.textsecuregcm.spam.ReportSpamTokenProvider;
|
||||
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
|
||||
|
@ -634,7 +636,7 @@ class MessageControllerTest {
|
|||
assertThat(response.getStatus(), is(equalTo(202)));
|
||||
|
||||
verify(reportMessageManager).report(Optional.of(senderNumber), Optional.of(senderAci), Optional.of(senderPni),
|
||||
messageGuid, AuthHelper.VALID_UUID);
|
||||
messageGuid, AuthHelper.VALID_UUID, Optional.empty());
|
||||
verify(deletedAccountsManager, never()).findDeletedAccountE164(any(UUID.class));
|
||||
verify(accountsManager, never()).getPhoneNumberIdentifier(anyString());
|
||||
|
||||
|
@ -651,7 +653,7 @@ class MessageControllerTest {
|
|||
assertThat(response.getStatus(), is(equalTo(202)));
|
||||
|
||||
verify(reportMessageManager).report(Optional.of(senderNumber), Optional.of(senderAci), Optional.of(senderPni),
|
||||
messageGuid, AuthHelper.VALID_UUID);
|
||||
messageGuid, AuthHelper.VALID_UUID, Optional.empty());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -681,7 +683,7 @@ class MessageControllerTest {
|
|||
assertThat(response.getStatus(), is(equalTo(202)));
|
||||
|
||||
verify(reportMessageManager).report(Optional.of(senderNumber), Optional.of(senderAci), Optional.of(senderPni),
|
||||
messageGuid, AuthHelper.VALID_UUID);
|
||||
messageGuid, AuthHelper.VALID_UUID, Optional.empty());
|
||||
verify(deletedAccountsManager, never()).findDeletedAccountE164(any(UUID.class));
|
||||
verify(accountsManager, never()).getPhoneNumberIdentifier(anyString());
|
||||
|
||||
|
@ -699,7 +701,7 @@ class MessageControllerTest {
|
|||
assertThat(response.getStatus(), is(equalTo(202)));
|
||||
|
||||
verify(reportMessageManager).report(Optional.of(senderNumber), Optional.of(senderAci), Optional.of(senderPni),
|
||||
messageGuid, AuthHelper.VALID_UUID);
|
||||
messageGuid, AuthHelper.VALID_UUID, Optional.empty());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -733,8 +735,12 @@ class MessageControllerTest {
|
|||
assertThat(response.getStatus(), is(equalTo(202)));
|
||||
verify(REPORT_SPAM_TOKEN_HANDLER).handle(any(), any(), any(), any(), any(), captor.capture());
|
||||
assertArrayEquals(new byte[3], captor.getValue());
|
||||
verify(reportMessageManager).report(Optional.of(senderNumber), Optional.of(senderAci), Optional.of(senderPni),
|
||||
messageGuid, AuthHelper.VALID_UUID);
|
||||
verify(reportMessageManager).report(eq(Optional.of(senderNumber)),
|
||||
eq(Optional.of(senderAci)),
|
||||
eq(Optional.of(senderPni)),
|
||||
eq(messageGuid),
|
||||
eq(AuthHelper.VALID_UUID),
|
||||
argThat(maybeBytes -> maybeBytes.map(bytes -> Arrays.equals(bytes, new byte[3])).orElse(false)));
|
||||
verify(deletedAccountsManager, never()).findDeletedAccountE164(any(UUID.class));
|
||||
verify(accountsManager, never()).getPhoneNumberIdentifier(anyString());
|
||||
when(accountsManager.getByAccountIdentifier(senderAci)).thenReturn(Optional.empty());
|
||||
|
@ -754,8 +760,12 @@ class MessageControllerTest {
|
|||
assertThat(response.getStatus(), is(equalTo(202)));
|
||||
verify(REPORT_SPAM_TOKEN_HANDLER).handle(any(), any(), any(), any(), any(), captor.capture());
|
||||
assertArrayEquals(new byte[5], captor.getValue());
|
||||
verify(reportMessageManager).report(Optional.of(senderNumber), Optional.of(senderAci), Optional.of(senderPni),
|
||||
messageGuid, AuthHelper.VALID_UUID);
|
||||
verify(reportMessageManager).report(eq(Optional.of(senderNumber)),
|
||||
eq(Optional.of(senderAci)),
|
||||
eq(Optional.of(senderPni)),
|
||||
eq(messageGuid),
|
||||
eq(AuthHelper.VALID_UUID),
|
||||
argThat(maybeBytes -> maybeBytes.map(bytes -> Arrays.equals(bytes, new byte[5])).orElse(false)));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
|
@ -81,16 +81,16 @@ class ReportMessageManagerTest {
|
|||
|
||||
when(reportMessageDynamoDb.remove(any())).thenReturn(false);
|
||||
reportMessageManager.report(Optional.of(sourceNumber), Optional.of(sourceAci), Optional.of(sourcePni), messageGuid,
|
||||
reporterUuid);
|
||||
reporterUuid, Optional.empty());
|
||||
|
||||
assertEquals(0, reportMessageManager.getRecentReportCount(sourceAccount));
|
||||
|
||||
when(reportMessageDynamoDb.remove(any())).thenReturn(true);
|
||||
reportMessageManager.report(Optional.of(sourceNumber), Optional.of(sourceAci), Optional.of(sourcePni), messageGuid,
|
||||
reporterUuid);
|
||||
reporterUuid, Optional.empty());
|
||||
|
||||
assertEquals(1, reportMessageManager.getRecentReportCount(sourceAccount));
|
||||
verify(listener).handleMessageReported(sourceNumber, messageGuid, reporterUuid);
|
||||
verify(listener).handleMessageReported(sourceNumber, messageGuid, reporterUuid, Optional.empty());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -100,7 +100,7 @@ class ReportMessageManagerTest {
|
|||
|
||||
for (int i = 0; i < 100; i++) {
|
||||
reportMessageManager.report(Optional.of(sourceNumber), Optional.of(sourceAci), Optional.of(sourcePni),
|
||||
messageGuid, UUID.randomUUID());
|
||||
messageGuid, UUID.randomUUID(), Optional.empty());
|
||||
}
|
||||
|
||||
assertTrue(reportMessageManager.getRecentReportCount(sourceAccount) > 10);
|
||||
|
@ -114,7 +114,7 @@ class ReportMessageManagerTest {
|
|||
for (int i = 0; i < 100; i++) {
|
||||
reportMessageManager.report(Optional.of(sourceNumber), Optional.of(sourceAci), Optional.of(sourcePni),
|
||||
messageGuid,
|
||||
reporterUuid);
|
||||
reporterUuid, Optional.empty());
|
||||
}
|
||||
|
||||
assertEquals(1, reportMessageManager.getRecentReportCount(sourceAccount));
|
||||
|
@ -127,11 +127,11 @@ class ReportMessageManagerTest {
|
|||
|
||||
for (int i = 0; i < 100; i++) {
|
||||
reportMessageManager.report(Optional.empty(), Optional.of(sourceAci), Optional.of(sourcePni),
|
||||
messageGuid, UUID.randomUUID());
|
||||
messageGuid, UUID.randomUUID(), Optional.empty());
|
||||
}
|
||||
|
||||
reportMessageManager.report(Optional.empty(), Optional.of(sourceAci), Optional.empty(),
|
||||
messageGuid, UUID.randomUUID());
|
||||
messageGuid, UUID.randomUUID(), Optional.empty());
|
||||
|
||||
final int recentReportCount = reportMessageManager.getRecentReportCount(sourceAccount);
|
||||
assertTrue(recentReportCount > 10);
|
||||
|
|
Loading…
Reference in New Issue