Always copy HTTP response headers to websocket responses.
This commit is contained in:
parent
817f057927
commit
eb8b5e5c01
|
@ -36,6 +36,7 @@ import org.whispersystems.websocket.session.ContextPrincipal;
|
||||||
import org.whispersystems.websocket.session.WebSocketSessionContext;
|
import org.whispersystems.websocket.session.WebSocketSessionContext;
|
||||||
import org.whispersystems.websocket.setup.WebSocketConnectListener;
|
import org.whispersystems.websocket.setup.WebSocketConnectListener;
|
||||||
|
|
||||||
|
import javax.ws.rs.core.MultivaluedMap;
|
||||||
import javax.ws.rs.core.Response;
|
import javax.ws.rs.core.Response;
|
||||||
import java.io.ByteArrayInputStream;
|
import java.io.ByteArrayInputStream;
|
||||||
import java.io.ByteArrayOutputStream;
|
import java.io.ByteArrayOutputStream;
|
||||||
|
@ -197,7 +198,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
|
||||||
byte[] responseBytes = messageFactory.createResponse(requestMessage.getRequestId(),
|
byte[] responseBytes = messageFactory.createResponse(requestMessage.getRequestId(),
|
||||||
response.getStatus(),
|
response.getStatus(),
|
||||||
response.getStatusInfo().getReasonPhrase(),
|
response.getStatusInfo().getReasonPhrase(),
|
||||||
new LinkedList<>(),
|
getHeaderList(response.getStringHeaders()),
|
||||||
Optional.ofNullable(body))
|
Optional.ofNullable(body))
|
||||||
.toByteArray();
|
.toByteArray();
|
||||||
|
|
||||||
|
@ -207,16 +208,10 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
|
||||||
|
|
||||||
private void sendErrorResponse(WebSocketRequestMessage requestMessage, Response error) {
|
private void sendErrorResponse(WebSocketRequestMessage requestMessage, Response error) {
|
||||||
if (requestMessage.hasRequestId()) {
|
if (requestMessage.hasRequestId()) {
|
||||||
List<String> headers = new LinkedList<>();
|
|
||||||
|
|
||||||
for (String key : error.getStringHeaders().keySet()) {
|
|
||||||
headers.add(key + ":" + error.getStringHeaders().getFirst(key));
|
|
||||||
}
|
|
||||||
|
|
||||||
WebSocketMessage response = messageFactory.createResponse(requestMessage.getRequestId(),
|
WebSocketMessage response = messageFactory.createResponse(requestMessage.getRequestId(),
|
||||||
error.getStatus(),
|
error.getStatus(),
|
||||||
"Error response",
|
"Error response",
|
||||||
headers,
|
getHeaderList(error.getStringHeaders()),
|
||||||
Optional.empty());
|
Optional.empty());
|
||||||
|
|
||||||
remoteEndpoint.sendBytesByFuture(ByteBuffer.wrap(response.toByteArray()));
|
remoteEndpoint.sendBytesByFuture(ByteBuffer.wrap(response.toByteArray()));
|
||||||
|
@ -229,4 +224,16 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket
|
||||||
return context;
|
return context;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@VisibleForTesting
|
||||||
|
static List<String> getHeaderList(final MultivaluedMap<String, String> headerMap) {
|
||||||
|
final List<String> headers = new LinkedList<>();
|
||||||
|
|
||||||
|
if (headerMap != null) {
|
||||||
|
for (String key : headerMap.keySet()) {
|
||||||
|
headers.add(key + ":" + headerMap.getFirst(key));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return headers;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,19 +36,21 @@ import javax.ws.rs.Path;
|
||||||
import javax.ws.rs.PathParam;
|
import javax.ws.rs.PathParam;
|
||||||
import javax.ws.rs.Produces;
|
import javax.ws.rs.Produces;
|
||||||
import javax.ws.rs.core.MediaType;
|
import javax.ws.rs.core.MediaType;
|
||||||
|
import javax.ws.rs.core.MultivaluedHashMap;
|
||||||
|
import javax.ws.rs.core.MultivaluedMap;
|
||||||
import javax.ws.rs.core.Response;
|
import javax.ws.rs.core.Response;
|
||||||
import javax.ws.rs.ext.ExceptionMapper;
|
import javax.ws.rs.ext.ExceptionMapper;
|
||||||
import javax.ws.rs.ext.Provider;
|
import javax.ws.rs.ext.Provider;
|
||||||
import java.io.OutputStream;
|
import java.io.OutputStream;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.security.Principal;
|
import java.security.Principal;
|
||||||
|
import java.util.Arrays;
|
||||||
import java.util.LinkedList;
|
import java.util.LinkedList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.concurrent.CompletableFuture;
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
|
||||||
import io.dropwizard.auth.Auth;
|
import io.dropwizard.auth.Auth;
|
||||||
import io.dropwizard.auth.AuthValueFactoryProvider;
|
|
||||||
import io.dropwizard.jersey.DropwizardResourceConfig;
|
import io.dropwizard.jersey.DropwizardResourceConfig;
|
||||||
import io.dropwizard.jersey.jackson.JacksonMessageBodyProvider;
|
import io.dropwizard.jersey.jackson.JacksonMessageBodyProvider;
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
@ -585,6 +587,21 @@ public class WebSocketResourceProviderTest {
|
||||||
assertThat(response.getBody().toStringUtf8()).isEqualTo("my response");
|
assertThat(response.getBody().toStringUtf8()).isEqualTo("my response");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testGetHeaderList() {
|
||||||
|
assertThat(WebSocketResourceProvider.getHeaderList(new MultivaluedHashMap<>())).isEmpty();
|
||||||
|
|
||||||
|
{
|
||||||
|
final MultivaluedMap<String, String> headers = new MultivaluedHashMap<>();
|
||||||
|
headers.put("test", Arrays.asList("a", "b", "c"));
|
||||||
|
|
||||||
|
final List<String> headerStrings = WebSocketResourceProvider.getHeaderList(headers);
|
||||||
|
|
||||||
|
assertThat(headerStrings).hasSize(1);
|
||||||
|
assertThat(headerStrings).contains("test:a");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private SubProtocol.WebSocketResponseMessage getResponse(ArgumentCaptor<ByteBuffer> responseCaptor) throws InvalidProtocolBufferException {
|
private SubProtocol.WebSocketResponseMessage getResponse(ArgumentCaptor<ByteBuffer> responseCaptor) throws InvalidProtocolBufferException {
|
||||||
return SubProtocol.WebSocketMessage.parseFrom(responseCaptor.getValue().array()).getResponse();
|
return SubProtocol.WebSocketMessage.parseFrom(responseCaptor.getValue().array()).getResponse();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue