Explicitly call spam-filter for challenges

Pass in the same information to the spam-filter, but just use explicit
method calls rather than jersey request filters.
This commit is contained in:
Ravi Khadiwala 2024-02-15 18:21:57 -06:00 committed by ravi-signal
parent 30b5ad1515
commit 4f40c128bf
7 changed files with 109 additions and 147 deletions

View File

@ -181,8 +181,8 @@ import org.whispersystems.textsecuregcm.s3.PolicySigner;
import org.whispersystems.textsecuregcm.s3.PostPolicyGenerator;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import org.whispersystems.textsecuregcm.spam.ChallengeConstraintChecker;
import org.whispersystems.textsecuregcm.spam.FilterSpam;
import org.whispersystems.textsecuregcm.spam.PushChallengeConfigProvider;
import org.whispersystems.textsecuregcm.spam.ReportSpamTokenProvider;
import org.whispersystems.textsecuregcm.spam.ScoreThresholdProvider;
import org.whispersystems.textsecuregcm.spam.SenderOverrideProvider;
@ -892,6 +892,12 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
log.warn("No spam-checkers found; using default (no-op) provider as a default");
return SpamChecker.noop();
});
final ChallengeConstraintChecker challengeConstraintChecker = spamFilter
.map(SpamFilter::getChallengeConstraintChecker)
.orElseGet(() -> {
log.warn("No challenge-constraint-checkers found; using default (no-op) provider as a default");
return ChallengeConstraintChecker.noop();
});
spamFilter.map(SpamFilter::getReportedMessageListener).ifPresent(reportMessageManager::addListener);
final RateLimitChallengeManager rateLimitChallengeManager = new RateLimitChallengeManager(pushChallengeManager,
@ -923,7 +929,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
new CertificateController(new CertificateGenerator(config.getDeliveryCertificate().certificate().value(),
config.getDeliveryCertificate().ecPrivateKey(), config.getDeliveryCertificate().expiresDays()),
zkAuthOperations, callingGenericZkSecretParams, clock),
new ChallengeController(rateLimitChallengeManager),
new ChallengeController(rateLimitChallengeManager, challengeConstraintChecker),
new DeviceController(config.getLinkDeviceSecretConfiguration().secret().value(), accountsManager,
rateLimiters, rateLimitersCluster, config.getMaxDevices(), clock),
new DirectoryV2Controller(directoryV2CredentialsGenerator),
@ -1009,8 +1015,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
WebSocketEnvironment<AuthenticatedAccount> provisioningEnvironment) {
List.of(
ScoreThresholdProvider.ScoreThresholdFeature.class,
SenderOverrideProvider.SenderOverrideFeature.class,
PushChallengeConfigProvider.PushChallengeConfigFeature.class)
SenderOverrideProvider.SenderOverrideFeature.class)
.forEach(feature -> {
environment.jersey().register(feature);
webSocketEnvironment.jersey().register(feature);

View File

@ -38,24 +38,25 @@ import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.spam.Extract;
import org.whispersystems.textsecuregcm.spam.FilterSpam;
import org.whispersystems.textsecuregcm.spam.PushChallengeConfig;
import org.whispersystems.textsecuregcm.spam.ScoreThreshold;
import org.whispersystems.textsecuregcm.spam.ChallengeConstraintChecker;
import org.whispersystems.textsecuregcm.spam.ChallengeConstraintChecker.ChallengeConstraints;
import org.whispersystems.websocket.auth.ReadOnly;
@Path("/v1/challenge")
@Tag(name = "Challenge")
@FilterSpam
public class ChallengeController {
private final RateLimitChallengeManager rateLimitChallengeManager;
private final ChallengeConstraintChecker challengeConstraintChecker;
private static final String CHALLENGE_RESPONSE_COUNTER_NAME = name(ChallengeController.class, "challengeResponse");
private static final String CHALLENGE_TYPE_TAG = "type";
public ChallengeController(final RateLimitChallengeManager rateLimitChallengeManager) {
public ChallengeController(
final RateLimitChallengeManager rateLimitChallengeManager,
final ChallengeConstraintChecker challengeConstraintChecker) {
this.rateLimitChallengeManager = rateLimitChallengeManager;
this.challengeConstraintChecker = challengeConstraintChecker;
}
@PUT
@ -81,17 +82,17 @@ public class ChallengeController {
public Response handleChallengeResponse(@ReadOnly @Auth final AuthenticatedAccount auth,
@Valid final AnswerChallengeRequest answerRequest,
@Context ContainerRequestContext requestContext,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@Extract final ScoreThreshold captchaScoreThreshold,
@Extract final PushChallengeConfig pushChallengeConfig) throws RateLimitExceededException, IOException {
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent) throws RateLimitExceededException, IOException {
Tags tags = Tags.of(UserAgentTagUtil.getPlatformTag(userAgent));
final ChallengeConstraints constraints = challengeConstraintChecker.challengeConstraints(
requestContext, auth.getAccount());
try {
if (answerRequest instanceof final AnswerPushChallengeRequest pushChallengeRequest) {
tags = tags.and(CHALLENGE_TYPE_TAG, "push");
if (!pushChallengeConfig.pushPermitted()) {
if (!constraints.pushPermitted()) {
return Response.status(429).build();
}
rateLimitChallengeManager.answerPushChallenge(auth.getAccount(), pushChallengeRequest.getChallenge());
@ -105,7 +106,7 @@ public class ChallengeController {
recaptchaChallengeRequest.getCaptcha(),
remoteAddress,
userAgent,
captchaScoreThreshold.getScoreThreshold());
constraints.captchaScoreThreshold());
if (!success) {
return Response.status(428).build();
@ -165,8 +166,10 @@ public class ChallengeController {
name = "Retry-After",
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public Response requestPushChallenge(@ReadOnly @Auth final AuthenticatedAccount auth,
@Extract PushChallengeConfig pushChallengeConfig) {
if (!pushChallengeConfig.pushPermitted()) {
@Context ContainerRequestContext requestContext) {
final ChallengeConstraints constraints = challengeConstraintChecker.challengeConstraints(
requestContext, auth.getAccount());
if (!constraints.pushPermitted()) {
return Response.status(429).build();
}
try {

View File

@ -0,0 +1,27 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.spam;
import org.whispersystems.textsecuregcm.storage.Account;
import javax.ws.rs.container.ContainerRequestContext;
import java.util.Optional;
public interface ChallengeConstraintChecker {
record ChallengeConstraints(boolean pushPermitted, Optional<Float> captchaScoreThreshold) {}
/**
* Retrieve constraints for captcha and push challenges
*
* @param authenticatedAccount The authenticated account attempting to request or solve a challenge
* @return ChallengeConstraints indicating what constraints should be applied to challenges
*/
ChallengeConstraints challengeConstraints(ContainerRequestContext requestContext, Account authenticatedAccount);
static ChallengeConstraintChecker noop() {
return (account, ctx) -> new ChallengeConstraints(true, Optional.empty());
}
}

View File

@ -1,44 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.spam;
import java.util.Optional;
import org.glassfish.jersey.server.ContainerRequest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* A PushChallengeConfig may be provided by an upstream request filter. If request contains a
* property for PROPERTY_NAME it can be forwarded to a downstream filter to indicate whether
* push-token challenges can be used in place of captchas when evaluating whether a request should
* be allowed to continue.
*/
public class PushChallengeConfig {
private static final Logger logger = LoggerFactory.getLogger(PushChallengeConfig.class);
public static final String PROPERTY_NAME = "pushChallengePermitted";
/**
* A score threshold in the range [0, 1.0]
*/
private final boolean pushPermitted;
/**
* Extract an optional score threshold parameter provided by an upstream request filter
*/
public PushChallengeConfig(final ContainerRequest containerRequest) {
this.pushPermitted = Optional
.ofNullable(containerRequest.getProperty(PROPERTY_NAME))
.filter(obj -> obj instanceof Boolean)
.map(obj -> (Boolean) obj)
.orElse(true); // not a typo! true is the default
}
public boolean pushPermitted() {
return pushPermitted;
}
}

View File

@ -1,60 +0,0 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.spam;
import java.util.function.Function;
import javax.inject.Singleton;
import javax.ws.rs.core.Feature;
import javax.ws.rs.core.FeatureContext;
import org.glassfish.jersey.internal.inject.AbstractBinder;
import org.glassfish.jersey.server.ContainerRequest;
import org.glassfish.jersey.server.model.Parameter;
import org.glassfish.jersey.server.spi.internal.ValueParamProvider;
/**
* Parses a {@link PushChallengeConfig} out of a {@link ContainerRequest} to provide to jersey resources.
*
* A request filter may enrich a ContainerRequest with a PushChallengeConfig by providing a float
* property with the name {@link PushChallengeConfig#PROPERTY_NAME}. This indicates whether push
* challenges may be considered when evaluating whether a request should proceed.
*
* A resource can consume a PushChallengeConfig with by annotating a PushChallengeConfig parameter with {@link Extract}
*/
public class PushChallengeConfigProvider implements ValueParamProvider {
/**
* Configures the PushChallengeConfigProvider
*/
public static class PushChallengeConfigFeature implements Feature {
@Override
public boolean configure(FeatureContext context) {
context.register(new AbstractBinder() {
@Override
protected void configure() {
bind(PushChallengeConfigProvider.class)
.to(ValueParamProvider.class)
.in(Singleton.class);
}
});
return true;
}
}
@Override
public Function<ContainerRequest, ?> getValueProvider(final Parameter parameter) {
if (parameter.getRawType().equals(PushChallengeConfig.class)
&& parameter.isAnnotationPresent(Extract.class)) {
return PushChallengeConfig::new;
}
return null;
}
@Override
public PriorityType getPriority() {
return Priority.HIGH;
}
}

View File

@ -64,4 +64,12 @@ public interface SpamFilter extends ContainerRequestFilter, Managed {
* @return a {@link SpamChecker} controlled by the spam filter
*/
SpamChecker getSpamChecker();
/**
* Return a checker that will be called to determine what constraints should be applied
* when a user requests or solves a challenge (captchas, push challenges, etc).
*
* @return a {@link ChallengeConstraintChecker} controlled by the spam filter
*/
ChallengeConstraintChecker getChallengeConstraintChecker();
}

View File

@ -23,15 +23,11 @@ import io.dropwizard.testing.junit5.ResourceExtension;
import java.io.IOException;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import javax.ws.rs.client.Entity;
import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.container.ContainerRequestFilter;
import javax.ws.rs.core.Feature;
import javax.ws.rs.core.FeatureContext;
import javax.ws.rs.core.Response;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
@ -40,9 +36,8 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.spam.PushChallengeConfigProvider;
import org.whispersystems.textsecuregcm.spam.ScoreThreshold;
import org.whispersystems.textsecuregcm.spam.ScoreThresholdProvider;
import org.whispersystems.textsecuregcm.spam.ChallengeConstraintChecker;
import org.whispersystems.textsecuregcm.spam.ChallengeConstraintChecker.ChallengeConstraints;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider;
@ -51,26 +46,14 @@ import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider;
class ChallengeControllerTest {
private static final RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class);
private static final ChallengeConstraintChecker challengeConstraintChecker = mock(ChallengeConstraintChecker.class);
private static final ChallengeController challengeController = new ChallengeController(rateLimitChallengeManager);
private static final AtomicReference<Float> scoreThreshold = new AtomicReference<>();
private static final ChallengeController challengeController =
new ChallengeController(rateLimitChallengeManager, challengeConstraintChecker);
private static final ResourceExtension EXTENSION = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter())
.addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedAccount.class))
.addProvider(ScoreThresholdProvider.ScoreThresholdFeature.class)
.addProvider(PushChallengeConfigProvider.PushChallengeConfigFeature.class)
.addProvider(new Feature() {
public boolean configure(FeatureContext featureContext) {
featureContext.register(new ContainerRequestFilter() {
public void filter(ContainerRequestContext requestContext) {
requestContext.setProperty(ScoreThreshold.PROPERTY_NAME, scoreThreshold.get());
}
});
return true;
}
})
.addProvider(new TestRemoteAddressFilterProvider("127.0.0.1"))
.setMapper(SystemMapper.jsonMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
@ -78,10 +61,15 @@ class ChallengeControllerTest {
.addResource(challengeController)
.build();
@BeforeEach
void setup() {
when(challengeConstraintChecker.challengeConstraints(any(), any()))
.thenReturn(new ChallengeConstraints(true, Optional.empty()));
}
@AfterEach
void teardown() {
reset(rateLimitChallengeManager);
scoreThreshold.set(null);
reset(rateLimitChallengeManager, challengeConstraintChecker);
}
@Test
@ -140,7 +128,8 @@ class ChallengeControllerTest {
if (hasThreshold) {
scoreThreshold.set(Float.valueOf(0.5f));
when(challengeConstraintChecker.challengeConstraints(any(), any()))
.thenReturn(new ChallengeConstraints(true, Optional.of(0.5F)));
}
final Response response = EXTENSION.target("/v1/challenge")
.request()
@ -240,6 +229,40 @@ class ChallengeControllerTest {
}
}
@Test
void testRequestPushChallengeNotPermitted() {
when(challengeConstraintChecker.challengeConstraints(any(), any()))
.thenReturn(new ChallengeConstraints(false, Optional.empty()));
final Response response = EXTENSION.target("/v1/challenge/push")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.post(Entity.text(""));
assertEquals(429, response.getStatus());
verifyNoInteractions(rateLimitChallengeManager);
}
@Test
void testAnswerPushChallengeNotPermitted() {
when(challengeConstraintChecker.challengeConstraints(any(), any()))
.thenReturn(new ChallengeConstraints(false, Optional.empty()));
final String pushChallengeJson = """
{
"type": "rateLimitPushChallenge",
"challenge": "Hello I am a push challenge token"
}
""";
final Response response = EXTENSION.target("/v1/challenge")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.json(pushChallengeJson));
assertEquals(429, response.getStatus());
verifyNoInteractions(rateLimitChallengeManager);
}
@Test
void testValidationError() {
final String unrecognizedJson = """