Explicitly call spam-filter for challenges

Pass in the same information to the spam-filter, but just use explicit
method calls rather than jersey request filters.
This commit is contained in:
Ravi Khadiwala 2024-02-15 18:21:57 -06:00 committed by ravi-signal
parent 30b5ad1515
commit 4f40c128bf
7 changed files with 109 additions and 147 deletions

View File

@ -181,8 +181,8 @@ import org.whispersystems.textsecuregcm.s3.PolicySigner;
import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator; import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import org.whispersystems.textsecuregcm.spam.ChallengeConstraintChecker;
import org.whispersystems.textsecuregcm.spam.FilterSpam; import org.whispersystems.textsecuregcm.spam.FilterSpam;
import org.whispersystems.textsecuregcm.spam.PushChallengeConfigProvider;
import org.whispersystems.textsecuregcm.spam.ReportSpamTokenProvider; import org.whispersystems.textsecuregcm.spam.ReportSpamTokenProvider;
import org.whispersystems.textsecuregcm.spam.ScoreThresholdProvider; import org.whispersystems.textsecuregcm.spam.ScoreThresholdProvider;
import org.whispersystems.textsecuregcm.spam.SenderOverrideProvider; import org.whispersystems.textsecuregcm.spam.SenderOverrideProvider;
@ -892,6 +892,12 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
log.warn("No spam-checkers found; using default (no-op) provider as a default"); log.warn("No spam-checkers found; using default (no-op) provider as a default");
return SpamChecker.noop(); return SpamChecker.noop();
}); });
final ChallengeConstraintChecker challengeConstraintChecker = spamFilter
.map(SpamFilter::getChallengeConstraintChecker)
.orElseGet(() -> {
log.warn("No challenge-constraint-checkers found; using default (no-op) provider as a default");
return ChallengeConstraintChecker.noop();
});
spamFilter.map(SpamFilter::getReportedMessageListener).ifPresent(reportMessageManager::addListener); spamFilter.map(SpamFilter::getReportedMessageListener).ifPresent(reportMessageManager::addListener);
final RateLimitChallengeManager rateLimitChallengeManager = new RateLimitChallengeManager(pushChallengeManager, final RateLimitChallengeManager rateLimitChallengeManager = new RateLimitChallengeManager(pushChallengeManager,
@ -923,7 +929,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
new CertificateController(new CertificateGenerator(config.getDeliveryCertificate().certificate().value(), new CertificateController(new CertificateGenerator(config.getDeliveryCertificate().certificate().value(),
config.getDeliveryCertificate().ecPrivateKey(), config.getDeliveryCertificate().expiresDays()), config.getDeliveryCertificate().ecPrivateKey(), config.getDeliveryCertificate().expiresDays()),
zkAuthOperations, callingGenericZkSecretParams, clock), zkAuthOperations, callingGenericZkSecretParams, clock),
new ChallengeController(rateLimitChallengeManager), new ChallengeController(rateLimitChallengeManager, challengeConstraintChecker),
new DeviceController(config.getLinkDeviceSecretConfiguration().secret().value(), accountsManager, new DeviceController(config.getLinkDeviceSecretConfiguration().secret().value(), accountsManager,
rateLimiters, rateLimitersCluster, config.getMaxDevices(), clock), rateLimiters, rateLimitersCluster, config.getMaxDevices(), clock),
new DirectoryV2Controller(directoryV2CredentialsGenerator), new DirectoryV2Controller(directoryV2CredentialsGenerator),
@ -1009,8 +1015,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
WebSocketEnvironment<AuthenticatedAccount> provisioningEnvironment) { WebSocketEnvironment<AuthenticatedAccount> provisioningEnvironment) {
List.of( List.of(
ScoreThresholdProvider.ScoreThresholdFeature.class, ScoreThresholdProvider.ScoreThresholdFeature.class,
SenderOverrideProvider.SenderOverrideFeature.class, SenderOverrideProvider.SenderOverrideFeature.class)
PushChallengeConfigProvider.PushChallengeConfigFeature.class)
.forEach(feature -> { .forEach(feature -> {
environment.jersey().register(feature); environment.jersey().register(feature);
webSocketEnvironment.jersey().register(feature); webSocketEnvironment.jersey().register(feature);

View File

@ -38,24 +38,25 @@ import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager; import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.spam.Extract; import org.whispersystems.textsecuregcm.spam.ChallengeConstraintChecker;
import org.whispersystems.textsecuregcm.spam.FilterSpam; import org.whispersystems.textsecuregcm.spam.ChallengeConstraintChecker.ChallengeConstraints;
import org.whispersystems.textsecuregcm.spam.PushChallengeConfig;
import org.whispersystems.textsecuregcm.spam.ScoreThreshold;
import org.whispersystems.websocket.auth.ReadOnly; import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v1/challenge") @Path("/v1/challenge")
@Tag(name = "Challenge") @Tag(name = "Challenge")
@FilterSpam
public class ChallengeController { public class ChallengeController {
private final RateLimitChallengeManager rateLimitChallengeManager; private final RateLimitChallengeManager rateLimitChallengeManager;
private final ChallengeConstraintChecker challengeConstraintChecker;
private static final String CHALLENGE_RESPONSE_COUNTER_NAME = name(ChallengeController.class, "challengeResponse"); private static final String CHALLENGE_RESPONSE_COUNTER_NAME = name(ChallengeController.class, "challengeResponse");
private static final String CHALLENGE_TYPE_TAG = "type"; private static final String CHALLENGE_TYPE_TAG = "type";
public ChallengeController(final RateLimitChallengeManager rateLimitChallengeManager) { public ChallengeController(
final RateLimitChallengeManager rateLimitChallengeManager,
final ChallengeConstraintChecker challengeConstraintChecker) {
this.rateLimitChallengeManager = rateLimitChallengeManager; this.rateLimitChallengeManager = rateLimitChallengeManager;
this.challengeConstraintChecker = challengeConstraintChecker;
} }
@PUT @PUT
@ -81,17 +82,17 @@ public class ChallengeController {
public Response handleChallengeResponse(@ReadOnly @Auth final AuthenticatedAccount auth, public Response handleChallengeResponse(@ReadOnly @Auth final AuthenticatedAccount auth,
@Valid final AnswerChallengeRequest answerRequest, @Valid final AnswerChallengeRequest answerRequest,
@Context ContainerRequestContext requestContext, @Context ContainerRequestContext requestContext,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) throws RateLimitExceededException, IOException {
@Extract final ScoreThreshold captchaScoreThreshold,
@Extract final PushChallengeConfig pushChallengeConfig) throws RateLimitExceededException, IOException {
Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(userAgent)); Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(userAgent));
final ChallengeConstraints constraints = challengeConstraintChecker.challengeConstraints(
requestContext, auth.getAccount());
try { try {
if (answerRequest instanceof final AnswerPushChallengeRequest pushChallengeRequest) { if (answerRequest instanceof final AnswerPushChallengeRequest pushChallengeRequest) {
tags = tags.and(CHALLENGE_TYPE_TAG, "push"); tags = tags.and(CHALLENGE_TYPE_TAG, "push");
if (!pushChallengeConfig.pushPermitted()) { if (!constraints.pushPermitted()) {
return Response.status(429).build(); return Response.status(429).build();
} }
rateLimitChallengeManager.answerPushChallenge(auth.getAccount(), pushChallengeRequest.getChallenge()); rateLimitChallengeManager.answerPushChallenge(auth.getAccount(), pushChallengeRequest.getChallenge());
@ -105,7 +106,7 @@ public class ChallengeController {
recaptchaChallengeRequest.getCaptcha(), recaptchaChallengeRequest.getCaptcha(),
remoteAddress, remoteAddress,
userAgent, userAgent,
captchaScoreThreshold.getScoreThreshold()); constraints.captchaScoreThreshold());
if (!success) { if (!success) {
return Response.status(428).build(); return Response.status(428).build();
@ -165,8 +166,10 @@ public class ChallengeController {
name = "Retry-After", name = "Retry-After",
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed")) description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public Response requestPushChallenge(@ReadOnly @Auth final AuthenticatedAccount auth, public Response requestPushChallenge(@ReadOnly @Auth final AuthenticatedAccount auth,
@Extract PushChallengeConfig pushChallengeConfig) { @Context ContainerRequestContext requestContext) {
if (!pushChallengeConfig.pushPermitted()) { final ChallengeConstraints constraints = challengeConstraintChecker.challengeConstraints(
requestContext, auth.getAccount());
if (!constraints.pushPermitted()) {
return Response.status(429).build(); return Response.status(429).build();
} }
try { try {

View File

@ -0,0 +1,27 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.spam;
import org.whispersystems.textsecuregcm.storage.Account;
import javax.ws.rs.container.ContainerRequestContext;
import java.util.Optional;
public interface ChallengeConstraintChecker {
record ChallengeConstraints(boolean pushPermitted, Optional<Float> captchaScoreThreshold) {}
/**
* Retrieve constraints for captcha and push challenges
*
* @param authenticatedAccount The authenticated account attempting to request or solve a challenge
* @return ChallengeConstraints indicating what constraints should be applied to challenges
*/
ChallengeConstraints challengeConstraints(ContainerRequestContext requestContext, Account authenticatedAccount);
static ChallengeConstraintChecker noop() {
return (account, ctx) -> new ChallengeConstraints(true, Optional.empty());
}
}

View File

@ -1,44 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.spam;
import java.util.Optional;
import org.glassfish.jersey.server.ContainerRequest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* A PushChallengeConfig may be provided by an upstream request filter. If request contains a
* property for PROPERTY_NAME it can be forwarded to a downstream filter to indicate whether
* push-token challenges can be used in place of captchas when evaluating whether a request should
* be allowed to continue.
*/
public class PushChallengeConfig {
private static final Logger logger = LoggerFactory.getLogger(PushChallengeConfig.class);
public static final String PROPERTY_NAME = "pushChallengePermitted";
/**
* A score threshold in the range [0, 1.0]
*/
private final boolean pushPermitted;
/**
* Extract an optional score threshold parameter provided by an upstream request filter
*/
public PushChallengeConfig(final ContainerRequest containerRequest) {
this.pushPermitted = Optional
.ofNullable(containerRequest.getProperty(PROPERTY_NAME))
.filter(obj -> obj instanceof Boolean)
.map(obj -> (Boolean) obj)
.orElse(true); // not a typo! true is the default
}
public boolean pushPermitted() {
return pushPermitted;
}
}

View File

@ -1,60 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.spam;
import java.util.function.Function;
import javax.inject.Singleton;
import javax.ws.rs.core.Feature;
import javax.ws.rs.core.FeatureContext;
import org.glassfish.jersey.internal.inject.AbstractBinder;
import org.glassfish.jersey.server.ContainerRequest;
import org.glassfish.jersey.server.model.Parameter;
import org.glassfish.jersey.server.spi.internal.ValueParamProvider;
/**
* Parses a {@link PushChallengeConfig} out of a {@link ContainerRequest} to provide to jersey resources.
*
* A request filter may enrich a ContainerRequest with a PushChallengeConfig by providing a float
* property with the name {@link PushChallengeConfig#PROPERTY_NAME}. This indicates whether push
* challenges may be considered when evaluating whether a request should proceed.
*
* A resource can consume a PushChallengeConfig with by annotating a PushChallengeConfig parameter with {@link Extract}
*/
public class PushChallengeConfigProvider implements ValueParamProvider {
/**
* Configures the PushChallengeConfigProvider
*/
public static class PushChallengeConfigFeature implements Feature {
@Override
public boolean configure(FeatureContext context) {
context.register(new AbstractBinder() {
@Override
protected void configure() {
bind(PushChallengeConfigProvider.class)
.to(ValueParamProvider.class)
.in(Singleton.class);
}
});
return true;
}
}
@Override
public Function<ContainerRequest, ?> getValueProvider(final Parameter parameter) {
if (parameter.getRawType().equals(PushChallengeConfig.class)
&& parameter.isAnnotationPresent(Extract.class)) {
return PushChallengeConfig::new;
}
return null;
}
@Override
public PriorityType getPriority() {
return Priority.HIGH;
}
}

View File

@ -64,4 +64,12 @@ public interface SpamFilter extends ContainerRequestFilter, Managed {
* @return a {@link SpamChecker} controlled by the spam filter * @return a {@link SpamChecker} controlled by the spam filter
*/ */
SpamChecker getSpamChecker(); SpamChecker getSpamChecker();
/**
* Return a checker that will be called to determine what constraints should be applied
* when a user requests or solves a challenge (captchas, push challenges, etc).
*
* @return a {@link ChallengeConstraintChecker} controlled by the spam filter
*/
ChallengeConstraintChecker getChallengeConstraintChecker();
} }

View File

@ -23,15 +23,11 @@ import io.dropwizard.testing.junit5.ResourceExtension;
import java.io.IOException; import java.io.IOException;
import java.time.Duration; import java.time.Duration;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import javax.ws.rs.client.Entity; import javax.ws.rs.client.Entity;
import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.container.ContainerRequestFilter;
import javax.ws.rs.core.Feature;
import javax.ws.rs.core.FeatureContext;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
@ -40,9 +36,8 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager; import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.spam.PushChallengeConfigProvider; import org.whispersystems.textsecuregcm.spam.ChallengeConstraintChecker;
import org.whispersystems.textsecuregcm.spam.ScoreThreshold; import org.whispersystems.textsecuregcm.spam.ChallengeConstraintChecker.ChallengeConstraints;
import org.whispersystems.textsecuregcm.spam.ScoreThresholdProvider;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider; import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider;
@ -51,26 +46,14 @@ import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider;
class ChallengeControllerTest { class ChallengeControllerTest {
private static final RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class); private static final RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class);
private static final ChallengeConstraintChecker challengeConstraintChecker = mock(ChallengeConstraintChecker.class);
private static final ChallengeController challengeController = new ChallengeController(rateLimitChallengeManager); private static final ChallengeController challengeController =
new ChallengeController(rateLimitChallengeManager, challengeConstraintChecker);
private static final AtomicReference<Float> scoreThreshold = new AtomicReference<>();
private static final ResourceExtension EXTENSION = ResourceExtension.builder() private static final ResourceExtension EXTENSION = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter()) .addProvider(AuthHelper.getAuthFilter())
.addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedAccount.class)) .addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedAccount.class))
.addProvider(ScoreThresholdProvider.ScoreThresholdFeature.class)
.addProvider(PushChallengeConfigProvider.PushChallengeConfigFeature.class)
.addProvider(new Feature() {
public boolean configure(FeatureContext featureContext) {
featureContext.register(new ContainerRequestFilter() {
public void filter(ContainerRequestContext requestContext) {
requestContext.setProperty(ScoreThreshold.PROPERTY_NAME, scoreThreshold.get());
}
});
return true;
}
})
.addProvider(new TestRemoteAddressFilterProvider("127.0.0.1")) .addProvider(new TestRemoteAddressFilterProvider("127.0.0.1"))
.setMapper(SystemMapper.jsonMapper()) .setMapper(SystemMapper.jsonMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
@ -78,10 +61,15 @@ class ChallengeControllerTest {
.addResource(challengeController) .addResource(challengeController)
.build(); .build();
@BeforeEach
void setup() {
when(challengeConstraintChecker.challengeConstraints(any(), any()))
.thenReturn(new ChallengeConstraints(true, Optional.empty()));
}
@AfterEach @AfterEach
void teardown() { void teardown() {
reset(rateLimitChallengeManager); reset(rateLimitChallengeManager, challengeConstraintChecker);
scoreThreshold.set(null);
} }
@Test @Test
@ -140,7 +128,8 @@ class ChallengeControllerTest {
if (hasThreshold) { if (hasThreshold) {
scoreThreshold.set(Float.valueOf(0.5f)); when(challengeConstraintChecker.challengeConstraints(any(), any()))
.thenReturn(new ChallengeConstraints(true, Optional.of(0.5F)));
} }
final Response response = EXTENSION.target("/v1/challenge") final Response response = EXTENSION.target("/v1/challenge")
.request() .request()
@ -240,6 +229,40 @@ class ChallengeControllerTest {
} }
} }
@Test
void testRequestPushChallengeNotPermitted() {
when(challengeConstraintChecker.challengeConstraints(any(), any()))
.thenReturn(new ChallengeConstraints(false, Optional.empty()));
final Response response = EXTENSION.target("/v1/challenge/push")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.post(Entity.text(""));
assertEquals(429, response.getStatus());
verifyNoInteractions(rateLimitChallengeManager);
}
@Test
void testAnswerPushChallengeNotPermitted() {
when(challengeConstraintChecker.challengeConstraints(any(), any()))
.thenReturn(new ChallengeConstraints(false, Optional.empty()));
final String pushChallengeJson = """
{
"type": "rateLimitPushChallenge",
"challenge": "Hello I am a push challenge token"
}
""";
final Response response = EXTENSION.target("/v1/challenge")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.json(pushChallengeJson));
assertEquals(429, response.getStatus());
verifyNoInteractions(rateLimitChallengeManager);
}
@Test @Test
void testValidationError() { void testValidationError() {
final String unrecognizedJson = """ final String unrecognizedJson = """