Refactor remote address/X-Forwarded-For handling
This commit is contained in:
parent
4475d65780
commit
2ab14ca59e
|
@ -190,6 +190,11 @@
|
||||||
<groupId>org.eclipse.jetty</groupId>
|
<groupId>org.eclipse.jetty</groupId>
|
||||||
<artifactId>jetty-servlets</artifactId>
|
<artifactId>jetty-servlets</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.eclipse.jetty.websocket</groupId>
|
||||||
|
<artifactId>websocket-jetty-client</artifactId>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.apache.commons</groupId>
|
<groupId>org.apache.commons</groupId>
|
||||||
|
|
|
@ -44,6 +44,7 @@ import java.util.concurrent.LinkedBlockingQueue;
|
||||||
import java.util.concurrent.ScheduledExecutorService;
|
import java.util.concurrent.ScheduledExecutorService;
|
||||||
import java.util.concurrent.ThreadPoolExecutor;
|
import java.util.concurrent.ThreadPoolExecutor;
|
||||||
import javax.servlet.DispatcherType;
|
import javax.servlet.DispatcherType;
|
||||||
|
import javax.servlet.Filter;
|
||||||
import javax.servlet.FilterRegistration;
|
import javax.servlet.FilterRegistration;
|
||||||
import javax.servlet.ServletRegistration;
|
import javax.servlet.ServletRegistration;
|
||||||
import org.eclipse.jetty.servlets.CrossOriginFilter;
|
import org.eclipse.jetty.servlets.CrossOriginFilter;
|
||||||
|
@ -115,6 +116,7 @@ import org.whispersystems.textsecuregcm.currency.CoinMarketCapClient;
|
||||||
import org.whispersystems.textsecuregcm.currency.CurrencyConversionManager;
|
import org.whispersystems.textsecuregcm.currency.CurrencyConversionManager;
|
||||||
import org.whispersystems.textsecuregcm.currency.FixerClient;
|
import org.whispersystems.textsecuregcm.currency.FixerClient;
|
||||||
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
|
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
|
||||||
|
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
|
||||||
import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter;
|
import org.whispersystems.textsecuregcm.filters.RemoteDeprecationFilter;
|
||||||
import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter;
|
import org.whispersystems.textsecuregcm.filters.RequestStatisticsFilter;
|
||||||
import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter;
|
import org.whispersystems.textsecuregcm.filters.TimestampResponseFilter;
|
||||||
|
@ -212,8 +214,8 @@ import org.whispersystems.textsecuregcm.util.DynamoDbFromConfig;
|
||||||
import org.whispersystems.textsecuregcm.util.ManagedAwsCrt;
|
import org.whispersystems.textsecuregcm.util.ManagedAwsCrt;
|
||||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||||
import org.whispersystems.textsecuregcm.util.UsernameHashZkProofVerifier;
|
import org.whispersystems.textsecuregcm.util.UsernameHashZkProofVerifier;
|
||||||
import org.whispersystems.textsecuregcm.util.VirtualThreadPinEventMonitor;
|
|
||||||
import org.whispersystems.textsecuregcm.util.VirtualExecutorServiceProvider;
|
import org.whispersystems.textsecuregcm.util.VirtualExecutorServiceProvider;
|
||||||
|
import org.whispersystems.textsecuregcm.util.VirtualThreadPinEventMonitor;
|
||||||
import org.whispersystems.textsecuregcm.util.logging.LoggingUnhandledExceptionMapper;
|
import org.whispersystems.textsecuregcm.util.logging.LoggingUnhandledExceptionMapper;
|
||||||
import org.whispersystems.textsecuregcm.util.logging.UncaughtExceptionHandler;
|
import org.whispersystems.textsecuregcm.util.logging.UncaughtExceptionHandler;
|
||||||
import org.whispersystems.textsecuregcm.websocket.AuthenticatedConnectListener;
|
import org.whispersystems.textsecuregcm.websocket.AuthenticatedConnectListener;
|
||||||
|
@ -718,10 +720,16 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
||||||
config.getBadges(), asyncCdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner, profileBadgeConverter, rateLimiters, zkProfileOperations, config.getCdnConfiguration().bucket()), basicCredentialAuthenticationInterceptor))
|
config.getBadges(), asyncCdnS3Client, profileCdnPolicyGenerator, profileCdnPolicySigner, profileBadgeConverter, rateLimiters, zkProfileOperations, config.getCdnConfiguration().bucket()), basicCredentialAuthenticationInterceptor))
|
||||||
.addService(new ProfileAnonymousGrpcService(accountsManager, profilesManager, profileBadgeConverter, zkProfileOperations));
|
.addService(new ProfileAnonymousGrpcService(accountsManager, profilesManager, profileBadgeConverter, zkProfileOperations));
|
||||||
|
|
||||||
RemoteDeprecationFilter remoteDeprecationFilter = new RemoteDeprecationFilter(dynamicConfigurationManager);
|
final List<Filter> filters = new ArrayList<>();
|
||||||
environment.servlets()
|
final RemoteDeprecationFilter remoteDeprecationFilter = new RemoteDeprecationFilter(dynamicConfigurationManager);
|
||||||
.addFilter("RemoteDeprecationFilter", remoteDeprecationFilter)
|
filters.add(remoteDeprecationFilter);
|
||||||
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
|
filters.add(new RemoteAddressFilter(useRemoteAddress));
|
||||||
|
|
||||||
|
for (Filter filter : filters) {
|
||||||
|
environment.servlets()
|
||||||
|
.addFilter(filter.getClass().getSimpleName(), filter)
|
||||||
|
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
|
||||||
|
}
|
||||||
|
|
||||||
// Note: interceptors run in the reverse order they are added; the remote deprecation filter
|
// Note: interceptors run in the reverse order they are added; the remote deprecation filter
|
||||||
// depends on the user-agent context so it has to come first here!
|
// depends on the user-agent context so it has to come first here!
|
||||||
|
@ -832,7 +840,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
||||||
new CertificateController(new CertificateGenerator(config.getDeliveryCertificate().certificate().value(),
|
new CertificateController(new CertificateGenerator(config.getDeliveryCertificate().certificate().value(),
|
||||||
config.getDeliveryCertificate().ecPrivateKey(), config.getDeliveryCertificate().expiresDays()),
|
config.getDeliveryCertificate().ecPrivateKey(), config.getDeliveryCertificate().expiresDays()),
|
||||||
zkAuthOperations, callingGenericZkSecretParams, clock),
|
zkAuthOperations, callingGenericZkSecretParams, clock),
|
||||||
new ChallengeController(rateLimitChallengeManager, useRemoteAddress),
|
new ChallengeController(rateLimitChallengeManager),
|
||||||
new DeviceController(config.getLinkDeviceSecretConfiguration().secret().value(), accountsManager,
|
new DeviceController(config.getLinkDeviceSecretConfiguration().secret().value(), accountsManager,
|
||||||
rateLimiters, rateLimitersCluster, config.getMaxDevices(), clock),
|
rateLimiters, rateLimitersCluster, config.getMaxDevices(), clock),
|
||||||
new DirectoryV2Controller(directoryV2CredentialsGenerator),
|
new DirectoryV2Controller(directoryV2CredentialsGenerator),
|
||||||
|
@ -859,7 +867,7 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
||||||
config.getCdnConfiguration().bucket()),
|
config.getCdnConfiguration().bucket()),
|
||||||
new VerificationController(registrationServiceClient, new VerificationSessionManager(verificationSessions),
|
new VerificationController(registrationServiceClient, new VerificationSessionManager(verificationSessions),
|
||||||
pushNotificationManager, registrationCaptchaManager, registrationRecoveryPasswordsManager, rateLimiters,
|
pushNotificationManager, registrationCaptchaManager, registrationRecoveryPasswordsManager, rateLimiters,
|
||||||
accountsManager, useRemoteAddress, dynamicConfigurationManager, clock)
|
accountsManager, dynamicConfigurationManager, clock)
|
||||||
);
|
);
|
||||||
if (config.getSubscription() != null && config.getOneTimeDonations() != null) {
|
if (config.getSubscription() != null && config.getOneTimeDonations() != null) {
|
||||||
commonControllers.add(new SubscriptionController(clock, config.getSubscription(), config.getOneTimeDonations(),
|
commonControllers.add(new SubscriptionController(clock, config.getSubscription(), config.getOneTimeDonations(),
|
||||||
|
@ -890,9 +898,11 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
|
||||||
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
|
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
|
||||||
|
|
||||||
WebSocketResourceProviderFactory<AuthenticatedAccount> webSocketServlet = new WebSocketResourceProviderFactory<>(
|
WebSocketResourceProviderFactory<AuthenticatedAccount> webSocketServlet = new WebSocketResourceProviderFactory<>(
|
||||||
webSocketEnvironment, AuthenticatedAccount.class, config.getWebSocketConfiguration());
|
webSocketEnvironment, AuthenticatedAccount.class, config.getWebSocketConfiguration(),
|
||||||
|
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
|
||||||
WebSocketResourceProviderFactory<AuthenticatedAccount> provisioningServlet = new WebSocketResourceProviderFactory<>(
|
WebSocketResourceProviderFactory<AuthenticatedAccount> provisioningServlet = new WebSocketResourceProviderFactory<>(
|
||||||
provisioningEnvironment, AuthenticatedAccount.class, config.getWebSocketConfiguration());
|
provisioningEnvironment, AuthenticatedAccount.class, config.getWebSocketConfiguration(),
|
||||||
|
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
|
||||||
|
|
||||||
ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", webSocketServlet);
|
ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", webSocketServlet);
|
||||||
ServletRegistration.Dynamic provisioning = environment.servlets().addServlet("Provisioning", provisioningServlet);
|
ServletRegistration.Dynamic provisioning = environment.servlets().addServlet("Provisioning", provisioningServlet);
|
||||||
|
|
|
@ -19,15 +19,14 @@ import io.swagger.v3.oas.annotations.parameters.RequestBody;
|
||||||
import io.swagger.v3.oas.annotations.responses.ApiResponse;
|
import io.swagger.v3.oas.annotations.responses.ApiResponse;
|
||||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import javax.servlet.http.HttpServletRequest;
|
|
||||||
import javax.validation.Valid;
|
import javax.validation.Valid;
|
||||||
import javax.ws.rs.BadRequestException;
|
|
||||||
import javax.ws.rs.Consumes;
|
import javax.ws.rs.Consumes;
|
||||||
import javax.ws.rs.HeaderParam;
|
import javax.ws.rs.HeaderParam;
|
||||||
import javax.ws.rs.POST;
|
import javax.ws.rs.POST;
|
||||||
import javax.ws.rs.PUT;
|
import javax.ws.rs.PUT;
|
||||||
import javax.ws.rs.Path;
|
import javax.ws.rs.Path;
|
||||||
import javax.ws.rs.Produces;
|
import javax.ws.rs.Produces;
|
||||||
|
import javax.ws.rs.container.ContainerRequestContext;
|
||||||
import javax.ws.rs.core.Context;
|
import javax.ws.rs.core.Context;
|
||||||
import javax.ws.rs.core.MediaType;
|
import javax.ws.rs.core.MediaType;
|
||||||
import javax.ws.rs.core.Response;
|
import javax.ws.rs.core.Response;
|
||||||
|
@ -35,6 +34,7 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
|
||||||
import org.whispersystems.textsecuregcm.entities.AnswerChallengeRequest;
|
import org.whispersystems.textsecuregcm.entities.AnswerChallengeRequest;
|
||||||
import org.whispersystems.textsecuregcm.entities.AnswerPushChallengeRequest;
|
import org.whispersystems.textsecuregcm.entities.AnswerPushChallengeRequest;
|
||||||
import org.whispersystems.textsecuregcm.entities.AnswerRecaptchaChallengeRequest;
|
import org.whispersystems.textsecuregcm.entities.AnswerRecaptchaChallengeRequest;
|
||||||
|
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
|
||||||
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
|
import org.whispersystems.textsecuregcm.limits.RateLimitChallengeManager;
|
||||||
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
|
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
|
||||||
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
|
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
|
||||||
|
@ -42,8 +42,6 @@ import org.whispersystems.textsecuregcm.spam.Extract;
|
||||||
import org.whispersystems.textsecuregcm.spam.FilterSpam;
|
import org.whispersystems.textsecuregcm.spam.FilterSpam;
|
||||||
import org.whispersystems.textsecuregcm.spam.PushChallengeConfig;
|
import org.whispersystems.textsecuregcm.spam.PushChallengeConfig;
|
||||||
import org.whispersystems.textsecuregcm.spam.ScoreThreshold;
|
import org.whispersystems.textsecuregcm.spam.ScoreThreshold;
|
||||||
import org.whispersystems.textsecuregcm.util.HeaderUtils;
|
|
||||||
import org.whispersystems.textsecuregcm.util.HttpServletRequestUtil;
|
|
||||||
|
|
||||||
@Path("/v1/challenge")
|
@Path("/v1/challenge")
|
||||||
@Tag(name = "Challenge")
|
@Tag(name = "Challenge")
|
||||||
|
@ -51,15 +49,12 @@ import org.whispersystems.textsecuregcm.util.HttpServletRequestUtil;
|
||||||
public class ChallengeController {
|
public class ChallengeController {
|
||||||
|
|
||||||
private final RateLimitChallengeManager rateLimitChallengeManager;
|
private final RateLimitChallengeManager rateLimitChallengeManager;
|
||||||
private final boolean useRemoteAddress;
|
|
||||||
|
|
||||||
private static final String CHALLENGE_RESPONSE_COUNTER_NAME = name(ChallengeController.class, "challengeResponse");
|
private static final String CHALLENGE_RESPONSE_COUNTER_NAME = name(ChallengeController.class, "challengeResponse");
|
||||||
private static final String CHALLENGE_TYPE_TAG = "type";
|
private static final String CHALLENGE_TYPE_TAG = "type";
|
||||||
|
|
||||||
public ChallengeController(final RateLimitChallengeManager rateLimitChallengeManager,
|
public ChallengeController(final RateLimitChallengeManager rateLimitChallengeManager) {
|
||||||
final boolean useRemoteAddress) {
|
|
||||||
this.rateLimitChallengeManager = rateLimitChallengeManager;
|
this.rateLimitChallengeManager = rateLimitChallengeManager;
|
||||||
this.useRemoteAddress = useRemoteAddress;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@PUT
|
@PUT
|
||||||
|
@ -84,8 +79,7 @@ public class ChallengeController {
|
||||||
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed"))
|
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed"))
|
||||||
public Response handleChallengeResponse(@Auth final AuthenticatedAccount auth,
|
public Response handleChallengeResponse(@Auth final AuthenticatedAccount auth,
|
||||||
@Valid final AnswerChallengeRequest answerRequest,
|
@Valid final AnswerChallengeRequest answerRequest,
|
||||||
@HeaderParam(HttpHeaders.X_FORWARDED_FOR) final String forwardedFor,
|
@Context ContainerRequestContext requestContext,
|
||||||
@Context HttpServletRequest request,
|
|
||||||
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
|
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
|
||||||
@Extract final ScoreThreshold captchaScoreThreshold,
|
@Extract final ScoreThreshold captchaScoreThreshold,
|
||||||
@Extract final PushChallengeConfig pushChallengeConfig) throws RateLimitExceededException, IOException {
|
@Extract final PushChallengeConfig pushChallengeConfig) throws RateLimitExceededException, IOException {
|
||||||
|
@ -103,9 +97,8 @@ public class ChallengeController {
|
||||||
} else if (answerRequest instanceof AnswerRecaptchaChallengeRequest recaptchaChallengeRequest) {
|
} else if (answerRequest instanceof AnswerRecaptchaChallengeRequest recaptchaChallengeRequest) {
|
||||||
tags = tags.and(CHALLENGE_TYPE_TAG, "recaptcha");
|
tags = tags.and(CHALLENGE_TYPE_TAG, "recaptcha");
|
||||||
|
|
||||||
final String remoteAddress = useRemoteAddress
|
final String remoteAddress = (String) requestContext.getProperty(
|
||||||
? HttpServletRequestUtil.getRemoteAddress(request)
|
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
|
||||||
: HeaderUtils.getMostRecentProxy(forwardedFor).orElseThrow(BadRequestException::new);
|
|
||||||
boolean success = rateLimitChallengeManager.answerRecaptchaChallenge(
|
boolean success = rateLimitChallengeManager.answerRecaptchaChallenge(
|
||||||
auth.getAccount(),
|
auth.getAccount(),
|
||||||
recaptchaChallengeRequest.getCaptcha(),
|
recaptchaChallengeRequest.getCaptcha(),
|
||||||
|
|
|
@ -31,7 +31,6 @@ import java.util.Optional;
|
||||||
import java.util.concurrent.CancellationException;
|
import java.util.concurrent.CancellationException;
|
||||||
import java.util.concurrent.CompletionException;
|
import java.util.concurrent.CompletionException;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
import javax.servlet.http.HttpServletRequest;
|
|
||||||
import javax.validation.Valid;
|
import javax.validation.Valid;
|
||||||
import javax.validation.constraints.NotNull;
|
import javax.validation.constraints.NotNull;
|
||||||
import javax.ws.rs.BadRequestException;
|
import javax.ws.rs.BadRequestException;
|
||||||
|
@ -49,6 +48,7 @@ import javax.ws.rs.PathParam;
|
||||||
import javax.ws.rs.Produces;
|
import javax.ws.rs.Produces;
|
||||||
import javax.ws.rs.ServerErrorException;
|
import javax.ws.rs.ServerErrorException;
|
||||||
import javax.ws.rs.WebApplicationException;
|
import javax.ws.rs.WebApplicationException;
|
||||||
|
import javax.ws.rs.container.ContainerRequestContext;
|
||||||
import javax.ws.rs.core.Context;
|
import javax.ws.rs.core.Context;
|
||||||
import javax.ws.rs.core.HttpHeaders;
|
import javax.ws.rs.core.HttpHeaders;
|
||||||
import javax.ws.rs.core.MediaType;
|
import javax.ws.rs.core.MediaType;
|
||||||
|
@ -66,6 +66,7 @@ import org.whispersystems.textsecuregcm.entities.SubmitVerificationCodeRequest;
|
||||||
import org.whispersystems.textsecuregcm.entities.UpdateVerificationSessionRequest;
|
import org.whispersystems.textsecuregcm.entities.UpdateVerificationSessionRequest;
|
||||||
import org.whispersystems.textsecuregcm.entities.VerificationCodeRequest;
|
import org.whispersystems.textsecuregcm.entities.VerificationCodeRequest;
|
||||||
import org.whispersystems.textsecuregcm.entities.VerificationSessionResponse;
|
import org.whispersystems.textsecuregcm.entities.VerificationSessionResponse;
|
||||||
|
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
|
||||||
import org.whispersystems.textsecuregcm.limits.RateLimiter;
|
import org.whispersystems.textsecuregcm.limits.RateLimiter;
|
||||||
import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
import org.whispersystems.textsecuregcm.limits.RateLimiters;
|
||||||
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
|
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
|
||||||
|
@ -88,8 +89,6 @@ import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
|
||||||
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
|
import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager;
|
||||||
import org.whispersystems.textsecuregcm.storage.VerificationSessionManager;
|
import org.whispersystems.textsecuregcm.storage.VerificationSessionManager;
|
||||||
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
|
||||||
import org.whispersystems.textsecuregcm.util.HeaderUtils;
|
|
||||||
import org.whispersystems.textsecuregcm.util.HttpServletRequestUtil;
|
|
||||||
import org.whispersystems.textsecuregcm.util.Pair;
|
import org.whispersystems.textsecuregcm.util.Pair;
|
||||||
import org.whispersystems.textsecuregcm.util.Util;
|
import org.whispersystems.textsecuregcm.util.Util;
|
||||||
|
|
||||||
|
@ -123,7 +122,6 @@ public class VerificationController {
|
||||||
private final RateLimiters rateLimiters;
|
private final RateLimiters rateLimiters;
|
||||||
private final AccountsManager accountsManager;
|
private final AccountsManager accountsManager;
|
||||||
|
|
||||||
private final boolean useRemoteAddress;
|
|
||||||
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
|
private final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager;
|
||||||
private final Clock clock;
|
private final Clock clock;
|
||||||
|
|
||||||
|
@ -134,7 +132,6 @@ public class VerificationController {
|
||||||
final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager,
|
final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager,
|
||||||
final RateLimiters rateLimiters,
|
final RateLimiters rateLimiters,
|
||||||
final AccountsManager accountsManager,
|
final AccountsManager accountsManager,
|
||||||
final boolean useRemoteAddress,
|
|
||||||
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
final DynamicConfigurationManager<DynamicConfiguration> dynamicConfigurationManager,
|
||||||
final Clock clock) {
|
final Clock clock) {
|
||||||
this.registrationServiceClient = registrationServiceClient;
|
this.registrationServiceClient = registrationServiceClient;
|
||||||
|
@ -144,7 +141,6 @@ public class VerificationController {
|
||||||
this.registrationRecoveryPasswordsManager = registrationRecoveryPasswordsManager;
|
this.registrationRecoveryPasswordsManager = registrationRecoveryPasswordsManager;
|
||||||
this.rateLimiters = rateLimiters;
|
this.rateLimiters = rateLimiters;
|
||||||
this.accountsManager = accountsManager;
|
this.accountsManager = accountsManager;
|
||||||
this.useRemoteAddress = useRemoteAddress;
|
|
||||||
this.dynamicConfigurationManager = dynamicConfigurationManager;
|
this.dynamicConfigurationManager = dynamicConfigurationManager;
|
||||||
this.clock = clock;
|
this.clock = clock;
|
||||||
}
|
}
|
||||||
|
@ -205,16 +201,13 @@ public class VerificationController {
|
||||||
@Consumes(MediaType.APPLICATION_JSON)
|
@Consumes(MediaType.APPLICATION_JSON)
|
||||||
@Produces(MediaType.APPLICATION_JSON)
|
@Produces(MediaType.APPLICATION_JSON)
|
||||||
public VerificationSessionResponse updateSession(@PathParam("sessionId") final String encodedSessionId,
|
public VerificationSessionResponse updateSession(@PathParam("sessionId") final String encodedSessionId,
|
||||||
@HeaderParam(com.google.common.net.HttpHeaders.X_FORWARDED_FOR) String forwardedFor,
|
|
||||||
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
|
@HeaderParam(HttpHeaders.USER_AGENT) final String userAgent,
|
||||||
@Context HttpServletRequest request,
|
@Context ContainerRequestContext requestContext,
|
||||||
@NotNull @Valid final UpdateVerificationSessionRequest updateVerificationSessionRequest,
|
@NotNull @Valid final UpdateVerificationSessionRequest updateVerificationSessionRequest,
|
||||||
@NotNull @Extract final ScoreThreshold scoreThreshold,
|
@NotNull @Extract final ScoreThreshold scoreThreshold,
|
||||||
@NotNull @Extract final SenderOverride senderOverride) {
|
@NotNull @Extract final SenderOverride senderOverride) {
|
||||||
|
|
||||||
final String sourceHost = useRemoteAddress
|
final String sourceHost = (String) requestContext.getProperty(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
|
||||||
? HttpServletRequestUtil.getRemoteAddress(request)
|
|
||||||
: HeaderUtils.getMostRecentProxy(forwardedFor).orElseThrow();
|
|
||||||
|
|
||||||
final Pair<String, PushNotification.TokenType> pushTokenAndType = validateAndExtractPushToken(
|
final Pair<String, PushNotification.TokenType> pushTokenAndType = validateAndExtractPushToken(
|
||||||
updateVerificationSessionRequest);
|
updateVerificationSessionRequest);
|
||||||
|
|
|
@ -0,0 +1,62 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2024 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.whispersystems.textsecuregcm.filters;
|
||||||
|
|
||||||
|
import javax.servlet.Filter;
|
||||||
|
import javax.servlet.FilterChain;
|
||||||
|
import javax.servlet.ServletException;
|
||||||
|
import javax.servlet.ServletRequest;
|
||||||
|
import javax.servlet.ServletResponse;
|
||||||
|
import javax.servlet.http.HttpServletRequest;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.whispersystems.textsecuregcm.util.HeaderUtils;
|
||||||
|
import org.whispersystems.textsecuregcm.util.HttpServletRequestUtil;
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets a {@link HttpServletRequest} attribute (that will also be available as a
|
||||||
|
* {@link javax.ws.rs.container.ContainerRequestContext} property) with the remote address of the connection, using
|
||||||
|
* either the {@link HttpServletRequest#getRemoteAddr()} or the {@code X-Forwarded-For} HTTP header value, depending on
|
||||||
|
* whether {@link #preferRemoteAddress} is {@code true}.
|
||||||
|
*/
|
||||||
|
public class RemoteAddressFilter implements Filter {
|
||||||
|
|
||||||
|
public static final String REMOTE_ADDRESS_ATTRIBUTE_NAME = RemoteAddressFilter.class.getName() + ".remoteAddress";
|
||||||
|
private static final Logger logger = LoggerFactory.getLogger(RemoteAddressFilter.class);
|
||||||
|
|
||||||
|
private final boolean preferRemoteAddress;
|
||||||
|
|
||||||
|
|
||||||
|
public RemoteAddressFilter(boolean preferRemoteAddress) {
|
||||||
|
this.preferRemoteAddress = preferRemoteAddress;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain)
|
||||||
|
throws ServletException, IOException {
|
||||||
|
|
||||||
|
if (request instanceof HttpServletRequest httpServletRequest) {
|
||||||
|
|
||||||
|
final String remoteAddress;
|
||||||
|
|
||||||
|
if (preferRemoteAddress) {
|
||||||
|
remoteAddress = HttpServletRequestUtil.getRemoteAddress(httpServletRequest);
|
||||||
|
} else {
|
||||||
|
final String forwardedFor = httpServletRequest.getHeader(com.google.common.net.HttpHeaders.X_FORWARDED_FOR);
|
||||||
|
remoteAddress = HeaderUtils.getMostRecentProxy(forwardedFor)
|
||||||
|
.orElseGet(() -> HttpServletRequestUtil.getRemoteAddress(httpServletRequest));
|
||||||
|
}
|
||||||
|
|
||||||
|
request.setAttribute(REMOTE_ADDRESS_ATTRIBUTE_NAME, remoteAddress);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
logger.warn("request was of unexpected type: {}", request.getClass());
|
||||||
|
}
|
||||||
|
|
||||||
|
chain.doFilter(request, response);
|
||||||
|
}
|
||||||
|
}
|
|
@ -7,15 +7,12 @@ package org.whispersystems.textsecuregcm.filters;
|
||||||
|
|
||||||
import static com.codahale.metrics.MetricRegistry.name;
|
import static com.codahale.metrics.MetricRegistry.name;
|
||||||
|
|
||||||
import com.google.common.annotations.VisibleForTesting;
|
|
||||||
import com.google.common.net.HttpHeaders;
|
import com.google.common.net.HttpHeaders;
|
||||||
import com.vdurmont.semver4j.Semver;
|
import com.vdurmont.semver4j.Semver;
|
||||||
|
|
||||||
import io.grpc.Metadata;
|
import io.grpc.Metadata;
|
||||||
import io.grpc.ServerCall;
|
import io.grpc.ServerCall;
|
||||||
import io.grpc.ServerCallHandler;
|
import io.grpc.ServerCallHandler;
|
||||||
import io.grpc.ServerInterceptor;
|
import io.grpc.ServerInterceptor;
|
||||||
import io.grpc.Status;
|
|
||||||
import io.micrometer.core.instrument.Metrics;
|
import io.micrometer.core.instrument.Metrics;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
@ -30,7 +27,6 @@ import javax.servlet.http.HttpServletResponse;
|
||||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
|
||||||
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRemoteDeprecationConfiguration;
|
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicRemoteDeprecationConfiguration;
|
||||||
import org.whispersystems.textsecuregcm.grpc.StatusConstants;
|
import org.whispersystems.textsecuregcm.grpc.StatusConstants;
|
||||||
import org.whispersystems.textsecuregcm.grpc.UserAgentInterceptor;
|
|
||||||
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
|
import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager;
|
||||||
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
|
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
|
||||||
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
|
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
|
||||||
|
|
|
@ -8,33 +8,25 @@ package org.whispersystems.textsecuregcm.limits;
|
||||||
import static java.util.Objects.requireNonNull;
|
import static java.util.Objects.requireNonNull;
|
||||||
|
|
||||||
import com.google.common.annotations.VisibleForTesting;
|
import com.google.common.annotations.VisibleForTesting;
|
||||||
import com.google.common.net.HttpHeaders;
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import javax.inject.Provider;
|
|
||||||
import javax.servlet.http.HttpServletRequest;
|
|
||||||
import javax.ws.rs.ClientErrorException;
|
import javax.ws.rs.ClientErrorException;
|
||||||
import javax.ws.rs.container.ContainerRequestContext;
|
import javax.ws.rs.container.ContainerRequestContext;
|
||||||
import javax.ws.rs.container.ContainerRequestFilter;
|
import javax.ws.rs.container.ContainerRequestFilter;
|
||||||
import javax.ws.rs.core.Context;
|
|
||||||
import javax.ws.rs.core.Response;
|
import javax.ws.rs.core.Response;
|
||||||
import javax.ws.rs.ext.ExceptionMapper;
|
import javax.ws.rs.ext.ExceptionMapper;
|
||||||
import org.glassfish.jersey.server.ExtendedUriInfo;
|
import org.glassfish.jersey.server.ExtendedUriInfo;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||||
|
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
|
||||||
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
|
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
|
||||||
import org.whispersystems.textsecuregcm.util.HeaderUtils;
|
|
||||||
import org.whispersystems.textsecuregcm.util.HttpServletRequestUtil;
|
|
||||||
|
|
||||||
public class RateLimitByIpFilter implements ContainerRequestFilter {
|
public class RateLimitByIpFilter implements ContainerRequestFilter {
|
||||||
|
|
||||||
private static final Logger logger = LoggerFactory.getLogger(RateLimitByIpFilter.class);
|
private static final Logger logger = LoggerFactory.getLogger(RateLimitByIpFilter.class);
|
||||||
|
|
||||||
@Context
|
|
||||||
private Provider<HttpServletRequest> httpServletRequestProvider;
|
|
||||||
|
|
||||||
@VisibleForTesting
|
@VisibleForTesting
|
||||||
static final RateLimitExceededException INVALID_HEADER_EXCEPTION = new RateLimitExceededException(Duration.ofHours(1),
|
static final RateLimitExceededException INVALID_HEADER_EXCEPTION = new RateLimitExceededException(Duration.ofHours(1),
|
||||||
true);
|
true);
|
||||||
|
@ -42,12 +34,10 @@ public class RateLimitByIpFilter implements ContainerRequestFilter {
|
||||||
private static final ExceptionMapper<RateLimitExceededException> EXCEPTION_MAPPER = new RateLimitExceededExceptionMapper();
|
private static final ExceptionMapper<RateLimitExceededException> EXCEPTION_MAPPER = new RateLimitExceededExceptionMapper();
|
||||||
|
|
||||||
private final RateLimiters rateLimiters;
|
private final RateLimiters rateLimiters;
|
||||||
private final boolean useRemoteAddress;
|
|
||||||
|
|
||||||
|
|
||||||
public RateLimitByIpFilter(final RateLimiters rateLimiters, final boolean useRemoteAddress) {
|
public RateLimitByIpFilter(final RateLimiters rateLimiters) {
|
||||||
this.rateLimiters = requireNonNull(rateLimiters);
|
this.rateLimiters = requireNonNull(rateLimiters);
|
||||||
this.useRemoteAddress = useRemoteAddress;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -70,18 +60,14 @@ public class RateLimitByIpFilter implements ContainerRequestFilter {
|
||||||
final RateLimiters.For handle = annotation.value();
|
final RateLimiters.For handle = annotation.value();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
final String xffHeader = requestContext.getHeaders().getFirst(HttpHeaders.X_FORWARDED_FOR);
|
final Optional<String> remoteAddress = Optional.ofNullable(
|
||||||
final Optional<String> remoteAddress = useRemoteAddress
|
(String) requestContext.getProperty(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME));
|
||||||
? Optional.of(HttpServletRequestUtil.getRemoteAddress(httpServletRequestProvider.get()))
|
|
||||||
: Optional.ofNullable(xffHeader)
|
|
||||||
.flatMap(HeaderUtils::getMostRecentProxy);
|
|
||||||
|
|
||||||
// checking if we failed to extract the most recent IP from the X-Forwarded-For header
|
// checking if we failed to extract the most recent IP for any reason
|
||||||
// for any reason
|
|
||||||
if (remoteAddress.isEmpty()) {
|
if (remoteAddress.isEmpty()) {
|
||||||
// checking if annotation is configured to fail when the most recent IP is not resolved
|
// checking if annotation is configured to fail when the most recent IP is not resolved
|
||||||
if (annotation.failOnUnresolvedIp()) {
|
if (annotation.failOnUnresolvedIp()) {
|
||||||
logger.error("Missing/bad X-Forwarded-For: {}", xffHeader);
|
logger.error("Remote address was null");
|
||||||
throw INVALID_HEADER_EXCEPTION;
|
throw INVALID_HEADER_EXCEPTION;
|
||||||
}
|
}
|
||||||
// otherwise, allow request
|
// otherwise, allow request
|
||||||
|
|
|
@ -69,6 +69,7 @@ import org.junit.jupiter.params.provider.Arguments;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
import org.junit.jupiter.params.provider.ValueSource;
|
import org.junit.jupiter.params.provider.ValueSource;
|
||||||
import org.mockito.ArgumentCaptor;
|
import org.mockito.ArgumentCaptor;
|
||||||
|
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
|
||||||
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
|
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
|
||||||
import org.whispersystems.textsecuregcm.storage.Account;
|
import org.whispersystems.textsecuregcm.storage.Account;
|
||||||
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
import org.whispersystems.textsecuregcm.storage.AccountsManager;
|
||||||
|
@ -306,9 +307,9 @@ class AuthEnablementRefreshRequirementProviderTest {
|
||||||
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||||
|
|
||||||
provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
|
provider = new WebSocketResourceProvider<>("127.0.0.1", RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME,
|
||||||
requestLog, new TestPrincipal("test", account, authenticatedDevice), new ProtobufWebSocketMessageFactory(),
|
applicationHandler, requestLog, new TestPrincipal("test", account, authenticatedDevice),
|
||||||
Optional.empty(), Duration.ofMillis(30000));
|
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||||
|
|
||||||
remoteEndpoint = mock(RemoteEndpoint.class);
|
remoteEndpoint = mock(RemoteEndpoint.class);
|
||||||
Session session = mock(Session.class);
|
Session session = mock(Session.class);
|
||||||
|
|
|
@ -91,6 +91,7 @@ import org.whispersystems.textsecuregcm.util.CompletableFutureTestUtil;
|
||||||
import org.whispersystems.textsecuregcm.util.MockUtils;
|
import org.whispersystems.textsecuregcm.util.MockUtils;
|
||||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||||
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
|
import org.whispersystems.textsecuregcm.util.TestRandomUtil;
|
||||||
|
import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider;
|
||||||
import org.whispersystems.textsecuregcm.util.UsernameHashZkProofVerifier;
|
import org.whispersystems.textsecuregcm.util.UsernameHashZkProofVerifier;
|
||||||
|
|
||||||
@ExtendWith(DropwizardExtensionsSupport.class)
|
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||||
|
@ -119,9 +120,6 @@ class AccountControllerTest {
|
||||||
private static final UUID SENDER_REG_LOCK_UUID = UUID.randomUUID();
|
private static final UUID SENDER_REG_LOCK_UUID = UUID.randomUUID();
|
||||||
private static final UUID SENDER_TRANSFER_UUID = UUID.randomUUID();
|
private static final UUID SENDER_TRANSFER_UUID = UUID.randomUUID();
|
||||||
|
|
||||||
private static final String NICE_HOST = "127.0.0.1";
|
|
||||||
private static final String RATE_LIMITED_IP_HOST = "10.0.0.1";
|
|
||||||
|
|
||||||
private static AccountsManager accountsManager = mock(AccountsManager.class);
|
private static AccountsManager accountsManager = mock(AccountsManager.class);
|
||||||
private static RateLimiters rateLimiters = mock(RateLimiters.class);
|
private static RateLimiters rateLimiters = mock(RateLimiters.class);
|
||||||
private static RateLimiter rateLimiter = mock(RateLimiter.class);
|
private static RateLimiter rateLimiter = mock(RateLimiter.class);
|
||||||
|
@ -140,6 +138,9 @@ class AccountControllerTest {
|
||||||
|
|
||||||
private byte[] registration_lock_key = new byte[32];
|
private byte[] registration_lock_key = new byte[32];
|
||||||
|
|
||||||
|
private static final TestRemoteAddressFilterProvider TEST_REMOTE_ADDRESS_FILTER_PROVIDER
|
||||||
|
= new TestRemoteAddressFilterProvider("127.0.0.1");
|
||||||
|
|
||||||
private static final ResourceExtension resources = ResourceExtension.builder()
|
private static final ResourceExtension resources = ResourceExtension.builder()
|
||||||
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
|
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
|
||||||
.addProvider(AuthHelper.getAuthFilter())
|
.addProvider(AuthHelper.getAuthFilter())
|
||||||
|
@ -148,7 +149,8 @@ class AccountControllerTest {
|
||||||
.addProvider(new RateLimitExceededExceptionMapper())
|
.addProvider(new RateLimitExceededExceptionMapper())
|
||||||
.addProvider(new ImpossiblePhoneNumberExceptionMapper())
|
.addProvider(new ImpossiblePhoneNumberExceptionMapper())
|
||||||
.addProvider(new NonNormalizedPhoneNumberExceptionMapper())
|
.addProvider(new NonNormalizedPhoneNumberExceptionMapper())
|
||||||
.addProvider(new RateLimitByIpFilter(rateLimiters, true))
|
.addProvider(TEST_REMOTE_ADDRESS_FILTER_PROVIDER)
|
||||||
|
.addProvider(new RateLimitByIpFilter(rateLimiters))
|
||||||
.addProvider(ScoreThresholdProvider.ScoreThresholdFeature.class)
|
.addProvider(ScoreThresholdProvider.ScoreThresholdFeature.class)
|
||||||
.addProvider(SenderOverrideProvider.SenderOverrideFeature.class)
|
.addProvider(SenderOverrideProvider.SenderOverrideFeature.class)
|
||||||
.setMapper(SystemMapper.jsonMapper())
|
.setMapper(SystemMapper.jsonMapper())
|
||||||
|
|
|
@ -43,18 +43,16 @@ import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
|
||||||
import org.whispersystems.textsecuregcm.spam.PushChallengeConfigProvider;
|
import org.whispersystems.textsecuregcm.spam.PushChallengeConfigProvider;
|
||||||
import org.whispersystems.textsecuregcm.spam.ScoreThreshold;
|
import org.whispersystems.textsecuregcm.spam.ScoreThreshold;
|
||||||
import org.whispersystems.textsecuregcm.spam.ScoreThresholdProvider;
|
import org.whispersystems.textsecuregcm.spam.ScoreThresholdProvider;
|
||||||
import org.whispersystems.textsecuregcm.spam.SenderOverride;
|
|
||||||
import org.whispersystems.textsecuregcm.spam.SenderOverrideProvider;
|
|
||||||
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
|
import org.whispersystems.textsecuregcm.tests.util.AuthHelper;
|
||||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||||
|
import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider;
|
||||||
|
|
||||||
@ExtendWith(DropwizardExtensionsSupport.class)
|
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||||
class ChallengeControllerTest {
|
class ChallengeControllerTest {
|
||||||
|
|
||||||
private static final RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class);
|
private static final RateLimitChallengeManager rateLimitChallengeManager = mock(RateLimitChallengeManager.class);
|
||||||
|
|
||||||
private static final ChallengeController challengeController = new ChallengeController(rateLimitChallengeManager,
|
private static final ChallengeController challengeController = new ChallengeController(rateLimitChallengeManager);
|
||||||
true);
|
|
||||||
|
|
||||||
private static final AtomicReference<Float> scoreThreshold = new AtomicReference<>();
|
private static final AtomicReference<Float> scoreThreshold = new AtomicReference<>();
|
||||||
|
|
||||||
|
@ -73,6 +71,7 @@ class ChallengeControllerTest {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
.addProvider(new TestRemoteAddressFilterProvider("127.0.0.1"))
|
||||||
.setMapper(SystemMapper.jsonMapper())
|
.setMapper(SystemMapper.jsonMapper())
|
||||||
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
|
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
|
||||||
.addResource(new RateLimitExceededExceptionMapper())
|
.addResource(new RateLimitExceededExceptionMapper())
|
||||||
|
|
|
@ -118,7 +118,7 @@ class VerificationControllerTest {
|
||||||
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
|
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
|
||||||
.addResource(
|
.addResource(
|
||||||
new VerificationController(registrationServiceClient, verificationSessionManager, pushNotificationManager,
|
new VerificationController(registrationServiceClient, verificationSessionManager, pushNotificationManager,
|
||||||
registrationCaptchaManager, registrationRecoveryPasswordsManager, rateLimiters, accountsManager, true,
|
registrationCaptchaManager, registrationRecoveryPasswordsManager, rateLimiters, accountsManager,
|
||||||
dynamicConfigurationManager, clock))
|
dynamicConfigurationManager, clock))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,308 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2024 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.whispersystems.textsecuregcm.filters;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
import static org.junit.jupiter.api.Assumptions.assumeTrue;
|
||||||
|
|
||||||
|
import com.google.common.net.HttpHeaders;
|
||||||
|
import io.dropwizard.core.Application;
|
||||||
|
import io.dropwizard.core.Configuration;
|
||||||
|
import io.dropwizard.core.setup.Environment;
|
||||||
|
import io.dropwizard.testing.junit5.DropwizardAppExtension;
|
||||||
|
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.net.InetAddress;
|
||||||
|
import java.net.URI;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.security.Principal;
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.EnumSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import javax.security.auth.Subject;
|
||||||
|
import javax.servlet.DispatcherType;
|
||||||
|
import javax.ws.rs.GET;
|
||||||
|
import javax.ws.rs.Path;
|
||||||
|
import javax.ws.rs.client.Client;
|
||||||
|
import javax.ws.rs.container.ContainerRequestContext;
|
||||||
|
import javax.ws.rs.core.Context;
|
||||||
|
import org.eclipse.jetty.util.HostPort;
|
||||||
|
import org.eclipse.jetty.websocket.api.Session;
|
||||||
|
import org.eclipse.jetty.websocket.api.WebSocketListener;
|
||||||
|
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
|
||||||
|
import org.eclipse.jetty.websocket.client.WebSocketClient;
|
||||||
|
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
|
||||||
|
import org.junit.jupiter.api.AfterEach;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Nested;
|
||||||
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
|
import org.junit.jupiter.params.provider.CsvSource;
|
||||||
|
import org.junit.jupiter.params.provider.ValueSource;
|
||||||
|
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||||
|
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
|
||||||
|
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
|
||||||
|
import org.whispersystems.websocket.messages.WebSocketMessage;
|
||||||
|
import org.whispersystems.websocket.messages.WebSocketMessageFactory;
|
||||||
|
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
|
||||||
|
import org.whispersystems.websocket.setup.WebSocketEnvironment;
|
||||||
|
|
||||||
|
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||||
|
class RemoteAddressFilterIntegrationTest {
|
||||||
|
|
||||||
|
private static final String WEBSOCKET_PREFIX = "/websocket";
|
||||||
|
private static final String REMOTE_ADDRESS_PATH = "/remoteAddress";
|
||||||
|
private static final String FORWARDED_FOR_PATH = "/forwardedFor";
|
||||||
|
private static final String WS_REQUEST_PATH = "/wsRequest";
|
||||||
|
|
||||||
|
// The Grizzly test container does not match the Jetty container used in real deployments, and JettyTestContainerFactory
|
||||||
|
// in jersey-test-framework-provider-jetty doesn’t easily support @Context HttpServletRequest, so this test runs a
|
||||||
|
// full Jetty server in a separate process
|
||||||
|
private static final DropwizardAppExtension<Configuration> EXTENSION = new DropwizardAppExtension<>(
|
||||||
|
TestApplication.class);
|
||||||
|
|
||||||
|
@Nested
|
||||||
|
class Rest {
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@ValueSource(strings = {"127.0.0.1", "0:0:0:0:0:0:0:1"})
|
||||||
|
void testRemoteAddress(String ip) throws Exception {
|
||||||
|
final Set<String> addresses = Arrays.stream(InetAddress.getAllByName("localhost"))
|
||||||
|
.map(InetAddress::getHostAddress)
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
|
||||||
|
assumeTrue(addresses.contains(ip), String.format("localhost does not resolve to %s", ip));
|
||||||
|
|
||||||
|
Client client = EXTENSION.client();
|
||||||
|
|
||||||
|
final RemoteAddressFilterIntegrationTest.TestResponse response = client.target(
|
||||||
|
String.format("http://%s:%d%s", HostPort.normalizeHost(ip), EXTENSION.getLocalPort(), REMOTE_ADDRESS_PATH))
|
||||||
|
.request("application/json")
|
||||||
|
.get(RemoteAddressFilterIntegrationTest.TestResponse.class);
|
||||||
|
|
||||||
|
assertEquals(ip, response.remoteAddress());
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@CsvSource(value = {"127.0.0.1, 192.168.1.1 \t 192.168.1.1",
|
||||||
|
"127.0.0.1, fe80:1:1:1:1:1:1:1 \t fe80:1:1:1:1:1:1:1"}, delimiterString = "\t")
|
||||||
|
void testForwardedFor(String forwardedFor, String expectedIp) {
|
||||||
|
|
||||||
|
Client client = EXTENSION.client();
|
||||||
|
|
||||||
|
final RemoteAddressFilterIntegrationTest.TestResponse response = client.target(
|
||||||
|
String.format("http://localhost:%d%s", EXTENSION.getLocalPort(), FORWARDED_FOR_PATH))
|
||||||
|
.request("application/json")
|
||||||
|
.header(HttpHeaders.X_FORWARDED_FOR, forwardedFor)
|
||||||
|
.get(RemoteAddressFilterIntegrationTest.TestResponse.class);
|
||||||
|
|
||||||
|
assertEquals(expectedIp, response.remoteAddress());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Nested
|
||||||
|
class WebSocket {
|
||||||
|
|
||||||
|
private WebSocketClient client;
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() throws Exception {
|
||||||
|
client = new WebSocketClient();
|
||||||
|
client.start();
|
||||||
|
}
|
||||||
|
|
||||||
|
@AfterEach
|
||||||
|
void tearDown() throws Exception {
|
||||||
|
client.stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@ValueSource(strings = {"127.0.0.1", "0:0:0:0:0:0:0:1"})
|
||||||
|
void testRemoteAddress(String ip) throws Exception {
|
||||||
|
final Set<String> addresses = Arrays.stream(InetAddress.getAllByName("localhost"))
|
||||||
|
.map(InetAddress::getHostAddress)
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
|
||||||
|
assumeTrue(addresses.contains(ip), String.format("localhost does not resolve to %s", ip));
|
||||||
|
|
||||||
|
final CompletableFuture<byte[]> responseFuture = new CompletableFuture<>();
|
||||||
|
final ClientEndpoint clientEndpoint = new ClientEndpoint(WS_REQUEST_PATH, responseFuture);
|
||||||
|
|
||||||
|
client.connect(clientEndpoint,
|
||||||
|
URI.create(
|
||||||
|
String.format("ws://%s:%d%s", HostPort.normalizeHost(ip), EXTENSION.getLocalPort(),
|
||||||
|
WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH)));
|
||||||
|
|
||||||
|
final byte[] responseBytes = responseFuture.get(1, TimeUnit.SECONDS);
|
||||||
|
|
||||||
|
final TestResponse response = SystemMapper.jsonMapper().readValue(responseBytes, TestResponse.class);
|
||||||
|
|
||||||
|
assertEquals(ip, response.remoteAddress());
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@CsvSource(value = {"127.0.0.1, 192.168.1.1 \t 192.168.1.1",
|
||||||
|
"127.0.0.1, fe80:1:1:1:1:1:1:1 \t fe80:1:1:1:1:1:1:1"}, delimiterString = "\t")
|
||||||
|
void testForwardedFor(String forwardedFor, String expectedIp) throws Exception {
|
||||||
|
|
||||||
|
final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
|
||||||
|
upgradeRequest.setHeader(HttpHeaders.X_FORWARDED_FOR, forwardedFor);
|
||||||
|
|
||||||
|
final CompletableFuture<byte[]> responseFuture = new CompletableFuture<>();
|
||||||
|
|
||||||
|
client.connect(new ClientEndpoint(WS_REQUEST_PATH, responseFuture),
|
||||||
|
URI.create(
|
||||||
|
String.format("ws://localhost:%d%s", EXTENSION.getLocalPort(), WEBSOCKET_PREFIX + FORWARDED_FOR_PATH)),
|
||||||
|
upgradeRequest);
|
||||||
|
|
||||||
|
final byte[] responseBytes = responseFuture.get(1, TimeUnit.SECONDS);
|
||||||
|
|
||||||
|
final TestResponse response = SystemMapper.jsonMapper().readValue(responseBytes, TestResponse.class);
|
||||||
|
|
||||||
|
assertEquals(expectedIp, response.remoteAddress());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class ClientEndpoint implements WebSocketListener {
|
||||||
|
|
||||||
|
private final String requestPath;
|
||||||
|
private final CompletableFuture<byte[]> responseFuture;
|
||||||
|
private final WebSocketMessageFactory messageFactory;
|
||||||
|
|
||||||
|
ClientEndpoint(String requestPath, CompletableFuture<byte[]> responseFuture) {
|
||||||
|
|
||||||
|
this.requestPath = requestPath;
|
||||||
|
this.responseFuture = responseFuture;
|
||||||
|
this.messageFactory = new ProtobufWebSocketMessageFactory();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onWebSocketConnect(final Session session) {
|
||||||
|
final byte[] requestBytes = messageFactory.createRequest(Optional.of(1L), "GET", requestPath,
|
||||||
|
List.of("Accept: application/json"),
|
||||||
|
Optional.empty()).toByteArray();
|
||||||
|
try {
|
||||||
|
session.getRemote().sendBytes(ByteBuffer.wrap(requestBytes));
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onWebSocketBinary(final byte[] payload, final int offset, final int length) {
|
||||||
|
|
||||||
|
try {
|
||||||
|
WebSocketMessage webSocketMessage = messageFactory.parseMessage(payload, offset, length);
|
||||||
|
|
||||||
|
if (Objects.requireNonNull(webSocketMessage.getType()) == WebSocketMessage.Type.RESPONSE_MESSAGE) {
|
||||||
|
assert 200 == webSocketMessage.getResponseMessage().getStatus();
|
||||||
|
responseFuture.complete(webSocketMessage.getResponseMessage().getBody().orElseThrow());
|
||||||
|
} else {
|
||||||
|
throw new RuntimeException("Unexpected message type: " + webSocketMessage.getType());
|
||||||
|
}
|
||||||
|
} catch (final Exception e) {
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public static abstract class TestController {
|
||||||
|
|
||||||
|
@GET
|
||||||
|
public RemoteAddressFilterIntegrationTest.TestResponse get(@Context ContainerRequestContext context) {
|
||||||
|
|
||||||
|
return new RemoteAddressFilterIntegrationTest.TestResponse(
|
||||||
|
(String) context.getProperty(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Path(REMOTE_ADDRESS_PATH)
|
||||||
|
public static class TestRemoteAddressController extends TestController {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Path(FORWARDED_FOR_PATH)
|
||||||
|
public static class TestForwardedForController extends TestController {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Path(WS_REQUEST_PATH)
|
||||||
|
public static class TestWebSocketController extends TestController {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public record TestResponse(String remoteAddress) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class TestApplication extends Application<Configuration> {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run(final Configuration configuration,
|
||||||
|
final Environment environment) throws Exception {
|
||||||
|
|
||||||
|
// 2 filters, to cover useRemoteAddress = {true, false}
|
||||||
|
// each has explicit (not wildcard) path matching
|
||||||
|
environment.servlets().addFilter("RemoteAddressFilterRemoteAddress", new RemoteAddressFilter(true))
|
||||||
|
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, REMOTE_ADDRESS_PATH,
|
||||||
|
WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH);
|
||||||
|
environment.servlets().addFilter("RemoteAddressFilterForwardedFor", new RemoteAddressFilter(false))
|
||||||
|
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, FORWARDED_FOR_PATH,
|
||||||
|
WEBSOCKET_PREFIX + FORWARDED_FOR_PATH);
|
||||||
|
|
||||||
|
environment.jersey().register(new TestRemoteAddressController());
|
||||||
|
environment.jersey().register(new TestForwardedForController());
|
||||||
|
|
||||||
|
// WebSocket set up
|
||||||
|
final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration();
|
||||||
|
|
||||||
|
WebSocketEnvironment<TestPrincipal> webSocketEnvironment = new WebSocketEnvironment<>(environment,
|
||||||
|
webSocketConfiguration, Duration.ofMillis(1000));
|
||||||
|
|
||||||
|
webSocketEnvironment.jersey().register(new TestWebSocketController());
|
||||||
|
|
||||||
|
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
|
||||||
|
|
||||||
|
WebSocketResourceProviderFactory<TestPrincipal> webSocketServlet = new WebSocketResourceProviderFactory<>(
|
||||||
|
webSocketEnvironment, TestPrincipal.class, webSocketConfiguration,
|
||||||
|
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
|
||||||
|
|
||||||
|
// 2 servlets, because the filter only runs for the Upgrade request
|
||||||
|
environment.servlets().addServlet("WebSocketForwardedFor", webSocketServlet)
|
||||||
|
.addMapping(WEBSOCKET_PREFIX + FORWARDED_FOR_PATH);
|
||||||
|
environment.servlets().addServlet("WebSocketRemoteAddress", webSocketServlet)
|
||||||
|
.addMapping(WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A minimal {@code Principal} implementation, only used to satisfy constructors
|
||||||
|
*/
|
||||||
|
public static class TestPrincipal implements Principal {
|
||||||
|
|
||||||
|
// Principal implementation
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getName() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean implies(final Subject subject) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,60 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2024 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.whispersystems.textsecuregcm.filters;
|
||||||
|
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
import com.google.common.net.HttpHeaders;
|
||||||
|
import javax.servlet.FilterChain;
|
||||||
|
import javax.servlet.ServletRequest;
|
||||||
|
import javax.servlet.ServletResponse;
|
||||||
|
import javax.servlet.http.HttpServletRequest;
|
||||||
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
|
import org.junit.jupiter.params.provider.CsvSource;
|
||||||
|
|
||||||
|
class RemoteAddressFilterTest {
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@CsvSource({
|
||||||
|
"127.0.0.1, 127.0.0.1",
|
||||||
|
"0:0:0:0:0:0:0:1, 0:0:0:0:0:0:0:1",
|
||||||
|
"[0:0:0:0:0:0:0:1], 0:0:0:0:0:0:0:1"
|
||||||
|
})
|
||||||
|
void testGetRemoteAddress(final String remoteAddr, final String expectedRemoteAddr) throws Exception {
|
||||||
|
final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class);
|
||||||
|
when(httpServletRequest.getRemoteAddr()).thenReturn(remoteAddr);
|
||||||
|
|
||||||
|
final RemoteAddressFilter filter = new RemoteAddressFilter(true);
|
||||||
|
|
||||||
|
final FilterChain filterChain = mock(FilterChain.class);
|
||||||
|
filter.doFilter(httpServletRequest, mock(ServletResponse.class), filterChain);
|
||||||
|
|
||||||
|
verify(httpServletRequest).setAttribute(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, expectedRemoteAddr);
|
||||||
|
verify(filterChain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
|
||||||
|
}
|
||||||
|
|
||||||
|
@ParameterizedTest
|
||||||
|
@CsvSource(value = {
|
||||||
|
"192.168.1.1, 127.0.0.1 \t 127.0.0.1",
|
||||||
|
"192.168.1.1, 0:0:0:0:0:0:0:1 \t 0:0:0:0:0:0:0:1"
|
||||||
|
}, delimiterString = "\t")
|
||||||
|
void testGetRemoteAddressFromHeader(final String forwardedFor, final String expectedRemoteAddr) throws Exception {
|
||||||
|
final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class);
|
||||||
|
when(httpServletRequest.getHeader(HttpHeaders.X_FORWARDED_FOR)).thenReturn(forwardedFor);
|
||||||
|
|
||||||
|
final RemoteAddressFilter filter = new RemoteAddressFilter(false);
|
||||||
|
|
||||||
|
final FilterChain filterChain = mock(FilterChain.class);
|
||||||
|
filter.doFilter(httpServletRequest, mock(ServletResponse.class), filterChain);
|
||||||
|
|
||||||
|
verify(httpServletRequest).setAttribute(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, expectedRemoteAddr);
|
||||||
|
verify(filterChain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -25,6 +25,7 @@ import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
import org.whispersystems.textsecuregcm.controllers.RateLimitExceededException;
|
||||||
import org.whispersystems.textsecuregcm.util.MockUtils;
|
import org.whispersystems.textsecuregcm.util.MockUtils;
|
||||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||||
|
import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider;
|
||||||
|
|
||||||
@ExtendWith(DropwizardExtensionsSupport.class)
|
@ExtendWith(DropwizardExtensionsSupport.class)
|
||||||
public class RateLimitedByIpTest {
|
public class RateLimitedByIpTest {
|
||||||
|
@ -60,7 +61,8 @@ public class RateLimitedByIpTest {
|
||||||
.setMapper(SystemMapper.jsonMapper())
|
.setMapper(SystemMapper.jsonMapper())
|
||||||
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
|
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
|
||||||
.addResource(new Controller())
|
.addResource(new Controller())
|
||||||
.addProvider(new RateLimitByIpFilter(RATE_LIMITERS, true))
|
.addProvider(new RateLimitByIpFilter(RATE_LIMITERS))
|
||||||
|
.addProvider(new TestRemoteAddressFilterProvider(IP))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -49,6 +49,7 @@ import org.glassfish.jersey.uri.UriTemplate;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.mockito.ArgumentCaptor;
|
import org.mockito.ArgumentCaptor;
|
||||||
|
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
|
||||||
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
|
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
|
||||||
import org.whispersystems.websocket.WebSocketResourceProvider;
|
import org.whispersystems.websocket.WebSocketResourceProvider;
|
||||||
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
|
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
|
||||||
|
@ -138,12 +139,8 @@ class MetricsRequestEventListenerTest {
|
||||||
final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||||
final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||||
final WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
final WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||||
applicationHandler,
|
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
|
||||||
requestLog,
|
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||||
new TestPrincipal("foo"),
|
|
||||||
new ProtobufWebSocketMessageFactory(),
|
|
||||||
Optional.empty(),
|
|
||||||
Duration.ofMillis(30000));
|
|
||||||
|
|
||||||
final Session session = mock(Session.class);
|
final Session session = mock(Session.class);
|
||||||
final RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
final RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
||||||
|
@ -204,9 +201,8 @@ class MetricsRequestEventListenerTest {
|
||||||
final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
final ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||||
final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
final WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||||
final WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
final WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||||
applicationHandler,
|
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
|
||||||
requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
|
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||||
Duration.ofMillis(30000));
|
|
||||||
|
|
||||||
final Session session = mock(Session.class);
|
final Session session = mock(Session.class);
|
||||||
final RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
final RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
||||||
|
|
|
@ -35,8 +35,7 @@ class HttpServletRequestUtilIntegrationTest {
|
||||||
// The Grizzly test container does not match the Jetty container used in real deployments, and JettyTestContainerFactory
|
// The Grizzly test container does not match the Jetty container used in real deployments, and JettyTestContainerFactory
|
||||||
// in jersey-test-framework-provider-jetty doesn’t easily support @Context HttpServletRequest, so this test runs a
|
// in jersey-test-framework-provider-jetty doesn’t easily support @Context HttpServletRequest, so this test runs a
|
||||||
// full Jetty server in a separate process
|
// full Jetty server in a separate process
|
||||||
private final DropwizardAppExtension<TestConfiguration> EXTENSION = new DropwizardAppExtension<>(
|
private final DropwizardAppExtension<Configuration> EXTENSION = new DropwizardAppExtension<>(TestApplication.class);
|
||||||
TestApplication.class);
|
|
||||||
|
|
||||||
@ParameterizedTest
|
@ParameterizedTest
|
||||||
@ValueSource(strings = {"127.0.0.1", "0:0:0:0:0:0:0:1"})
|
@ValueSource(strings = {"127.0.0.1", "0:0:0:0:0:0:0:1"})
|
||||||
|
@ -72,13 +71,11 @@ class HttpServletRequestUtilIntegrationTest {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class TestApplication extends Application<TestConfiguration> {
|
public static class TestApplication extends Application<Configuration> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void run(final TestConfiguration configuration, final Environment environment) throws Exception {
|
public void run(final Configuration configuration, final Environment environment) throws Exception {
|
||||||
environment.jersey().register(new TestController());
|
environment.jersey().register(new TestController());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class TestConfiguration extends Configuration {}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,31 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2024 Signal Messenger, LLC
|
||||||
|
* SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.whispersystems.textsecuregcm.util;
|
||||||
|
|
||||||
|
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
|
||||||
|
import javax.annotation.Priority;
|
||||||
|
import javax.ws.rs.container.ContainerRequestContext;
|
||||||
|
import javax.ws.rs.container.ContainerRequestFilter;
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Adds the request property set by {@link RemoteAddressFilter} for test scenarios that depend on it, but do not have
|
||||||
|
* access to a full {@code HttpServletRequest} pipline
|
||||||
|
*/
|
||||||
|
@Priority(Integer.MIN_VALUE) // highest priority, since other filters might depend on it
|
||||||
|
public class TestRemoteAddressFilterProvider implements ContainerRequestFilter {
|
||||||
|
|
||||||
|
private final String ip;
|
||||||
|
|
||||||
|
public TestRemoteAddressFilterProvider(String ip) {
|
||||||
|
this.ip = ip;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void filter(final ContainerRequestContext requestContext) throws IOException {
|
||||||
|
requestContext.setProperty(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, ip);
|
||||||
|
}
|
||||||
|
}
|
|
@ -55,6 +55,7 @@ import org.junit.jupiter.params.ParameterizedTest;
|
||||||
import org.junit.jupiter.params.provider.Arguments;
|
import org.junit.jupiter.params.provider.Arguments;
|
||||||
import org.junit.jupiter.params.provider.MethodSource;
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
|
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
|
||||||
import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper;
|
import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper;
|
||||||
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
import org.whispersystems.textsecuregcm.util.SystemMapper;
|
||||||
import org.whispersystems.websocket.WebSocketResourceProvider;
|
import org.whispersystems.websocket.WebSocketResourceProvider;
|
||||||
|
@ -173,9 +174,9 @@ class LoggingUnhandledExceptionMapperTest {
|
||||||
|
|
||||||
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
|
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||||
requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
|
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
|
||||||
Duration.ofMillis(30000));
|
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||||
|
|
||||||
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
||||||
doAnswer(answer -> {
|
doAnswer(answer -> {
|
||||||
|
|
|
@ -65,6 +65,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
|
||||||
private final WebsocketRequestLog requestLog;
|
private final WebsocketRequestLog requestLog;
|
||||||
private final Duration idleTimeout;
|
private final Duration idleTimeout;
|
||||||
private final String remoteAddress;
|
private final String remoteAddress;
|
||||||
|
private final String remoteAddressPropertyName;
|
||||||
|
|
||||||
private Session session;
|
private Session session;
|
||||||
private RemoteEndpoint remoteEndpoint;
|
private RemoteEndpoint remoteEndpoint;
|
||||||
|
@ -73,6 +74,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
|
||||||
private static final Set<String> EXCLUDED_UPGRADE_REQUEST_HEADERS = Set.of("connection", "upgrade");
|
private static final Set<String> EXCLUDED_UPGRADE_REQUEST_HEADERS = Set.of("connection", "upgrade");
|
||||||
|
|
||||||
public WebSocketResourceProvider(String remoteAddress,
|
public WebSocketResourceProvider(String remoteAddress,
|
||||||
|
String remoteAddressPropertyName,
|
||||||
ApplicationHandler jerseyHandler,
|
ApplicationHandler jerseyHandler,
|
||||||
WebsocketRequestLog requestLog,
|
WebsocketRequestLog requestLog,
|
||||||
T authenticated,
|
T authenticated,
|
||||||
|
@ -80,6 +82,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
|
||||||
Optional<WebSocketConnectListener> connectListener,
|
Optional<WebSocketConnectListener> connectListener,
|
||||||
Duration idleTimeout) {
|
Duration idleTimeout) {
|
||||||
this.remoteAddress = remoteAddress;
|
this.remoteAddress = remoteAddress;
|
||||||
|
this.remoteAddressPropertyName = remoteAddressPropertyName;
|
||||||
this.jerseyHandler = jerseyHandler;
|
this.jerseyHandler = jerseyHandler;
|
||||||
this.requestLog = requestLog;
|
this.requestLog = requestLog;
|
||||||
this.authenticated = authenticated;
|
this.authenticated = authenticated;
|
||||||
|
@ -169,6 +172,8 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
|
||||||
containerRequest.setEntityStream(new ByteArrayInputStream(requestMessage.getBody().get()));
|
containerRequest.setEntityStream(new ByteArrayInputStream(requestMessage.getBody().get()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
containerRequest.setProperty(remoteAddressPropertyName, remoteAddress);
|
||||||
|
|
||||||
ByteArrayOutputStream responseBody = new ByteArrayOutputStream();
|
ByteArrayOutputStream responseBody = new ByteArrayOutputStream();
|
||||||
CompletableFuture<ContainerResponse> responseFuture = (CompletableFuture<ContainerResponse>) jerseyHandler.apply(
|
CompletableFuture<ContainerResponse> responseFuture = (CompletableFuture<ContainerResponse>) jerseyHandler.apply(
|
||||||
containerRequest, responseBody);
|
containerRequest, responseBody);
|
||||||
|
|
|
@ -6,13 +6,12 @@ package org.whispersystems.websocket;
|
||||||
|
|
||||||
import static java.util.Optional.ofNullable;
|
import static java.util.Optional.ofNullable;
|
||||||
|
|
||||||
import com.google.common.net.HttpHeaders;
|
|
||||||
import io.dropwizard.jersey.jackson.JacksonMessageBodyProvider;
|
import io.dropwizard.jersey.jackson.JacksonMessageBodyProvider;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.net.InetSocketAddress;
|
|
||||||
import java.security.Principal;
|
import java.security.Principal;
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
|
import javax.ws.rs.InternalServerErrorException;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
|
import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
|
||||||
import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
|
import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
|
||||||
import org.eclipse.jetty.websocket.server.JettyWebSocketCreator;
|
import org.eclipse.jetty.websocket.server.JettyWebSocketCreator;
|
||||||
|
@ -38,8 +37,10 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends Jetty
|
||||||
private final ApplicationHandler jerseyApplicationHandler;
|
private final ApplicationHandler jerseyApplicationHandler;
|
||||||
private final WebSocketConfiguration configuration;
|
private final WebSocketConfiguration configuration;
|
||||||
|
|
||||||
|
private final String remoteAddressPropertyName;
|
||||||
|
|
||||||
public WebSocketResourceProviderFactory(WebSocketEnvironment<T> environment, Class<T> principalClass,
|
public WebSocketResourceProviderFactory(WebSocketEnvironment<T> environment, Class<T> principalClass,
|
||||||
WebSocketConfiguration configuration) {
|
WebSocketConfiguration configuration, String remoteAddressPropertyName) {
|
||||||
this.environment = environment;
|
this.environment = environment;
|
||||||
|
|
||||||
environment.jersey().register(new WebSocketSessionContextValueFactoryProvider.Binder());
|
environment.jersey().register(new WebSocketSessionContextValueFactoryProvider.Binder());
|
||||||
|
@ -49,6 +50,7 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends Jetty
|
||||||
this.jerseyApplicationHandler = new ApplicationHandler(environment.jersey());
|
this.jerseyApplicationHandler = new ApplicationHandler(environment.jersey());
|
||||||
|
|
||||||
this.configuration = configuration;
|
this.configuration = configuration;
|
||||||
|
this.remoteAddressPropertyName = remoteAddressPropertyName;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -69,6 +71,7 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends Jetty
|
||||||
}
|
}
|
||||||
|
|
||||||
return new WebSocketResourceProvider<>(getRemoteAddress(request),
|
return new WebSocketResourceProvider<>(getRemoteAddress(request),
|
||||||
|
remoteAddressPropertyName,
|
||||||
this.jerseyApplicationHandler,
|
this.jerseyApplicationHandler,
|
||||||
this.environment.getRequestLog(),
|
this.environment.getRequestLog(),
|
||||||
authenticated,
|
authenticated,
|
||||||
|
@ -93,18 +96,11 @@ public class WebSocketResourceProviderFactory<T extends Principal> extends Jetty
|
||||||
}
|
}
|
||||||
|
|
||||||
private String getRemoteAddress(JettyServerUpgradeRequest request) {
|
private String getRemoteAddress(JettyServerUpgradeRequest request) {
|
||||||
String forwardedFor = request.getHeader(HttpHeaders.X_FORWARDED_FOR);
|
final String remoteAddress = (String) request.getHttpServletRequest().getAttribute(remoteAddressPropertyName);
|
||||||
|
if (StringUtils.isBlank(remoteAddress)) {
|
||||||
if (forwardedFor == null || forwardedFor.isBlank()) {
|
logger.error("Remote address property is not present");
|
||||||
if (request.getRemoteSocketAddress() instanceof InetSocketAddress inetSocketAddress) {
|
throw new InternalServerErrorException();
|
||||||
return inetSocketAddress.getAddress().getHostAddress();
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
} else {
|
|
||||||
return Arrays.stream(forwardedFor.split(","))
|
|
||||||
.map(String::trim)
|
|
||||||
.reduce((a, b) -> b)
|
|
||||||
.orElseThrow();
|
|
||||||
}
|
}
|
||||||
|
return remoteAddress;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,8 +18,8 @@ import java.io.IOException;
|
||||||
import java.security.Principal;
|
import java.security.Principal;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import javax.security.auth.Subject;
|
import javax.security.auth.Subject;
|
||||||
|
import javax.servlet.http.HttpServletRequest;
|
||||||
import org.eclipse.jetty.websocket.api.Session;
|
import org.eclipse.jetty.websocket.api.Session;
|
||||||
import org.eclipse.jetty.websocket.api.UpgradeRequest;
|
|
||||||
import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
|
import org.eclipse.jetty.websocket.server.JettyServerUpgradeRequest;
|
||||||
import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
|
import org.eclipse.jetty.websocket.server.JettyServerUpgradeResponse;
|
||||||
import org.eclipse.jetty.websocket.server.JettyWebSocketServletFactory;
|
import org.eclipse.jetty.websocket.server.JettyWebSocketServletFactory;
|
||||||
|
@ -33,6 +33,8 @@ import org.whispersystems.websocket.setup.WebSocketEnvironment;
|
||||||
|
|
||||||
public class WebSocketResourceProviderFactoryTest {
|
public class WebSocketResourceProviderFactoryTest {
|
||||||
|
|
||||||
|
private static final String REMOTE_ADDRESS_PROPERTY_NAME = "org.whispersystems.websocket.test.remoteAddress";
|
||||||
|
|
||||||
private ResourceConfig jerseyEnvironment;
|
private ResourceConfig jerseyEnvironment;
|
||||||
private WebSocketEnvironment<Account> environment;
|
private WebSocketEnvironment<Account> environment;
|
||||||
private WebSocketAuthenticator<Account> authenticator;
|
private WebSocketAuthenticator<Account> authenticator;
|
||||||
|
@ -59,7 +61,7 @@ public class WebSocketResourceProviderFactoryTest {
|
||||||
when(environment.jersey()).thenReturn(jerseyEnvironment);
|
when(environment.jersey()).thenReturn(jerseyEnvironment);
|
||||||
|
|
||||||
WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class,
|
WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class,
|
||||||
mock(WebSocketConfiguration.class));
|
mock(WebSocketConfiguration.class), REMOTE_ADDRESS_PROPERTY_NAME);
|
||||||
Object connection = factory.createWebSocket(request, response);
|
Object connection = factory.createWebSocket(request, response);
|
||||||
|
|
||||||
assertNull(connection);
|
assertNull(connection);
|
||||||
|
@ -69,24 +71,25 @@ public class WebSocketResourceProviderFactoryTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testValidAuthorization() throws AuthenticationException {
|
void testValidAuthorization() throws AuthenticationException {
|
||||||
Session session = mock(Session.class);
|
|
||||||
Account account = new Account();
|
Account account = new Account();
|
||||||
|
|
||||||
when(environment.getAuthenticator()).thenReturn(authenticator);
|
when(environment.getAuthenticator()).thenReturn(authenticator);
|
||||||
when(authenticator.authenticate(eq(request))).thenReturn(
|
when(authenticator.authenticate(eq(request))).thenReturn(
|
||||||
new WebSocketAuthenticator.AuthenticationResult<>(Optional.of(account), true));
|
new WebSocketAuthenticator.AuthenticationResult<>(Optional.of(account), true));
|
||||||
when(environment.jersey()).thenReturn(jerseyEnvironment);
|
when(environment.jersey()).thenReturn(jerseyEnvironment);
|
||||||
when(session.getUpgradeRequest()).thenReturn(mock(UpgradeRequest.class));
|
final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class);
|
||||||
|
when(httpServletRequest.getAttribute(REMOTE_ADDRESS_PROPERTY_NAME)).thenReturn("127.0.0.1");
|
||||||
|
when(request.getHttpServletRequest()).thenReturn(httpServletRequest);
|
||||||
|
|
||||||
WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class,
|
WebSocketResourceProviderFactory<?> factory = new WebSocketResourceProviderFactory<>(environment, Account.class,
|
||||||
mock(WebSocketConfiguration.class));
|
mock(WebSocketConfiguration.class), REMOTE_ADDRESS_PROPERTY_NAME);
|
||||||
Object connection = factory.createWebSocket(request, response);
|
Object connection = factory.createWebSocket(request, response);
|
||||||
|
|
||||||
assertNotNull(connection);
|
assertNotNull(connection);
|
||||||
verifyNoMoreInteractions(response);
|
verifyNoMoreInteractions(response);
|
||||||
verify(authenticator).authenticate(eq(request));
|
verify(authenticator).authenticate(eq(request));
|
||||||
|
|
||||||
((WebSocketResourceProvider<?>) connection).onWebSocketConnect(session);
|
((WebSocketResourceProvider<?>) connection).onWebSocketConnect(mock(Session.class));
|
||||||
|
|
||||||
assertNotNull(((WebSocketResourceProvider<?>) connection).getContext().getAuthenticated());
|
assertNotNull(((WebSocketResourceProvider<?>) connection).getContext().getAuthenticated());
|
||||||
assertEquals(((WebSocketResourceProvider<?>) connection).getContext().getAuthenticated(), account);
|
assertEquals(((WebSocketResourceProvider<?>) connection).getContext().getAuthenticated(), account);
|
||||||
|
@ -100,7 +103,8 @@ public class WebSocketResourceProviderFactoryTest {
|
||||||
|
|
||||||
WebSocketResourceProviderFactory<Account> factory = new WebSocketResourceProviderFactory<>(environment,
|
WebSocketResourceProviderFactory<Account> factory = new WebSocketResourceProviderFactory<>(environment,
|
||||||
Account.class,
|
Account.class,
|
||||||
mock(WebSocketConfiguration.class));
|
mock(WebSocketConfiguration.class),
|
||||||
|
REMOTE_ADDRESS_PROPERTY_NAME);
|
||||||
Object connection = factory.createWebSocket(request, response);
|
Object connection = factory.createWebSocket(request, response);
|
||||||
|
|
||||||
assertNull(connection);
|
assertNull(connection);
|
||||||
|
@ -115,7 +119,8 @@ public class WebSocketResourceProviderFactoryTest {
|
||||||
|
|
||||||
WebSocketResourceProviderFactory<Account> factory = new WebSocketResourceProviderFactory<>(environment,
|
WebSocketResourceProviderFactory<Account> factory = new WebSocketResourceProviderFactory<>(environment,
|
||||||
Account.class,
|
Account.class,
|
||||||
mock(WebSocketConfiguration.class));
|
mock(WebSocketConfiguration.class),
|
||||||
|
REMOTE_ADDRESS_PROPERTY_NAME);
|
||||||
factory.configure(servletFactory);
|
factory.configure(servletFactory);
|
||||||
|
|
||||||
verify(servletFactory).setCreator(eq(factory));
|
verify(servletFactory).setCreator(eq(factory));
|
||||||
|
|
|
@ -70,12 +70,15 @@ import org.whispersystems.websocket.setup.WebSocketConnectListener;
|
||||||
|
|
||||||
class WebSocketResourceProviderTest {
|
class WebSocketResourceProviderTest {
|
||||||
|
|
||||||
|
private static final String REMOTE_ADDRESS_PROPERTY_NAME = "org.whispersystems.weboscket.test.remoteAddress";
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testOnConnect() {
|
void testOnConnect() {
|
||||||
ApplicationHandler applicationHandler = mock(ApplicationHandler.class);
|
ApplicationHandler applicationHandler = mock(ApplicationHandler.class);
|
||||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||||
WebSocketConnectListener connectListener = mock(WebSocketConnectListener.class);
|
WebSocketConnectListener connectListener = mock(WebSocketConnectListener.class);
|
||||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||||
|
REMOTE_ADDRESS_PROPERTY_NAME,
|
||||||
applicationHandler, requestLog,
|
applicationHandler, requestLog,
|
||||||
new TestPrincipal("fooz"),
|
new TestPrincipal("fooz"),
|
||||||
new ProtobufWebSocketMessageFactory(),
|
new ProtobufWebSocketMessageFactory(),
|
||||||
|
@ -104,9 +107,9 @@ class WebSocketResourceProviderTest {
|
||||||
void testMockedRouteMessageSuccess() throws Exception {
|
void testMockedRouteMessageSuccess() throws Exception {
|
||||||
ApplicationHandler applicationHandler = mock(ApplicationHandler.class);
|
ApplicationHandler applicationHandler = mock(ApplicationHandler.class);
|
||||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
|
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||||
requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
|
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
|
||||||
Duration.ofMillis(30000));
|
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||||
|
|
||||||
Session session = mock(Session.class);
|
Session session = mock(Session.class);
|
||||||
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
||||||
|
@ -180,9 +183,9 @@ class WebSocketResourceProviderTest {
|
||||||
void testMockedRouteMessageFailure() throws Exception {
|
void testMockedRouteMessageFailure() throws Exception {
|
||||||
ApplicationHandler applicationHandler = mock(ApplicationHandler.class);
|
ApplicationHandler applicationHandler = mock(ApplicationHandler.class);
|
||||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
|
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||||
requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
|
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
|
||||||
Duration.ofMillis(30000));
|
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||||
|
|
||||||
Session session = mock(Session.class);
|
Session session = mock(Session.class);
|
||||||
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
||||||
|
@ -236,9 +239,9 @@ class WebSocketResourceProviderTest {
|
||||||
|
|
||||||
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
|
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||||
requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
|
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
|
||||||
Duration.ofMillis(30000));
|
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||||
|
|
||||||
Session session = mock(Session.class);
|
Session session = mock(Session.class);
|
||||||
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
||||||
|
@ -276,9 +279,9 @@ class WebSocketResourceProviderTest {
|
||||||
|
|
||||||
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
|
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||||
requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
|
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("foo"),
|
||||||
Duration.ofMillis(30000));
|
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||||
|
|
||||||
Session session = mock(Session.class);
|
Session session = mock(Session.class);
|
||||||
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
||||||
|
@ -316,9 +319,9 @@ class WebSocketResourceProviderTest {
|
||||||
|
|
||||||
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
|
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||||
requestLog, new TestPrincipal("authorizedUserName"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
|
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("authorizedUserName"),
|
||||||
Duration.ofMillis(30000));
|
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||||
|
|
||||||
Session session = mock(Session.class);
|
Session session = mock(Session.class);
|
||||||
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
||||||
|
@ -356,8 +359,9 @@ class WebSocketResourceProviderTest {
|
||||||
|
|
||||||
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
|
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||||
requestLog, null, new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, null, new ProtobufWebSocketMessageFactory(),
|
||||||
|
Optional.empty(), Duration.ofMillis(30000));
|
||||||
|
|
||||||
Session session = mock(Session.class);
|
Session session = mock(Session.class);
|
||||||
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
||||||
|
@ -394,9 +398,9 @@ class WebSocketResourceProviderTest {
|
||||||
|
|
||||||
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
|
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||||
requestLog, new TestPrincipal("something"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
|
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("something"),
|
||||||
Duration.ofMillis(30000));
|
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||||
|
|
||||||
Session session = mock(Session.class);
|
Session session = mock(Session.class);
|
||||||
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
||||||
|
@ -434,8 +438,9 @@ class WebSocketResourceProviderTest {
|
||||||
|
|
||||||
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
|
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||||
requestLog, null, new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, null, new ProtobufWebSocketMessageFactory(),
|
||||||
|
Optional.empty(), Duration.ofMillis(30000));
|
||||||
|
|
||||||
Session session = mock(Session.class);
|
Session session = mock(Session.class);
|
||||||
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
||||||
|
@ -473,9 +478,9 @@ class WebSocketResourceProviderTest {
|
||||||
|
|
||||||
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
|
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||||
requestLog, new TestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
|
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"),
|
||||||
Duration.ofMillis(30000));
|
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||||
|
|
||||||
Session session = mock(Session.class);
|
Session session = mock(Session.class);
|
||||||
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
||||||
|
@ -514,9 +519,9 @@ class WebSocketResourceProviderTest {
|
||||||
|
|
||||||
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
|
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||||
requestLog, new TestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
|
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"),
|
||||||
Duration.ofMillis(30000));
|
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||||
|
|
||||||
Session session = mock(Session.class);
|
Session session = mock(Session.class);
|
||||||
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
||||||
|
@ -556,9 +561,9 @@ class WebSocketResourceProviderTest {
|
||||||
|
|
||||||
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
|
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||||
requestLog, new TestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
|
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"),
|
||||||
Duration.ofMillis(30000));
|
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||||
|
|
||||||
Session session = mock(Session.class);
|
Session session = mock(Session.class);
|
||||||
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
||||||
|
@ -596,9 +601,9 @@ class WebSocketResourceProviderTest {
|
||||||
|
|
||||||
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
|
||||||
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
|
||||||
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
|
WebSocketResourceProvider<TestPrincipal> provider = new WebSocketResourceProvider<>("127.0.0.1",
|
||||||
requestLog, new TestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(),
|
REMOTE_ADDRESS_PROPERTY_NAME, applicationHandler, requestLog, new TestPrincipal("gooduser"),
|
||||||
Duration.ofMillis(30000));
|
new ProtobufWebSocketMessageFactory(), Optional.empty(), Duration.ofMillis(30000));
|
||||||
|
|
||||||
Session session = mock(Session.class);
|
Session session = mock(Session.class);
|
||||||
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class);
|
||||||
|
|
Loading…
Reference in New Issue