From 2ab14ca59ecd8a27c189306c569f649e46700530 Mon Sep 17 00:00:00 2001 From: Chris Eager Date: Fri, 2 Feb 2024 12:53:02 -0600 Subject: [PATCH] Refactor remote address/X-Forwarded-For handling --- service/pom.xml | 5 + .../textsecuregcm/WhisperServerService.java | 28 +- .../controllers/ChallengeController.java | 19 +- .../controllers/VerificationController.java | 15 +- .../filters/RemoteAddressFilter.java | 62 ++++ .../filters/RemoteDeprecationFilter.java | 4 - .../limits/RateLimitByIpFilter.java | 26 +- ...blementRefreshRequirementProviderTest.java | 7 +- .../controllers/AccountControllerTest.java | 10 +- .../controllers/ChallengeControllerTest.java | 7 +- .../VerificationControllerTest.java | 2 +- .../RemoteAddressFilterIntegrationTest.java | 308 ++++++++++++++++++ .../filters/RemoteAddressFilterTest.java | 60 ++++ .../limits/RateLimitedByIpTest.java | 4 +- .../MetricsRequestEventListenerTest.java | 14 +- ...HttpServletRequestUtilIntegrationTest.java | 9 +- .../util/TestRemoteAddressFilterProvider.java | 31 ++ .../LoggingUnhandledExceptionMapperTest.java | 7 +- .../websocket/WebSocketResourceProvider.java | 5 + .../WebSocketResourceProviderFactory.java | 28 +- .../WebSocketResourceProviderFactoryTest.java | 21 +- .../WebSocketResourceProviderTest.java | 73 +++-- 22 files changed, 599 insertions(+), 146 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilter.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilterIntegrationTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilterTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/util/TestRemoteAddressFilterProvider.java diff --git a/service/pom.xml b/service/pom.xml index fc31db589..96d816bee 100644 --- a/service/pom.xml +++ b/service/pom.xml @@ -190,6 +190,11 @@ org.eclipse.jetty jetty-servlets + + org.eclipse.jetty.websocket + websocket-jetty-client + test + org.apache.commons diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index cac78049c..27deba45f 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -44,6 +44,7 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ThreadPoolExecutor; import javax.servlet.DispatcherType; +import javax.servlet.Filter; import javax.servlet.FilterRegistration; import javax.servlet.ServletRegistration; import org.eclipse.jetty.servlets.CrossOriginFilter; @@ -115,6 +116,7 @@ import org.whispersystems.textsecuregcm.currency.CoinMarketCapClient; import org.whispersystems.textsecuregcm.currency.CurrencyConversionManager; import org.whispersystems.textsecuregcm.currency.FixerClient; import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; +import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter; import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter; import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter; @@ -212,8 +214,8 @@ import org.whispersystems.textsecuregcm.util.DynamoDbFromConfig; import org.whispersystems.textsecuregcm.util.ManagedAwsCrt; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.UsernameHashZkProofVerifier; -import org.whispersystems.textsecuregcm.util.VirtualThreadPinEventMonitor; import org.whispersystems.textsecuregcm.util.VirtualExecutorServiceProvider; +import org.whispersystems.textsecuregcm.util.VirtualThreadPinEventMonitor; import org.whispersystems.textsecuregcm.util.logging.LoggingUnhandledExceptionMapper; import org.whispersystems.textsecuregcm.util.logging.UncaughtExceptionHandler; import org.whispersystems.textsecuregcm.websocket.AuthenticatedConnectListener; @@ -718,10 +720,16 @@ public class WhisperServerService extends Application filters = new ArrayList<>(); + final RemoteDeprecationFilter remoteDeprecationFilter = new RemoteDeprecationFilter(dynamicConfigurationManager); + filters.add(remoteDeprecationFilter); + filters.add(new RemoteAddressFilter(useRemoteAddress)); + + for (Filter filter : filters) { + environment.servlets() + .addFilter(filter.getClass().getSimpleName(), filter) + .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*"); + } // Note: interceptors run in the reverse order they are added; the remote deprecation filter // depends on the user-agent context so it has to come first here! @@ -832,7 +840,7 @@ public class WhisperServerService extends Application webSocketServlet = new WebSocketResourceProviderFactory<>( - webSocketEnvironment, AuthenticatedAccount.class, config.getWebSocketConfiguration()); + webSocketEnvironment, AuthenticatedAccount.class, config.getWebSocketConfiguration(), + RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); WebSocketResourceProviderFactory provisioningServlet = new WebSocketResourceProviderFactory<>( - provisioningEnvironment, AuthenticatedAccount.class, config.getWebSocketConfiguration()); + provisioningEnvironment, AuthenticatedAccount.class, config.getWebSocketConfiguration(), + RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", webSocketServlet); ServletRegistration.Dynamic provisioning = environment.servlets().addServlet("Provisioning", provisioningServlet); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ChallengeController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ChallengeController.java index 6bfa566c1..360f70f7b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ChallengeController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ChallengeController.java @@ -19,15 +19,14 @@ import io.swagger.v3.oas.annotations.parameters.RequestBody; import io.swagger.v3.oas.annotations.responses.ApiResponse; import io.swagger.v3.oas.annotations.tags.Tag; import java.io.IOException; -import javax.servlet.http.HttpServletRequest; import javax.validation.Valid; -import javax.ws.rs.BadRequestException; import javax.ws.rs.Consumes; import javax.ws.rs.HeaderParam; import javax.ws.rs.POST; import javax.ws.rs.PUT; import javax.ws.rs.Path; import javax.ws.rs.Produces; +import javax.ws.rs.container.ContainerRequestContext; import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; @@ -35,6 +34,7 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; import org.whispersystems.textsecuregcm.entities.AnswerChallengeRequest; import org.whispersystems.textsecuregcm.entities.AnswerPushChallengeRequest; import org.whispersystems.textsecuregcm.entities.AnswerRecaptchaChallengeRequest; +import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; @@ -42,8 +42,6 @@ import org.whispersystems.textsecuregcm.spam.Extract; import org.whispersystems.textsecuregcm.spam.FilterSpam; import org.whispersystems.textsecuregcm.spam.PushChallengeConfig; import org.whispersystems.textsecuregcm.spam.ScoreThreshold; -import org.whispersystems.textsecuregcm.util.HeaderUtils; -import org.whispersystems.textsecuregcm.util.HttpServletRequestUtil; @Path("/v1/challenge") @Tag(name = "Challenge") @@ -51,15 +49,12 @@ import org.whispersystems.textsecuregcm.util.HttpServletRequestUtil; public class ChallengeController { private final RateLimitChallengeManager rateLimitChallengeManager; - private final boolean useRemoteAddress; private static final String CHALLENGE_RESPONSE_COUNTER_NAME = name(ChallengeController.class, "challengeResponse"); private static final String CHALLENGE_TYPE_TAG = "type"; - public ChallengeController(final RateLimitChallengeManager rateLimitChallengeManager, - final boolean useRemoteAddress) { + public ChallengeController(final RateLimitChallengeManager rateLimitChallengeManager) { this.rateLimitChallengeManager = rateLimitChallengeManager; - this.useRemoteAddress = useRemoteAddress; } @PUT @@ -84,8 +79,7 @@ public class ChallengeController { description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed")) public Response handleChallengeResponse(@Auth final AuthenticatedAccount auth, @Valid final AnswerChallengeRequest answerRequest, - @HeaderParam(HttpHeaders.X_FORWARDED_FOR) final String forwardedFor, - @Context HttpServletRequest request, + @Context ContainerRequestContext requestContext, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent, @Extract final ScoreThreshold captchaScoreThreshold, @Extract final PushChallengeConfig pushChallengeConfig) throws RateLimitExceededException, IOException { @@ -103,9 +97,8 @@ public class ChallengeController { } else if (answerRequest instanceof AnswerRecaptchaChallengeRequest recaptchaChallengeRequest) { tags = tags.and(CHALLENGE_TYPE_TAG, "recaptcha"); - final String remoteAddress = useRemoteAddress - ? HttpServletRequestUtil.getRemoteAddress(request) - : HeaderUtils.getMostRecentProxy(forwardedFor).orElseThrow(BadRequestException::new); + final String remoteAddress = (String) requestContext.getProperty( + RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); boolean success = rateLimitChallengeManager.answerRecaptchaChallenge( auth.getAccount(), recaptchaChallengeRequest.getCaptcha(), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/VerificationController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/VerificationController.java index 673b2f41b..2ec24943d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/VerificationController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/VerificationController.java @@ -31,7 +31,6 @@ import java.util.Optional; import java.util.concurrent.CancellationException; import java.util.concurrent.CompletionException; import java.util.concurrent.TimeUnit; -import javax.servlet.http.HttpServletRequest; import javax.validation.Valid; import javax.validation.constraints.NotNull; import javax.ws.rs.BadRequestException; @@ -49,6 +48,7 @@ import javax.ws.rs.PathParam; import javax.ws.rs.Produces; import javax.ws.rs.ServerErrorException; import javax.ws.rs.WebApplicationException; +import javax.ws.rs.container.ContainerRequestContext; import javax.ws.rs.core.Context; import javax.ws.rs.core.HttpHeaders; import javax.ws.rs.core.MediaType; @@ -66,6 +66,7 @@ import org.whispersystems.textsecuregcm.entities.SubmitVerificationCodeRequest; import org.whispersystems.textsecuregcm.entities.UpdateVerificationSessionRequest; import org.whispersystems.textsecuregcm.entities.VerificationCodeRequest; import org.whispersystems.textsecuregcm.entities.VerificationSessionResponse; +import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; @@ -88,8 +89,6 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager; import org.whispersystems.textsecuregcm.storage.VerificationSessionManager; import org.whispersystems.textsecuregcm.util.ExceptionUtils; -import org.whispersystems.textsecuregcm.util.HeaderUtils; -import org.whispersystems.textsecuregcm.util.HttpServletRequestUtil; import org.whispersystems.textsecuregcm.util.Pair; import org.whispersystems.textsecuregcm.util.Util; @@ -123,7 +122,6 @@ public class VerificationController { private final RateLimiters rateLimiters; private final AccountsManager accountsManager; - private final boolean useRemoteAddress; private final DynamicConfigurationManager dynamicConfigurationManager; private final Clock clock; @@ -134,7 +132,6 @@ public class VerificationController { final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager, final RateLimiters rateLimiters, final AccountsManager accountsManager, - final boolean useRemoteAddress, final DynamicConfigurationManager dynamicConfigurationManager, final Clock clock) { this.registrationServiceClient = registrationServiceClient; @@ -144,7 +141,6 @@ public class VerificationController { this.registrationRecoveryPasswordsManager = registrationRecoveryPasswordsManager; this.rateLimiters = rateLimiters; this.accountsManager = accountsManager; - this.useRemoteAddress = useRemoteAddress; this.dynamicConfigurationManager = dynamicConfigurationManager; this.clock = clock; } @@ -205,16 +201,13 @@ public class VerificationController { @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) public VerificationSessionResponse updateSession(@PathParam("sessionId") final String encodedSessionId, - @HeaderParam(com.google.common.net.HttpHeaders.X_FORWARDED_FOR) String forwardedFor, @HeaderParam(HttpHeaders.USER_AGENT) final String userAgent, - @Context HttpServletRequest request, + @Context ContainerRequestContext requestContext, @NotNull @Valid final UpdateVerificationSessionRequest updateVerificationSessionRequest, @NotNull @Extract final ScoreThreshold scoreThreshold, @NotNull @Extract final SenderOverride senderOverride) { - final String sourceHost = useRemoteAddress - ? HttpServletRequestUtil.getRemoteAddress(request) - : HeaderUtils.getMostRecentProxy(forwardedFor).orElseThrow(); + final String sourceHost = (String) requestContext.getProperty(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); final Pair pushTokenAndType = validateAndExtractPushToken( updateVerificationSessionRequest); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilter.java b/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilter.java new file mode 100644 index 000000000..750cb9b47 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilter.java @@ -0,0 +1,62 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.filters; + +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.util.HeaderUtils; +import org.whispersystems.textsecuregcm.util.HttpServletRequestUtil; +import java.io.IOException; + +/** + * Sets a {@link HttpServletRequest} attribute (that will also be available as a + * {@link javax.ws.rs.container.ContainerRequestContext} property) with the remote address of the connection, using + * either the {@link HttpServletRequest#getRemoteAddr()} or the {@code X-Forwarded-For} HTTP header value, depending on + * whether {@link #preferRemoteAddress} is {@code true}. + */ +public class RemoteAddressFilter implements Filter { + + public static final String REMOTE_ADDRESS_ATTRIBUTE_NAME = RemoteAddressFilter.class.getName() + ".remoteAddress"; + private static final Logger logger = LoggerFactory.getLogger(RemoteAddressFilter.class); + + private final boolean preferRemoteAddress; + + + public RemoteAddressFilter(boolean preferRemoteAddress) { + this.preferRemoteAddress = preferRemoteAddress; + } + + @Override + public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain) + throws ServletException, IOException { + + if (request instanceof HttpServletRequest httpServletRequest) { + + final String remoteAddress; + + if (preferRemoteAddress) { + remoteAddress = HttpServletRequestUtil.getRemoteAddress(httpServletRequest); + } else { + final String forwardedFor = httpServletRequest.getHeader(com.google.common.net.HttpHeaders.X_FORWARDED_FOR); + remoteAddress = HeaderUtils.getMostRecentProxy(forwardedFor) + .orElseGet(() -> HttpServletRequestUtil.getRemoteAddress(httpServletRequest)); + } + + request.setAttribute(REMOTE_ADDRESS_ATTRIBUTE_NAME, remoteAddress); + + } else { + logger.warn("request was of unexpected type: {}", request.getClass()); + } + + chain.doFilter(request, response); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteDeprecationFilter.java b/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteDeprecationFilter.java index 6e9c590af..49169526c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteDeprecationFilter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteDeprecationFilter.java @@ -7,15 +7,12 @@ package org.whispersystems.textsecuregcm.filters; import static com.codahale.metrics.MetricRegistry.name; -import com.google.common.annotations.VisibleForTesting; import com.google.common.net.HttpHeaders; import com.vdurmont.semver4j.Semver; - import io.grpc.Metadata; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptor; -import io.grpc.Status; import io.micrometer.core.instrument.Metrics; import java.io.IOException; import java.util.Map; @@ -30,7 +27,6 @@ import javax.servlet.http.HttpServletResponse; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRemoteDeprecationConfiguration; import org.whispersystems.textsecuregcm.grpc.StatusConstants; -import org.whispersystems.textsecuregcm.grpc.UserAgentInterceptor; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimitByIpFilter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimitByIpFilter.java index 0aa244ab2..e169b4a91 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimitByIpFilter.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimitByIpFilter.java @@ -8,33 +8,25 @@ package org.whispersystems.textsecuregcm.limits; import static java.util.Objects.requireNonNull; import com.google.common.annotations.VisibleForTesting; -import com.google.common.net.HttpHeaders; import java.io.IOException; import java.time.Duration; import java.util.Optional; -import javax.inject.Provider; -import javax.servlet.http.HttpServletRequest; import javax.ws.rs.ClientErrorException; import javax.ws.rs.container.ContainerRequestContext; import javax.ws.rs.container.ContainerRequestFilter; -import javax.ws.rs.core.Context; import javax.ws.rs.core.Response; import javax.ws.rs.ext.ExceptionMapper; import org.glassfish.jersey.server.ExtendedUriInfo; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; +import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; -import org.whispersystems.textsecuregcm.util.HeaderUtils; -import org.whispersystems.textsecuregcm.util.HttpServletRequestUtil; public class RateLimitByIpFilter implements ContainerRequestFilter { private static final Logger logger = LoggerFactory.getLogger(RateLimitByIpFilter.class); - @Context - private Provider httpServletRequestProvider; - @VisibleForTesting static final RateLimitExceededException INVALID_HEADER_EXCEPTION = new RateLimitExceededException(Duration.ofHours(1), true); @@ -42,12 +34,10 @@ public class RateLimitByIpFilter implements ContainerRequestFilter { private static final ExceptionMapper EXCEPTION_MAPPER = new RateLimitExceededExceptionMapper(); private final RateLimiters rateLimiters; - private final boolean useRemoteAddress; - public RateLimitByIpFilter(final RateLimiters rateLimiters, final boolean useRemoteAddress) { + public RateLimitByIpFilter(final RateLimiters rateLimiters) { this.rateLimiters = requireNonNull(rateLimiters); - this.useRemoteAddress = useRemoteAddress; } @Override @@ -70,18 +60,14 @@ public class RateLimitByIpFilter implements ContainerRequestFilter { final RateLimiters.For handle = annotation.value(); try { - final String xffHeader = requestContext.getHeaders().getFirst(HttpHeaders.X_FORWARDED_FOR); - final Optional remoteAddress = useRemoteAddress - ? Optional.of(HttpServletRequestUtil.getRemoteAddress(httpServletRequestProvider.get())) - : Optional.ofNullable(xffHeader) - .flatMap(HeaderUtils::getMostRecentProxy); + final Optional remoteAddress = Optional.ofNullable( + (String) requestContext.getProperty(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME)); - // checking if we failed to extract the most recent IP from the X-Forwarded-For header - // for any reason + // checking if we failed to extract the most recent IP for any reason if (remoteAddress.isEmpty()) { // checking if annotation is configured to fail when the most recent IP is not resolved if (annotation.failOnUnresolvedIp()) { - logger.error("Missing/bad X-Forwarded-For: {}", xffHeader); + logger.error("Remote address was null"); throw INVALID_HEADER_EXCEPTION; } // otherwise, allow request diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java index 58376e1f5..253a54b7c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java @@ -69,6 +69,7 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentCaptor; +import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; @@ -306,9 +307,9 @@ class AuthEnablementRefreshRequirementProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); - provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, - requestLog, new TestPrincipal("test", account, authenticatedDevice), new ProtobufWebSocketMessageFactory(), - Optional.empty(), Duration.ofMillis(30000)); + provider = new WebSocketResourceProvider<>("127.0.0.1", RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, + applicationHandler, requestLog, new TestPrincipal("test", account, authenticatedDevice), + new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); remoteEndpoint = mock(RemoteEndpoint.class); Session session = mock(Session.class); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java index 0ef36844b..ddbecb0fe 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java @@ -91,6 +91,7 @@ import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil; import org.whispersystems.textsecuregcm.util.MockUtils; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.TestRandomUtil; +import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider; import org.whispersystems.textsecuregcm.util.UsernameHashZkProofVerifier; @ExtendWith(DropwizardExtensionsSupport.class) @@ -119,9 +120,6 @@ class AccountControllerTest { private static final UUID SENDER_REG_LOCK_UUID = UUID.randomUUID(); private static final UUID SENDER_TRANSFER_UUID = UUID.randomUUID(); - private static final String NICE_HOST = "127.0.0.1"; - private static final String RATE_LIMITED_IP_HOST = "10.0.0.1"; - private static AccountsManager accountsManager = mock(AccountsManager.class); private static RateLimiters rateLimiters = mock(RateLimiters.class); private static RateLimiter rateLimiter = mock(RateLimiter.class); @@ -140,6 +138,9 @@ class AccountControllerTest { private byte[] registration_lock_key = new byte[32]; + private static final TestRemoteAddressFilterProvider TEST_REMOTE_ADDRESS_FILTER_PROVIDER + = new TestRemoteAddressFilterProvider("127.0.0.1"); + private static final ResourceExtension resources = ResourceExtension.builder() .addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE) .addProvider(AuthHelper.getAuthFilter()) @@ -148,7 +149,8 @@ class AccountControllerTest { .addProvider(new RateLimitExceededExceptionMapper()) .addProvider(new ImpossiblePhoneNumberExceptionMapper()) .addProvider(new NonNormalizedPhoneNumberExceptionMapper()) - .addProvider(new RateLimitByIpFilter(rateLimiters, true)) + .addProvider(TEST_REMOTE_ADDRESS_FILTER_PROVIDER) + .addProvider(new RateLimitByIpFilter(rateLimiters)) .addProvider(ScoreThresholdProvider.ScoreThresholdFeature.class) .addProvider(SenderOverrideProvider.SenderOverrideFeature.class) .setMapper(SystemMapper.jsonMapper()) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ChallengeControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ChallengeControllerTest.java index 24493f20a..f761f3995 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ChallengeControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ChallengeControllerTest.java @@ -43,18 +43,16 @@ import org.whispersystems.textsecuregcm.push.NotPushRegisteredException; import org.whispersystems.textsecuregcm.spam.PushChallengeConfigProvider; import org.whispersystems.textsecuregcm.spam.ScoreThreshold; import org.whispersystems.textsecuregcm.spam.ScoreThresholdProvider; -import org.whispersystems.textsecuregcm.spam.SenderOverride; -import org.whispersystems.textsecuregcm.spam.SenderOverrideProvider; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.util.SystemMapper; +import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider; @ExtendWith(DropwizardExtensionsSupport.class) class ChallengeControllerTest { private static final RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class); - private static final ChallengeController challengeController = new ChallengeController(rateLimitChallengeManager, - true); + private static final ChallengeController challengeController = new ChallengeController(rateLimitChallengeManager); private static final AtomicReference scoreThreshold = new AtomicReference<>(); @@ -73,6 +71,7 @@ class ChallengeControllerTest { return true; } }) + .addProvider(new TestRemoteAddressFilterProvider("127.0.0.1")) .setMapper(SystemMapper.jsonMapper()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .addResource(new RateLimitExceededExceptionMapper()) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/VerificationControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/VerificationControllerTest.java index 7d3f03cb4..ba2ca2710 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/VerificationControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/VerificationControllerTest.java @@ -118,7 +118,7 @@ class VerificationControllerTest { .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .addResource( new VerificationController(registrationServiceClient, verificationSessionManager, pushNotificationManager, - registrationCaptchaManager, registrationRecoveryPasswordsManager, rateLimiters, accountsManager, true, + registrationCaptchaManager, registrationRecoveryPasswordsManager, rateLimiters, accountsManager, dynamicConfigurationManager, clock)) .build(); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilterIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilterIntegrationTest.java new file mode 100644 index 000000000..4db9a4ae7 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilterIntegrationTest.java @@ -0,0 +1,308 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.filters; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import com.google.common.net.HttpHeaders; +import io.dropwizard.core.Application; +import io.dropwizard.core.Configuration; +import io.dropwizard.core.setup.Environment; +import io.dropwizard.testing.junit5.DropwizardAppExtension; +import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; +import java.io.IOException; +import java.net.InetAddress; +import java.net.URI; +import java.nio.ByteBuffer; +import java.security.Principal; +import java.time.Duration; +import java.util.Arrays; +import java.util.EnumSet; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import javax.security.auth.Subject; +import javax.servlet.DispatcherType; +import javax.ws.rs.GET; +import javax.ws.rs.Path; +import javax.ws.rs.client.Client; +import javax.ws.rs.container.ContainerRequestContext; +import javax.ws.rs.core.Context; +import org.eclipse.jetty.util.HostPort; +import org.eclipse.jetty.websocket.api.Session; +import org.eclipse.jetty.websocket.api.WebSocketListener; +import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; +import org.eclipse.jetty.websocket.client.WebSocketClient; +import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.ValueSource; +import org.whispersystems.textsecuregcm.util.SystemMapper; +import org.whispersystems.websocket.WebSocketResourceProviderFactory; +import org.whispersystems.websocket.configuration.WebSocketConfiguration; +import org.whispersystems.websocket.messages.WebSocketMessage; +import org.whispersystems.websocket.messages.WebSocketMessageFactory; +import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory; +import org.whispersystems.websocket.setup.WebSocketEnvironment; + +@ExtendWith(DropwizardExtensionsSupport.class) +class RemoteAddressFilterIntegrationTest { + + private static final String WEBSOCKET_PREFIX = "/websocket"; + private static final String REMOTE_ADDRESS_PATH = "/remoteAddress"; + private static final String FORWARDED_FOR_PATH = "/forwardedFor"; + private static final String WS_REQUEST_PATH = "/wsRequest"; + + // The Grizzly test container does not match the Jetty container used in real deployments, and JettyTestContainerFactory + // in jersey-test-framework-provider-jetty doesn’t easily support @Context HttpServletRequest, so this test runs a + // full Jetty server in a separate process + private static final DropwizardAppExtension EXTENSION = new DropwizardAppExtension<>( + TestApplication.class); + + @Nested + class Rest { + + @ParameterizedTest + @ValueSource(strings = {"127.0.0.1", "0:0:0:0:0:0:0:1"}) + void testRemoteAddress(String ip) throws Exception { + final Set addresses = Arrays.stream(InetAddress.getAllByName("localhost")) + .map(InetAddress::getHostAddress) + .collect(Collectors.toSet()); + + assumeTrue(addresses.contains(ip), String.format("localhost does not resolve to %s", ip)); + + Client client = EXTENSION.client(); + + final RemoteAddressFilterIntegrationTest.TestResponse response = client.target( + String.format("http://%s:%d%s", HostPort.normalizeHost(ip), EXTENSION.getLocalPort(), REMOTE_ADDRESS_PATH)) + .request("application/json") + .get(RemoteAddressFilterIntegrationTest.TestResponse.class); + + assertEquals(ip, response.remoteAddress()); + } + + @ParameterizedTest + @CsvSource(value = {"127.0.0.1, 192.168.1.1 \t 192.168.1.1", + "127.0.0.1, fe80:1:1:1:1:1:1:1 \t fe80:1:1:1:1:1:1:1"}, delimiterString = "\t") + void testForwardedFor(String forwardedFor, String expectedIp) { + + Client client = EXTENSION.client(); + + final RemoteAddressFilterIntegrationTest.TestResponse response = client.target( + String.format("http://localhost:%d%s", EXTENSION.getLocalPort(), FORWARDED_FOR_PATH)) + .request("application/json") + .header(HttpHeaders.X_FORWARDED_FOR, forwardedFor) + .get(RemoteAddressFilterIntegrationTest.TestResponse.class); + + assertEquals(expectedIp, response.remoteAddress()); + } + } + + @Nested + class WebSocket { + + private WebSocketClient client; + + @BeforeEach + void setUp() throws Exception { + client = new WebSocketClient(); + client.start(); + } + + @AfterEach + void tearDown() throws Exception { + client.stop(); + } + + @ParameterizedTest + @ValueSource(strings = {"127.0.0.1", "0:0:0:0:0:0:0:1"}) + void testRemoteAddress(String ip) throws Exception { + final Set addresses = Arrays.stream(InetAddress.getAllByName("localhost")) + .map(InetAddress::getHostAddress) + .collect(Collectors.toSet()); + + assumeTrue(addresses.contains(ip), String.format("localhost does not resolve to %s", ip)); + + final CompletableFuture responseFuture = new CompletableFuture<>(); + final ClientEndpoint clientEndpoint = new ClientEndpoint(WS_REQUEST_PATH, responseFuture); + + client.connect(clientEndpoint, + URI.create( + String.format("ws://%s:%d%s", HostPort.normalizeHost(ip), EXTENSION.getLocalPort(), + WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH))); + + final byte[] responseBytes = responseFuture.get(1, TimeUnit.SECONDS); + + final TestResponse response = SystemMapper.jsonMapper().readValue(responseBytes, TestResponse.class); + + assertEquals(ip, response.remoteAddress()); + } + + @ParameterizedTest + @CsvSource(value = {"127.0.0.1, 192.168.1.1 \t 192.168.1.1", + "127.0.0.1, fe80:1:1:1:1:1:1:1 \t fe80:1:1:1:1:1:1:1"}, delimiterString = "\t") + void testForwardedFor(String forwardedFor, String expectedIp) throws Exception { + + final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest(); + upgradeRequest.setHeader(HttpHeaders.X_FORWARDED_FOR, forwardedFor); + + final CompletableFuture responseFuture = new CompletableFuture<>(); + + client.connect(new ClientEndpoint(WS_REQUEST_PATH, responseFuture), + URI.create( + String.format("ws://localhost:%d%s", EXTENSION.getLocalPort(), WEBSOCKET_PREFIX + FORWARDED_FOR_PATH)), + upgradeRequest); + + final byte[] responseBytes = responseFuture.get(1, TimeUnit.SECONDS); + + final TestResponse response = SystemMapper.jsonMapper().readValue(responseBytes, TestResponse.class); + + assertEquals(expectedIp, response.remoteAddress()); + } + } + + private static class ClientEndpoint implements WebSocketListener { + + private final String requestPath; + private final CompletableFuture responseFuture; + private final WebSocketMessageFactory messageFactory; + + ClientEndpoint(String requestPath, CompletableFuture responseFuture) { + + this.requestPath = requestPath; + this.responseFuture = responseFuture; + this.messageFactory = new ProtobufWebSocketMessageFactory(); + } + + @Override + public void onWebSocketConnect(final Session session) { + final byte[] requestBytes = messageFactory.createRequest(Optional.of(1L), "GET", requestPath, + List.of("Accept: application/json"), + Optional.empty()).toByteArray(); + try { + session.getRemote().sendBytes(ByteBuffer.wrap(requestBytes)); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void onWebSocketBinary(final byte[] payload, final int offset, final int length) { + + try { + WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length); + + if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.RESPONSE_MESSAGE) { + assert 200 == webSocketMessage.getResponseMessage().getStatus(); + responseFuture.complete(webSocketMessage.getResponseMessage().getBody().orElseThrow()); + } else { + throw new RuntimeException("Unexpected message type: " + webSocketMessage.getType()); + } + } catch (final Exception e) { + throw new RuntimeException(e); + } + + } + + } + + public static abstract class TestController { + + @GET + public RemoteAddressFilterIntegrationTest.TestResponse get(@Context ContainerRequestContext context) { + + return new RemoteAddressFilterIntegrationTest.TestResponse( + (String) context.getProperty(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME)); + } + } + + @Path(REMOTE_ADDRESS_PATH) + public static class TestRemoteAddressController extends TestController { + + } + + @Path(FORWARDED_FOR_PATH) + public static class TestForwardedForController extends TestController { + + } + + @Path(WS_REQUEST_PATH) + public static class TestWebSocketController extends TestController { + + } + + public record TestResponse(String remoteAddress) { + + } + + public static class TestApplication extends Application { + + @Override + public void run(final Configuration configuration, + final Environment environment) throws Exception { + + // 2 filters, to cover useRemoteAddress = {true, false} + // each has explicit (not wildcard) path matching + environment.servlets().addFilter("RemoteAddressFilterRemoteAddress", new RemoteAddressFilter(true)) + .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, REMOTE_ADDRESS_PATH, + WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH); + environment.servlets().addFilter("RemoteAddressFilterForwardedFor", new RemoteAddressFilter(false)) + .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, FORWARDED_FOR_PATH, + WEBSOCKET_PREFIX + FORWARDED_FOR_PATH); + + environment.jersey().register(new TestRemoteAddressController()); + environment.jersey().register(new TestForwardedForController()); + + // WebSocket set up + final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration(); + + WebSocketEnvironment webSocketEnvironment = new WebSocketEnvironment<>(environment, + webSocketConfiguration, Duration.ofMillis(1000)); + + webSocketEnvironment.jersey().register(new TestWebSocketController()); + + JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null); + + WebSocketResourceProviderFactory webSocketServlet = new WebSocketResourceProviderFactory<>( + webSocketEnvironment, TestPrincipal.class, webSocketConfiguration, + RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); + + // 2 servlets, because the filter only runs for the Upgrade request + environment.servlets().addServlet("WebSocketForwardedFor", webSocketServlet) + .addMapping(WEBSOCKET_PREFIX + FORWARDED_FOR_PATH); + environment.servlets().addServlet("WebSocketRemoteAddress", webSocketServlet) + .addMapping(WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH); + + } + } + + /** + * A minimal {@code Principal} implementation, only used to satisfy constructors + */ + public static class TestPrincipal implements Principal { + + // Principal implementation + + @Override + public String getName() { + return null; + } + + @Override + public boolean implies(final Subject subject) { + return false; + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilterTest.java new file mode 100644 index 000000000..1fdf20995 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilterTest.java @@ -0,0 +1,60 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.filters; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.net.HttpHeaders; +import javax.servlet.FilterChain; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +class RemoteAddressFilterTest { + + @ParameterizedTest + @CsvSource({ + "127.0.0.1, 127.0.0.1", + "0:0:0:0:0:0:0:1, 0:0:0:0:0:0:0:1", + "[0:0:0:0:0:0:0:1], 0:0:0:0:0:0:0:1" + }) + void testGetRemoteAddress(final String remoteAddr, final String expectedRemoteAddr) throws Exception { + final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class); + when(httpServletRequest.getRemoteAddr()).thenReturn(remoteAddr); + + final RemoteAddressFilter filter = new RemoteAddressFilter(true); + + final FilterChain filterChain = mock(FilterChain.class); + filter.doFilter(httpServletRequest, mock(ServletResponse.class), filterChain); + + verify(httpServletRequest).setAttribute(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, expectedRemoteAddr); + verify(filterChain).doFilter(any(ServletRequest.class), any(ServletResponse.class)); + } + + @ParameterizedTest + @CsvSource(value = { + "192.168.1.1, 127.0.0.1 \t 127.0.0.1", + "192.168.1.1, 0:0:0:0:0:0:0:1 \t 0:0:0:0:0:0:0:1" + }, delimiterString = "\t") + void testGetRemoteAddressFromHeader(final String forwardedFor, final String expectedRemoteAddr) throws Exception { + final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class); + when(httpServletRequest.getHeader(HttpHeaders.X_FORWARDED_FOR)).thenReturn(forwardedFor); + + final RemoteAddressFilter filter = new RemoteAddressFilter(false); + + final FilterChain filterChain = mock(FilterChain.class); + filter.doFilter(httpServletRequest, mock(ServletResponse.class), filterChain); + + verify(httpServletRequest).setAttribute(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, expectedRemoteAddr); + verify(filterChain).doFilter(any(ServletRequest.class), any(ServletResponse.class)); + } + +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitedByIpTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitedByIpTest.java index 3f63155d6..6d3ff3d3e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitedByIpTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitedByIpTest.java @@ -25,6 +25,7 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException; import org.whispersystems.textsecuregcm.util.MockUtils; import org.whispersystems.textsecuregcm.util.SystemMapper; +import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider; @ExtendWith(DropwizardExtensionsSupport.class) public class RateLimitedByIpTest { @@ -60,7 +61,8 @@ public class RateLimitedByIpTest { .setMapper(SystemMapper.jsonMapper()) .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .addResource(new Controller()) - .addProvider(new RateLimitByIpFilter(RATE_LIMITERS, true)) + .addProvider(new RateLimitByIpFilter(RATE_LIMITERS)) + .addProvider(new TestRemoteAddressFilterProvider(IP)) .build(); @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java index f3b1962f5..a712a07ef 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java @@ -49,6 +49,7 @@ import org.glassfish.jersey.uri.UriTemplate; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; +import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.websocket.WebSocketResourceProvider; import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider; @@ -138,12 +139,8 @@ class MetricsRequestEventListenerTest { final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); final WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - applicationHandler, - requestLog, - new TestPrincipal("foo"), - new ProtobufWebSocketMessageFactory(), - Optional.empty(), - Duration.ofMillis(30000)); + RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"), + new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); final Session session = mock(Session.class); final RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); @@ -204,9 +201,8 @@ class MetricsRequestEventListenerTest { final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); final WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", - applicationHandler, - requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), - Duration.ofMillis(30000)); + RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"), + new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); final Session session = mock(Session.class); final RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/HttpServletRequestUtilIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/HttpServletRequestUtilIntegrationTest.java index a2b3e7430..0722ca99c 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/HttpServletRequestUtilIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/HttpServletRequestUtilIntegrationTest.java @@ -35,8 +35,7 @@ class HttpServletRequestUtilIntegrationTest { // The Grizzly test container does not match the Jetty container used in real deployments, and JettyTestContainerFactory // in jersey-test-framework-provider-jetty doesn’t easily support @Context HttpServletRequest, so this test runs a // full Jetty server in a separate process - private final DropwizardAppExtension EXTENSION = new DropwizardAppExtension<>( - TestApplication.class); + private final DropwizardAppExtension EXTENSION = new DropwizardAppExtension<>(TestApplication.class); @ParameterizedTest @ValueSource(strings = {"127.0.0.1", "0:0:0:0:0:0:0:1"}) @@ -72,13 +71,11 @@ class HttpServletRequestUtilIntegrationTest { } - public static class TestApplication extends Application { + public static class TestApplication extends Application { @Override - public void run(final TestConfiguration configuration, final Environment environment) throws Exception { + public void run(final Configuration configuration, final Environment environment) throws Exception { environment.jersey().register(new TestController()); } } - - public static class TestConfiguration extends Configuration {} } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/TestRemoteAddressFilterProvider.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/TestRemoteAddressFilterProvider.java new file mode 100644 index 000000000..61ca1cd0e --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/TestRemoteAddressFilterProvider.java @@ -0,0 +1,31 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.util; + +import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; +import javax.annotation.Priority; +import javax.ws.rs.container.ContainerRequestContext; +import javax.ws.rs.container.ContainerRequestFilter; +import java.io.IOException; + +/** + * Adds the request property set by {@link RemoteAddressFilter} for test scenarios that depend on it, but do not have + * access to a full {@code HttpServletRequest} pipline + */ +@Priority(Integer.MIN_VALUE) // highest priority, since other filters might depend on it +public class TestRemoteAddressFilterProvider implements ContainerRequestFilter { + + private final String ip; + + public TestRemoteAddressFilterProvider(String ip) { + this.ip = ip; + } + + @Override + public void filter(final ContainerRequestContext requestContext) throws IOException { + requestContext.setProperty(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, ip); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java index e75a7b0af..207d4dfbe 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java @@ -55,6 +55,7 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.slf4j.Logger; +import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.websocket.WebSocketResourceProvider; @@ -173,9 +174,9 @@ class LoggingUnhandledExceptionMapperTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); - WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, - requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), - Duration.ofMillis(30000)); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", + RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"), + new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); doAnswer(answer -> { diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java index 2ca34905d..006fa5333 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java @@ -65,6 +65,7 @@ public class WebSocketResourceProvider implements WebSocket private final WebsocketRequestLog requestLog; private final Duration idleTimeout; private final String remoteAddress; + private final String remoteAddressPropertyName; private Session session; private RemoteEndpoint remoteEndpoint; @@ -73,6 +74,7 @@ public class WebSocketResourceProvider implements WebSocket private static final Set EXCLUDED_UPGRADE_REQUEST_HEADERS = Set.of("connection", "upgrade"); public WebSocketResourceProvider(String remoteAddress, + String remoteAddressPropertyName, ApplicationHandler jerseyHandler, WebsocketRequestLog requestLog, T authenticated, @@ -80,6 +82,7 @@ public class WebSocketResourceProvider implements WebSocket Optional connectListener, Duration idleTimeout) { this.remoteAddress = remoteAddress; + this.remoteAddressPropertyName = remoteAddressPropertyName; this.jerseyHandler = jerseyHandler; this.requestLog = requestLog; this.authenticated = authenticated; @@ -169,6 +172,8 @@ public class WebSocketResourceProvider implements WebSocket containerRequest.setEntityStream(new ByteArrayInputStream(requestMessage.getBody().get())); } + containerRequest.setProperty(remoteAddressPropertyName, remoteAddress); + ByteArrayOutputStream responseBody = new ByteArrayOutputStream(); CompletableFuture responseFuture = (CompletableFuture) jerseyHandler.apply( containerRequest, responseBody); diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java index b60faa223..1b3314430 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java @@ -6,13 +6,12 @@ package org.whispersystems.websocket; import static java.util.Optional.ofNullable; -import com.google.common.net.HttpHeaders; import io.dropwizard.jersey.jackson.JacksonMessageBodyProvider; import java.io.IOException; -import java.net.InetSocketAddress; import java.security.Principal; -import java.util.Arrays; import java.util.Optional; +import javax.ws.rs.InternalServerErrorException; +import org.apache.commons.lang3.StringUtils; import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest; import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse; import org.eclipse.jetty.websocket.server.JettyWebSocketCreator; @@ -38,8 +37,10 @@ public class WebSocketResourceProviderFactory extends Jetty private final ApplicationHandler jerseyApplicationHandler; private final WebSocketConfiguration configuration; + private final String remoteAddressPropertyName; + public WebSocketResourceProviderFactory(WebSocketEnvironment environment, Class principalClass, - WebSocketConfiguration configuration) { + WebSocketConfiguration configuration, String remoteAddressPropertyName) { this.environment = environment; environment.jersey().register(new WebSocketSessionContextValueFactoryProvider.Binder()); @@ -49,6 +50,7 @@ public class WebSocketResourceProviderFactory extends Jetty this.jerseyApplicationHandler = new ApplicationHandler(environment.jersey()); this.configuration = configuration; + this.remoteAddressPropertyName = remoteAddressPropertyName; } @Override @@ -69,6 +71,7 @@ public class WebSocketResourceProviderFactory extends Jetty } return new WebSocketResourceProvider<>(getRemoteAddress(request), + remoteAddressPropertyName, this.jerseyApplicationHandler, this.environment.getRequestLog(), authenticated, @@ -93,18 +96,11 @@ public class WebSocketResourceProviderFactory extends Jetty } private String getRemoteAddress(JettyServerUpgradeRequest request) { - String forwardedFor = request.getHeader(HttpHeaders.X_FORWARDED_FOR); - - if (forwardedFor == null || forwardedFor.isBlank()) { - if (request.getRemoteSocketAddress() instanceof InetSocketAddress inetSocketAddress) { - return inetSocketAddress.getAddress().getHostAddress(); - } - return null; - } else { - return Arrays.stream(forwardedFor.split(",")) - .map(String::trim) - .reduce((a, b) -> b) - .orElseThrow(); + final String remoteAddress = (String) request.getHttpServletRequest().getAttribute(remoteAddressPropertyName); + if (StringUtils.isBlank(remoteAddress)) { + logger.error("Remote address property is not present"); + throw new InternalServerErrorException(); } + return remoteAddress; } } diff --git a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java index fe76b4faf..4546242ea 100644 --- a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java +++ b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java @@ -18,8 +18,8 @@ import java.io.IOException; import java.security.Principal; import java.util.Optional; import javax.security.auth.Subject; +import javax.servlet.http.HttpServletRequest; import org.eclipse.jetty.websocket.api.Session; -import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest; import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse; import org.eclipse.jetty.websocket.server.JettyWebSocketServletFactory; @@ -33,6 +33,8 @@ import org.whispersystems.websocket.setup.WebSocketEnvironment; public class WebSocketResourceProviderFactoryTest { + private static final String REMOTE_ADDRESS_PROPERTY_NAME = "org.whispersystems.websocket.test.remoteAddress"; + private ResourceConfig jerseyEnvironment; private WebSocketEnvironment environment; private WebSocketAuthenticator authenticator; @@ -59,7 +61,7 @@ public class WebSocketResourceProviderFactoryTest { when(environment.jersey()).thenReturn(jerseyEnvironment); WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory<>(environment, Account.class, - mock(WebSocketConfiguration.class)); + mock(WebSocketConfiguration.class), REMOTE_ADDRESS_PROPERTY_NAME); Object connection = factory.createWebSocket(request, response); assertNull(connection); @@ -69,24 +71,25 @@ public class WebSocketResourceProviderFactoryTest { @Test void testValidAuthorization() throws AuthenticationException { - Session session = mock(Session.class); Account account = new Account(); when(environment.getAuthenticator()).thenReturn(authenticator); when(authenticator.authenticate(eq(request))).thenReturn( new WebSocketAuthenticator.AuthenticationResult<>(Optional.of(account), true)); when(environment.jersey()).thenReturn(jerseyEnvironment); - when(session.getUpgradeRequest()).thenReturn(mock(UpgradeRequest.class)); + final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class); + when(httpServletRequest.getAttribute(REMOTE_ADDRESS_PROPERTY_NAME)).thenReturn("127.0.0.1"); + when(request.getHttpServletRequest()).thenReturn(httpServletRequest); WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory<>(environment, Account.class, - mock(WebSocketConfiguration.class)); + mock(WebSocketConfiguration.class), REMOTE_ADDRESS_PROPERTY_NAME); Object connection = factory.createWebSocket(request, response); assertNotNull(connection); verifyNoMoreInteractions(response); verify(authenticator).authenticate(eq(request)); - ((WebSocketResourceProvider) connection).onWebSocketConnect(session); + ((WebSocketResourceProvider) connection).onWebSocketConnect(mock(Session.class)); assertNotNull(((WebSocketResourceProvider) connection).getContext().getAuthenticated()); assertEquals(((WebSocketResourceProvider) connection).getContext().getAuthenticated(), account); @@ -100,7 +103,8 @@ public class WebSocketResourceProviderFactoryTest { WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory<>(environment, Account.class, - mock(WebSocketConfiguration.class)); + mock(WebSocketConfiguration.class), + REMOTE_ADDRESS_PROPERTY_NAME); Object connection = factory.createWebSocket(request, response); assertNull(connection); @@ -115,7 +119,8 @@ public class WebSocketResourceProviderFactoryTest { WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory<>(environment, Account.class, - mock(WebSocketConfiguration.class)); + mock(WebSocketConfiguration.class), + REMOTE_ADDRESS_PROPERTY_NAME); factory.configure(servletFactory); verify(servletFactory).setCreator(eq(factory)); diff --git a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java index 6feba40e6..2f8536cd4 100644 --- a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java +++ b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java @@ -70,12 +70,15 @@ import org.whispersystems.websocket.setup.WebSocketConnectListener; class WebSocketResourceProviderTest { + private static final String REMOTE_ADDRESS_PROPERTY_NAME = "org.whispersystems.weboscket.test.remoteAddress"; + @Test void testOnConnect() { ApplicationHandler applicationHandler = mock(ApplicationHandler.class); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); WebSocketConnectListener connectListener = mock(WebSocketConnectListener.class); WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("fooz"), new ProtobufWebSocketMessageFactory(), @@ -104,9 +107,9 @@ class WebSocketResourceProviderTest { void testMockedRouteMessageSuccess() throws Exception { ApplicationHandler applicationHandler = mock(ApplicationHandler.class); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); - WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, - requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), - Duration.ofMillis(30000)); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"), + new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); @@ -180,9 +183,9 @@ class WebSocketResourceProviderTest { void testMockedRouteMessageFailure() throws Exception { ApplicationHandler applicationHandler = mock(ApplicationHandler.class); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); - WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, - requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), - Duration.ofMillis(30000)); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"), + new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); @@ -236,9 +239,9 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); - WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, - requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), - Duration.ofMillis(30000)); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"), + new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); @@ -276,9 +279,9 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); - WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, - requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), - Duration.ofMillis(30000)); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"), + new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); @@ -316,9 +319,9 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); - WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, - requestLog, new TestPrincipal("authorizedUserName"), new ProtobufWebSocketMessageFactory(), Optional.empty(), - Duration.ofMillis(30000)); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("authorizedUserName"), + new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); @@ -356,8 +359,9 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); - WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, - requestLog, null, new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, null, new ProtobufWebSocketMessageFactory(), + Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); @@ -394,9 +398,9 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); - WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, - requestLog, new TestPrincipal("something"), new ProtobufWebSocketMessageFactory(), Optional.empty(), - Duration.ofMillis(30000)); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("something"), + new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); @@ -434,8 +438,9 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); - WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, - requestLog, null, new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, null, new ProtobufWebSocketMessageFactory(), + Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); @@ -473,9 +478,9 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); - WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, - requestLog, new TestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(), - Duration.ofMillis(30000)); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"), + new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); @@ -514,9 +519,9 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); - WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, - requestLog, new TestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(), - Duration.ofMillis(30000)); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"), + new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); @@ -556,9 +561,9 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); - WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, - requestLog, new TestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(), - Duration.ofMillis(30000)); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"), + new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); @@ -596,9 +601,9 @@ class WebSocketResourceProviderTest { ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); - WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, - requestLog, new TestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(), - Duration.ofMillis(30000)); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", + REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"), + new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000)); Session session = mock(Session.class); RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);