Always copy HTTP response headers to websocket responses.
This commit is contained in:
		
							parent
							
								
									817f057927
								
							
						
					
					
						commit
						eb8b5e5c01
					
				|  | @ -36,6 +36,7 @@ import org.whispersystems.websocket.session.ContextPrincipal; | |||
| import org.whispersystems.websocket.session.WebSocketSessionContext; | ||||
| import org.whispersystems.websocket.setup.WebSocketConnectListener; | ||||
| 
 | ||||
| import javax.ws.rs.core.MultivaluedMap; | ||||
| import javax.ws.rs.core.Response; | ||||
| import java.io.ByteArrayInputStream; | ||||
| import java.io.ByteArrayOutputStream; | ||||
|  | @ -197,7 +198,7 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket | |||
|       byte[] responseBytes = messageFactory.createResponse(requestMessage.getRequestId(), | ||||
|                                                            response.getStatus(), | ||||
|                                                            response.getStatusInfo().getReasonPhrase(), | ||||
|                                                            new LinkedList<>(), | ||||
|                                                            getHeaderList(response.getStringHeaders()), | ||||
|                                                            Optional.ofNullable(body)) | ||||
|                                            .toByteArray(); | ||||
| 
 | ||||
|  | @ -207,16 +208,10 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket | |||
| 
 | ||||
|   private void sendErrorResponse(WebSocketRequestMessage requestMessage, Response error) { | ||||
|     if (requestMessage.hasRequestId()) { | ||||
|       List<String> headers = new LinkedList<>(); | ||||
| 
 | ||||
|       for (String key : error.getStringHeaders().keySet()) { | ||||
|         headers.add(key + ":" + error.getStringHeaders().getFirst(key)); | ||||
|       } | ||||
| 
 | ||||
|       WebSocketMessage response = messageFactory.createResponse(requestMessage.getRequestId(), | ||||
|                                                                 error.getStatus(), | ||||
|                                                                 "Error response", | ||||
|                                                                 headers, | ||||
|                                                                 getHeaderList(error.getStringHeaders()), | ||||
|                                                                 Optional.empty()); | ||||
| 
 | ||||
|       remoteEndpoint.sendBytesByFuture(ByteBuffer.wrap(response.toByteArray())); | ||||
|  | @ -229,4 +224,16 @@ public class WebSocketResourceProvider<T extends Principal> implements WebSocket | |||
|     return context; | ||||
|   } | ||||
| 
 | ||||
|   @VisibleForTesting | ||||
|   static List<String> getHeaderList(final MultivaluedMap<String, String> headerMap) { | ||||
|     final List<String> headers = new LinkedList<>(); | ||||
| 
 | ||||
|     if (headerMap != null) { | ||||
|       for (String key : headerMap.keySet()) { | ||||
|         headers.add(key + ":" + headerMap.getFirst(key)); | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     return headers; | ||||
|   } | ||||
| } | ||||
|  |  | |||
|  | @ -36,19 +36,21 @@ import javax.ws.rs.Path; | |||
| import javax.ws.rs.PathParam; | ||||
| import javax.ws.rs.Produces; | ||||
| import javax.ws.rs.core.MediaType; | ||||
| import javax.ws.rs.core.MultivaluedHashMap; | ||||
| import javax.ws.rs.core.MultivaluedMap; | ||||
| import javax.ws.rs.core.Response; | ||||
| import javax.ws.rs.ext.ExceptionMapper; | ||||
| import javax.ws.rs.ext.Provider; | ||||
| import java.io.OutputStream; | ||||
| import java.nio.ByteBuffer; | ||||
| import java.security.Principal; | ||||
| import java.util.Arrays; | ||||
| import java.util.LinkedList; | ||||
| import java.util.List; | ||||
| import java.util.Optional; | ||||
| import java.util.concurrent.CompletableFuture; | ||||
| 
 | ||||
| import io.dropwizard.auth.Auth; | ||||
| import io.dropwizard.auth.AuthValueFactoryProvider; | ||||
| import io.dropwizard.jersey.DropwizardResourceConfig; | ||||
| import io.dropwizard.jersey.jackson.JacksonMessageBodyProvider; | ||||
| import static org.assertj.core.api.Assertions.assertThat; | ||||
|  | @ -585,6 +587,21 @@ public class WebSocketResourceProviderTest { | |||
|     assertThat(response.getBody().toStringUtf8()).isEqualTo("my response"); | ||||
|   } | ||||
| 
 | ||||
|   @Test | ||||
|   public void testGetHeaderList() { | ||||
|     assertThat(WebSocketResourceProvider.getHeaderList(new MultivaluedHashMap<>())).isEmpty(); | ||||
| 
 | ||||
|     { | ||||
|       final MultivaluedMap<String, String> headers = new MultivaluedHashMap<>(); | ||||
|       headers.put("test", Arrays.asList("a", "b", "c")); | ||||
| 
 | ||||
|       final List<String> headerStrings = WebSocketResourceProvider.getHeaderList(headers); | ||||
| 
 | ||||
|       assertThat(headerStrings).hasSize(1); | ||||
|       assertThat(headerStrings).contains("test:a"); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   private SubProtocol.WebSocketResponseMessage getResponse(ArgumentCaptor<ByteBuffer> responseCaptor) throws InvalidProtocolBufferException { | ||||
|     return SubProtocol.WebSocketMessage.parseFrom(responseCaptor.getValue().array()).getResponse(); | ||||
|   } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Jon Chambers
						Jon Chambers