diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java index e75708d79..6cabf7b15 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java @@ -36,6 +36,7 @@ import org.whispersystems.websocket.session.ContextPrincipal; import org.whispersystems.websocket.session.WebSocketSessionContext; import org.whispersystems.websocket.setup.WebSocketConnectListener; +import javax.ws.rs.core.MultivaluedMap; import javax.ws.rs.core.Response; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -197,7 +198,7 @@ public class WebSocketResourceProvider implements WebSocket byte[] responseBytes = messageFactory.createResponse(requestMessage.getRequestId(), response.getStatus(), response.getStatusInfo().getReasonPhrase(), - new LinkedList<>(), + getHeaderList(response.getStringHeaders()), Optional.ofNullable(body)) .toByteArray(); @@ -207,16 +208,10 @@ public class WebSocketResourceProvider implements WebSocket private void sendErrorResponse(WebSocketRequestMessage requestMessage, Response error) { if (requestMessage.hasRequestId()) { - List headers = new LinkedList<>(); - - for (String key : error.getStringHeaders().keySet()) { - headers.add(key + ":" + error.getStringHeaders().getFirst(key)); - } - WebSocketMessage response = messageFactory.createResponse(requestMessage.getRequestId(), error.getStatus(), "Error response", - headers, + getHeaderList(error.getStringHeaders()), Optional.empty()); remoteEndpoint.sendBytesByFuture(ByteBuffer.wrap(response.toByteArray())); @@ -229,4 +224,16 @@ public class WebSocketResourceProvider implements WebSocket return context; } + @VisibleForTesting + static List getHeaderList(final MultivaluedMap headerMap) { + final List headers = new LinkedList<>(); + + if (headerMap != null) { + for (String key : headerMap.keySet()) { + headers.add(key + ":" + headerMap.getFirst(key)); + } + } + + return headers; + } } diff --git a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java index a5d874733..803fecc1b 100644 --- a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java +++ b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java @@ -36,19 +36,21 @@ import javax.ws.rs.Path; import javax.ws.rs.PathParam; import javax.ws.rs.Produces; import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.MultivaluedHashMap; +import javax.ws.rs.core.MultivaluedMap; import javax.ws.rs.core.Response; import javax.ws.rs.ext.ExceptionMapper; import javax.ws.rs.ext.Provider; import java.io.OutputStream; import java.nio.ByteBuffer; import java.security.Principal; +import java.util.Arrays; import java.util.LinkedList; import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; import io.dropwizard.auth.Auth; -import io.dropwizard.auth.AuthValueFactoryProvider; import io.dropwizard.jersey.DropwizardResourceConfig; import io.dropwizard.jersey.jackson.JacksonMessageBodyProvider; import static org.assertj.core.api.Assertions.assertThat; @@ -585,6 +587,21 @@ public class WebSocketResourceProviderTest { assertThat(response.getBody().toStringUtf8()).isEqualTo("my response"); } + @Test + public void testGetHeaderList() { + assertThat(WebSocketResourceProvider.getHeaderList(new MultivaluedHashMap<>())).isEmpty(); + + { + final MultivaluedMap headers = new MultivaluedHashMap<>(); + headers.put("test", Arrays.asList("a", "b", "c")); + + final List headerStrings = WebSocketResourceProvider.getHeaderList(headers); + + assertThat(headerStrings).hasSize(1); + assertThat(headerStrings).contains("test:a"); + } + } + private SubProtocol.WebSocketResponseMessage getResponse(ArgumentCaptor responseCaptor) throws InvalidProtocolBufferException { return SubProtocol.WebSocketMessage.parseFrom(responseCaptor.getValue().array()).getResponse(); }