Explicitly call spam-filter for messages

Pass in the same information to the spam-filter, but just use explicit
method calls rather than jersey request filters.
This commit is contained in:
Ravi Khadiwala 2024-02-05 11:59:04 -06:00 committed by ravi-signal
parent 0965ab8063
commit 3b44ed6d16
7 changed files with 155 additions and 104 deletions

View File

@ -43,6 +43,7 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.stream.Stream;
import javax.servlet.DispatcherType;
import javax.servlet.Filter;
import javax.servlet.FilterRegistration;
@ -177,6 +178,7 @@ import org.whispersystems.textsecuregcm.spam.RateLimitChallengeListener;
import org.whispersystems.textsecuregcm.spam.ReportSpamTokenProvider;
import org.whispersystems.textsecuregcm.spam.ScoreThresholdProvider;
import org.whispersystems.textsecuregcm.spam.SenderOverrideProvider;
import org.whispersystems.textsecuregcm.spam.SpamChecker;
import org.whispersystems.textsecuregcm.spam.SpamFilter;
import org.whispersystems.textsecuregcm.storage.AccountLockManager;
import org.whispersystems.textsecuregcm.storage.Accounts;
@ -773,55 +775,51 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
webSocketEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET, clientReleaseManager));
webSocketEnvironment.jersey().register(new KeepAliveController(clientPresenceManager));
boolean registeredSpamFilter = false;
ReportSpamTokenProvider reportSpamTokenProvider = null;
List<RateLimitChallengeListener> rateLimitChallengeListeners = new ArrayList<>();
for (final SpamFilter filter : ServiceLoader.load(SpamFilter.class)) {
if (filter.getClass().isAnnotationPresent(FilterSpam.class)) {
try {
filter.configure(config.getSpamFilterConfiguration().getEnvironment());
ReportSpamTokenProvider thisProvider = filter.getReportSpamTokenProvider();
if (reportSpamTokenProvider == null) {
reportSpamTokenProvider = thisProvider;
} else if (thisProvider != null) {
log.info("Multiple spam report token providers found. Using the first.");
final List<SpamFilter> spamFilters = ServiceLoader.load(SpamFilter.class)
.stream()
.map(ServiceLoader.Provider::get)
.filter(s -> s.getClass().isAnnotationPresent(FilterSpam.class))
.flatMap(filter -> {
try {
filter.configure(config.getSpamFilterConfiguration().getEnvironment());
return Stream.of(filter);
} catch (Exception e) {
log.warn("Failed to register spam filter: {}", filter.getClass().getName(), e);
return Stream.empty();
}
filter.getReportedMessageListeners().forEach(reportMessageManager::addListener);
environment.lifecycle().manage(filter);
environment.jersey().register(filter);
webSocketEnvironment.jersey().register(filter);
log.info("Registered spam filter: {}", filter.getClass().getName());
registeredSpamFilter = true;
} catch (final Exception e) {
log.warn("Failed to register spam filter: {}", filter.getClass().getName(), e);
}
} else {
log.warn("Spam filter {} not annotated with @FilterSpam and will not be installed",
filter.getClass().getName());
}
if (filter instanceof RateLimitChallengeListener) {
log.info("Registered rate limit challenge listener: {}", filter.getClass().getName());
rateLimitChallengeListeners.add((RateLimitChallengeListener) filter);
}
})
.toList();
if (spamFilters.size() > 1) {
log.warn("Multiple spam report token providers found. Using the first.");
}
RateLimitChallengeManager rateLimitChallengeManager = new RateLimitChallengeManager(pushChallengeManager,
captchaChecker, rateLimiters, rateLimitChallengeListeners);
if (!registeredSpamFilter) {
final Optional<SpamFilter> spamFilter = spamFilters.stream().findFirst();
if (spamFilter.isEmpty()) {
log.warn("No spam filters installed");
}
final ReportSpamTokenProvider reportSpamTokenProvider = spamFilter
.map(SpamFilter::getReportSpamTokenProvider)
.orElseGet(() -> {
log.warn("No spam-reporting token providers found; using default (no-op) provider as a default");
return ReportSpamTokenProvider.noop();
});
final SpamChecker spamChecker = spamFilter
.map(SpamFilter::getSpamChecker)
.orElseGet(() -> {
log.warn("No spam-checkers found; using default (no-op) provider as a default");
return SpamChecker.noop();
});
spamFilter.map(SpamFilter::getReportedMessageListener).ifPresent(reportMessageManager::addListener);
if (reportSpamTokenProvider == null) {
log.warn("No spam-reporting token providers found; using default (no-op) provider as a default");
reportSpamTokenProvider = ReportSpamTokenProvider.noop();
}
final RateLimitChallengeManager rateLimitChallengeManager = new RateLimitChallengeManager(pushChallengeManager,
captchaChecker, rateLimiters, spamFilter.map(SpamFilter::getRateLimitChallengeListener).stream().toList());
spamFilter.ifPresent(filter -> {
environment.lifecycle().manage(filter);
environment.jersey().register(filter);
webSocketEnvironment.jersey().register(filter);
log.info("Registered spam filter: {}", filter.getClass().getName());
});
final List<Object> commonControllers = Lists.newArrayList(
new AccountController(accountsManager, rateLimiters, turnTokenGenerator, registrationRecoveryPasswordsManager,
@ -850,7 +848,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
new MessageController(rateLimiters, messageByteLimitCardinalityEstimator, messageSender, receiptSender,
accountsManager, messagesManager, pushNotificationManager, reportMessageManager,
multiRecipientMessageExecutor, messageDeliveryScheduler, reportSpamTokenProvider, clientReleaseManager,
dynamicConfigurationManager, zkSecretParams),
dynamicConfigurationManager, zkSecretParams, spamChecker),
new PaymentsController(currencyManager, paymentsCredentialsGenerator),
new ProfileController(clock, rateLimiters, accountsManager, profilesManager, dynamicConfigurationManager,
profileBadgeConverter, config.getBadges(), cdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner,

View File

@ -112,6 +112,7 @@ import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.spam.FilterSpam;
import org.whispersystems.textsecuregcm.spam.ReportSpamTokenProvider;
import org.whispersystems.textsecuregcm.spam.SpamChecker;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
@ -135,6 +136,7 @@ import reactor.util.function.Tuples;
@io.swagger.v3.oas.annotations.tags.Tag(name = "Messages")
public class MessageController {
private record MultiRecipientDeliveryData(
ServiceIdentifier serviceIdentifier,
Account account,
@ -144,8 +146,6 @@ public class MessageController {
private static final Logger logger = LoggerFactory.getLogger(MessageController.class);
public static final String DESTINATION_ACCOUNT_PROPERTY_NAME = "destinationAccount";
private final RateLimiters rateLimiters;
private final CardinalityEstimator messageByteLimitEstimator;
private final MessageSender messageSender;
@ -160,6 +160,7 @@ public class MessageController {
private final ClientReleaseManager clientReleaseManager;
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
private final ServerSecretParams serverSecretParams;
private final SpamChecker spamChecker;
private static final int MAX_FETCH_ACCOUNT_CONCURRENCY = 8;
@ -202,7 +203,8 @@ public class MessageController {
@Nonnull ReportSpamTokenProvider reportSpamTokenProvider,
final ClientReleaseManager clientReleaseManager,
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
final ServerSecretParams serverSecretParams) {
final ServerSecretParams serverSecretParams,
final SpamChecker spamChecker) {
this.rateLimiters = rateLimiters;
this.messageByteLimitEstimator = messageByteLimitEstimator;
this.messageSender = messageSender;
@ -217,6 +219,7 @@ public class MessageController {
this.clientReleaseManager = clientReleaseManager;
this.dynamicConfigurationManager = dynamicConfigurationManager;
this.serverSecretParams = serverSecretParams;
this.spamChecker = spamChecker;
}
@Timed
@ -224,7 +227,6 @@ public class MessageController {
@PUT
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
@FilterSpam
@ManagedAsync
public Response sendMessage(@Auth Optional<AuthenticatedAccount> source,
@HeaderParam(HeaderUtils.UNIDENTIFIED_ACCESS_KEY) Optional<Anonymous> accessKey,
@ -234,12 +236,12 @@ public class MessageController {
@NotNull @Valid IncomingMessageList messages,
@Context ContainerRequestContext context) throws RateLimitExceededException {
if (source.isEmpty() && accessKey.isEmpty() && !isStory) {
throw new WebApplicationException(Response.Status.UNAUTHORIZED);
}
final String senderType;
if (source.isPresent()) {
if (source.get().getAccount().isIdentifiedBy(destinationIdentifier)) {
senderType = SENDER_TYPE_SELF;
@ -250,13 +252,31 @@ public class MessageController {
senderType = SENDER_TYPE_UNIDENTIFIED;
}
final Optional<byte[]> spamReportToken;
if (senderType.equals(SENDER_TYPE_IDENTIFIED)) {
spamReportToken = reportSpamTokenProvider.makeReportSpamToken(context);
} else {
spamReportToken = Optional.empty();
boolean isSyncMessage = source.isPresent() && source.get().getAccount().isIdentifiedBy(destinationIdentifier);
if (isSyncMessage && destinationIdentifier.identityType() == IdentityType.PNI) {
throw new WebApplicationException(Status.FORBIDDEN);
}
Optional<Account> destination;
if (!isSyncMessage) {
destination = accountsManager.getByServiceIdentifier(destinationIdentifier);
} else {
destination = source.map(AuthenticatedAccount::getAccount);
}
final Optional<Response> spamCheck = spamChecker.checkForSpam(
context, source.map(AuthenticatedAccount::getAccount), destination);
if (spamCheck.isPresent()) {
return spamCheck.get();
}
final Optional<byte[]> spamReportToken = switch (senderType) {
case SENDER_TYPE_IDENTIFIED ->
reportSpamTokenProvider.makeReportSpamToken(context, source.get().getAccount(), destination);
default -> Optional.empty();
};
int totalContentLength = 0;
for (final IncomingMessage message : messages.messages()) {
@ -282,21 +302,6 @@ public class MessageController {
}
try {
boolean isSyncMessage = source.isPresent() && source.get().getAccount().isIdentifiedBy(destinationIdentifier);
if (isSyncMessage && destinationIdentifier.identityType() == IdentityType.PNI) {
throw new WebApplicationException(Status.FORBIDDEN);
}
Optional<Account> destination;
if (!isSyncMessage) {
destination = accountsManager.getByServiceIdentifier(destinationIdentifier);
} else {
destination = source.map(AuthenticatedAccount::getAccount);
}
destination.ifPresent(account -> context.setProperty(DESTINATION_ACCOUNT_PROPERTY_NAME, account));
// Stories will be checked by the client; we bypass access checks here for stories.
if (!isStory) {
OptionalAccess.verify(source.map(AuthenticatedAccount::getAccount), accessKey, destination);
@ -416,7 +421,6 @@ public class MessageController {
@PUT
@Consumes(MultiRecipientMessageProvider.MEDIA_TYPE)
@Produces(MediaType.APPLICATION_JSON)
@FilterSpam
@Operation(
summary = "Send multi-recipient sealed-sender message",
description = """
@ -460,7 +464,15 @@ public class MessageController {
@Parameter(description="If true, the message is a story; access tokens are not checked and sending to nonexistent recipients is permitted")
@QueryParam("story") boolean isStory,
@Parameter(description="The sealed-sender multi-recipient message payload as serialized by libsignal")
@NotNull SealedSenderMultiRecipientMessage multiRecipientMessage) throws RateLimitExceededException {
@NotNull SealedSenderMultiRecipientMessage multiRecipientMessage,
@Context ContainerRequestContext context) throws RateLimitExceededException {
final Optional<Response> spamCheck = spamChecker.checkForSpam(context, Optional.empty(), Optional.empty());
if (spamCheck.isPresent()) {
return spamCheck.get();
}
if (groupSendCredential == null && accessKeys == null && !isStory) {
throw new NotAuthorizedException("A group send credential or unidentified access key is required for non-story messages");
}

View File

@ -12,13 +12,4 @@ import java.io.IOException;
public interface RateLimitChallengeListener {
void handleRateLimitChallengeAnswered(Account account, ChallengeType type);
/**
* Configures this rate limit challenge listener. This method will be called before the service begins processing any
* challenges.
*
* @param environmentName the name of the environment in which this listener is running (e.g. "staging" or "production")
* @throws IOException if the listener could not read its configuration source for any reason
*/
void configure(String environmentName) throws IOException;
}

View File

@ -1,5 +1,6 @@
package org.whispersystems.textsecuregcm.spam;
import org.whispersystems.textsecuregcm.storage.Account;
import javax.ws.rs.container.ContainerRequestContext;
import java.util.Optional;
import java.util.function.Function;
@ -12,10 +13,13 @@ public interface ReportSpamTokenProvider {
/**
* Generate a new ReportSpamToken
*
* @param context the message request context
* @param context the message request context
* @param sender the account that sent the unsealed sender message
* @param maybeDestination the intended recepient of the message if available
* @return either a generated token or nothing
*/
Optional<byte[]> makeReportSpamToken(ContainerRequestContext context);
Optional<byte[]> makeReportSpamToken(ContainerRequestContext context, final Account sender,
final Optional<Account> maybeDestination);
/**
* Provider which generates nothing
@ -23,6 +27,6 @@ public interface ReportSpamTokenProvider {
* @return the provider
*/
static ReportSpamTokenProvider noop() {
return context -> Optional.empty();
return (ignoredContext, ignoredSender, ignoredDest) -> Optional.empty();
}
}

View File

@ -0,0 +1,31 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.spam;
import org.whispersystems.textsecuregcm.storage.Account;
import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.core.Response;
import java.util.Optional;
public interface SpamChecker {
/**
* Determine if a message may be spam
*
* @param requestContext The request context for a message send attempt
* @param maybeSource The sender of the message, could be empty if this as message sent with sealed sender
* @param maybeDestination The destination of the message, could be empty if the destination does not exist or could
* not be retrieved
* @return A response to return if the request is determined to be spam, otherwise empty if the message should be sent
*/
Optional<Response> checkForSpam(
final ContainerRequestContext requestContext,
final Optional<Account> maybeSource,
final Optional<Account> maybeDestination);
static SpamChecker noop() {
return (ignoredContext, ignoredSource, ignoredDestination) -> Optional.empty();
}
}

View File

@ -9,26 +9,26 @@ import io.dropwizard.lifecycle.Managed;
import org.whispersystems.textsecuregcm.storage.ReportedMessageListener;
import javax.ws.rs.container.ContainerRequestFilter;
import java.io.IOException;
import java.util.List;
/**
* A spam filter is a {@link ContainerRequestFilter} that filters requests to message-sending endpoints to
* detect and respond to patterns of spam.
* A spam filter is a {@link ContainerRequestFilter} that filters requests to endpoints to detect and respond to
* patterns of spam and fraud.
* <p/>
* Spam filters are managed components that are generally loaded dynamically via a
* {@link java.util.ServiceLoader}. Their {@link #configure(String)} method will be called prior to be adding to the
* server's pool of {@link Managed} objects.
* Spam filters are managed components that are generally loaded dynamically via a {@link java.util.ServiceLoader}.
* Their {@link #configure(String)} method will be called prior to be adding to the server's pool of {@link Managed}
* objects.
* <p/>
* Spam filters must be annotated with {@link FilterSpam}, a name binding annotation that
* restricts the endpoints to which the filter may apply.
* Spam filters must be annotated with {@link FilterSpam}, a name binding annotation that restricts the endpoints to
* which the filter may apply.
*/
public interface SpamFilter extends ContainerRequestFilter, Managed {
/**
* Configures this spam filter. This method will be called before the filter is added to the server's pool
* of managed objects and before the server processes any requests.
* Configures this spam filter. This method will be called before the filter is added to the server's pool of managed
* objects and before the server processes any requests.
*
* @param environmentName the name of the environment in which this filter is running (e.g. "staging" or "production")
* @param environmentName the name of the environment in which this filter is running (e.g. "staging" or
* "production")
* @throws IOException if the filter could not read its configuration source for any reason
*/
void configure(String environmentName) throws IOException;
@ -41,10 +41,27 @@ public interface SpamFilter extends ContainerRequestFilter, Managed {
ReportSpamTokenProvider getReportSpamTokenProvider();
/**
* Return any and all reported message listeners controlled by the spam filter. Listeners will be registered with the
* Return a reported message listener controlled by the spam filter. Listeners will be registered with the
* {@link org.whispersystems.textsecuregcm.storage.ReportMessageManager}.
*
* @return a list of reported message listeners controlled by the spam filter
* @return a reported message listener controlled by the spam filter
*/
List<ReportedMessageListener> getReportedMessageListeners();
ReportedMessageListener getReportedMessageListener();
/**
* Return a rate limit challenge listener. Listeners will be registered with the
* {@link org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager}
*
* @return a {@link RateLimitChallengeListener} controlled by the spam filter
*/
RateLimitChallengeListener getRateLimitChallengeListener();
/**
* Return a spam checker that will be called on message sends via the
* {@link org.whispersystems.textsecuregcm.controllers.MessageController} to determine whether a specific message
* spend is spam.
*
* @return a {@link SpamChecker} controlled by the spam filter
*/
SpamChecker getSpamChecker();
}

View File

@ -48,7 +48,6 @@ import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@ -81,10 +80,10 @@ import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.ArgumentSets;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.mockito.ArgumentCaptor;
import org.signal.libsignal.protocol.ServiceId;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.signal.libsignal.protocol.util.Hex;
import org.signal.libsignal.zkgroup.ServerPublicParams;
import org.signal.libsignal.zkgroup.ServerSecretParams;
import org.signal.libsignal.zkgroup.groups.ClientZkGroupCipher;
import org.signal.libsignal.zkgroup.groups.GroupMasterKey;
import org.signal.libsignal.zkgroup.groups.GroupSecretParams;
@ -92,8 +91,6 @@ import org.signal.libsignal.zkgroup.groups.UuidCiphertext;
import org.signal.libsignal.zkgroup.groupsend.GroupSendCredential;
import org.signal.libsignal.zkgroup.groupsend.GroupSendCredentialPresentation;
import org.signal.libsignal.zkgroup.groupsend.GroupSendCredentialResponse;
import org.signal.libsignal.zkgroup.ServerPublicParams;
import org.signal.libsignal.zkgroup.ServerSecretParams;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.UnidentifiedAccessUtil;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
@ -123,6 +120,7 @@ import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.spam.ReportSpamTokenProvider;
import org.whispersystems.textsecuregcm.spam.SpamChecker;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
@ -206,7 +204,7 @@ class MessageControllerTest {
new MessageController(rateLimiters, cardinalityEstimator, messageSender, receiptSender, accountsManager,
messagesManager, pushNotificationManager, reportMessageManager, multiRecipientMessageExecutor,
messageDeliveryScheduler, ReportSpamTokenProvider.noop(), mock(ClientReleaseManager.class), dynamicConfigurationManager,
serverSecretParams))
serverSecretParams, SpamChecker.noop()))
.build();
@BeforeEach