diff --git a/service/pom.xml b/service/pom.xml
index fc31db589..96d816bee 100644
--- a/service/pom.xml
+++ b/service/pom.xml
@@ -190,6 +190,11 @@
org.eclipse.jetty
jetty-servlets
+
+ org.eclipse.jetty.websocket
+ websocket-jetty-client
+ test
+
org.apache.commons
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
index cac78049c..27deba45f 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
@@ -44,6 +44,7 @@ import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadPoolExecutor;
import javax.servlet.DispatcherType;
+import javax.servlet.Filter;
import javax.servlet.FilterRegistration;
import javax.servlet.ServletRegistration;
import org.eclipse.jetty.servlets.CrossOriginFilter;
@@ -115,6 +116,7 @@ import org.whispersystems.textsecuregcm.currency.CoinMarketCapClient;
import org.whispersystems.textsecuregcm.currency.CurrencyConversionManager;
import org.whispersystems.textsecuregcm.currency.FixerClient;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
+import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter;
import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter;
import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter;
@@ -212,8 +214,8 @@ import org.whispersystems.textsecuregcm.util.DynamoDbFromConfig;
import org.whispersystems.textsecuregcm.util.ManagedAwsCrt;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.UsernameHashZkProofVerifier;
-import org.whispersystems.textsecuregcm.util.VirtualThreadPinEventMonitor;
import org.whispersystems.textsecuregcm.util.VirtualExecutorServiceProvider;
+import org.whispersystems.textsecuregcm.util.VirtualThreadPinEventMonitor;
import org.whispersystems.textsecuregcm.util.logging.LoggingUnhandledExceptionMapper;
import org.whispersystems.textsecuregcm.util.logging.UncaughtExceptionHandler;
import org.whispersystems.textsecuregcm.websocket.AuthenticatedConnectListener;
@@ -718,10 +720,16 @@ public class WhisperServerService extends Application filters = new ArrayList<>();
+ final RemoteDeprecationFilter remoteDeprecationFilter = new RemoteDeprecationFilter(dynamicConfigurationManager);
+ filters.add(remoteDeprecationFilter);
+ filters.add(new RemoteAddressFilter(useRemoteAddress));
+
+ for (Filter filter : filters) {
+ environment.servlets()
+ .addFilter(filter.getClass().getSimpleName(), filter)
+ .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
+ }
// Note: interceptors run in the reverse order they are added; the remote deprecation filter
// depends on the user-agent context so it has to come first here!
@@ -832,7 +840,7 @@ public class WhisperServerService extends Application webSocketServlet = new WebSocketResourceProviderFactory<>(
- webSocketEnvironment, AuthenticatedAccount.class, config.getWebSocketConfiguration());
+ webSocketEnvironment, AuthenticatedAccount.class, config.getWebSocketConfiguration(),
+ RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
WebSocketResourceProviderFactory provisioningServlet = new WebSocketResourceProviderFactory<>(
- provisioningEnvironment, AuthenticatedAccount.class, config.getWebSocketConfiguration());
+ provisioningEnvironment, AuthenticatedAccount.class, config.getWebSocketConfiguration(),
+ RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", webSocketServlet);
ServletRegistration.Dynamic provisioning = environment.servlets().addServlet("Provisioning", provisioningServlet);
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ChallengeController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ChallengeController.java
index 6bfa566c1..360f70f7b 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ChallengeController.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ChallengeController.java
@@ -19,15 +19,14 @@ import io.swagger.v3.oas.annotations.parameters.RequestBody;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag;
import java.io.IOException;
-import javax.servlet.http.HttpServletRequest;
import javax.validation.Valid;
-import javax.ws.rs.BadRequestException;
import javax.ws.rs.Consumes;
import javax.ws.rs.HeaderParam;
import javax.ws.rs.POST;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
+import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
@@ -35,6 +34,7 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.entities.AnswerChallengeRequest;
import org.whispersystems.textsecuregcm.entities.AnswerPushChallengeRequest;
import org.whispersystems.textsecuregcm.entities.AnswerRecaptchaChallengeRequest;
+import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
@@ -42,8 +42,6 @@ import org.whispersystems.textsecuregcm.spam.Extract;
import org.whispersystems.textsecuregcm.spam.FilterSpam;
import org.whispersystems.textsecuregcm.spam.PushChallengeConfig;
import org.whispersystems.textsecuregcm.spam.ScoreThreshold;
-import org.whispersystems.textsecuregcm.util.HeaderUtils;
-import org.whispersystems.textsecuregcm.util.HttpServletRequestUtil;
@Path("/v1/challenge")
@Tag(name = "Challenge")
@@ -51,15 +49,12 @@ import org.whispersystems.textsecuregcm.util.HttpServletRequestUtil;
public class ChallengeController {
private final RateLimitChallengeManager rateLimitChallengeManager;
- private final boolean useRemoteAddress;
private static final String CHALLENGE_RESPONSE_COUNTER_NAME = name(ChallengeController.class, "challengeResponse");
private static final String CHALLENGE_TYPE_TAG = "type";
- public ChallengeController(final RateLimitChallengeManager rateLimitChallengeManager,
- final boolean useRemoteAddress) {
+ public ChallengeController(final RateLimitChallengeManager rateLimitChallengeManager) {
this.rateLimitChallengeManager = rateLimitChallengeManager;
- this.useRemoteAddress = useRemoteAddress;
}
@PUT
@@ -84,8 +79,7 @@ public class ChallengeController {
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public Response handleChallengeResponse(@Auth final AuthenticatedAccount auth,
@Valid final AnswerChallengeRequest answerRequest,
- @HeaderParam(HttpHeaders.X_FORWARDED_FOR) final String forwardedFor,
- @Context HttpServletRequest request,
+ @Context ContainerRequestContext requestContext,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
@Extract final ScoreThreshold captchaScoreThreshold,
@Extract final PushChallengeConfig pushChallengeConfig) throws RateLimitExceededException, IOException {
@@ -103,9 +97,8 @@ public class ChallengeController {
} else if (answerRequest instanceof AnswerRecaptchaChallengeRequest recaptchaChallengeRequest) {
tags = tags.and(CHALLENGE_TYPE_TAG, "recaptcha");
- final String remoteAddress = useRemoteAddress
- ? HttpServletRequestUtil.getRemoteAddress(request)
- : HeaderUtils.getMostRecentProxy(forwardedFor).orElseThrow(BadRequestException::new);
+ final String remoteAddress = (String) requestContext.getProperty(
+ RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
boolean success = rateLimitChallengeManager.answerRecaptchaChallenge(
auth.getAccount(),
recaptchaChallengeRequest.getCaptcha(),
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/VerificationController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/VerificationController.java
index 673b2f41b..2ec24943d 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/VerificationController.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/VerificationController.java
@@ -31,7 +31,6 @@ import java.util.Optional;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletionException;
import java.util.concurrent.TimeUnit;
-import javax.servlet.http.HttpServletRequest;
import javax.validation.Valid;
import javax.validation.constraints.NotNull;
import javax.ws.rs.BadRequestException;
@@ -49,6 +48,7 @@ import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.ServerErrorException;
import javax.ws.rs.WebApplicationException;
+import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.HttpHeaders;
import javax.ws.rs.core.MediaType;
@@ -66,6 +66,7 @@ import org.whispersystems.textsecuregcm.entities.SubmitVerificationCodeRequest;
import org.whispersystems.textsecuregcm.entities.UpdateVerificationSessionRequest;
import org.whispersystems.textsecuregcm.entities.VerificationCodeRequest;
import org.whispersystems.textsecuregcm.entities.VerificationSessionResponse;
+import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
@@ -88,8 +89,6 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
import org.whispersystems.textsecuregcm.storage.VerificationSessionManager;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
-import org.whispersystems.textsecuregcm.util.HeaderUtils;
-import org.whispersystems.textsecuregcm.util.HttpServletRequestUtil;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util;
@@ -123,7 +122,6 @@ public class VerificationController {
private final RateLimiters rateLimiters;
private final AccountsManager accountsManager;
- private final boolean useRemoteAddress;
private final DynamicConfigurationManager dynamicConfigurationManager;
private final Clock clock;
@@ -134,7 +132,6 @@ public class VerificationController {
final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager,
final RateLimiters rateLimiters,
final AccountsManager accountsManager,
- final boolean useRemoteAddress,
final DynamicConfigurationManager dynamicConfigurationManager,
final Clock clock) {
this.registrationServiceClient = registrationServiceClient;
@@ -144,7 +141,6 @@ public class VerificationController {
this.registrationRecoveryPasswordsManager = registrationRecoveryPasswordsManager;
this.rateLimiters = rateLimiters;
this.accountsManager = accountsManager;
- this.useRemoteAddress = useRemoteAddress;
this.dynamicConfigurationManager = dynamicConfigurationManager;
this.clock = clock;
}
@@ -205,16 +201,13 @@ public class VerificationController {
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public VerificationSessionResponse updateSession(@PathParam("sessionId") final String encodedSessionId,
- @HeaderParam(com.google.common.net.HttpHeaders.X_FORWARDED_FOR) String forwardedFor,
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
- @Context HttpServletRequest request,
+ @Context ContainerRequestContext requestContext,
@NotNull @Valid final UpdateVerificationSessionRequest updateVerificationSessionRequest,
@NotNull @Extract final ScoreThreshold scoreThreshold,
@NotNull @Extract final SenderOverride senderOverride) {
- final String sourceHost = useRemoteAddress
- ? HttpServletRequestUtil.getRemoteAddress(request)
- : HeaderUtils.getMostRecentProxy(forwardedFor).orElseThrow();
+ final String sourceHost = (String) requestContext.getProperty(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
final Pair pushTokenAndType = validateAndExtractPushToken(
updateVerificationSessionRequest);
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilter.java b/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilter.java
new file mode 100644
index 000000000..750cb9b47
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilter.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright 2024 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.filters;
+
+import javax.servlet.Filter;
+import javax.servlet.FilterChain;
+import javax.servlet.ServletException;
+import javax.servlet.ServletRequest;
+import javax.servlet.ServletResponse;
+import javax.servlet.http.HttpServletRequest;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.whispersystems.textsecuregcm.util.HeaderUtils;
+import org.whispersystems.textsecuregcm.util.HttpServletRequestUtil;
+import java.io.IOException;
+
+/**
+ * Sets a {@link HttpServletRequest} attribute (that will also be available as a
+ * {@link javax.ws.rs.container.ContainerRequestContext} property) with the remote address of the connection, using
+ * either the {@link HttpServletRequest#getRemoteAddr()} or the {@code X-Forwarded-For} HTTP header value, depending on
+ * whether {@link #preferRemoteAddress} is {@code true}.
+ */
+public class RemoteAddressFilter implements Filter {
+
+ public static final String REMOTE_ADDRESS_ATTRIBUTE_NAME = RemoteAddressFilter.class.getName() + ".remoteAddress";
+ private static final Logger logger = LoggerFactory.getLogger(RemoteAddressFilter.class);
+
+ private final boolean preferRemoteAddress;
+
+
+ public RemoteAddressFilter(boolean preferRemoteAddress) {
+ this.preferRemoteAddress = preferRemoteAddress;
+ }
+
+ @Override
+ public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain)
+ throws ServletException, IOException {
+
+ if (request instanceof HttpServletRequest httpServletRequest) {
+
+ final String remoteAddress;
+
+ if (preferRemoteAddress) {
+ remoteAddress = HttpServletRequestUtil.getRemoteAddress(httpServletRequest);
+ } else {
+ final String forwardedFor = httpServletRequest.getHeader(com.google.common.net.HttpHeaders.X_FORWARDED_FOR);
+ remoteAddress = HeaderUtils.getMostRecentProxy(forwardedFor)
+ .orElseGet(() -> HttpServletRequestUtil.getRemoteAddress(httpServletRequest));
+ }
+
+ request.setAttribute(REMOTE_ADDRESS_ATTRIBUTE_NAME, remoteAddress);
+
+ } else {
+ logger.warn("request was of unexpected type: {}", request.getClass());
+ }
+
+ chain.doFilter(request, response);
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteDeprecationFilter.java b/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteDeprecationFilter.java
index 6e9c590af..49169526c 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteDeprecationFilter.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/filters/RemoteDeprecationFilter.java
@@ -7,15 +7,12 @@ package org.whispersystems.textsecuregcm.filters;
import static com.codahale.metrics.MetricRegistry.name;
-import com.google.common.annotations.VisibleForTesting;
import com.google.common.net.HttpHeaders;
import com.vdurmont.semver4j.Semver;
-
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
-import io.grpc.Status;
import io.micrometer.core.instrument.Metrics;
import java.io.IOException;
import java.util.Map;
@@ -30,7 +27,6 @@ import javax.servlet.http.HttpServletResponse;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRemoteDeprecationConfiguration;
import org.whispersystems.textsecuregcm.grpc.StatusConstants;
-import org.whispersystems.textsecuregcm.grpc.UserAgentInterceptor;
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimitByIpFilter.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimitByIpFilter.java
index 0aa244ab2..e169b4a91 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimitByIpFilter.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimitByIpFilter.java
@@ -8,33 +8,25 @@ package org.whispersystems.textsecuregcm.limits;
import static java.util.Objects.requireNonNull;
import com.google.common.annotations.VisibleForTesting;
-import com.google.common.net.HttpHeaders;
import java.io.IOException;
import java.time.Duration;
import java.util.Optional;
-import javax.inject.Provider;
-import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.ClientErrorException;
import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.container.ContainerRequestFilter;
-import javax.ws.rs.core.Context;
import javax.ws.rs.core.Response;
import javax.ws.rs.ext.ExceptionMapper;
import org.glassfish.jersey.server.ExtendedUriInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
+import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
-import org.whispersystems.textsecuregcm.util.HeaderUtils;
-import org.whispersystems.textsecuregcm.util.HttpServletRequestUtil;
public class RateLimitByIpFilter implements ContainerRequestFilter {
private static final Logger logger = LoggerFactory.getLogger(RateLimitByIpFilter.class);
- @Context
- private Provider httpServletRequestProvider;
-
@VisibleForTesting
static final RateLimitExceededException INVALID_HEADER_EXCEPTION = new RateLimitExceededException(Duration.ofHours(1),
true);
@@ -42,12 +34,10 @@ public class RateLimitByIpFilter implements ContainerRequestFilter {
private static final ExceptionMapper EXCEPTION_MAPPER = new RateLimitExceededExceptionMapper();
private final RateLimiters rateLimiters;
- private final boolean useRemoteAddress;
- public RateLimitByIpFilter(final RateLimiters rateLimiters, final boolean useRemoteAddress) {
+ public RateLimitByIpFilter(final RateLimiters rateLimiters) {
this.rateLimiters = requireNonNull(rateLimiters);
- this.useRemoteAddress = useRemoteAddress;
}
@Override
@@ -70,18 +60,14 @@ public class RateLimitByIpFilter implements ContainerRequestFilter {
final RateLimiters.For handle = annotation.value();
try {
- final String xffHeader = requestContext.getHeaders().getFirst(HttpHeaders.X_FORWARDED_FOR);
- final Optional remoteAddress = useRemoteAddress
- ? Optional.of(HttpServletRequestUtil.getRemoteAddress(httpServletRequestProvider.get()))
- : Optional.ofNullable(xffHeader)
- .flatMap(HeaderUtils::getMostRecentProxy);
+ final Optional remoteAddress = Optional.ofNullable(
+ (String) requestContext.getProperty(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME));
- // checking if we failed to extract the most recent IP from the X-Forwarded-For header
- // for any reason
+ // checking if we failed to extract the most recent IP for any reason
if (remoteAddress.isEmpty()) {
// checking if annotation is configured to fail when the most recent IP is not resolved
if (annotation.failOnUnresolvedIp()) {
- logger.error("Missing/bad X-Forwarded-For: {}", xffHeader);
+ logger.error("Remote address was null");
throw INVALID_HEADER_EXCEPTION;
}
// otherwise, allow request
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java
index 58376e1f5..253a54b7c 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java
@@ -69,6 +69,7 @@ import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
+import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
@@ -306,9 +307,9 @@ class AuthEnablementRefreshRequirementProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
- provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
- requestLog, new TestPrincipal("test", account, authenticatedDevice), new ProtobufWebSocketMessageFactory(),
- Optional.empty(), Duration.ofMillis(30000));
+ provider = new WebSocketResourceProvider<>("127.0.0.1", RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME,
+ applicationHandler, requestLog, new TestPrincipal("test", account, authenticatedDevice),
+ new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
remoteEndpoint = mock(RemoteEndpoint.class);
Session session = mock(Session.class);
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java
index 0ef36844b..ddbecb0fe 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/AccountControllerTest.java
@@ -91,6 +91,7 @@ import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
import org.whispersystems.textsecuregcm.util.MockUtils;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
+import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider;
import org.whispersystems.textsecuregcm.util.UsernameHashZkProofVerifier;
@ExtendWith(DropwizardExtensionsSupport.class)
@@ -119,9 +120,6 @@ class AccountControllerTest {
private static final UUID SENDER_REG_LOCK_UUID = UUID.randomUUID();
private static final UUID SENDER_TRANSFER_UUID = UUID.randomUUID();
- private static final String NICE_HOST = "127.0.0.1";
- private static final String RATE_LIMITED_IP_HOST = "10.0.0.1";
-
private static AccountsManager accountsManager = mock(AccountsManager.class);
private static RateLimiters rateLimiters = mock(RateLimiters.class);
private static RateLimiter rateLimiter = mock(RateLimiter.class);
@@ -140,6 +138,9 @@ class AccountControllerTest {
private byte[] registration_lock_key = new byte[32];
+ private static final TestRemoteAddressFilterProvider TEST_REMOTE_ADDRESS_FILTER_PROVIDER
+ = new TestRemoteAddressFilterProvider("127.0.0.1");
+
private static final ResourceExtension resources = ResourceExtension.builder()
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
.addProvider(AuthHelper.getAuthFilter())
@@ -148,7 +149,8 @@ class AccountControllerTest {
.addProvider(new RateLimitExceededExceptionMapper())
.addProvider(new ImpossiblePhoneNumberExceptionMapper())
.addProvider(new NonNormalizedPhoneNumberExceptionMapper())
- .addProvider(new RateLimitByIpFilter(rateLimiters, true))
+ .addProvider(TEST_REMOTE_ADDRESS_FILTER_PROVIDER)
+ .addProvider(new RateLimitByIpFilter(rateLimiters))
.addProvider(ScoreThresholdProvider.ScoreThresholdFeature.class)
.addProvider(SenderOverrideProvider.SenderOverrideFeature.class)
.setMapper(SystemMapper.jsonMapper())
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ChallengeControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ChallengeControllerTest.java
index 24493f20a..f761f3995 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ChallengeControllerTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/ChallengeControllerTest.java
@@ -43,18 +43,16 @@ import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.spam.PushChallengeConfigProvider;
import org.whispersystems.textsecuregcm.spam.ScoreThreshold;
import org.whispersystems.textsecuregcm.spam.ScoreThresholdProvider;
-import org.whispersystems.textsecuregcm.spam.SenderOverride;
-import org.whispersystems.textsecuregcm.spam.SenderOverrideProvider;
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
+import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider;
@ExtendWith(DropwizardExtensionsSupport.class)
class ChallengeControllerTest {
private static final RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class);
- private static final ChallengeController challengeController = new ChallengeController(rateLimitChallengeManager,
- true);
+ private static final ChallengeController challengeController = new ChallengeController(rateLimitChallengeManager);
private static final AtomicReference scoreThreshold = new AtomicReference<>();
@@ -73,6 +71,7 @@ class ChallengeControllerTest {
return true;
}
})
+ .addProvider(new TestRemoteAddressFilterProvider("127.0.0.1"))
.setMapper(SystemMapper.jsonMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new RateLimitExceededExceptionMapper())
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/VerificationControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/VerificationControllerTest.java
index 7d3f03cb4..ba2ca2710 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/VerificationControllerTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/VerificationControllerTest.java
@@ -118,7 +118,7 @@ class VerificationControllerTest {
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(
new VerificationController(registrationServiceClient, verificationSessionManager, pushNotificationManager,
- registrationCaptchaManager, registrationRecoveryPasswordsManager, rateLimiters, accountsManager, true,
+ registrationCaptchaManager, registrationRecoveryPasswordsManager, rateLimiters, accountsManager,
dynamicConfigurationManager, clock))
.build();
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilterIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilterIntegrationTest.java
new file mode 100644
index 000000000..4db9a4ae7
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilterIntegrationTest.java
@@ -0,0 +1,308 @@
+/*
+ * Copyright 2024 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.filters;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assumptions.assumeTrue;
+
+import com.google.common.net.HttpHeaders;
+import io.dropwizard.core.Application;
+import io.dropwizard.core.Configuration;
+import io.dropwizard.core.setup.Environment;
+import io.dropwizard.testing.junit5.DropwizardAppExtension;
+import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
+import java.io.IOException;
+import java.net.InetAddress;
+import java.net.URI;
+import java.nio.ByteBuffer;
+import java.security.Principal;
+import java.time.Duration;
+import java.util.Arrays;
+import java.util.EnumSet;
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+import javax.security.auth.Subject;
+import javax.servlet.DispatcherType;
+import javax.ws.rs.GET;
+import javax.ws.rs.Path;
+import javax.ws.rs.client.Client;
+import javax.ws.rs.container.ContainerRequestContext;
+import javax.ws.rs.core.Context;
+import org.eclipse.jetty.util.HostPort;
+import org.eclipse.jetty.websocket.api.Session;
+import org.eclipse.jetty.websocket.api.WebSocketListener;
+import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
+import org.eclipse.jetty.websocket.client.WebSocketClient;
+import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Nested;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.CsvSource;
+import org.junit.jupiter.params.provider.ValueSource;
+import org.whispersystems.textsecuregcm.util.SystemMapper;
+import org.whispersystems.websocket.WebSocketResourceProviderFactory;
+import org.whispersystems.websocket.configuration.WebSocketConfiguration;
+import org.whispersystems.websocket.messages.WebSocketMessage;
+import org.whispersystems.websocket.messages.WebSocketMessageFactory;
+import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
+import org.whispersystems.websocket.setup.WebSocketEnvironment;
+
+@ExtendWith(DropwizardExtensionsSupport.class)
+class RemoteAddressFilterIntegrationTest {
+
+ private static final String WEBSOCKET_PREFIX = "/websocket";
+ private static final String REMOTE_ADDRESS_PATH = "/remoteAddress";
+ private static final String FORWARDED_FOR_PATH = "/forwardedFor";
+ private static final String WS_REQUEST_PATH = "/wsRequest";
+
+ // The Grizzly test container does not match the Jetty container used in real deployments, and JettyTestContainerFactory
+ // in jersey-test-framework-provider-jetty doesn’t easily support @Context HttpServletRequest, so this test runs a
+ // full Jetty server in a separate process
+ private static final DropwizardAppExtension EXTENSION = new DropwizardAppExtension<>(
+ TestApplication.class);
+
+ @Nested
+ class Rest {
+
+ @ParameterizedTest
+ @ValueSource(strings = {"127.0.0.1", "0:0:0:0:0:0:0:1"})
+ void testRemoteAddress(String ip) throws Exception {
+ final Set addresses = Arrays.stream(InetAddress.getAllByName("localhost"))
+ .map(InetAddress::getHostAddress)
+ .collect(Collectors.toSet());
+
+ assumeTrue(addresses.contains(ip), String.format("localhost does not resolve to %s", ip));
+
+ Client client = EXTENSION.client();
+
+ final RemoteAddressFilterIntegrationTest.TestResponse response = client.target(
+ String.format("http://%s:%d%s", HostPort.normalizeHost(ip), EXTENSION.getLocalPort(), REMOTE_ADDRESS_PATH))
+ .request("application/json")
+ .get(RemoteAddressFilterIntegrationTest.TestResponse.class);
+
+ assertEquals(ip, response.remoteAddress());
+ }
+
+ @ParameterizedTest
+ @CsvSource(value = {"127.0.0.1, 192.168.1.1 \t 192.168.1.1",
+ "127.0.0.1, fe80:1:1:1:1:1:1:1 \t fe80:1:1:1:1:1:1:1"}, delimiterString = "\t")
+ void testForwardedFor(String forwardedFor, String expectedIp) {
+
+ Client client = EXTENSION.client();
+
+ final RemoteAddressFilterIntegrationTest.TestResponse response = client.target(
+ String.format("http://localhost:%d%s", EXTENSION.getLocalPort(), FORWARDED_FOR_PATH))
+ .request("application/json")
+ .header(HttpHeaders.X_FORWARDED_FOR, forwardedFor)
+ .get(RemoteAddressFilterIntegrationTest.TestResponse.class);
+
+ assertEquals(expectedIp, response.remoteAddress());
+ }
+ }
+
+ @Nested
+ class WebSocket {
+
+ private WebSocketClient client;
+
+ @BeforeEach
+ void setUp() throws Exception {
+ client = new WebSocketClient();
+ client.start();
+ }
+
+ @AfterEach
+ void tearDown() throws Exception {
+ client.stop();
+ }
+
+ @ParameterizedTest
+ @ValueSource(strings = {"127.0.0.1", "0:0:0:0:0:0:0:1"})
+ void testRemoteAddress(String ip) throws Exception {
+ final Set addresses = Arrays.stream(InetAddress.getAllByName("localhost"))
+ .map(InetAddress::getHostAddress)
+ .collect(Collectors.toSet());
+
+ assumeTrue(addresses.contains(ip), String.format("localhost does not resolve to %s", ip));
+
+ final CompletableFuture responseFuture = new CompletableFuture<>();
+ final ClientEndpoint clientEndpoint = new ClientEndpoint(WS_REQUEST_PATH, responseFuture);
+
+ client.connect(clientEndpoint,
+ URI.create(
+ String.format("ws://%s:%d%s", HostPort.normalizeHost(ip), EXTENSION.getLocalPort(),
+ WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH)));
+
+ final byte[] responseBytes = responseFuture.get(1, TimeUnit.SECONDS);
+
+ final TestResponse response = SystemMapper.jsonMapper().readValue(responseBytes, TestResponse.class);
+
+ assertEquals(ip, response.remoteAddress());
+ }
+
+ @ParameterizedTest
+ @CsvSource(value = {"127.0.0.1, 192.168.1.1 \t 192.168.1.1",
+ "127.0.0.1, fe80:1:1:1:1:1:1:1 \t fe80:1:1:1:1:1:1:1"}, delimiterString = "\t")
+ void testForwardedFor(String forwardedFor, String expectedIp) throws Exception {
+
+ final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
+ upgradeRequest.setHeader(HttpHeaders.X_FORWARDED_FOR, forwardedFor);
+
+ final CompletableFuture responseFuture = new CompletableFuture<>();
+
+ client.connect(new ClientEndpoint(WS_REQUEST_PATH, responseFuture),
+ URI.create(
+ String.format("ws://localhost:%d%s", EXTENSION.getLocalPort(), WEBSOCKET_PREFIX + FORWARDED_FOR_PATH)),
+ upgradeRequest);
+
+ final byte[] responseBytes = responseFuture.get(1, TimeUnit.SECONDS);
+
+ final TestResponse response = SystemMapper.jsonMapper().readValue(responseBytes, TestResponse.class);
+
+ assertEquals(expectedIp, response.remoteAddress());
+ }
+ }
+
+ private static class ClientEndpoint implements WebSocketListener {
+
+ private final String requestPath;
+ private final CompletableFuture responseFuture;
+ private final WebSocketMessageFactory messageFactory;
+
+ ClientEndpoint(String requestPath, CompletableFuture responseFuture) {
+
+ this.requestPath = requestPath;
+ this.responseFuture = responseFuture;
+ this.messageFactory = new ProtobufWebSocketMessageFactory();
+ }
+
+ @Override
+ public void onWebSocketConnect(final Session session) {
+ final byte[] requestBytes = messageFactory.createRequest(Optional.of(1L), "GET", requestPath,
+ List.of("Accept: application/json"),
+ Optional.empty()).toByteArray();
+ try {
+ session.getRemote().sendBytes(ByteBuffer.wrap(requestBytes));
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public void onWebSocketBinary(final byte[] payload, final int offset, final int length) {
+
+ try {
+ WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length);
+
+ if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.RESPONSE_MESSAGE) {
+ assert 200 == webSocketMessage.getResponseMessage().getStatus();
+ responseFuture.complete(webSocketMessage.getResponseMessage().getBody().orElseThrow());
+ } else {
+ throw new RuntimeException("Unexpected message type: " + webSocketMessage.getType());
+ }
+ } catch (final Exception e) {
+ throw new RuntimeException(e);
+ }
+
+ }
+
+ }
+
+ public static abstract class TestController {
+
+ @GET
+ public RemoteAddressFilterIntegrationTest.TestResponse get(@Context ContainerRequestContext context) {
+
+ return new RemoteAddressFilterIntegrationTest.TestResponse(
+ (String) context.getProperty(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME));
+ }
+ }
+
+ @Path(REMOTE_ADDRESS_PATH)
+ public static class TestRemoteAddressController extends TestController {
+
+ }
+
+ @Path(FORWARDED_FOR_PATH)
+ public static class TestForwardedForController extends TestController {
+
+ }
+
+ @Path(WS_REQUEST_PATH)
+ public static class TestWebSocketController extends TestController {
+
+ }
+
+ public record TestResponse(String remoteAddress) {
+
+ }
+
+ public static class TestApplication extends Application {
+
+ @Override
+ public void run(final Configuration configuration,
+ final Environment environment) throws Exception {
+
+ // 2 filters, to cover useRemoteAddress = {true, false}
+ // each has explicit (not wildcard) path matching
+ environment.servlets().addFilter("RemoteAddressFilterRemoteAddress", new RemoteAddressFilter(true))
+ .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, REMOTE_ADDRESS_PATH,
+ WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH);
+ environment.servlets().addFilter("RemoteAddressFilterForwardedFor", new RemoteAddressFilter(false))
+ .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, FORWARDED_FOR_PATH,
+ WEBSOCKET_PREFIX + FORWARDED_FOR_PATH);
+
+ environment.jersey().register(new TestRemoteAddressController());
+ environment.jersey().register(new TestForwardedForController());
+
+ // WebSocket set up
+ final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration();
+
+ WebSocketEnvironment webSocketEnvironment = new WebSocketEnvironment<>(environment,
+ webSocketConfiguration, Duration.ofMillis(1000));
+
+ webSocketEnvironment.jersey().register(new TestWebSocketController());
+
+ JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
+
+ WebSocketResourceProviderFactory webSocketServlet = new WebSocketResourceProviderFactory<>(
+ webSocketEnvironment, TestPrincipal.class, webSocketConfiguration,
+ RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
+
+ // 2 servlets, because the filter only runs for the Upgrade request
+ environment.servlets().addServlet("WebSocketForwardedFor", webSocketServlet)
+ .addMapping(WEBSOCKET_PREFIX + FORWARDED_FOR_PATH);
+ environment.servlets().addServlet("WebSocketRemoteAddress", webSocketServlet)
+ .addMapping(WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH);
+
+ }
+ }
+
+ /**
+ * A minimal {@code Principal} implementation, only used to satisfy constructors
+ */
+ public static class TestPrincipal implements Principal {
+
+ // Principal implementation
+
+ @Override
+ public String getName() {
+ return null;
+ }
+
+ @Override
+ public boolean implies(final Subject subject) {
+ return false;
+ }
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilterTest.java
new file mode 100644
index 000000000..1fdf20995
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/filters/RemoteAddressFilterTest.java
@@ -0,0 +1,60 @@
+/*
+ * Copyright 2024 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.filters;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import com.google.common.net.HttpHeaders;
+import javax.servlet.FilterChain;
+import javax.servlet.ServletRequest;
+import javax.servlet.ServletResponse;
+import javax.servlet.http.HttpServletRequest;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.CsvSource;
+
+class RemoteAddressFilterTest {
+
+ @ParameterizedTest
+ @CsvSource({
+ "127.0.0.1, 127.0.0.1",
+ "0:0:0:0:0:0:0:1, 0:0:0:0:0:0:0:1",
+ "[0:0:0:0:0:0:0:1], 0:0:0:0:0:0:0:1"
+ })
+ void testGetRemoteAddress(final String remoteAddr, final String expectedRemoteAddr) throws Exception {
+ final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class);
+ when(httpServletRequest.getRemoteAddr()).thenReturn(remoteAddr);
+
+ final RemoteAddressFilter filter = new RemoteAddressFilter(true);
+
+ final FilterChain filterChain = mock(FilterChain.class);
+ filter.doFilter(httpServletRequest, mock(ServletResponse.class), filterChain);
+
+ verify(httpServletRequest).setAttribute(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, expectedRemoteAddr);
+ verify(filterChain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
+ }
+
+ @ParameterizedTest
+ @CsvSource(value = {
+ "192.168.1.1, 127.0.0.1 \t 127.0.0.1",
+ "192.168.1.1, 0:0:0:0:0:0:0:1 \t 0:0:0:0:0:0:0:1"
+ }, delimiterString = "\t")
+ void testGetRemoteAddressFromHeader(final String forwardedFor, final String expectedRemoteAddr) throws Exception {
+ final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class);
+ when(httpServletRequest.getHeader(HttpHeaders.X_FORWARDED_FOR)).thenReturn(forwardedFor);
+
+ final RemoteAddressFilter filter = new RemoteAddressFilter(false);
+
+ final FilterChain filterChain = mock(FilterChain.class);
+ filter.doFilter(httpServletRequest, mock(ServletResponse.class), filterChain);
+
+ verify(httpServletRequest).setAttribute(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, expectedRemoteAddr);
+ verify(filterChain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
+ }
+
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitedByIpTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitedByIpTest.java
index 3f63155d6..6d3ff3d3e 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitedByIpTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/limits/RateLimitedByIpTest.java
@@ -25,6 +25,7 @@ import org.junit.jupiter.api.extension.ExtendWith;
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
import org.whispersystems.textsecuregcm.util.MockUtils;
import org.whispersystems.textsecuregcm.util.SystemMapper;
+import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider;
@ExtendWith(DropwizardExtensionsSupport.class)
public class RateLimitedByIpTest {
@@ -60,7 +61,8 @@ public class RateLimitedByIpTest {
.setMapper(SystemMapper.jsonMapper())
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new Controller())
- .addProvider(new RateLimitByIpFilter(RATE_LIMITERS, true))
+ .addProvider(new RateLimitByIpFilter(RATE_LIMITERS))
+ .addProvider(new TestRemoteAddressFilterProvider(IP))
.build();
@Test
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java
index f3b1962f5..a712a07ef 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java
@@ -49,6 +49,7 @@ import org.glassfish.jersey.uri.UriTemplate;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
+import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.websocket.WebSocketResourceProvider;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
@@ -138,12 +139,8 @@ class MetricsRequestEventListenerTest {
final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
final WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1",
- applicationHandler,
- requestLog,
- new TestPrincipal("foo"),
- new ProtobufWebSocketMessageFactory(),
- Optional.empty(),
- Duration.ofMillis(30000));
+ RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
+ new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
final Session session = mock(Session.class);
final RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@@ -204,9 +201,8 @@ class MetricsRequestEventListenerTest {
final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
final WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1",
- applicationHandler,
- requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
- Duration.ofMillis(30000));
+ RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
+ new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
final Session session = mock(Session.class);
final RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/HttpServletRequestUtilIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/HttpServletRequestUtilIntegrationTest.java
index a2b3e7430..0722ca99c 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/util/HttpServletRequestUtilIntegrationTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/HttpServletRequestUtilIntegrationTest.java
@@ -35,8 +35,7 @@ class HttpServletRequestUtilIntegrationTest {
// The Grizzly test container does not match the Jetty container used in real deployments, and JettyTestContainerFactory
// in jersey-test-framework-provider-jetty doesn’t easily support @Context HttpServletRequest, so this test runs a
// full Jetty server in a separate process
- private final DropwizardAppExtension EXTENSION = new DropwizardAppExtension<>(
- TestApplication.class);
+ private final DropwizardAppExtension EXTENSION = new DropwizardAppExtension<>(TestApplication.class);
@ParameterizedTest
@ValueSource(strings = {"127.0.0.1", "0:0:0:0:0:0:0:1"})
@@ -72,13 +71,11 @@ class HttpServletRequestUtilIntegrationTest {
}
- public static class TestApplication extends Application {
+ public static class TestApplication extends Application {
@Override
- public void run(final TestConfiguration configuration, final Environment environment) throws Exception {
+ public void run(final Configuration configuration, final Environment environment) throws Exception {
environment.jersey().register(new TestController());
}
}
-
- public static class TestConfiguration extends Configuration {}
}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/TestRemoteAddressFilterProvider.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/TestRemoteAddressFilterProvider.java
new file mode 100644
index 000000000..61ca1cd0e
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/TestRemoteAddressFilterProvider.java
@@ -0,0 +1,31 @@
+/*
+ * Copyright 2024 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.util;
+
+import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
+import javax.annotation.Priority;
+import javax.ws.rs.container.ContainerRequestContext;
+import javax.ws.rs.container.ContainerRequestFilter;
+import java.io.IOException;
+
+/**
+ * Adds the request property set by {@link RemoteAddressFilter} for test scenarios that depend on it, but do not have
+ * access to a full {@code HttpServletRequest} pipline
+ */
+@Priority(Integer.MIN_VALUE) // highest priority, since other filters might depend on it
+public class TestRemoteAddressFilterProvider implements ContainerRequestFilter {
+
+ private final String ip;
+
+ public TestRemoteAddressFilterProvider(String ip) {
+ this.ip = ip;
+ }
+
+ @Override
+ public void filter(final ContainerRequestContext requestContext) throws IOException {
+ requestContext.setProperty(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, ip);
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java
index e75a7b0af..207d4dfbe 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java
@@ -55,6 +55,7 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.slf4j.Logger;
+import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.websocket.WebSocketResourceProvider;
@@ -173,9 +174,9 @@ class LoggingUnhandledExceptionMapperTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
- WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
- requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
- Duration.ofMillis(30000));
+ WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1",
+ RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
+ new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
doAnswer(answer -> {
diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java
index 2ca34905d..006fa5333 100644
--- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java
+++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java
@@ -65,6 +65,7 @@ public class WebSocketResourceProvider implements WebSocket
private final WebsocketRequestLog requestLog;
private final Duration idleTimeout;
private final String remoteAddress;
+ private final String remoteAddressPropertyName;
private Session session;
private RemoteEndpoint remoteEndpoint;
@@ -73,6 +74,7 @@ public class WebSocketResourceProvider implements WebSocket
private static final Set EXCLUDED_UPGRADE_REQUEST_HEADERS = Set.of("connection", "upgrade");
public WebSocketResourceProvider(String remoteAddress,
+ String remoteAddressPropertyName,
ApplicationHandler jerseyHandler,
WebsocketRequestLog requestLog,
T authenticated,
@@ -80,6 +82,7 @@ public class WebSocketResourceProvider implements WebSocket
Optional connectListener,
Duration idleTimeout) {
this.remoteAddress = remoteAddress;
+ this.remoteAddressPropertyName = remoteAddressPropertyName;
this.jerseyHandler = jerseyHandler;
this.requestLog = requestLog;
this.authenticated = authenticated;
@@ -169,6 +172,8 @@ public class WebSocketResourceProvider implements WebSocket
containerRequest.setEntityStream(new ByteArrayInputStream(requestMessage.getBody().get()));
}
+ containerRequest.setProperty(remoteAddressPropertyName, remoteAddress);
+
ByteArrayOutputStream responseBody = new ByteArrayOutputStream();
CompletableFuture responseFuture = (CompletableFuture) jerseyHandler.apply(
containerRequest, responseBody);
diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java
index b60faa223..1b3314430 100644
--- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java
+++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java
@@ -6,13 +6,12 @@ package org.whispersystems.websocket;
import static java.util.Optional.ofNullable;
-import com.google.common.net.HttpHeaders;
import io.dropwizard.jersey.jackson.JacksonMessageBodyProvider;
import java.io.IOException;
-import java.net.InetSocketAddress;
import java.security.Principal;
-import java.util.Arrays;
import java.util.Optional;
+import javax.ws.rs.InternalServerErrorException;
+import org.apache.commons.lang3.StringUtils;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
import org.eclipse.jetty.websocket.server.JettyWebSocketCreator;
@@ -38,8 +37,10 @@ public class WebSocketResourceProviderFactory extends Jetty
private final ApplicationHandler jerseyApplicationHandler;
private final WebSocketConfiguration configuration;
+ private final String remoteAddressPropertyName;
+
public WebSocketResourceProviderFactory(WebSocketEnvironment environment, Class principalClass,
- WebSocketConfiguration configuration) {
+ WebSocketConfiguration configuration, String remoteAddressPropertyName) {
this.environment = environment;
environment.jersey().register(new WebSocketSessionContextValueFactoryProvider.Binder());
@@ -49,6 +50,7 @@ public class WebSocketResourceProviderFactory extends Jetty
this.jerseyApplicationHandler = new ApplicationHandler(environment.jersey());
this.configuration = configuration;
+ this.remoteAddressPropertyName = remoteAddressPropertyName;
}
@Override
@@ -69,6 +71,7 @@ public class WebSocketResourceProviderFactory extends Jetty
}
return new WebSocketResourceProvider<>(getRemoteAddress(request),
+ remoteAddressPropertyName,
this.jerseyApplicationHandler,
this.environment.getRequestLog(),
authenticated,
@@ -93,18 +96,11 @@ public class WebSocketResourceProviderFactory extends Jetty
}
private String getRemoteAddress(JettyServerUpgradeRequest request) {
- String forwardedFor = request.getHeader(HttpHeaders.X_FORWARDED_FOR);
-
- if (forwardedFor == null || forwardedFor.isBlank()) {
- if (request.getRemoteSocketAddress() instanceof InetSocketAddress inetSocketAddress) {
- return inetSocketAddress.getAddress().getHostAddress();
- }
- return null;
- } else {
- return Arrays.stream(forwardedFor.split(","))
- .map(String::trim)
- .reduce((a, b) -> b)
- .orElseThrow();
+ final String remoteAddress = (String) request.getHttpServletRequest().getAttribute(remoteAddressPropertyName);
+ if (StringUtils.isBlank(remoteAddress)) {
+ logger.error("Remote address property is not present");
+ throw new InternalServerErrorException();
}
+ return remoteAddress;
}
}
diff --git a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java
index fe76b4faf..4546242ea 100644
--- a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java
+++ b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java
@@ -18,8 +18,8 @@ import java.io.IOException;
import java.security.Principal;
import java.util.Optional;
import javax.security.auth.Subject;
+import javax.servlet.http.HttpServletRequest;
import org.eclipse.jetty.websocket.api.Session;
-import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
import org.eclipse.jetty.websocket.server.JettyWebSocketServletFactory;
@@ -33,6 +33,8 @@ import org.whispersystems.websocket.setup.WebSocketEnvironment;
public class WebSocketResourceProviderFactoryTest {
+ private static final String REMOTE_ADDRESS_PROPERTY_NAME = "org.whispersystems.websocket.test.remoteAddress";
+
private ResourceConfig jerseyEnvironment;
private WebSocketEnvironment environment;
private WebSocketAuthenticator authenticator;
@@ -59,7 +61,7 @@ public class WebSocketResourceProviderFactoryTest {
when(environment.jersey()).thenReturn(jerseyEnvironment);
WebSocketResourceProviderFactory> factory = new WebSocketResourceProviderFactory<>(environment, Account.class,
- mock(WebSocketConfiguration.class));
+ mock(WebSocketConfiguration.class), REMOTE_ADDRESS_PROPERTY_NAME);
Object connection = factory.createWebSocket(request, response);
assertNull(connection);
@@ -69,24 +71,25 @@ public class WebSocketResourceProviderFactoryTest {
@Test
void testValidAuthorization() throws AuthenticationException {
- Session session = mock(Session.class);
Account account = new Account();
when(environment.getAuthenticator()).thenReturn(authenticator);
when(authenticator.authenticate(eq(request))).thenReturn(
new WebSocketAuthenticator.AuthenticationResult<>(Optional.of(account), true));
when(environment.jersey()).thenReturn(jerseyEnvironment);
- when(session.getUpgradeRequest()).thenReturn(mock(UpgradeRequest.class));
+ final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class);
+ when(httpServletRequest.getAttribute(REMOTE_ADDRESS_PROPERTY_NAME)).thenReturn("127.0.0.1");
+ when(request.getHttpServletRequest()).thenReturn(httpServletRequest);
WebSocketResourceProviderFactory> factory = new WebSocketResourceProviderFactory<>(environment, Account.class,
- mock(WebSocketConfiguration.class));
+ mock(WebSocketConfiguration.class), REMOTE_ADDRESS_PROPERTY_NAME);
Object connection = factory.createWebSocket(request, response);
assertNotNull(connection);
verifyNoMoreInteractions(response);
verify(authenticator).authenticate(eq(request));
- ((WebSocketResourceProvider>) connection).onWebSocketConnect(session);
+ ((WebSocketResourceProvider>) connection).onWebSocketConnect(mock(Session.class));
assertNotNull(((WebSocketResourceProvider>) connection).getContext().getAuthenticated());
assertEquals(((WebSocketResourceProvider>) connection).getContext().getAuthenticated(), account);
@@ -100,7 +103,8 @@ public class WebSocketResourceProviderFactoryTest {
WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory<>(environment,
Account.class,
- mock(WebSocketConfiguration.class));
+ mock(WebSocketConfiguration.class),
+ REMOTE_ADDRESS_PROPERTY_NAME);
Object connection = factory.createWebSocket(request, response);
assertNull(connection);
@@ -115,7 +119,8 @@ public class WebSocketResourceProviderFactoryTest {
WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory<>(environment,
Account.class,
- mock(WebSocketConfiguration.class));
+ mock(WebSocketConfiguration.class),
+ REMOTE_ADDRESS_PROPERTY_NAME);
factory.configure(servletFactory);
verify(servletFactory).setCreator(eq(factory));
diff --git a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java
index 6feba40e6..2f8536cd4 100644
--- a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java
+++ b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java
@@ -70,12 +70,15 @@ import org.whispersystems.websocket.setup.WebSocketConnectListener;
class WebSocketResourceProviderTest {
+ private static final String REMOTE_ADDRESS_PROPERTY_NAME = "org.whispersystems.weboscket.test.remoteAddress";
+
@Test
void testOnConnect() {
ApplicationHandler applicationHandler = mock(ApplicationHandler.class);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
WebSocketConnectListener connectListener = mock(WebSocketConnectListener.class);
WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1",
+ REMOTE_ADDRESS_PROPERTY_NAME,
applicationHandler, requestLog,
new TestPrincipal("fooz"),
new ProtobufWebSocketMessageFactory(),
@@ -104,9 +107,9 @@ class WebSocketResourceProviderTest {
void testMockedRouteMessageSuccess() throws Exception {
ApplicationHandler applicationHandler = mock(ApplicationHandler.class);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
- WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
- requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
- Duration.ofMillis(30000));
+ WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1",
+ REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
+ new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@@ -180,9 +183,9 @@ class WebSocketResourceProviderTest {
void testMockedRouteMessageFailure() throws Exception {
ApplicationHandler applicationHandler = mock(ApplicationHandler.class);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
- WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
- requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
- Duration.ofMillis(30000));
+ WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1",
+ REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
+ new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@@ -236,9 +239,9 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
- WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
- requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
- Duration.ofMillis(30000));
+ WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1",
+ REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
+ new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@@ -276,9 +279,9 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
- WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
- requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
- Duration.ofMillis(30000));
+ WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1",
+ REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
+ new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@@ -316,9 +319,9 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
- WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
- requestLog, new TestPrincipal("authorizedUserName"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
- Duration.ofMillis(30000));
+ WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1",
+ REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("authorizedUserName"),
+ new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@@ -356,8 +359,9 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
- WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
- requestLog, null, new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
+ WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1",
+ REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, null, new ProtobufWebSocketMessageFactory(),
+ Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@@ -394,9 +398,9 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
- WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
- requestLog, new TestPrincipal("something"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
- Duration.ofMillis(30000));
+ WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1",
+ REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("something"),
+ new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@@ -434,8 +438,9 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
- WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
- requestLog, null, new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
+ WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1",
+ REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, null, new ProtobufWebSocketMessageFactory(),
+ Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@@ -473,9 +478,9 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
- WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
- requestLog, new TestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
- Duration.ofMillis(30000));
+ WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1",
+ REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"),
+ new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@@ -514,9 +519,9 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
- WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
- requestLog, new TestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
- Duration.ofMillis(30000));
+ WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1",
+ REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"),
+ new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@@ -556,9 +561,9 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
- WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
- requestLog, new TestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
- Duration.ofMillis(30000));
+ WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1",
+ REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"),
+ new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
@@ -596,9 +601,9 @@ class WebSocketResourceProviderTest {
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
- WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
- requestLog, new TestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
- Duration.ofMillis(30000));
+ WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1",
+ REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"),
+ new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
Session session = mock(Session.class);
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);