Introduce spam report tokens

This commit is contained in:
erik-signal 2023-01-19 11:13:43 -05:00 committed by GitHub
parent ee5aaf5383
commit ab26a65b6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 325 additions and 57 deletions

View File

@ -73,6 +73,8 @@ import org.whispersystems.dispatch.DispatchManager;
import org.whispersystems.textsecuregcm.abuse.AbusiveMessageFilter; import org.whispersystems.textsecuregcm.abuse.AbusiveMessageFilter;
import org.whispersystems.textsecuregcm.abuse.FilterAbusiveMessages; import org.whispersystems.textsecuregcm.abuse.FilterAbusiveMessages;
import org.whispersystems.textsecuregcm.abuse.RateLimitChallengeListener; import org.whispersystems.textsecuregcm.abuse.RateLimitChallengeListener;
import org.whispersystems.textsecuregcm.abuse.ReportSpamTokenHandler;
import org.whispersystems.textsecuregcm.abuse.ReportSpamTokenProvider;
import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; import org.whispersystems.textsecuregcm.auth.AccountAuthenticator;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.CertificateGenerator; import org.whispersystems.textsecuregcm.auth.CertificateGenerator;
@ -678,50 +680,29 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
environment.jersey().register(new KeysController(rateLimiters, keys, accountsManager)); environment.jersey().register(new KeysController(rateLimiters, keys, accountsManager));
final List<Object> commonControllers = Lists.newArrayList(
new ArtController(rateLimiters, artCredentialsGenerator),
new AttachmentControllerV2(rateLimiters, config.getAwsAttachmentsConfiguration().getAccessKey(), config.getAwsAttachmentsConfiguration().getAccessSecret(), config.getAwsAttachmentsConfiguration().getRegion(), config.getAwsAttachmentsConfiguration().getBucket()),
new AttachmentControllerV3(rateLimiters, config.getGcpAttachmentsConfiguration().getDomain(), config.getGcpAttachmentsConfiguration().getEmail(), config.getGcpAttachmentsConfiguration().getMaxSizeInBytes(), config.getGcpAttachmentsConfiguration().getPathPrefix(), config.getGcpAttachmentsConfiguration().getRsaSigningKey()),
new CertificateController(new CertificateGenerator(config.getDeliveryCertificate().getCertificate(), config.getDeliveryCertificate().getPrivateKey(), config.getDeliveryCertificate().getExpiresDays()), zkAuthOperations, clock),
new ChallengeController(rateLimitChallengeManager),
new DeviceController(pendingDevicesManager, accountsManager, messagesManager, keys, rateLimiters, config.getMaxDevices()),
new DirectoryController(directoryCredentialsGenerator),
new DirectoryV2Controller(directoryV2CredentialsGenerator),
new DonationController(clock, zkReceiptOperations, redeemedReceiptsManager, accountsManager, config.getBadges(),
ReceiptCredentialPresentation::new),
new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, deletedAccountsManager, messagesManager, pushNotificationManager, reportMessageManager, multiRecipientMessageExecutor),
new PaymentsController(currencyManager, paymentsCredentialsGenerator),
new ProfileController(clock, rateLimiters, accountsManager, profilesManager, dynamicConfigurationManager,
profileBadgeConverter, config.getBadges(), cdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner,
config.getCdnConfiguration().getBucket(), zkProfileOperations, batchIdentityCheckExecutor),
new ProvisioningController(rateLimiters, provisioningManager),
new RemoteConfigController(remoteConfigsManager, adminEventLogger,
config.getRemoteConfigConfiguration().getAuthorizedTokens(),
config.getRemoteConfigConfiguration().getGlobalConfig()),
new SecureBackupController(backupCredentialsGenerator),
new SecureStorageController(storageCredentialsGenerator),
new StickerController(rateLimiters, config.getCdnConfiguration().getAccessKey(),
config.getCdnConfiguration().getAccessSecret(), config.getCdnConfiguration().getRegion(),
config.getCdnConfiguration().getBucket())
);
if (config.getSubscription() != null && config.getOneTimeDonations() != null) {
commonControllers.add(new SubscriptionController(clock, config.getSubscription(), config.getOneTimeDonations(),
subscriptionManager, stripeManager, braintreeManager, zkReceiptOperations, issuedReceiptsManager, profileBadgeConverter,
resourceBundleLevelTranslator));
}
for (Object controller : commonControllers) {
environment.jersey().register(controller);
webSocketEnvironment.jersey().register(controller);
}
boolean registeredAbusiveMessageFilter = false; boolean registeredAbusiveMessageFilter = false;
ReportSpamTokenProvider reportSpamTokenProvider = null;
ReportSpamTokenHandler reportSpamTokenHandler = null;
for (final AbusiveMessageFilter filter : ServiceLoader.load(AbusiveMessageFilter.class)) { for (final AbusiveMessageFilter filter : ServiceLoader.load(AbusiveMessageFilter.class)) {
if (filter.getClass().isAnnotationPresent(FilterAbusiveMessages.class)) { if (filter.getClass().isAnnotationPresent(FilterAbusiveMessages.class)) {
try { try {
filter.configure(config.getAbusiveMessageFilterConfiguration().getEnvironment()); filter.configure(config.getAbusiveMessageFilterConfiguration().getEnvironment());
ReportSpamTokenProvider thisProvider = filter.getReportSpamTokenProvider();
if (reportSpamTokenProvider == null) {
reportSpamTokenProvider = thisProvider;
} else if (thisProvider != null) {
log.info("Multiple spam report token providers found. Using the first.");
}
ReportSpamTokenHandler thisHandler = filter.getReportSpamTokenHandler();
if (reportSpamTokenHandler == null) {
reportSpamTokenHandler = thisHandler;
} else if (thisProvider != null) {
log.info("Multiple spam report token handlers found. Using the first.");
}
environment.lifecycle().manage(filter); environment.lifecycle().manage(filter);
environment.jersey().register(filter); environment.jersey().register(filter);
webSocketEnvironment.jersey().register(filter); webSocketEnvironment.jersey().register(filter);
@ -746,6 +727,52 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
log.warn("No abusive message filters installed"); log.warn("No abusive message filters installed");
} }
if (reportSpamTokenProvider == null) {
reportSpamTokenProvider = ReportSpamTokenProvider.noop();
}
if (reportSpamTokenHandler == null) {
reportSpamTokenHandler = ReportSpamTokenHandler.noop();
}
final List<Object> commonControllers = Lists.newArrayList(
new ArtController(rateLimiters, artCredentialsGenerator),
new AttachmentControllerV2(rateLimiters, config.getAwsAttachmentsConfiguration().getAccessKey(), config.getAwsAttachmentsConfiguration().getAccessSecret(), config.getAwsAttachmentsConfiguration().getRegion(), config.getAwsAttachmentsConfiguration().getBucket()),
new AttachmentControllerV3(rateLimiters, config.getGcpAttachmentsConfiguration().getDomain(), config.getGcpAttachmentsConfiguration().getEmail(), config.getGcpAttachmentsConfiguration().getMaxSizeInBytes(), config.getGcpAttachmentsConfiguration().getPathPrefix(), config.getGcpAttachmentsConfiguration().getRsaSigningKey()),
new CertificateController(new CertificateGenerator(config.getDeliveryCertificate().getCertificate(), config.getDeliveryCertificate().getPrivateKey(), config.getDeliveryCertificate().getExpiresDays()), zkAuthOperations, clock),
new ChallengeController(rateLimitChallengeManager),
new DeviceController(pendingDevicesManager, accountsManager, messagesManager, keys, rateLimiters, config.getMaxDevices()),
new DirectoryController(directoryCredentialsGenerator),
new DirectoryV2Controller(directoryV2CredentialsGenerator),
new DonationController(clock, zkReceiptOperations, redeemedReceiptsManager, accountsManager, config.getBadges(),
ReceiptCredentialPresentation::new),
new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, deletedAccountsManager, messagesManager, pushNotificationManager, reportMessageManager, multiRecipientMessageExecutor,
reportSpamTokenProvider, reportSpamTokenHandler),
new PaymentsController(currencyManager, paymentsCredentialsGenerator),
new ProfileController(clock, rateLimiters, accountsManager, profilesManager, dynamicConfigurationManager,
profileBadgeConverter, config.getBadges(), cdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner,
config.getCdnConfiguration().getBucket(), zkProfileOperations, batchIdentityCheckExecutor),
new ProvisioningController(rateLimiters, provisioningManager),
new RemoteConfigController(remoteConfigsManager, adminEventLogger,
config.getRemoteConfigConfiguration().getAuthorizedTokens(),
config.getRemoteConfigConfiguration().getGlobalConfig()),
new SecureBackupController(backupCredentialsGenerator),
new SecureStorageController(storageCredentialsGenerator),
new StickerController(rateLimiters, config.getCdnConfiguration().getAccessKey(),
config.getCdnConfiguration().getAccessSecret(), config.getCdnConfiguration().getRegion(),
config.getCdnConfiguration().getBucket())
);
if (config.getSubscription() != null && config.getOneTimeDonations() != null) {
commonControllers.add(new SubscriptionController(clock, config.getSubscription(), config.getOneTimeDonations(),
subscriptionManager, stripeManager, braintreeManager, zkReceiptOperations, issuedReceiptsManager, profileBadgeConverter,
resourceBundleLevelTranslator));
}
for (Object controller : commonControllers) {
environment.jersey().register(controller);
webSocketEnvironment.jersey().register(controller);
}
WebSocketEnvironment<AuthenticatedAccount> provisioningEnvironment = new WebSocketEnvironment<>(environment, WebSocketEnvironment<AuthenticatedAccount> provisioningEnvironment = new WebSocketEnvironment<>(environment,
webSocketEnvironment.getRequestLog(), 60000); webSocketEnvironment.getRequestLog(), 60000);
provisioningEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager)); provisioningEnvironment.jersey().register(new WebsocketRefreshApplicationEventListener(accountsManager, clientPresenceManager));

View File

@ -30,4 +30,19 @@ public interface AbusiveMessageFilter extends ContainerRequestFilter, Managed {
* @throws IOException if the filter could not read its configuration source for any reason * @throws IOException if the filter could not read its configuration source for any reason
*/ */
void configure(String environmentName) throws IOException; void configure(String environmentName) throws IOException;
/**
* Builds a spam report token provider. This will generate tokens used by the spam reporting system.
*
* @return the configured spam report token provider.
*/
ReportSpamTokenProvider getReportSpamTokenProvider();
/**
* Builds a spam report token handler. This will handle tokens received by the spam reporting system.
*
* @return the configured spam report token handler
*/
ReportSpamTokenHandler getReportSpamTokenHandler();
} }

View File

@ -0,0 +1,47 @@
package org.whispersystems.textsecuregcm.abuse;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
/**
* Handles ReportSpamTokens during spam reports.
*/
public interface ReportSpamTokenHandler {
/**
* Handle spam reports using the given ReportSpamToken and other provided parameters.
*
* @param reportSpamToken binary data representing a spam report token.
* @return true if the token could be handled (and was), false otherwise.
*/
CompletableFuture<Boolean> handle(
Optional<String> sourceNumber,
Optional<UUID> sourceAci,
Optional<UUID> sourcePni,
UUID messageGuid,
UUID spamReporterUuid,
byte[] reportSpamToken);
/**
* Handler which does nothing.
*
* @return the handler
*/
static ReportSpamTokenHandler noop() {
return new ReportSpamTokenHandler() {
@Override
public CompletableFuture<Boolean> handle(
final Optional<String> sourceNumber,
final Optional<UUID> sourceAci,
final Optional<UUID> sourcePni,
final UUID messageGuid,
final UUID spamReporterUuid,
final byte[] reportSpamToken) {
return CompletableFuture.completedFuture(false);
}
};
}
}

View File

@ -0,0 +1,38 @@
package org.whispersystems.textsecuregcm.abuse;
import javax.ws.rs.container.ContainerRequestContext;
import java.util.Optional;
import java.util.function.Function;
/**
* Generates ReportSpamTokens to be used for spam reports.
*/
public interface ReportSpamTokenProvider {
/**
* Generate a new ReportSpamToken
*
* @param context the message request context
* @return either a generated token or nothing
*/
Optional<byte[]> makeReportSpamToken(ContainerRequestContext context);
/**
* Provider which generates nothing
*
* @return the provider
*/
static ReportSpamTokenProvider noop() {
return create(c -> Optional.empty());
}
/**
* Provider which generates ReportSpamTokens using the given function
*
* @param fn function from message requests to optional tokens
* @return the provider
*/
static ReportSpamTokenProvider create(Function<ContainerRequestContext, Optional<byte[]>> fn) {
return fn::apply;
}
}

View File

@ -55,12 +55,16 @@ import javax.ws.rs.PathParam;
import javax.ws.rs.Produces; import javax.ws.rs.Produces;
import javax.ws.rs.QueryParam; import javax.ws.rs.QueryParam;
import javax.ws.rs.WebApplicationException; import javax.ws.rs.WebApplicationException;
import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.Status; import javax.ws.rs.core.Response.Status;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.abuse.FilterAbusiveMessages; import org.whispersystems.textsecuregcm.abuse.FilterAbusiveMessages;
import org.whispersystems.textsecuregcm.abuse.ReportSpamTokenHandler;
import org.whispersystems.textsecuregcm.abuse.ReportSpamTokenProvider;
import org.whispersystems.textsecuregcm.auth.Anonymous; import org.whispersystems.textsecuregcm.auth.Anonymous;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.CombinedUnidentifiedSenderAccessKeys; import org.whispersystems.textsecuregcm.auth.CombinedUnidentifiedSenderAccessKeys;
@ -78,6 +82,7 @@ import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.SendMessageResponse; import org.whispersystems.textsecuregcm.entities.SendMessageResponse;
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse; import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
import org.whispersystems.textsecuregcm.entities.SpamReport;
import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
@ -117,6 +122,8 @@ public class MessageController {
private final PushNotificationManager pushNotificationManager; private final PushNotificationManager pushNotificationManager;
private final ReportMessageManager reportMessageManager; private final ReportMessageManager reportMessageManager;
private final ExecutorService multiRecipientMessageExecutor; private final ExecutorService multiRecipientMessageExecutor;
private final ReportSpamTokenProvider reportSpamTokenProvider;
private final ReportSpamTokenHandler reportSpamTokenHandler;
private static final String REJECT_OVERSIZE_MESSAGE_COUNTER = name(MessageController.class, "rejectOversizeMessage"); private static final String REJECT_OVERSIZE_MESSAGE_COUNTER = name(MessageController.class, "rejectOversizeMessage");
private static final String SENT_MESSAGE_COUNTER_NAME = name(MessageController.class, "sentMessages"); private static final String SENT_MESSAGE_COUNTER_NAME = name(MessageController.class, "sentMessages");
@ -147,7 +154,9 @@ public class MessageController {
MessagesManager messagesManager, MessagesManager messagesManager,
PushNotificationManager pushNotificationManager, PushNotificationManager pushNotificationManager,
ReportMessageManager reportMessageManager, ReportMessageManager reportMessageManager,
@Nonnull ExecutorService multiRecipientMessageExecutor) { @Nonnull ExecutorService multiRecipientMessageExecutor,
@Nonnull ReportSpamTokenProvider reportSpamTokenProvider,
@Nonnull ReportSpamTokenHandler reportSpamTokenHandler) {
this.rateLimiters = rateLimiters; this.rateLimiters = rateLimiters;
this.messageSender = messageSender; this.messageSender = messageSender;
this.receiptSender = receiptSender; this.receiptSender = receiptSender;
@ -157,6 +166,9 @@ public class MessageController {
this.pushNotificationManager = pushNotificationManager; this.pushNotificationManager = pushNotificationManager;
this.reportMessageManager = reportMessageManager; this.reportMessageManager = reportMessageManager;
this.multiRecipientMessageExecutor = Objects.requireNonNull(multiRecipientMessageExecutor); this.multiRecipientMessageExecutor = Objects.requireNonNull(multiRecipientMessageExecutor);
this.reportSpamTokenProvider = reportSpamTokenProvider;
this.reportSpamTokenHandler = reportSpamTokenHandler;
} }
@Timed @Timed
@ -171,7 +183,9 @@ public class MessageController {
@HeaderParam(HttpHeaders.X_FORWARDED_FOR) String forwardedFor, @HeaderParam(HttpHeaders.X_FORWARDED_FOR) String forwardedFor,
@PathParam("destination") UUID destinationUuid, @PathParam("destination") UUID destinationUuid,
@QueryParam("story") boolean isStory, @QueryParam("story") boolean isStory,
@NotNull @Valid IncomingMessageList messages) @NotNull @Valid IncomingMessageList messages,
@Context ContainerRequestContext context
)
throws RateLimitExceededException { throws RateLimitExceededException {
if (source.isEmpty() && accessKey.isEmpty() && !isStory) { if (source.isEmpty() && accessKey.isEmpty() && !isStory) {
@ -190,6 +204,13 @@ public class MessageController {
senderType = SENDER_TYPE_UNIDENTIFIED; senderType = SENDER_TYPE_UNIDENTIFIED;
} }
final Optional<byte[]> spamReportToken;
if (senderType.equals(SENDER_TYPE_IDENTIFIED)) {
spamReportToken = reportSpamTokenProvider.makeReportSpamToken(context);
} else {
spamReportToken = Optional.empty();
}
for (final IncomingMessage message : messages.messages()) { for (final IncomingMessage message : messages.messages()) {
int contentLength = 0; int contentLength = 0;
@ -267,7 +288,18 @@ public class MessageController {
if (destinationDevice.isPresent()) { if (destinationDevice.isPresent()) {
Metrics.counter(SENT_MESSAGE_COUNTER_NAME, tags).increment(); Metrics.counter(SENT_MESSAGE_COUNTER_NAME, tags).increment();
sendIndividualMessage(source, destination.get(), destinationDevice.get(), destinationUuid, messages.timestamp(), messages.online(), isStory, messages.urgent(), incomingMessage, userAgent); sendIndividualMessage(
source,
destination.get(),
destinationDevice.get(),
destinationUuid,
messages.timestamp(),
messages.online(),
isStory,
messages.urgent(),
incomingMessage,
userAgent,
spamReportToken);
} }
} }
@ -570,9 +602,14 @@ public class MessageController {
@Timed @Timed
@POST @POST
@Consumes(MediaType.APPLICATION_JSON)
@Path("/report/{source}/{messageGuid}") @Path("/report/{source}/{messageGuid}")
public Response reportMessage(@Auth AuthenticatedAccount auth, @PathParam("source") String source, public Response reportSpamMessage(
@PathParam("messageGuid") UUID messageGuid) { @Auth AuthenticatedAccount auth,
@PathParam("source") String source,
@PathParam("messageGuid") UUID messageGuid,
@Nullable @Valid SpamReport spamReport
) {
final Optional<String> sourceNumber; final Optional<String> sourceNumber;
final Optional<UUID> sourceAci; final Optional<UUID> sourceAci;
@ -602,13 +639,30 @@ public class MessageController {
} }
} }
reportMessageManager.report(sourceNumber, sourceAci, sourcePni, messageGuid, auth.getAccount().getUuid()); UUID spamReporterUuid = auth.getAccount().getUuid();
// spam report token is optional, but if provided ensure it is valid base64.
byte[] spamReportToken = null;
if (spamReport != null) {
try {
spamReportToken = Base64.getDecoder().decode(spamReport.token());
} catch (IllegalArgumentException e) {
throw new WebApplicationException(Response.status(400).build());
}
}
// fire-and-forget: we don't want to block the response on this action.
CompletableFuture<Boolean> ignored =
reportSpamTokenHandler.handle(sourceNumber, sourceAci, sourcePni, messageGuid, spamReporterUuid, spamReportToken);
reportMessageManager.report(sourceNumber, sourceAci, sourcePni, messageGuid, spamReporterUuid);
return Response.status(Status.ACCEPTED) return Response.status(Status.ACCEPTED)
.build(); .build();
} }
private void sendIndividualMessage(Optional<AuthenticatedAccount> source, private void sendIndividualMessage(
Optional<AuthenticatedAccount> source,
Account destinationAccount, Account destinationAccount,
Device destinationDevice, Device destinationDevice,
UUID destinationUuid, UUID destinationUuid,
@ -617,18 +671,23 @@ public class MessageController {
boolean story, boolean story,
boolean urgent, boolean urgent,
IncomingMessage incomingMessage, IncomingMessage incomingMessage,
String userAgentString) String userAgentString,
Optional<byte[]> spamReportToken)
throws NoSuchUserException { throws NoSuchUserException {
try { try {
final Envelope envelope; final Envelope envelope;
try { try {
envelope = incomingMessage.toEnvelope(destinationUuid, Account sourceAccount = source.map(AuthenticatedAccount::getAccount).orElse(null);
source.map(AuthenticatedAccount::getAccount).orElse(null), Long sourceDeviceId = source.map(account -> account.getAuthenticatedDevice().getId()).orElse(null);
source.map(authenticatedAccount -> authenticatedAccount.getAuthenticatedDevice().getId()).orElse(null), envelope = incomingMessage.toEnvelope(
destinationUuid,
sourceAccount,
sourceDeviceId,
timestamp == 0 ? System.currentTimeMillis() : timestamp, timestamp == 0 ? System.currentTimeMillis() : timestamp,
story, story,
urgent); urgent,
spamReportToken.orElse(null));
} catch (final IllegalArgumentException e) { } catch (final IllegalArgumentException e) {
logger.warn("Received bad envelope type {} from {}", incomingMessage.type(), userAgentString); logger.warn("Received bad envelope type {} from {}", incomingMessage.type(), userAgentString);
throw new BadRequestException(e); throw new BadRequestException(e);

View File

@ -18,7 +18,8 @@ public record IncomingMessage(int type, long destinationDeviceId, int destinatio
@Nullable Long sourceDeviceId, @Nullable Long sourceDeviceId,
final long timestamp, final long timestamp,
final boolean story, final boolean story,
final boolean urgent) { final boolean urgent,
@Nullable byte[] reportSpamToken) {
final MessageProtos.Envelope.Type envelopeType = MessageProtos.Envelope.Type.forNumber(type()); final MessageProtos.Envelope.Type envelopeType = MessageProtos.Envelope.Type.forNumber(type());
@ -36,10 +37,15 @@ public record IncomingMessage(int type, long destinationDeviceId, int destinatio
.setUrgent(urgent); .setUrgent(urgent);
if (sourceAccount != null && sourceDeviceId != null) { if (sourceAccount != null && sourceDeviceId != null) {
envelopeBuilder.setSourceUuid(sourceAccount.getUuid().toString()) envelopeBuilder
.setSourceUuid(sourceAccount.getUuid().toString())
.setSourceDevice(sourceDeviceId.intValue()); .setSourceDevice(sourceDeviceId.intValue());
} }
if (reportSpamToken != null) {
envelopeBuilder.setReportSpamToken(ByteString.copyFrom(reportSpamToken));
}
if (StringUtils.isNotEmpty(content())) { if (StringUtils.isNotEmpty(content())) {
envelopeBuilder.setContent(ByteString.copyFrom(Base64.getDecoder().decode(content()))); envelopeBuilder.setContent(ByteString.copyFrom(Base64.getDecoder().decode(content())));
} }

View File

@ -0,0 +1,7 @@
package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonProperty;
import javax.validation.Valid;
import javax.validation.constraints.NotEmpty;
public record SpamReport(@JsonProperty("token") @NotEmpty String token) {}

View File

@ -33,7 +33,8 @@ message Envelope {
optional bool urgent = 14 [default=true]; optional bool urgent = 14 [default=true];
optional string updated_pni = 15; optional string updated_pni = 15;
optional bool story = 16; // indicates that the content is a story. optional bool story = 16; // indicates that the content is a story.
// next: 17 optional bytes report_spam_token = 17; // token sent when reporting spam
// next: 18
} }
message ProvisioningUuid { message ProvisioningUuid {

View File

@ -8,6 +8,7 @@ package org.whispersystems.textsecuregcm.controllers;
import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertNull;
@ -17,6 +18,7 @@ import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.anyBoolean; import static org.mockito.Mockito.anyBoolean;
import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset; import static org.mockito.Mockito.reset;
@ -65,6 +67,8 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.abuse.ReportSpamTokenHandler;
import org.whispersystems.textsecuregcm.abuse.ReportSpamTokenProvider;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.DisabledPermittedAuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.OptionalAccess; import org.whispersystems.textsecuregcm.auth.OptionalAccess;
@ -79,6 +83,7 @@ import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList; import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntityList;
import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse; import org.whispersystems.textsecuregcm.entities.SendMultiRecipientMessageResponse;
import org.whispersystems.textsecuregcm.entities.SignedPreKey; import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.entities.SpamReport;
import org.whispersystems.textsecuregcm.entities.StaleDevices; import org.whispersystems.textsecuregcm.entities.StaleDevices;
import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.limits.RateLimiters;
@ -124,8 +129,6 @@ class MessageControllerTest {
private static final String INTERNATIONAL_RECIPIENT = "+61123456789"; private static final String INTERNATIONAL_RECIPIENT = "+61123456789";
private static final UUID INTERNATIONAL_UUID = UUID.fromString("33333333-3333-3333-3333-333333333333"); private static final UUID INTERNATIONAL_UUID = UUID.fromString("33333333-3333-3333-3333-333333333333");
private Account internationalAccount;
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private static final RedisAdvancedClusterCommands<String, String> redisCommands = mock(RedisAdvancedClusterCommands.class); private static final RedisAdvancedClusterCommands<String, String> redisCommands = mock(RedisAdvancedClusterCommands.class);
@ -139,6 +142,7 @@ class MessageControllerTest {
private static final PushNotificationManager pushNotificationManager = mock(PushNotificationManager.class); private static final PushNotificationManager pushNotificationManager = mock(PushNotificationManager.class);
private static final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class); private static final ReportMessageManager reportMessageManager = mock(ReportMessageManager.class);
private static final ExecutorService multiRecipientMessageExecutor = mock(ExecutorService.class); private static final ExecutorService multiRecipientMessageExecutor = mock(ExecutorService.class);
private static final ReportSpamTokenHandler REPORT_SPAM_TOKEN_HANDLER = mock(ReportSpamTokenHandler.class);
private static final ResourceExtension resources = ResourceExtension.builder() private static final ResourceExtension resources = ResourceExtension.builder()
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE) .addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
@ -150,7 +154,8 @@ class MessageControllerTest {
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource( .addResource(
new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, deletedAccountsManager, new MessageController(rateLimiters, messageSender, receiptSender, accountsManager, deletedAccountsManager,
messagesManager, pushNotificationManager, reportMessageManager, multiRecipientMessageExecutor)) messagesManager, pushNotificationManager, reportMessageManager, multiRecipientMessageExecutor,
ReportSpamTokenProvider.noop(), REPORT_SPAM_TOKEN_HANDLER))
.build(); .build();
@BeforeEach @BeforeEach
@ -167,7 +172,8 @@ class MessageControllerTest {
Account singleDeviceAccount = AccountsHelper.generateTestAccount(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, SINGLE_DEVICE_PNI, singleDeviceList, UNIDENTIFIED_ACCESS_BYTES); Account singleDeviceAccount = AccountsHelper.generateTestAccount(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, SINGLE_DEVICE_PNI, singleDeviceList, UNIDENTIFIED_ACCESS_BYTES);
Account multiDeviceAccount = AccountsHelper.generateTestAccount(MULTI_DEVICE_RECIPIENT, MULTI_DEVICE_UUID, MULTI_DEVICE_PNI, multiDeviceList, UNIDENTIFIED_ACCESS_BYTES); Account multiDeviceAccount = AccountsHelper.generateTestAccount(MULTI_DEVICE_RECIPIENT, MULTI_DEVICE_UUID, MULTI_DEVICE_PNI, multiDeviceList, UNIDENTIFIED_ACCESS_BYTES);
internationalAccount = AccountsHelper.generateTestAccount(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID, UUID.randomUUID(), singleDeviceList, UNIDENTIFIED_ACCESS_BYTES); Account internationalAccount = AccountsHelper.generateTestAccount(INTERNATIONAL_RECIPIENT, INTERNATIONAL_UUID,
UUID.randomUUID(), singleDeviceList, UNIDENTIFIED_ACCESS_BYTES);
when(accountsManager.getByAccountIdentifier(eq(SINGLE_DEVICE_UUID))).thenReturn(Optional.of(singleDeviceAccount)); when(accountsManager.getByAccountIdentifier(eq(SINGLE_DEVICE_UUID))).thenReturn(Optional.of(singleDeviceAccount));
when(accountsManager.getByPhoneNumberIdentifier(SINGLE_DEVICE_PNI)).thenReturn(Optional.of(singleDeviceAccount)); when(accountsManager.getByPhoneNumberIdentifier(SINGLE_DEVICE_PNI)).thenReturn(Optional.of(singleDeviceAccount));
@ -176,6 +182,8 @@ class MessageControllerTest {
when(accountsManager.getByAccountIdentifier(INTERNATIONAL_UUID)).thenReturn(Optional.of(internationalAccount)); when(accountsManager.getByAccountIdentifier(INTERNATIONAL_UUID)).thenReturn(Optional.of(internationalAccount));
when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter); when(rateLimiters.getMessagesLimiter()).thenReturn(rateLimiter);
when(REPORT_SPAM_TOKEN_HANDLER.handle(any(), any(), any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(false));
} }
private static Device generateTestDevice(final long id, final int registrationId, final int pniRegistrationId, final SignedPreKey signedPreKey, final long createdAt, final long lastSeen) { private static Device generateTestDevice(final long id, final int registrationId, final int pniRegistrationId, final SignedPreKey signedPreKey, final long createdAt, final long lastSeen) {
@ -198,12 +206,14 @@ class MessageControllerTest {
messageSender, messageSender,
receiptSender, receiptSender,
accountsManager, accountsManager,
deletedAccountsManager,
messagesManager, messagesManager,
rateLimiters, rateLimiters,
rateLimiter, rateLimiter,
pushNotificationManager, pushNotificationManager,
reportMessageManager, reportMessageManager,
multiRecipientMessageExecutor multiRecipientMessageExecutor,
REPORT_SPAM_TOKEN_HANDLER
); );
} }
@ -692,6 +702,64 @@ class MessageControllerTest {
messageGuid, AuthHelper.VALID_UUID); messageGuid, AuthHelper.VALID_UUID);
} }
@Test
void testReportMessageByAciWithSpamReportToken() {
final String senderNumber = "+12125550001";
final UUID senderAci = UUID.randomUUID();
final UUID senderPni = UUID.randomUUID();
UUID messageGuid = UUID.randomUUID();
final Account account = mock(Account.class);
when(account.getUuid()).thenReturn(senderAci);
when(account.getNumber()).thenReturn(senderNumber);
when(account.getPhoneNumberIdentifier()).thenReturn(senderPni);
when(accountsManager.getByAccountIdentifier(senderAci)).thenReturn(Optional.of(account));
when(deletedAccountsManager.findDeletedAccountE164(senderAci)).thenReturn(Optional.of(senderNumber));
when(accountsManager.getPhoneNumberIdentifier(senderNumber)).thenReturn(senderPni);
when(REPORT_SPAM_TOKEN_HANDLER.handle(any(), any(), any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(true));
ArgumentCaptor<byte[]> captor = ArgumentCaptor.forClass(byte[].class);
String token = Base64.getEncoder().encodeToString(new byte[3]);
Entity<SpamReport> entity = Entity.entity(new SpamReport(token), "application/json");
Response response =
resources.getJerseyTest()
.target(String.format("/v1/messages/report/%s/%s", senderAci, messageGuid))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.post(entity);
assertThat(response.getStatus(), is(equalTo(202)));
verify(REPORT_SPAM_TOKEN_HANDLER).handle(any(), any(), any(), any(), any(), captor.capture());
assertArrayEquals(new byte[3], captor.getValue());
verify(reportMessageManager).report(Optional.of(senderNumber), Optional.of(senderAci), Optional.of(senderPni),
messageGuid, AuthHelper.VALID_UUID);
verify(deletedAccountsManager, never()).findDeletedAccountE164(any(UUID.class));
verify(accountsManager, never()).getPhoneNumberIdentifier(anyString());
when(accountsManager.getByAccountIdentifier(senderAci)).thenReturn(Optional.empty());
clearInvocations(REPORT_SPAM_TOKEN_HANDLER);
messageGuid = UUID.randomUUID();
token = Base64.getEncoder().encodeToString(new byte[5]);
entity = Entity.entity(new SpamReport(token), "application/json");
response =
resources.getJerseyTest()
.target(String.format("/v1/messages/report/%s/%s", senderAci, messageGuid))
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.post(entity);
assertThat(response.getStatus(), is(equalTo(202)));
verify(REPORT_SPAM_TOKEN_HANDLER).handle(any(), any(), any(), any(), any(), captor.capture());
assertArrayEquals(new byte[5], captor.getValue());
verify(reportMessageManager).report(Optional.of(senderNumber), Optional.of(senderAci), Optional.of(senderPni),
messageGuid, AuthHelper.VALID_UUID);
}
@Test @Test
void testValidateContentLength() throws Exception { void testValidateContentLength() throws Exception {
final int contentLength = Math.toIntExact(MessageController.MAX_MESSAGE_SIZE + 1); final int contentLength = Math.toIntExact(MessageController.MAX_MESSAGE_SIZE + 1);