From 3cdc58200a02701b2b707e0365f23878fb03f78f Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Tue, 9 Mar 2021 14:18:21 -0500 Subject: [PATCH] Copy headers from the initial websocket upgrade request into subsequent resource requests. --- .../MetricsRequestEventListenerTest.java | 2 + .../websocket/WebSocketResourceProvider.java | 41 +++++++++++++------ .../WebSocketResourceProviderTest.java | 40 ++++++++++++++++++ 3 files changed, 71 insertions(+), 12 deletions(-) diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java index 018246991..c6be54ccb 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsRequestEventListenerTest.java @@ -47,6 +47,7 @@ import java.util.Collections; import java.util.HashSet; import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; @@ -158,6 +159,7 @@ public class MetricsRequestEventListenerTest { when(session.getUpgradeRequest()).thenReturn(request); when(session.getRemote()).thenReturn(remoteEndpoint); when(request.getHeader("User-Agent")).thenReturn("Signal-Android 4.53.7 (Android 8.1)"); + when(request.getHeaders()).thenReturn(Map.of("User-Agent", List.of("Signal-Android 4.53.7 (Android 8.1)"))); final ArgumentCaptor> tagCaptor = ArgumentCaptor.forClass(Iterable.class); when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_COUNTER_NAME), any(Iterable.class))).thenReturn(counter); diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java index c52b4c8c5..3133d9d2a 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java @@ -37,6 +37,7 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; @@ -59,7 +60,8 @@ public class WebSocketResourceProvider implements WebSocket private Session session; private RemoteEndpoint remoteEndpoint; private WebSocketSessionContext context; - private String userAgent; + + private static final Set EXCLUDED_UPGRADE_REQUEST_HEADERS = Set.of("connection", "upgrade"); public WebSocketResourceProvider(String remoteAddress, ApplicationHandler jerseyHandler, @@ -81,7 +83,6 @@ public class WebSocketResourceProvider implements WebSocket @Override public void onWebSocketConnect(Session session) { this.session = session; - this.userAgent = session.getUpgradeRequest().getHeader("User-Agent"); this.remoteEndpoint = session.getRemote(); this.context = new WebSocketSessionContext(new WebSocketClient(session, remoteEndpoint, messageFactory, requestMap)); this.context.setAuthenticated(authenticated); @@ -142,16 +143,7 @@ public class WebSocketResourceProvider implements WebSocket private void handleRequest(WebSocketRequestMessage requestMessage) { ContainerRequest containerRequest = new ContainerRequest(null, URI.create(requestMessage.getPath()), requestMessage.getVerb(), new WebSocketSecurityContext(new ContextPrincipal(context)), new MapPropertiesDelegate(new HashMap<>()), null); - - for (Map.Entry entry : requestMessage.getHeaders().entrySet()) { - containerRequest.header(entry.getKey(), entry.getValue()); - } - - final List requestUserAgentHeader = containerRequest.getRequestHeader("User-Agent"); - - if ((requestUserAgentHeader == null || requestUserAgentHeader.isEmpty()) && userAgent != null) { - containerRequest.header("User-Agent", userAgent); - } + containerRequest.headers(getCombinedHeaders(session.getUpgradeRequest().getHeaders(), requestMessage.getHeaders())); if (requestMessage.getBody().isPresent()) { containerRequest.setEntityStream(new ByteArrayInputStream(requestMessage.getBody().get())); @@ -171,6 +163,31 @@ public class WebSocketResourceProvider implements WebSocket }); } + @VisibleForTesting + static Map> getCombinedHeaders(final Map> upgradeRequestHeaders, final Map requestMessageHeaders) { + final Map> combinedHeaders = new HashMap<>(); + + upgradeRequestHeaders.entrySet().stream() + .filter(entry -> shouldIncludeUpgradeRequestHeader(entry.getKey())) + .forEach(entry -> combinedHeaders.put(entry.getKey(), entry.getValue())); + + requestMessageHeaders.entrySet().stream() + .filter(entry -> shouldIncludeRequestMessageHeader(entry.getKey())) + .forEach(entry -> combinedHeaders.put(entry.getKey(), List.of(entry.getValue()))); + + return combinedHeaders; + } + + @VisibleForTesting + static boolean shouldIncludeUpgradeRequestHeader(final String header) { + return !EXCLUDED_UPGRADE_REQUEST_HEADERS.contains(header.toLowerCase()) && !header.toLowerCase().contains("websocket-"); + } + + @VisibleForTesting + static boolean shouldIncludeRequestMessageHeader(final String header) { + return !"X-Forwarded-For".equalsIgnoreCase(header.trim()); + } + private void handleResponse(WebSocketResponseMessage responseMessage) { CompletableFuture future = requestMap.remove(responseMessage.getRequestId()); diff --git a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java index ad08af621..d8fc2772b 100644 --- a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java +++ b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java @@ -54,6 +54,7 @@ import java.security.Principal; import java.util.Arrays; import java.util.LinkedList; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; @@ -613,6 +614,45 @@ public class WebSocketResourceProviderTest { } } + @Test + public void testShouldIncludeUpgradeRequestHeader() { + assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("Upgrade")).isFalse(); + assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("Connection")).isFalse(); + assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("Sec-WebSocket-Key")).isFalse(); + assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("User-Agent")).isTrue(); + assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("X-Forwarded-For")).isTrue(); + } + + @Test + public void testShouldIncludeRequestMessageHeader() { + assertThat(WebSocketResourceProvider.shouldIncludeRequestMessageHeader("X-Forwarded-For")).isFalse(); + assertThat(WebSocketResourceProvider.shouldIncludeRequestMessageHeader("User-Agent")).isTrue(); + } + + @Test + public void testGetCombinedHeaders() { + final Map> upgradeRequestHeaders = Map.of( + "Host", List.of("server.example.com"), + "Upgrade", List.of("websocket"), + "Connection", List.of("Upgrade"), + "Sec-WebSocket-Key", List.of("dGhlIHNhbXBsZSBub25jZQ=="), + "Sec-WebSocket-Protocol", List.of("chat, superchat"), + "Sec-WebSocket-Version", List.of("13"), + "X-Forwarded-For", List.of("127.0.0.1"), + "User-Agent", List.of("Upgrade request user agent")); + + final Map requestMessageHeaders = Map.of( + "X-Forwarded-For", "192.168.0.1", + "User-Agent", "Request message user agent"); + + final Map> expectedHeaders = Map.of( + "Host", List.of("server.example.com"), + "X-Forwarded-For", List.of("127.0.0.1"), + "User-Agent", List.of("Request message user agent")); + + assertThat(WebSocketResourceProvider.getCombinedHeaders(upgradeRequestHeaders, requestMessageHeaders)).isEqualTo(expectedHeaders); + } + private SubProtocol.WebSocketResponseMessage getResponse(ArgumentCaptor responseCaptor) throws InvalidProtocolBufferException { return SubProtocol.WebSocketMessage.parseFrom(responseCaptor.getValue().array()).getResponse(); }