Copy headers from the initial websocket upgrade request into subsequent resource requests.
This commit is contained in:
parent
933dd81d82
commit
3cdc58200a
|
@ -47,6 +47,7 @@ import java.util.Collections;
|
|||
import java.util.HashSet;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
|
||||
|
@ -158,6 +159,7 @@ public class MetricsRequestEventListenerTest {
|
|||
when(session.getUpgradeRequest()).thenReturn(request);
|
||||
when(session.getRemote()).thenReturn(remoteEndpoint);
|
||||
when(request.getHeader("User-Agent")).thenReturn("Signal-Android 4.53.7 (Android 8.1)");
|
||||
when(request.getHeaders()).thenReturn(Map.of("User-Agent", List.of("Signal-Android 4.53.7 (Android 8.1)")));
|
||||
|
||||
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
|
||||
when(meterRegistry.counter(eq(MetricsRequestEventListener.REQUEST_COUNTER_NAME), any(Iterable.class))).thenReturn(counter);
|
||||
|
|
|
@ -37,6 +37,7 @@ import java.util.LinkedList;
|
|||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
|
@ -59,7 +60,8 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
|
|||
private Session session;
|
||||
private RemoteEndpoint remoteEndpoint;
|
||||
private WebSocketSessionContext context;
|
||||
private String userAgent;
|
||||
|
||||
private static final Set<String> EXCLUDED_UPGRADE_REQUEST_HEADERS = Set.of("connection", "upgrade");
|
||||
|
||||
public WebSocketResourceProvider(String remoteAddress,
|
||||
ApplicationHandler jerseyHandler,
|
||||
|
@ -81,7 +83,6 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
|
|||
@Override
|
||||
public void onWebSocketConnect(Session session) {
|
||||
this.session = session;
|
||||
this.userAgent = session.getUpgradeRequest().getHeader("User-Agent");
|
||||
this.remoteEndpoint = session.getRemote();
|
||||
this.context = new WebSocketSessionContext(new WebSocketClient(session, remoteEndpoint, messageFactory, requestMap));
|
||||
this.context.setAuthenticated(authenticated);
|
||||
|
@ -142,16 +143,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
|
|||
|
||||
private void handleRequest(WebSocketRequestMessage requestMessage) {
|
||||
ContainerRequest containerRequest = new ContainerRequest(null, URI.create(requestMessage.getPath()), requestMessage.getVerb(), new WebSocketSecurityContext(new ContextPrincipal(context)), new MapPropertiesDelegate(new HashMap<>()), null);
|
||||
|
||||
for (Map.Entry<String, String> entry : requestMessage.getHeaders().entrySet()) {
|
||||
containerRequest.header(entry.getKey(), entry.getValue());
|
||||
}
|
||||
|
||||
final List<String> requestUserAgentHeader = containerRequest.getRequestHeader("User-Agent");
|
||||
|
||||
if ((requestUserAgentHeader == null || requestUserAgentHeader.isEmpty()) && userAgent != null) {
|
||||
containerRequest.header("User-Agent", userAgent);
|
||||
}
|
||||
containerRequest.headers(getCombinedHeaders(session.getUpgradeRequest().getHeaders(), requestMessage.getHeaders()));
|
||||
|
||||
if (requestMessage.getBody().isPresent()) {
|
||||
containerRequest.setEntityStream(new ByteArrayInputStream(requestMessage.getBody().get()));
|
||||
|
@ -171,6 +163,31 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
|
|||
});
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
static Map<String, List<String>> getCombinedHeaders(final Map<String, List<String>> upgradeRequestHeaders, final Map<String, String> requestMessageHeaders) {
|
||||
final Map<String, List<String>> combinedHeaders = new HashMap<>();
|
||||
|
||||
upgradeRequestHeaders.entrySet().stream()
|
||||
.filter(entry -> shouldIncludeUpgradeRequestHeader(entry.getKey()))
|
||||
.forEach(entry -> combinedHeaders.put(entry.getKey(), entry.getValue()));
|
||||
|
||||
requestMessageHeaders.entrySet().stream()
|
||||
.filter(entry -> shouldIncludeRequestMessageHeader(entry.getKey()))
|
||||
.forEach(entry -> combinedHeaders.put(entry.getKey(), List.of(entry.getValue())));
|
||||
|
||||
return combinedHeaders;
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
static boolean shouldIncludeUpgradeRequestHeader(final String header) {
|
||||
return !EXCLUDED_UPGRADE_REQUEST_HEADERS.contains(header.toLowerCase()) && !header.toLowerCase().contains("websocket-");
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
static boolean shouldIncludeRequestMessageHeader(final String header) {
|
||||
return !"X-Forwarded-For".equalsIgnoreCase(header.trim());
|
||||
}
|
||||
|
||||
private void handleResponse(WebSocketResponseMessage responseMessage) {
|
||||
CompletableFuture<WebSocketResponseMessage> future = requestMap.remove(responseMessage.getRequestId());
|
||||
|
||||
|
|
|
@ -54,6 +54,7 @@ import java.security.Principal;
|
|||
import java.util.Arrays;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
|
||||
|
@ -613,6 +614,45 @@ public class WebSocketResourceProviderTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testShouldIncludeUpgradeRequestHeader() {
|
||||
assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("Upgrade")).isFalse();
|
||||
assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("Connection")).isFalse();
|
||||
assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("Sec-WebSocket-Key")).isFalse();
|
||||
assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("User-Agent")).isTrue();
|
||||
assertThat(WebSocketResourceProvider.shouldIncludeUpgradeRequestHeader("X-Forwarded-For")).isTrue();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testShouldIncludeRequestMessageHeader() {
|
||||
assertThat(WebSocketResourceProvider.shouldIncludeRequestMessageHeader("X-Forwarded-For")).isFalse();
|
||||
assertThat(WebSocketResourceProvider.shouldIncludeRequestMessageHeader("User-Agent")).isTrue();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetCombinedHeaders() {
|
||||
final Map<String, List<String>> upgradeRequestHeaders = Map.of(
|
||||
"Host", List.of("server.example.com"),
|
||||
"Upgrade", List.of("websocket"),
|
||||
"Connection", List.of("Upgrade"),
|
||||
"Sec-WebSocket-Key", List.of("dGhlIHNhbXBsZSBub25jZQ=="),
|
||||
"Sec-WebSocket-Protocol", List.of("chat, superchat"),
|
||||
"Sec-WebSocket-Version", List.of("13"),
|
||||
"X-Forwarded-For", List.of("127.0.0.1"),
|
||||
"User-Agent", List.of("Upgrade request user agent"));
|
||||
|
||||
final Map<String, String> requestMessageHeaders = Map.of(
|
||||
"X-Forwarded-For", "192.168.0.1",
|
||||
"User-Agent", "Request message user agent");
|
||||
|
||||
final Map<String, List<String>> expectedHeaders = Map.of(
|
||||
"Host", List.of("server.example.com"),
|
||||
"X-Forwarded-For", List.of("127.0.0.1"),
|
||||
"User-Agent", List.of("Request message user agent"));
|
||||
|
||||
assertThat(WebSocketResourceProvider.getCombinedHeaders(upgradeRequestHeaders, requestMessageHeaders)).isEqualTo(expectedHeaders);
|
||||
}
|
||||
|
||||
private SubProtocol.WebSocketResponseMessage getResponse(ArgumentCaptor<ByteBuffer> responseCaptor) throws InvalidProtocolBufferException {
|
||||
return SubProtocol.WebSocketMessage.parseFrom(responseCaptor.getValue().array()).getResponse();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue