diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index e6eae4ee8..055b4c15f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -73,6 +73,8 @@ import org.whispersystems.dispatch.DispatchManager; import org.whispersystems.textsecuregcm.abuse.AbusiveMessageFilter; import org.whispersystems.textsecuregcm.abuse.FilterAbusiveMessages; import org.whispersystems.textsecuregcm.abuse.RateLimitChallengeListener; +import org.whispersystems.textsecuregcm.abuse.ReportSpamTokenHandler; +import org.whispersystems.textsecuregcm.abuse.ReportSpamTokenProvider; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.CertificateGenerator; @@ -678,50 +680,29 @@ public class WhisperServerService extends Application commonControllers = Lists.newArrayList( - new ArtController(rateLimiters, artCredentialsGenerator), - new AttachmentControllerV2(rateLimiters, config.getAwsAttachmentsConfiguration().getAccessKey(), config.getAwsAttachmentsConfiguration().getAccessSecret(), config.getAwsAttachmentsConfiguration().getRegion(), config.getAwsAttachmentsConfiguration().getBucket()), - new AttachmentControllerV3(rateLimiters, config.getGcpAttachmentsConfiguration().getDomain(), config.getGcpAttachmentsConfiguration().getEmail(), config.getGcpAttachmentsConfiguration().getMaxSizeInBytes(), config.getGcpAttachmentsConfiguration().getPathPrefix(), config.getGcpAttachmentsConfiguration().getRsaSigningKey()), - new CertificateController(new CertificateGenerator(config.getDeliveryCertificate().getCertificate(), config.getDeliveryCertificate().getPrivateKey(), config.getDeliveryCertificate().getExpiresDays()), zkAuthOperations, clock), - new ChallengeController(rateLimitChallengeManager), - new DeviceController(pendingDevicesManager, accountsManager, messagesManager, keys, rateLimiters, config.getMaxDevices()), - new DirectoryController(directoryCredentialsGenerator), - new DirectoryV2Controller(directoryV2CredentialsGenerator), - new DonationController(clock, zkReceiptOperations, redeemedReceiptsManager, accountsManager, config.getBadges(), - ReceiptCredentialPresentation::new), - new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, deletedAccountsManager, messagesManager, pushNotificationManager, reportMessageManager, multiRecipientMessageExecutor), - new PaymentsController(currencyManager, paymentsCredentialsGenerator), - new ProfileController(clock, rateLimiters, accountsManager, profilesManager, dynamicConfigurationManager, - profileBadgeConverter, config.getBadges(), cdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner, - config.getCdnConfiguration().getBucket(), zkProfileOperations, batchIdentityCheckExecutor), - new ProvisioningController(rateLimiters, provisioningManager), - new RemoteConfigController(remoteConfigsManager, adminEventLogger, - config.getRemoteConfigConfiguration().getAuthorizedTokens(), - config.getRemoteConfigConfiguration().getGlobalConfig()), - new SecureBackupController(backupCredentialsGenerator), - new SecureStorageController(storageCredentialsGenerator), - new StickerController(rateLimiters, config.getCdnConfiguration().getAccessKey(), - config.getCdnConfiguration().getAccessSecret(), config.getCdnConfiguration().getRegion(), - config.getCdnConfiguration().getBucket()) - ); - if (config.getSubscription() != null && config.getOneTimeDonations() != null) { - commonControllers.add(new SubscriptionController(clock, config.getSubscription(), config.getOneTimeDonations(), - subscriptionManager, stripeManager, braintreeManager, zkReceiptOperations, issuedReceiptsManager, profileBadgeConverter, - resourceBundleLevelTranslator)); - } - - for (Object controller : commonControllers) { - environment.jersey().register(controller); - webSocketEnvironment.jersey().register(controller); - } - boolean registeredAbusiveMessageFilter = false; + ReportSpamTokenProvider reportSpamTokenProvider = null; + ReportSpamTokenHandler reportSpamTokenHandler = null; for (final AbusiveMessageFilter filter : ServiceLoader.load(AbusiveMessageFilter.class)) { if (filter.getClass().isAnnotationPresent(FilterAbusiveMessages.class)) { try { filter.configure(config.getAbusiveMessageFilterConfiguration().getEnvironment()); + ReportSpamTokenProvider thisProvider = filter.getReportSpamTokenProvider(); + if (reportSpamTokenProvider == null) { + reportSpamTokenProvider = thisProvider; + } else if (thisProvider != null) { + log.info("Multiple spam report token providers found. Using the first."); + } + + ReportSpamTokenHandler thisHandler = filter.getReportSpamTokenHandler(); + if (reportSpamTokenHandler == null) { + reportSpamTokenHandler = thisHandler; + } else if (thisProvider != null) { + log.info("Multiple spam report token handlers found. Using the first."); + } + environment.lifecycle().manage(filter); environment.jersey().register(filter); webSocketEnvironment.jersey().register(filter); @@ -746,6 +727,52 @@ public class WhisperServerService extends Application commonControllers = Lists.newArrayList( + new ArtController(rateLimiters, artCredentialsGenerator), + new AttachmentControllerV2(rateLimiters, config.getAwsAttachmentsConfiguration().getAccessKey(), config.getAwsAttachmentsConfiguration().getAccessSecret(), config.getAwsAttachmentsConfiguration().getRegion(), config.getAwsAttachmentsConfiguration().getBucket()), + new AttachmentControllerV3(rateLimiters, config.getGcpAttachmentsConfiguration().getDomain(), config.getGcpAttachmentsConfiguration().getEmail(), config.getGcpAttachmentsConfiguration().getMaxSizeInBytes(), config.getGcpAttachmentsConfiguration().getPathPrefix(), config.getGcpAttachmentsConfiguration().getRsaSigningKey()), + new CertificateController(new CertificateGenerator(config.getDeliveryCertificate().getCertificate(), config.getDeliveryCertificate().getPrivateKey(), config.getDeliveryCertificate().getExpiresDays()), zkAuthOperations, clock), + new ChallengeController(rateLimitChallengeManager), + new DeviceController(pendingDevicesManager, accountsManager, messagesManager, keys, rateLimiters, config.getMaxDevices()), + new DirectoryController(directoryCredentialsGenerator), + new DirectoryV2Controller(directoryV2CredentialsGenerator), + new DonationController(clock, zkReceiptOperations, redeemedReceiptsManager, accountsManager, config.getBadges(), + ReceiptCredentialPresentation::new), + new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, deletedAccountsManager, messagesManager, pushNotificationManager, reportMessageManager, multiRecipientMessageExecutor, + reportSpamTokenProvider, reportSpamTokenHandler), + new PaymentsController(currencyManager, paymentsCredentialsGenerator), + new ProfileController(clock, rateLimiters, accountsManager, profilesManager, dynamicConfigurationManager, + profileBadgeConverter, config.getBadges(), cdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner, + config.getCdnConfiguration().getBucket(), zkProfileOperations, batchIdentityCheckExecutor), + new ProvisioningController(rateLimiters, provisioningManager), + new RemoteConfigController(remoteConfigsManager, adminEventLogger, + config.getRemoteConfigConfiguration().getAuthorizedTokens(), + config.getRemoteConfigConfiguration().getGlobalConfig()), + new SecureBackupController(backupCredentialsGenerator), + new SecureStorageController(storageCredentialsGenerator), + new StickerController(rateLimiters, config.getCdnConfiguration().getAccessKey(), + config.getCdnConfiguration().getAccessSecret(), config.getCdnConfiguration().getRegion(), + config.getCdnConfiguration().getBucket()) + ); + if (config.getSubscription() != null && config.getOneTimeDonations() != null) { + commonControllers.add(new SubscriptionController(clock, config.getSubscription(), config.getOneTimeDonations(), + subscriptionManager, stripeManager, braintreeManager, zkReceiptOperations, issuedReceiptsManager, profileBadgeConverter, + resourceBundleLevelTranslator)); + } + + for (Object controller : commonControllers) { + environment.jersey().register(controller); + webSocketEnvironment.jersey().register(controller); + } + WebSocketEnvironment provisioningEnvironment = new WebSocketEnvironment<>(environment, webSocketEnvironment.getRequestLog(), 60000); provisioningEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager)); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/abuse/AbusiveMessageFilter.java b/service/src/main/java/org/whispersystems/textsecuregcm/abuse/AbusiveMessageFilter.java index 0a0538769..654e15951 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/abuse/AbusiveMessageFilter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/abuse/AbusiveMessageFilter.java @@ -30,4 +30,19 @@ public interface AbusiveMessageFilter extends ContainerRequestFilter, Managed { * @throws IOException if the filter could not read its configuration source for any reason */ void configure(String environmentName) throws IOException; + + /** + * Builds a spam report token provider. This will generate tokens used by the spam reporting system. + * + * @return the configured spam report token provider. + */ + ReportSpamTokenProvider getReportSpamTokenProvider(); + + /** + * Builds a spam report token handler. This will handle tokens received by the spam reporting system. + * + * @return the configured spam report token handler + */ + ReportSpamTokenHandler getReportSpamTokenHandler(); + } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/abuse/ReportSpamTokenHandler.java b/service/src/main/java/org/whispersystems/textsecuregcm/abuse/ReportSpamTokenHandler.java new file mode 100644 index 000000000..7e5c1e387 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/abuse/ReportSpamTokenHandler.java @@ -0,0 +1,47 @@ +package org.whispersystems.textsecuregcm.abuse; + +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; + +/** + * Handles ReportSpamTokens during spam reports. + */ +public interface ReportSpamTokenHandler { + + /** + * Handle spam reports using the given ReportSpamToken and other provided parameters. + * + * @param reportSpamToken binary data representing a spam report token. + * @return true if the token could be handled (and was), false otherwise. + */ + CompletableFuture handle( + Optional sourceNumber, + Optional sourceAci, + Optional sourcePni, + UUID messageGuid, + UUID spamReporterUuid, + byte[] reportSpamToken); + + /** + * Handler which does nothing. + * + * @return the handler + */ + static ReportSpamTokenHandler noop() { + return new ReportSpamTokenHandler() { + @Override + public CompletableFuture handle( + final Optional sourceNumber, + final Optional sourceAci, + final Optional sourcePni, + final UUID messageGuid, + final UUID spamReporterUuid, + final byte[] reportSpamToken) { + return CompletableFuture.completedFuture(false); + } + }; + } + + +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/abuse/ReportSpamTokenProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/abuse/ReportSpamTokenProvider.java new file mode 100644 index 000000000..fbfe66ec1 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/abuse/ReportSpamTokenProvider.java @@ -0,0 +1,38 @@ +package org.whispersystems.textsecuregcm.abuse; + +import javax.ws.rs.container.ContainerRequestContext; +import java.util.Optional; +import java.util.function.Function; + +/** + * Generates ReportSpamTokens to be used for spam reports. + */ +public interface ReportSpamTokenProvider { + + /** + * Generate a new ReportSpamToken + * + * @param context the message request context + * @return either a generated token or nothing + */ + Optional makeReportSpamToken(ContainerRequestContext context); + + /** + * Provider which generates nothing + * + * @return the provider + */ + static ReportSpamTokenProvider noop() { + return create(c -> Optional.empty()); + } + + /** + * Provider which generates ReportSpamTokens using the given function + * + * @param fn function from message requests to optional tokens + * @return the provider + */ + static ReportSpamTokenProvider create(Function> fn) { + return fn::apply; + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java index 258cb581a..a638bd981 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/MessageController.java @@ -55,12 +55,16 @@ import javax.ws.rs.PathParam; import javax.ws.rs.Produces; import javax.ws.rs.QueryParam; import javax.ws.rs.WebApplicationException; +import javax.ws.rs.container.ContainerRequestContext; +import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import javax.ws.rs.core.Response.Status; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.abuse.FilterAbusiveMessages; +import org.whispersystems.textsecuregcm.abuse.ReportSpamTokenHandler; +import org.whispersystems.textsecuregcm.abuse.ReportSpamTokenProvider; import org.whispersystems.textsecuregcm.auth.Anonymous; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.CombinedUnidentifiedSenderAccessKeys; @@ -78,6 +82,7 @@ import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.entities.SendMessageResponse; import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse; +import org.whispersystems.textsecuregcm.entities.SpamReport; import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.metrics.MessageMetrics; @@ -117,6 +122,8 @@ public class MessageController { private final PushNotificationManager pushNotificationManager; private final ReportMessageManager reportMessageManager; private final ExecutorService multiRecipientMessageExecutor; + private final ReportSpamTokenProvider reportSpamTokenProvider; + private final ReportSpamTokenHandler reportSpamTokenHandler; private static final String REJECT_OVERSIZE_MESSAGE_COUNTER = name(MessageController.class, "rejectOversizeMessage"); private static final String SENT_MESSAGE_COUNTER_NAME = name(MessageController.class, "sentMessages"); @@ -147,7 +154,9 @@ public class MessageController { MessagesManager messagesManager, PushNotificationManager pushNotificationManager, ReportMessageManager reportMessageManager, - @Nonnull ExecutorService multiRecipientMessageExecutor) { + @Nonnull ExecutorService multiRecipientMessageExecutor, + @Nonnull ReportSpamTokenProvider reportSpamTokenProvider, + @Nonnull ReportSpamTokenHandler reportSpamTokenHandler) { this.rateLimiters = rateLimiters; this.messageSender = messageSender; this.receiptSender = receiptSender; @@ -157,6 +166,9 @@ public class MessageController { this.pushNotificationManager = pushNotificationManager; this.reportMessageManager = reportMessageManager; this.multiRecipientMessageExecutor = Objects.requireNonNull(multiRecipientMessageExecutor); + this.reportSpamTokenProvider = reportSpamTokenProvider; + this.reportSpamTokenHandler = reportSpamTokenHandler; + } @Timed @@ -171,7 +183,9 @@ public class MessageController { @HeaderParam(HttpHeaders.X_FORWARDED_FOR) String forwardedFor, @PathParam("destination") UUID destinationUuid, @QueryParam("story") boolean isStory, - @NotNull @Valid IncomingMessageList messages) + @NotNull @Valid IncomingMessageList messages, + @Context ContainerRequestContext context + ) throws RateLimitExceededException { if (source.isEmpty() && accessKey.isEmpty() && !isStory) { @@ -190,6 +204,13 @@ public class MessageController { senderType = SENDER_TYPE_UNIDENTIFIED; } + final Optional spamReportToken; + if (senderType.equals(SENDER_TYPE_IDENTIFIED)) { + spamReportToken = reportSpamTokenProvider.makeReportSpamToken(context); + } else { + spamReportToken = Optional.empty(); + } + for (final IncomingMessage message : messages.messages()) { int contentLength = 0; @@ -267,7 +288,18 @@ public class MessageController { if (destinationDevice.isPresent()) { Metrics.counter(SENT_MESSAGE_COUNTER_NAME, tags).increment(); - sendIndividualMessage(source, destination.get(), destinationDevice.get(), destinationUuid, messages.timestamp(), messages.online(), isStory, messages.urgent(), incomingMessage, userAgent); + sendIndividualMessage( + source, + destination.get(), + destinationDevice.get(), + destinationUuid, + messages.timestamp(), + messages.online(), + isStory, + messages.urgent(), + incomingMessage, + userAgent, + spamReportToken); } } @@ -570,9 +602,14 @@ public class MessageController { @Timed @POST + @Consumes(MediaType.APPLICATION_JSON) @Path("/report/{source}/{messageGuid}") - public Response reportMessage(@Auth AuthenticatedAccount auth, @PathParam("source") String source, - @PathParam("messageGuid") UUID messageGuid) { + public Response reportSpamMessage( + @Auth AuthenticatedAccount auth, + @PathParam("source") String source, + @PathParam("messageGuid") UUID messageGuid, + @Nullable @Valid SpamReport spamReport + ) { final Optional sourceNumber; final Optional sourceAci; @@ -602,13 +639,30 @@ public class MessageController { } } - reportMessageManager.report(sourceNumber, sourceAci, sourcePni, messageGuid, auth.getAccount().getUuid()); + UUID spamReporterUuid = auth.getAccount().getUuid(); + + // spam report token is optional, but if provided ensure it is valid base64. + byte[] spamReportToken = null; + if (spamReport != null) { + try { + spamReportToken = Base64.getDecoder().decode(spamReport.token()); + } catch (IllegalArgumentException e) { + throw new WebApplicationException(Response.status(400).build()); + } + } + + // fire-and-forget: we don't want to block the response on this action. + CompletableFuture ignored = + reportSpamTokenHandler.handle(sourceNumber, sourceAci, sourcePni, messageGuid, spamReporterUuid, spamReportToken); + + reportMessageManager.report(sourceNumber, sourceAci, sourcePni, messageGuid, spamReporterUuid); return Response.status(Status.ACCEPTED) .build(); } - private void sendIndividualMessage(Optional source, + private void sendIndividualMessage( + Optional source, Account destinationAccount, Device destinationDevice, UUID destinationUuid, @@ -617,18 +671,23 @@ public class MessageController { boolean story, boolean urgent, IncomingMessage incomingMessage, - String userAgentString) + String userAgentString, + Optional spamReportToken) throws NoSuchUserException { try { final Envelope envelope; try { - envelope = incomingMessage.toEnvelope(destinationUuid, - source.map(AuthenticatedAccount::getAccount).orElse(null), - source.map(authenticatedAccount -> authenticatedAccount.getAuthenticatedDevice().getId()).orElse(null), + Account sourceAccount = source.map(AuthenticatedAccount::getAccount).orElse(null); + Long sourceDeviceId = source.map(account -> account.getAuthenticatedDevice().getId()).orElse(null); + envelope = incomingMessage.toEnvelope( + destinationUuid, + sourceAccount, + sourceDeviceId, timestamp == 0 ? System.currentTimeMillis() : timestamp, story, - urgent); + urgent, + spamReportToken.orElse(null)); } catch (final IllegalArgumentException e) { logger.warn("Received bad envelope type {} from {}", incomingMessage.type(), userAgentString); throw new BadRequestException(e); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java index 7cbefe6c3..17496e9a7 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/IncomingMessage.java @@ -18,7 +18,8 @@ public record IncomingMessage(int type, long destinationDeviceId, int destinatio @Nullable Long sourceDeviceId, final long timestamp, final boolean story, - final boolean urgent) { + final boolean urgent, + @Nullable byte[] reportSpamToken) { final MessageProtos.Envelope.Type envelopeType = MessageProtos.Envelope.Type.forNumber(type()); @@ -36,10 +37,15 @@ public record IncomingMessage(int type, long destinationDeviceId, int destinatio .setUrgent(urgent); if (sourceAccount != null && sourceDeviceId != null) { - envelopeBuilder.setSourceUuid(sourceAccount.getUuid().toString()) + envelopeBuilder + .setSourceUuid(sourceAccount.getUuid().toString()) .setSourceDevice(sourceDeviceId.intValue()); } + if (reportSpamToken != null) { + envelopeBuilder.setReportSpamToken(ByteString.copyFrom(reportSpamToken)); + } + if (StringUtils.isNotEmpty(content())) { envelopeBuilder.setContent(ByteString.copyFrom(Base64.getDecoder().decode(content()))); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/SpamReport.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/SpamReport.java new file mode 100644 index 000000000..a035ce2bc --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/SpamReport.java @@ -0,0 +1,7 @@ +package org.whispersystems.textsecuregcm.entities; + +import com.fasterxml.jackson.annotation.JsonProperty; +import javax.validation.Valid; +import javax.validation.constraints.NotEmpty; + +public record SpamReport(@JsonProperty("token") @NotEmpty String token) {} diff --git a/service/src/main/proto/TextSecure.proto b/service/src/main/proto/TextSecure.proto index bae56dde3..ef340fb1b 100644 --- a/service/src/main/proto/TextSecure.proto +++ b/service/src/main/proto/TextSecure.proto @@ -33,7 +33,8 @@ message Envelope { optional bool urgent = 14 [default=true]; optional string updated_pni = 15; optional bool story = 16; // indicates that the content is a story. - // next: 17 + optional bytes report_spam_token = 17; // token sent when reporting spam + // next: 18 } message ProvisioningUuid { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java index c390c0c1d..c11b70f8f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/MessageControllerTest.java @@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.controllers; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNull; @@ -17,6 +18,7 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.anyBoolean; import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; @@ -65,6 +67,8 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.mockito.ArgumentCaptor; +import org.whispersystems.textsecuregcm.abuse.ReportSpamTokenHandler; +import org.whispersystems.textsecuregcm.abuse.ReportSpamTokenProvider; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.OptionalAccess; @@ -79,6 +83,7 @@ import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse; import org.whispersystems.textsecuregcm.entities.SignedPreKey; +import org.whispersystems.textsecuregcm.entities.SpamReport; import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; @@ -124,8 +129,6 @@ class MessageControllerTest { private static final String INTERNATIONAL_RECIPIENT = "+61123456789"; private static final UUID INTERNATIONAL_UUID = UUID.fromString("33333333-3333-3333-3333-333333333333"); - private Account internationalAccount; - @SuppressWarnings("unchecked") private static final RedisAdvancedClusterCommands redisCommands = mock(RedisAdvancedClusterCommands.class); @@ -139,6 +142,7 @@ class MessageControllerTest { private static final PushNotificationManager pushNotificationManager = mock(PushNotificationManager.class); private static final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class); private static final ExecutorService multiRecipientMessageExecutor = mock(ExecutorService.class); + private static final ReportSpamTokenHandler REPORT_SPAM_TOKEN_HANDLER = mock(ReportSpamTokenHandler.class); private static final ResourceExtension resources = ResourceExtension.builder() .addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE) @@ -150,7 +154,8 @@ class MessageControllerTest { .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .addResource( new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, deletedAccountsManager, - messagesManager, pushNotificationManager, reportMessageManager, multiRecipientMessageExecutor)) + messagesManager, pushNotificationManager, reportMessageManager, multiRecipientMessageExecutor, + ReportSpamTokenProvider.noop(), REPORT_SPAM_TOKEN_HANDLER)) .build(); @BeforeEach @@ -167,7 +172,8 @@ class MessageControllerTest { Account singleDeviceAccount = AccountsHelper.generateTestAccount(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, SINGLE_DEVICE_PNI, singleDeviceList, UNIDENTIFIED_ACCESS_BYTES); Account multiDeviceAccount = AccountsHelper.generateTestAccount(MULTI_DEVICE_RECIPIENT, MULTI_DEVICE_UUID, MULTI_DEVICE_PNI, multiDeviceList, UNIDENTIFIED_ACCESS_BYTES); - internationalAccount = AccountsHelper.generateTestAccount(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID, UUID.randomUUID(), singleDeviceList, UNIDENTIFIED_ACCESS_BYTES); + Account internationalAccount = AccountsHelper.generateTestAccount(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID, + UUID.randomUUID(), singleDeviceList, UNIDENTIFIED_ACCESS_BYTES); when(accountsManager.getByAccountIdentifier(eq(SINGLE_DEVICE_UUID))).thenReturn(Optional.of(singleDeviceAccount)); when(accountsManager.getByPhoneNumberIdentifier(SINGLE_DEVICE_PNI)).thenReturn(Optional.of(singleDeviceAccount)); @@ -176,6 +182,8 @@ class MessageControllerTest { when(accountsManager.getByAccountIdentifier(INTERNATIONAL_UUID)).thenReturn(Optional.of(internationalAccount)); when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter); + + when(REPORT_SPAM_TOKEN_HANDLER.handle(any(), any(), any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(false)); } private static Device generateTestDevice(final long id, final int registrationId, final int pniRegistrationId, final SignedPreKey signedPreKey, final long createdAt, final long lastSeen) { @@ -198,12 +206,14 @@ class MessageControllerTest { messageSender, receiptSender, accountsManager, + deletedAccountsManager, messagesManager, rateLimiters, rateLimiter, pushNotificationManager, reportMessageManager, - multiRecipientMessageExecutor + multiRecipientMessageExecutor, + REPORT_SPAM_TOKEN_HANDLER ); } @@ -692,6 +702,64 @@ class MessageControllerTest { messageGuid, AuthHelper.VALID_UUID); } + @Test + void testReportMessageByAciWithSpamReportToken() { + + final String senderNumber = "+12125550001"; + final UUID senderAci = UUID.randomUUID(); + final UUID senderPni = UUID.randomUUID(); + UUID messageGuid = UUID.randomUUID(); + + final Account account = mock(Account.class); + when(account.getUuid()).thenReturn(senderAci); + when(account.getNumber()).thenReturn(senderNumber); + when(account.getPhoneNumberIdentifier()).thenReturn(senderPni); + + when(accountsManager.getByAccountIdentifier(senderAci)).thenReturn(Optional.of(account)); + when(deletedAccountsManager.findDeletedAccountE164(senderAci)).thenReturn(Optional.of(senderNumber)); + when(accountsManager.getPhoneNumberIdentifier(senderNumber)).thenReturn(senderPni); + when(REPORT_SPAM_TOKEN_HANDLER.handle(any(), any(), any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(true)); + + ArgumentCaptor captor = ArgumentCaptor.forClass(byte[].class); + + String token = Base64.getEncoder().encodeToString(new byte[3]); + Entity entity = Entity.entity(new SpamReport(token), "application/json"); + Response response = + resources.getJerseyTest() + .target(String.format("/v1/messages/report/%s/%s", senderAci, messageGuid)) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .post(entity); + + assertThat(response.getStatus(), is(equalTo(202))); + verify(REPORT_SPAM_TOKEN_HANDLER).handle(any(), any(), any(), any(), any(), captor.capture()); + assertArrayEquals(new byte[3], captor.getValue()); + verify(reportMessageManager).report(Optional.of(senderNumber), Optional.of(senderAci), Optional.of(senderPni), + messageGuid, AuthHelper.VALID_UUID); + verify(deletedAccountsManager, never()).findDeletedAccountE164(any(UUID.class)); + verify(accountsManager, never()).getPhoneNumberIdentifier(anyString()); + when(accountsManager.getByAccountIdentifier(senderAci)).thenReturn(Optional.empty()); + + clearInvocations(REPORT_SPAM_TOKEN_HANDLER); + + messageGuid = UUID.randomUUID(); + + token = Base64.getEncoder().encodeToString(new byte[5]); + entity = Entity.entity(new SpamReport(token), "application/json"); + response = + resources.getJerseyTest() + .target(String.format("/v1/messages/report/%s/%s", senderAci, messageGuid)) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .post(entity); + + assertThat(response.getStatus(), is(equalTo(202))); + verify(REPORT_SPAM_TOKEN_HANDLER).handle(any(), any(), any(), any(), any(), captor.capture()); + assertArrayEquals(new byte[5], captor.getValue()); + verify(reportMessageManager).report(Optional.of(senderNumber), Optional.of(senderAci), Optional.of(senderPni), + messageGuid, AuthHelper.VALID_UUID); + } + @Test void testValidateContentLength() throws Exception { final int contentLength = Math.toIntExact(MessageController.MAX_MESSAGE_SIZE + 1);