Conditionally use `request.remoteAddr` instead of `X-Forwarded-For`

This commit is contained in:
Chris Eager 2023-11-22 17:05:50 -06:00 committed by Chris Eager
parent b1fd025ea6
commit a027c4ce1f
8 changed files with 48 additions and 16 deletions

View File

@ -39,6 +39,7 @@ import java.time.Duration;
import java.util.Collections; import java.util.Collections;
import java.util.EnumSet; import java.util.EnumSet;
import java.util.List; import java.util.List;
import java.util.Optional;
import java.util.ServiceLoader; import java.util.ServiceLoader;
import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue; import java.util.concurrent.BlockingQueue;
@ -300,6 +301,10 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
MetricsUtil.configureRegistries(config, environment); MetricsUtil.configureRegistries(config, environment);
final boolean useRemoteAddress = Optional.ofNullable(
System.getenv("SIGNAL_USE_REMOTE_ADDRESS"))
.isPresent();
HeaderControlledResourceBundleLookup headerControlledResourceBundleLookup = HeaderControlledResourceBundleLookup headerControlledResourceBundleLookup =
new HeaderControlledResourceBundleLookup(); new HeaderControlledResourceBundleLookup();
ConfiguredProfileBadgeConverter profileBadgeConverter = new ConfiguredProfileBadgeConverter( ConfiguredProfileBadgeConverter profileBadgeConverter = new ConfiguredProfileBadgeConverter(
@ -800,7 +805,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
new ArchiveController(backupAuthManager, backupManager), new ArchiveController(backupAuthManager, backupManager),
new CallLinkController(rateLimiters, callingGenericZkSecretParams), new CallLinkController(rateLimiters, callingGenericZkSecretParams),
new CertificateController(new CertificateGenerator(config.getDeliveryCertificate().certificate().value(), config.getDeliveryCertificate().ecPrivateKey(), config.getDeliveryCertificate().expiresDays()), zkAuthOperations, callingGenericZkSecretParams, clock), new CertificateController(new CertificateGenerator(config.getDeliveryCertificate().certificate().value(), config.getDeliveryCertificate().ecPrivateKey(), config.getDeliveryCertificate().expiresDays()), zkAuthOperations, callingGenericZkSecretParams, clock),
new ChallengeController(rateLimitChallengeManager), new ChallengeController(rateLimitChallengeManager, useRemoteAddress),
new DeviceController(config.getLinkDeviceSecretConfiguration().secret().value(), accountsManager, messagesManager, keysManager, rateLimiters, new DeviceController(config.getLinkDeviceSecretConfiguration().secret().value(), accountsManager, messagesManager, keysManager, rateLimiters,
rateLimitersCluster, config.getMaxDevices(), clock), rateLimitersCluster, config.getMaxDevices(), clock),
new DirectoryV2Controller(directoryV2CredentialsGenerator), new DirectoryV2Controller(directoryV2CredentialsGenerator),
@ -831,7 +836,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
config.getCdnConfiguration().bucket()), config.getCdnConfiguration().bucket()),
new VerificationController(registrationServiceClient, new VerificationSessionManager(verificationSessions), new VerificationController(registrationServiceClient, new VerificationSessionManager(verificationSessions),
pushNotificationManager, registrationCaptchaManager, registrationRecoveryPasswordsManager, rateLimiters, pushNotificationManager, registrationCaptchaManager, registrationRecoveryPasswordsManager, rateLimiters,
accountsManager, clock) accountsManager, useRemoteAddress, clock)
); );
if (config.getSubscription() != null && config.getOneTimeDonations() != null) { if (config.getSubscription() != null && config.getOneTimeDonations() != null) {
commonControllers.add(new SubscriptionController(clock, config.getSubscription(), config.getOneTimeDonations(), commonControllers.add(new SubscriptionController(clock, config.getSubscription(), config.getOneTimeDonations(),

View File

@ -19,6 +19,7 @@ import io.swagger.v3.oas.annotations.parameters.RequestBody;
import io.swagger.v3.oas.annotations.responses.ApiResponse; import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import java.io.IOException; import java.io.IOException;
import javax.servlet.http.HttpServletRequest;
import javax.validation.Valid; import javax.validation.Valid;
import javax.ws.rs.BadRequestException; import javax.ws.rs.BadRequestException;
import javax.ws.rs.Consumes; import javax.ws.rs.Consumes;
@ -27,6 +28,7 @@ import javax.ws.rs.POST;
import javax.ws.rs.PUT; import javax.ws.rs.PUT;
import javax.ws.rs.Path; import javax.ws.rs.Path;
import javax.ws.rs.Produces; import javax.ws.rs.Produces;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
@ -48,12 +50,15 @@ import org.whispersystems.textsecuregcm.util.HeaderUtils;
public class ChallengeController { public class ChallengeController {
private final RateLimitChallengeManager rateLimitChallengeManager; private final RateLimitChallengeManager rateLimitChallengeManager;
private final boolean useRemoteAddress;
private static final String CHALLENGE_RESPONSE_COUNTER_NAME = name(ChallengeController.class, "challengeResponse"); private static final String CHALLENGE_RESPONSE_COUNTER_NAME = name(ChallengeController.class, "challengeResponse");
private static final String CHALLENGE_TYPE_TAG = "type"; private static final String CHALLENGE_TYPE_TAG = "type";
public ChallengeController(final RateLimitChallengeManager rateLimitChallengeManager) { public ChallengeController(final RateLimitChallengeManager rateLimitChallengeManager,
final boolean useRemoteAddress) {
this.rateLimitChallengeManager = rateLimitChallengeManager; this.rateLimitChallengeManager = rateLimitChallengeManager;
this.useRemoteAddress = useRemoteAddress;
} }
@PUT @PUT
@ -79,6 +84,7 @@ public class ChallengeController {
public Response handleChallengeResponse(@Auth final AuthenticatedAccount auth, public Response handleChallengeResponse(@Auth final AuthenticatedAccount auth,
@Valid final AnswerChallengeRequest answerRequest, @Valid final AnswerChallengeRequest answerRequest,
@HeaderParam(HttpHeaders.X_FORWARDED_FOR) final String forwardedFor, @HeaderParam(HttpHeaders.X_FORWARDED_FOR) final String forwardedFor,
@Context HttpServletRequest request,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@Extract final ScoreThreshold captchaScoreThreshold, @Extract final ScoreThreshold captchaScoreThreshold,
@Extract final PushChallengeConfig pushChallengeConfig) throws RateLimitExceededException, IOException { @Extract final PushChallengeConfig pushChallengeConfig) throws RateLimitExceededException, IOException {
@ -96,11 +102,13 @@ public class ChallengeController {
} else if (answerRequest instanceof AnswerRecaptchaChallengeRequest recaptchaChallengeRequest) { } else if (answerRequest instanceof AnswerRecaptchaChallengeRequest recaptchaChallengeRequest) {
tags = tags.and(CHALLENGE_TYPE_TAG, "recaptcha"); tags = tags.and(CHALLENGE_TYPE_TAG, "recaptcha");
final String mostRecentProxy = HeaderUtils.getMostRecentProxy(forwardedFor).orElseThrow(() -> new BadRequestException()); final String remoteAddress = useRemoteAddress
? request.getRemoteAddr()
: HeaderUtils.getMostRecentProxy(forwardedFor).orElseThrow(BadRequestException::new);
boolean success = rateLimitChallengeManager.answerRecaptchaChallenge( boolean success = rateLimitChallengeManager.answerRecaptchaChallenge(
auth.getAccount(), auth.getAccount(),
recaptchaChallengeRequest.getCaptcha(), recaptchaChallengeRequest.getCaptcha(),
mostRecentProxy, remoteAddress,
userAgent, userAgent,
captchaScoreThreshold.getScoreThreshold()); captchaScoreThreshold.getScoreThreshold());

View File

@ -31,6 +31,7 @@ import java.util.Optional;
import java.util.concurrent.CancellationException; import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionException;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import javax.servlet.http.HttpServletRequest;
import javax.validation.Valid; import javax.validation.Valid;
import javax.validation.constraints.NotNull; import javax.validation.constraints.NotNull;
import javax.ws.rs.BadRequestException; import javax.ws.rs.BadRequestException;
@ -48,6 +49,7 @@ import javax.ws.rs.PathParam;
import javax.ws.rs.Produces; import javax.ws.rs.Produces;
import javax.ws.rs.ServerErrorException; import javax.ws.rs.ServerErrorException;
import javax.ws.rs.WebApplicationException; import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.HttpHeaders; import javax.ws.rs.core.HttpHeaders;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
@ -116,6 +118,7 @@ public class VerificationController {
private final RateLimiters rateLimiters; private final RateLimiters rateLimiters;
private final AccountsManager accountsManager; private final AccountsManager accountsManager;
private final boolean useRemoteAddress;
private final Clock clock; private final Clock clock;
public VerificationController(final RegistrationServiceClient registrationServiceClient, public VerificationController(final RegistrationServiceClient registrationServiceClient,
@ -125,6 +128,7 @@ public class VerificationController {
final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager, final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager,
final RateLimiters rateLimiters, final RateLimiters rateLimiters,
final AccountsManager accountsManager, final AccountsManager accountsManager,
final boolean useRemoteAddress,
final Clock clock) { final Clock clock) {
this.registrationServiceClient = registrationServiceClient; this.registrationServiceClient = registrationServiceClient;
this.verificationSessionManager = verificationSessionManager; this.verificationSessionManager = verificationSessionManager;
@ -133,6 +137,7 @@ public class VerificationController {
this.registrationRecoveryPasswordsManager = registrationRecoveryPasswordsManager; this.registrationRecoveryPasswordsManager = registrationRecoveryPasswordsManager;
this.rateLimiters = rateLimiters; this.rateLimiters = rateLimiters;
this.accountsManager = accountsManager; this.accountsManager = accountsManager;
this.useRemoteAddress = useRemoteAddress;
this.clock = clock; this.clock = clock;
} }
@ -194,10 +199,13 @@ public class VerificationController {
public VerificationSessionResponse updateSession(@PathParam("sessionId") final String encodedSessionId, public VerificationSessionResponse updateSession(@PathParam("sessionId") final String encodedSessionId,
@HeaderParam(com.google.common.net.HttpHeaders.X_FORWARDED_FOR) String forwardedFor, @HeaderParam(com.google.common.net.HttpHeaders.X_FORWARDED_FOR) String forwardedFor,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@Context HttpServletRequest request,
@NotNull @Valid final UpdateVerificationSessionRequest updateVerificationSessionRequest, @NotNull @Valid final UpdateVerificationSessionRequest updateVerificationSessionRequest,
@NotNull @Extract final ScoreThreshold captchaScoreThreshold) { @NotNull @Extract final ScoreThreshold captchaScoreThreshold) {
final String sourceHost = HeaderUtils.getMostRecentProxy(forwardedFor).orElseThrow(); final String sourceHost = useRemoteAddress
? request.getRemoteAddr()
: HeaderUtils.getMostRecentProxy(forwardedFor).orElseThrow();
final Pair<String, PushNotification.TokenType> pushTokenAndType = validateAndExtractPushToken( final Pair<String, PushNotification.TokenType> pushTokenAndType = validateAndExtractPushToken(
updateVerificationSessionRequest); updateVerificationSessionRequest);

View File

@ -12,9 +12,12 @@ import com.google.common.net.HttpHeaders;
import java.io.IOException; import java.io.IOException;
import java.time.Duration; import java.time.Duration;
import java.util.Optional; import java.util.Optional;
import javax.inject.Provider;
import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.ClientErrorException; import javax.ws.rs.ClientErrorException;
import javax.ws.rs.container.ContainerRequestContext; import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.container.ContainerRequestFilter; import javax.ws.rs.container.ContainerRequestFilter;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import javax.ws.rs.ext.ExceptionMapper; import javax.ws.rs.ext.ExceptionMapper;
import org.glassfish.jersey.server.ExtendedUriInfo; import org.glassfish.jersey.server.ExtendedUriInfo;
@ -28,6 +31,9 @@ public class RateLimitByIpFilter implements ContainerRequestFilter {
private static final Logger logger = LoggerFactory.getLogger(RateLimitByIpFilter.class); private static final Logger logger = LoggerFactory.getLogger(RateLimitByIpFilter.class);
@Context
private Provider<HttpServletRequest> httpServletRequestProvider;
@VisibleForTesting @VisibleForTesting
static final RateLimitExceededException INVALID_HEADER_EXCEPTION = new RateLimitExceededException(Duration.ofHours(1), static final RateLimitExceededException INVALID_HEADER_EXCEPTION = new RateLimitExceededException(Duration.ofHours(1),
true); true);
@ -35,10 +41,12 @@ public class RateLimitByIpFilter implements ContainerRequestFilter {
private static final ExceptionMapper<RateLimitExceededException> EXCEPTION_MAPPER = new RateLimitExceededExceptionMapper(); private static final ExceptionMapper<RateLimitExceededException> EXCEPTION_MAPPER = new RateLimitExceededExceptionMapper();
private final RateLimiters rateLimiters; private final RateLimiters rateLimiters;
private final boolean useRemoteAddress;
public RateLimitByIpFilter(final RateLimiters rateLimiters) { public RateLimitByIpFilter(final RateLimiters rateLimiters, final boolean useRemoteAddress) {
this.rateLimiters = requireNonNull(rateLimiters); this.rateLimiters = requireNonNull(rateLimiters);
this.useRemoteAddress = useRemoteAddress;
} }
@Override @Override
@ -62,12 +70,14 @@ public class RateLimitByIpFilter implements ContainerRequestFilter {
try { try {
final String xffHeader = requestContext.getHeaders().getFirst(HttpHeaders.X_FORWARDED_FOR); final String xffHeader = requestContext.getHeaders().getFirst(HttpHeaders.X_FORWARDED_FOR);
final Optional<String> maybeMostRecentProxy = Optional.ofNullable(xffHeader) final Optional<String> remoteAddress = useRemoteAddress
.flatMap(HeaderUtils::getMostRecentProxy); ? Optional.of(httpServletRequestProvider.get().getRemoteAddr())
: Optional.ofNullable(xffHeader)
.flatMap(HeaderUtils::getMostRecentProxy);
// checking if we failed to extract the most recent IP from the X-Forwarded-For header // checking if we failed to extract the most recent IP from the X-Forwarded-For header
// for any reason // for any reason
if (maybeMostRecentProxy.isEmpty()) { if (remoteAddress.isEmpty()) {
// checking if annotation is configured to fail when the most recent IP is not resolved // checking if annotation is configured to fail when the most recent IP is not resolved
if (annotation.failOnUnresolvedIp()) { if (annotation.failOnUnresolvedIp()) {
logger.error("Missing/bad X-Forwarded-For: {}", xffHeader); logger.error("Missing/bad X-Forwarded-For: {}", xffHeader);
@ -78,7 +88,7 @@ public class RateLimitByIpFilter implements ContainerRequestFilter {
} }
final RateLimiter rateLimiter = rateLimiters.forDescriptor(handle); final RateLimiter rateLimiter = rateLimiters.forDescriptor(handle);
rateLimiter.validate(maybeMostRecentProxy.get()); rateLimiter.validate(remoteAddress.get());
} catch (RateLimitExceededException e) { } catch (RateLimitExceededException e) {
final Response response = EXCEPTION_MAPPER.toResponse(e); final Response response = EXCEPTION_MAPPER.toResponse(e);
throw new ClientErrorException(response); throw new ClientErrorException(response);

View File

@ -38,7 +38,6 @@ import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream; import java.util.stream.Stream;
import javax.ws.rs.client.Entity; import javax.ws.rs.client.Entity;
import javax.ws.rs.client.Invocation; import javax.ws.rs.client.Invocation;
@ -153,7 +152,7 @@ class AccountControllerTest {
.addProvider(new RateLimitExceededExceptionMapper()) .addProvider(new RateLimitExceededExceptionMapper())
.addProvider(new ImpossiblePhoneNumberExceptionMapper()) .addProvider(new ImpossiblePhoneNumberExceptionMapper())
.addProvider(new NonNormalizedPhoneNumberExceptionMapper()) .addProvider(new NonNormalizedPhoneNumberExceptionMapper())
.addProvider(new RateLimitByIpFilter(rateLimiters)) .addProvider(new RateLimitByIpFilter(rateLimiters, true))
.addProvider(ScoreThresholdProvider.ScoreThresholdFeature.class) .addProvider(ScoreThresholdProvider.ScoreThresholdFeature.class)
.setMapper(SystemMapper.jsonMapper()) .setMapper(SystemMapper.jsonMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())

View File

@ -53,7 +53,8 @@ class ChallengeControllerTest {
private static final RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class); private static final RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class);
private static final ChallengeController challengeController = new ChallengeController(rateLimitChallengeManager); private static final ChallengeController challengeController = new ChallengeController(rateLimitChallengeManager,
true);
private static final AtomicReference<Float> scoreThreshold = new AtomicReference<>(); private static final AtomicReference<Float> scoreThreshold = new AtomicReference<>();

View File

@ -109,7 +109,8 @@ class VerificationControllerTest {
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource( .addResource(
new VerificationController(registrationServiceClient, verificationSessionManager, pushNotificationManager, new VerificationController(registrationServiceClient, verificationSessionManager, pushNotificationManager,
registrationCaptchaManager, registrationRecoveryPasswordsManager, rateLimiters, accountsManager, clock)) registrationCaptchaManager, registrationRecoveryPasswordsManager, rateLimiters, accountsManager, true,
clock))
.build(); .build();
@BeforeEach @BeforeEach

View File

@ -64,7 +64,7 @@ public class RateLimitedByIpTest {
.setMapper(SystemMapper.jsonMapper()) .setMapper(SystemMapper.jsonMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new Controller()) .addResource(new Controller())
.addProvider(new RateLimitByIpFilter(RATE_LIMITERS)) .addProvider(new RateLimitByIpFilter(RATE_LIMITERS, true))
.build(); .build();
@Test @Test