diff --git a/pom.xml b/pom.xml index 7726efa48..96b267048 100644 --- a/pom.xml +++ b/pom.xml @@ -16,14 +16,21 @@ - 1.3.16 - 2.9.10.20191020 + 2.0.2 0.14.1 + 2.25.1 UTF-8 3.06 + + io.dropwizard + dropwizard-dependencies + 2.0.2 + + + org.whispersystems.textsecure TextSecureServer 1.0 @@ -78,50 +85,32 @@ ${dropwizard.version} test + + org.hamcrest + hamcrest-all + 1.3 + test + com.github.tomakehurst wiremock-jre8 - 2.23.2 + 2.26.2 test - com.google.guava - guava + org.hamcrest + hamcrest-core - org.eclipse.jetty - jetty-server - - - org.eclipse.jetty - jetty-servlets - - - org.eclipse.jetty - jetty-servlet - - - org.eclipse.jetty - jetty-webapp - - - com.fasterxml.jackson.core - jackson-core - - - com.fasterxml.jackson.core - jackson-annotations - - - com.fasterxml.jackson.core - jackson-databind + javax.xml.bind + jaxb-api org.mockito mockito-core - 2.25.1 + ${mockito.version} test @@ -156,6 +145,15 @@ 3.0.0-M1 + + org.apache.maven.plugins + maven-enforcer-plugin + 1.4.1 + + + + + diff --git a/service/pom.xml b/service/pom.xml index db670d2bc..efd720a3c 100644 --- a/service/pom.xml +++ b/service/pom.xml @@ -83,7 +83,7 @@ com.amazonaws aws-java-sdk-sqs - 1.11.362 + 1.11.366 @@ -140,7 +140,7 @@ org.glassfish.jersey.test-framework.providers jersey-test-framework-provider-grizzly2 - 2.25.1 + 2.30 test @@ -171,22 +171,6 @@ - - - - com.fasterxml.jackson.core - jackson-databind - 2.9.10.1 - - - com.fasterxml.jackson.dataformat - jackson-dataformat-cbor - 2.9.10 - - - - - ${parent.artifactId}-${TextSecureServer.version} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 6b5ddcdef..1a89bb289 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -284,7 +284,7 @@ public class WhisperServerService extends Application webSocketEnvironment = new WebSocketEnvironment<>(environment, config.getWebSocketConfiguration(), 90000); webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator)); webSocketEnvironment.setConnectListener(new AuthenticatedConnectListener(pushSender, receiptSender, messagesManager, pubSubManager, apnFallbackManager)); webSocketEnvironment.jersey().register(new KeepAliveController(pubSubManager)); @@ -294,12 +294,15 @@ public class WhisperServerService extends Application provisioningEnvironment = new WebSocketEnvironment<>(environment, webSocketEnvironment.getRequestLog(), 60000); provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(pubSubManager)); provisioningEnvironment.jersey().register(new KeepAliveController(pubSubManager)); - WebSocketResourceProviderFactory webSocketServlet = new WebSocketResourceProviderFactory(webSocketEnvironment ); - WebSocketResourceProviderFactory provisioningServlet = new WebSocketResourceProviderFactory(provisioningEnvironment); + registerCorsFilter(environment); + registerExceptionMappers(environment, webSocketEnvironment, provisioningEnvironment); + + WebSocketResourceProviderFactory webSocketServlet = new WebSocketResourceProviderFactory<>(webSocketEnvironment, Account.class); + WebSocketResourceProviderFactory provisioningServlet = new WebSocketResourceProviderFactory<>(provisioningEnvironment, Account.class); ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", webSocketServlet ); ServletRegistration.Dynamic provisioning = environment.servlets().addServlet("Provisioning", provisioningServlet); @@ -310,26 +313,11 @@ public class WhisperServerService extends Application webSocketEnvironment, WebSocketEnvironment provisioningEnvironment) { + environment.jersey().register(new IOExceptionMapper()); + environment.jersey().register(new RateLimitExceededExceptionMapper()); + environment.jersey().register(new InvalidWebsocketAddressExceptionMapper()); + environment.jersey().register(new DeviceLimitExceededExceptionMapper()); + + webSocketEnvironment.jersey().register(new IOExceptionMapper()); + webSocketEnvironment.jersey().register(new RateLimitExceededExceptionMapper()); + webSocketEnvironment.jersey().register(new InvalidWebsocketAddressExceptionMapper()); + webSocketEnvironment.jersey().register(new DeviceLimitExceededExceptionMapper()); + + provisioningEnvironment.jersey().register(new IOExceptionMapper()); + provisioningEnvironment.jersey().register(new RateLimitExceededExceptionMapper()); + provisioningEnvironment.jersey().register(new InvalidWebsocketAddressExceptionMapper()); + provisioningEnvironment.jersey().register(new DeviceLimitExceededExceptionMapper()); + } + + private void registerCorsFilter(Environment environment) { + FilterRegistration.Dynamic filter = environment.servlets().addFilter("CORS", CrossOriginFilter.class); + filter.addMappingForUrlPatterns(EnumSet.allOf(DispatcherType.class), true, "/*"); + filter.setInitParameter("allowedOrigins", "*"); + filter.setInitParameter("allowedHeaders", "Content-Type,Authorization,X-Requested-With,Content-Length,Accept,Origin,X-Signal-Agent"); + filter.setInitParameter("allowedMethods", "GET,PUT,POST,DELETE,OPTIONS"); + filter.setInitParameter("preflightMaxAge", "5184000"); + filter.setInitParameter("allowCredentials", "true"); + } + public static void main(String[] args) throws Exception { new WhisperServerService().run(args); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java index d4610f1eb..fc517aaa4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/ProfileController.java @@ -5,7 +5,6 @@ import com.codahale.metrics.annotation.Timed; import org.apache.commons.codec.DecoderException; import org.apache.commons.codec.binary.Base64; import org.apache.commons.codec.binary.Hex; -import org.hibernate.validator.valuehandling.UnwrapValidatedValue; import org.signal.zkgroup.InvalidInputException; import org.signal.zkgroup.VerificationFailedException; import org.signal.zkgroup.profiles.ProfileKeyCommitment; @@ -34,6 +33,7 @@ import org.whispersystems.textsecuregcm.util.ExactlySize; import org.whispersystems.textsecuregcm.util.Pair; import javax.validation.Valid; +import javax.validation.valueextraction.Unwrapping; import javax.ws.rs.Consumes; import javax.ws.rs.GET; import javax.ws.rs.HeaderParam; @@ -266,7 +266,7 @@ public class ProfileController { @PUT @Produces(MediaType.APPLICATION_JSON) @Path("/name/{name}") - public void setProfile(@Auth Account account, @PathParam("name") @UnwrapValidatedValue(true) @ExactlySize({72, 108}) Optional name) { + public void setProfile(@Auth Account account, @PathParam("name") @ExactlySize(value = {72, 108}, payload = {Unwrapping.Unwrap.class}) Optional name) { account.setProfileName(name.orElse(null)); accountsManager.update(account); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/DirectoryFeedbackRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/DirectoryFeedbackRequest.java index cafb4a779..ce7fe744d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/DirectoryFeedbackRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/DirectoryFeedbackRequest.java @@ -20,11 +20,12 @@ package org.whispersystems.textsecuregcm.entities; import com.fasterxml.jackson.annotation.JsonProperty; import javax.validation.constraints.Size; +import javax.validation.valueextraction.Unwrapping; import java.util.Optional; public class DirectoryFeedbackRequest { - @Size(max = 1024) + @Size(max = 1024, payload = {Unwrapping.Unwrap.class}) @JsonProperty private Optional reason; diff --git a/websocket-resources/pom.xml b/websocket-resources/pom.xml index dbb773b02..b81d2427f 100644 --- a/websocket-resources/pom.xml +++ b/websocket-resources/pom.xml @@ -16,13 +16,20 @@ org.eclipse.jetty.websocket websocket-server - 9.4.18.v20190429 + 9.4.26.v20200117 com.google.protobuf protobuf-java 2.6.1 + + + org.mockito + mockito-inline + ${mockito.version} + test + \ No newline at end of file diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java index 26b9c4d5a..e75708d79 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProvider.java @@ -17,32 +17,33 @@ package org.whispersystems.websocket; import com.google.common.annotations.VisibleForTesting; -import org.eclipse.jetty.server.RequestLog; import org.eclipse.jetty.websocket.api.RemoteEndpoint; import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.WebSocketListener; +import org.glassfish.jersey.internal.MapPropertiesDelegate; +import org.glassfish.jersey.server.ApplicationHandler; +import org.glassfish.jersey.server.ContainerRequest; +import org.glassfish.jersey.server.ContainerResponse; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.websocket.logging.WebsocketRequestLog; import org.whispersystems.websocket.messages.InvalidMessageException; import org.whispersystems.websocket.messages.WebSocketMessage; import org.whispersystems.websocket.messages.WebSocketMessageFactory; import org.whispersystems.websocket.messages.WebSocketRequestMessage; import org.whispersystems.websocket.messages.WebSocketResponseMessage; -import org.whispersystems.websocket.servlet.LoggableRequest; -import org.whispersystems.websocket.servlet.LoggableResponse; -import org.whispersystems.websocket.servlet.NullServletResponse; -import org.whispersystems.websocket.servlet.WebSocketServletRequest; -import org.whispersystems.websocket.servlet.WebSocketServletResponse; +import org.whispersystems.websocket.session.ContextPrincipal; import org.whispersystems.websocket.session.WebSocketSessionContext; import org.whispersystems.websocket.setup.WebSocketConnectListener; -import javax.servlet.ServletException; -import javax.servlet.http.HttpServlet; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; import javax.ws.rs.core.Response; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.net.URI; import java.nio.ByteBuffer; +import java.security.Principal; +import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -52,31 +53,34 @@ import java.util.concurrent.ConcurrentHashMap; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") -public class WebSocketResourceProvider implements WebSocketListener { +public class WebSocketResourceProvider implements WebSocketListener { private static final Logger logger = LoggerFactory.getLogger(WebSocketResourceProvider.class); private final Map> requestMap = new ConcurrentHashMap<>(); - private final Object authenticated; + private final T authenticated; private final WebSocketMessageFactory messageFactory; private final Optional connectListener; - private final HttpServlet servlet; - private final RequestLog requestLog; + private final ApplicationHandler jerseyHandler; + private final WebsocketRequestLog requestLog; private final long idleTimeoutMillis; + private final String remoteAddress; private Session session; private RemoteEndpoint remoteEndpoint; private WebSocketSessionContext context; - public WebSocketResourceProvider(HttpServlet servlet, - RequestLog requestLog, - Object authenticated, + public WebSocketResourceProvider(String remoteAddress, + ApplicationHandler jerseyHandler, + WebsocketRequestLog requestLog, + T authenticated, WebSocketMessageFactory messageFactory, Optional connectListener, long idleTimeoutMillis) { - this.servlet = servlet; + this.remoteAddress = remoteAddress; + this.jerseyHandler = jerseyHandler; this.requestLog = requestLog; this.authenticated = authenticated; this.messageFactory = messageFactory; @@ -131,7 +135,7 @@ public class WebSocketResourceProvider implements WebSocketListener { context.notifyClosed(statusCode, reason); for (long requestId : requestMap.keySet()) { - CompletableFuture outstandingRequest = requestMap.remove(requestId); + CompletableFuture outstandingRequest = requestMap.remove(requestId); if (outstandingRequest != null) { outstandingRequest.completeExceptionally(new IOException("Connection closed!")); @@ -146,17 +150,28 @@ public class WebSocketResourceProvider implements WebSocketListener { } private void handleRequest(WebSocketRequestMessage requestMessage) { - try { - HttpServletRequest servletRequest = createRequest(requestMessage, context); - HttpServletResponse servletResponse = createResponse(requestMessage); + ContainerRequest containerRequest = new ContainerRequest(null, URI.create(requestMessage.getPath()), requestMessage.getVerb(), new WebSocketSecurityContext(new ContextPrincipal(context)), new MapPropertiesDelegate(new HashMap<>()), null); - servlet.service(servletRequest, servletResponse); - servletResponse.flushBuffer(); - requestLog.log(new LoggableRequest(servletRequest), new LoggableResponse(servletResponse)); - } catch (IOException | ServletException e) { - logger.warn("Servlet Error: " + requestMessage.getVerb() + " " + requestMessage.getPath() + "\n" + requestMessage.getBody(), e); - sendErrorResponse(requestMessage, Response.status(500).build()); + for (Map.Entry entry : requestMessage.getHeaders().entrySet()) { + containerRequest.header(entry.getKey(), entry.getValue()); } + + if (requestMessage.getBody().isPresent()) { + containerRequest.setEntityStream(new ByteArrayInputStream(requestMessage.getBody().get())); + } + + ByteArrayOutputStream responseBody = new ByteArrayOutputStream(); + CompletableFuture responseFuture = (CompletableFuture) jerseyHandler.apply(containerRequest, responseBody); + + responseFuture.thenAccept(response -> { + sendResponse(requestMessage, response, responseBody); + requestLog.log(remoteAddress, containerRequest, response); + }).exceptionally(exception -> { + logger.warn("Websocket Error: " + requestMessage.getVerb() + " " + requestMessage.getPath() + "\n" + requestMessage.getBody(), exception); + sendErrorResponse(requestMessage, Response.status(500).build()); + requestLog.log(remoteAddress, containerRequest, new ContainerResponse(containerRequest, Response.status(500).build())); + return null; + }); } private void handleResponse(WebSocketResponseMessage responseMessage) { @@ -171,17 +186,22 @@ public class WebSocketResourceProvider implements WebSocketListener { session.close(status, message); } - private HttpServletRequest createRequest(WebSocketRequestMessage message, - WebSocketSessionContext context) - { - return new WebSocketServletRequest(context, message, servlet.getServletContext()); - } + private void sendResponse(WebSocketRequestMessage requestMessage, ContainerResponse response, ByteArrayOutputStream responseBody) { + if (requestMessage.hasRequestId()) { + byte[] body = responseBody.toByteArray(); - private HttpServletResponse createResponse(WebSocketRequestMessage message) { - if (message.hasRequestId()) { - return new WebSocketServletResponse(remoteEndpoint, message.getRequestId(), messageFactory); - } else { - return new NullServletResponse(); + if (body.length <= 0) { + body = null; + } + + byte[] responseBytes = messageFactory.createResponse(requestMessage.getRequestId(), + response.getStatus(), + response.getStatusInfo().getReasonPhrase(), + new LinkedList<>(), + Optional.ofNullable(body)) + .toByteArray(); + + remoteEndpoint.sendBytesByFuture(ByteBuffer.wrap(responseBytes)); } } @@ -203,8 +223,10 @@ public class WebSocketResourceProvider implements WebSocketListener { } } + @VisibleForTesting WebSocketSessionContext getContext() { return context; } + } diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java index 93396b46e..d8e99da95 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketResourceProviderFactory.java @@ -16,77 +16,56 @@ */ package org.whispersystems.websocket; -import org.eclipse.jetty.server.Server; -import org.eclipse.jetty.util.AttributesMap; import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest; import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse; import org.eclipse.jetty.websocket.servlet.WebSocketCreator; import org.eclipse.jetty.websocket.servlet.WebSocketServlet; import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory; +import org.glassfish.jersey.server.ApplicationHandler; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.websocket.auth.AuthenticationException; import org.whispersystems.websocket.auth.WebSocketAuthenticator; import org.whispersystems.websocket.auth.WebSocketAuthenticator.AuthenticationResult; -import org.whispersystems.websocket.auth.internal.WebSocketAuthValueFactoryProvider; +import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider; import org.whispersystems.websocket.session.WebSocketSessionContextValueFactoryProvider; import org.whispersystems.websocket.setup.WebSocketEnvironment; -import javax.servlet.Filter; -import javax.servlet.FilterRegistration; -import javax.servlet.RequestDispatcher; -import javax.servlet.Servlet; -import javax.servlet.ServletConfig; -import javax.servlet.ServletContext; -import javax.servlet.ServletException; -import javax.servlet.ServletRegistration; -import javax.servlet.SessionCookieConfig; -import javax.servlet.SessionTrackingMode; -import javax.servlet.descriptor.JspConfigDescriptor; import java.io.IOException; -import java.io.InputStream; -import java.net.MalformedURLException; -import java.net.URL; -import java.security.AccessController; -import java.util.Collections; -import java.util.Enumeration; -import java.util.EventListener; -import java.util.Map; +import java.security.Principal; +import java.util.Arrays; import java.util.Optional; -import java.util.Set; import io.dropwizard.jersey.jackson.JacksonMessageBodyProvider; +import static java.util.Optional.ofNullable; -public class WebSocketResourceProviderFactory extends WebSocketServlet implements WebSocketCreator { +public class WebSocketResourceProviderFactory extends WebSocketServlet implements WebSocketCreator { private static final Logger logger = LoggerFactory.getLogger(WebSocketResourceProviderFactory.class); - private final WebSocketEnvironment environment; + private final WebSocketEnvironment environment; + private final ApplicationHandler jerseyApplicationHandler; - public WebSocketResourceProviderFactory(WebSocketEnvironment environment) - throws ServletException - { + public WebSocketResourceProviderFactory(WebSocketEnvironment environment, Class principalClass) { this.environment = environment; environment.jersey().register(new WebSocketSessionContextValueFactoryProvider.Binder()); - environment.jersey().register(new WebSocketAuthValueFactoryProvider.Binder()); + environment.jersey().register(new WebsocketAuthValueFactoryProvider.Binder(principalClass)); environment.jersey().register(new JacksonMessageBodyProvider(environment.getObjectMapper())); - } - public void start() throws ServletException { - this.environment.getJerseyServletContainer().init(new WServletConfig()); + this.jerseyApplicationHandler = new ApplicationHandler(environment.jersey()); } @Override public Object createWebSocket(ServletUpgradeRequest request, ServletUpgradeResponse response) { try { - Optional authenticator = Optional.ofNullable(environment.getAuthenticator()); - Object authenticated = null; + Optional> authenticator = Optional.ofNullable(environment.getAuthenticator()); + T authenticated = null; if (authenticator.isPresent()) { - AuthenticationResult authenticationResult = authenticator.get().authenticate(request); + AuthenticationResult authenticationResult = authenticator.get().authenticate(request); - if (!authenticationResult.getUser().isPresent() && authenticationResult.isRequired()) { + if (authenticationResult.getUser().isEmpty() && authenticationResult.isRequired()) { response.sendForbidden("Unauthorized"); return null; } else { @@ -94,14 +73,18 @@ public class WebSocketResourceProviderFactory extends WebSocketServlet implement } } - return new WebSocketResourceProvider(this.environment.getJerseyServletContainer(), - this.environment.getRequestLog(), - authenticated, - this.environment.getMessageFactory(), - Optional.ofNullable(this.environment.getConnectListener()), - this.environment.getIdleTimeoutMillis()); + return new WebSocketResourceProvider(getRemoteAddress(request), + this.jerseyApplicationHandler, + this.environment.getRequestLog(), + authenticated, + this.environment.getMessageFactory(), + ofNullable(this.environment.getConnectListener()), + this.environment.getIdleTimeoutMillis()); } catch (AuthenticationException | IOException e) { logger.warn("Authentication failure", e); + try { + response.sendError(500, "Failure"); + } catch (IOException ex) {} return null; } } @@ -111,358 +94,16 @@ public class WebSocketResourceProviderFactory extends WebSocketServlet implement factory.setCreator(this); } - private static class WServletConfig implements ServletConfig { + private String getRemoteAddress(ServletUpgradeRequest request) { + String forwardedFor = request.getHeader("X-Forwarded-For"); - private final ServletContext context = new NoContext(); - - @Override - public String getServletName() { - return "WebSocketResourceServlet"; - } - - @Override - public ServletContext getServletContext() { - return context; - } - - @Override - public String getInitParameter(String name) { - return null; - } - - @Override - public Enumeration getInitParameterNames() { - return new Enumeration() { - @Override - public boolean hasMoreElements() { - return false; - } - - @Override - public String nextElement() { - return null; - } - }; + if (forwardedFor == null || forwardedFor.isBlank()) { + return request.getRemoteAddress(); + } else { + return Arrays.stream(forwardedFor.split(",")) + .map(String::trim) + .reduce((a, b) -> b) + .orElseThrow(); } } - - public static class NoContext extends AttributesMap implements ServletContext - { - - private int effectiveMajorVersion = 3; - private int effectiveMinorVersion = 0; - - @Override - public ServletContext getContext(String uripath) - { - return null; - } - - @Override - public int getMajorVersion() - { - return 3; - } - - @Override - public String getMimeType(String file) - { - return null; - } - - @Override - public int getMinorVersion() - { - return 0; - } - - @Override - public RequestDispatcher getNamedDispatcher(String name) - { - return null; - } - - @Override - public RequestDispatcher getRequestDispatcher(String uriInContext) - { - return null; - } - - @Override - public String getRealPath(String path) - { - return null; - } - - @Override - public URL getResource(String path) throws MalformedURLException - { - return null; - } - - @Override - public InputStream getResourceAsStream(String path) - { - return null; - } - - @Override - public Set getResourcePaths(String path) - { - return null; - } - - @Override - public String getServerInfo() - { - return "websocketresources/" + Server.getVersion(); - } - - @Override - @Deprecated - public Servlet getServlet(String name) throws ServletException - { - return null; - } - - @SuppressWarnings("unchecked") - @Override - @Deprecated - public Enumeration getServletNames() - { - return Collections.enumeration(Collections.EMPTY_LIST); - } - - @SuppressWarnings("unchecked") - @Override - @Deprecated - public Enumeration getServlets() - { - return Collections.enumeration(Collections.EMPTY_LIST); - } - - @Override - public void log(Exception exception, String msg) - { - logger.warn(msg,exception); - } - - @Override - public void log(String msg) - { - logger.info(msg); - } - - @Override - public void log(String message, Throwable throwable) - { - logger.warn(message,throwable); - } - - @Override - public String getInitParameter(String name) - { - return null; - } - - @SuppressWarnings("unchecked") - @Override - public Enumeration getInitParameterNames() - { - return Collections.enumeration(Collections.EMPTY_LIST); - } - - - @Override - public String getServletContextName() - { - return "No Context"; - } - - @Override - public String getContextPath() - { - return null; - } - - - @Override - public boolean setInitParameter(String name, String value) - { - return false; - } - - @Override - public FilterRegistration.Dynamic addFilter(String filterName, Class filterClass) - { - return null; - } - - @Override - public FilterRegistration.Dynamic addFilter(String filterName, Filter filter) - { - return null; - } - - @Override - public FilterRegistration.Dynamic addFilter(String filterName, String className) - { - return null; - } - - @Override - public javax.servlet.ServletRegistration.Dynamic addServlet(String servletName, Class servletClass) - { - return null; - } - - @Override - public javax.servlet.ServletRegistration.Dynamic addServlet(String servletName, Servlet servlet) - { - return null; - } - - @Override - public javax.servlet.ServletRegistration.Dynamic addServlet(String servletName, String className) - { - return null; - } - - @Override - public T createFilter(Class c) throws ServletException - { - return null; - } - - @Override - public T createServlet(Class c) throws ServletException - { - return null; - } - - @Override - public Set getDefaultSessionTrackingModes() - { - return null; - } - - @Override - public Set getEffectiveSessionTrackingModes() - { - return null; - } - - @Override - public FilterRegistration getFilterRegistration(String filterName) - { - return null; - } - - @Override - public Map getFilterRegistrations() - { - return null; - } - - @Override - public ServletRegistration getServletRegistration(String servletName) - { - return null; - } - - @Override - public Map getServletRegistrations() - { - return null; - } - - @Override - public SessionCookieConfig getSessionCookieConfig() - { - return null; - } - - @Override - public void setSessionTrackingModes(Set sessionTrackingModes) - { - } - - @Override - public void addListener(String className) - { - } - - @Override - public void addListener(T t) - { - } - - @Override - public void addListener(Class listenerClass) - { - } - - @Override - public T createListener(Class clazz) throws ServletException - { - try - { - return clazz.newInstance(); - } - catch (InstantiationException e) - { - throw new ServletException(e); - } - catch (IllegalAccessException e) - { - throw new ServletException(e); - } - } - - @Override - public ClassLoader getClassLoader() - { - AccessController.checkPermission(new RuntimePermission("getClassLoader")); - return WebSocketResourceProviderFactory.class.getClassLoader(); - } - - @Override - public int getEffectiveMajorVersion() - { - return effectiveMajorVersion; - } - - @Override - public int getEffectiveMinorVersion() - { - return effectiveMinorVersion; - } - - public void setEffectiveMajorVersion (int v) - { - this.effectiveMajorVersion = v; - } - - public void setEffectiveMinorVersion (int v) - { - this.effectiveMinorVersion = v; - } - - @Override - public JspConfigDescriptor getJspConfigDescriptor() - { - return null; - } - - @Override - public void declareRoles(String... roleNames) - { - } - - @Override - public String getVirtualServerName() { - return null; - } - } - } diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketSecurityContext.java b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketSecurityContext.java new file mode 100644 index 000000000..c0072e51a --- /dev/null +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/WebSocketSecurityContext.java @@ -0,0 +1,40 @@ +package org.whispersystems.websocket; + +import org.whispersystems.websocket.session.ContextPrincipal; +import org.whispersystems.websocket.session.WebSocketSessionContext; + +import javax.ws.rs.core.SecurityContext; +import java.security.Principal; + +public class WebSocketSecurityContext implements SecurityContext { + + private final ContextPrincipal principal; + + public WebSocketSecurityContext(ContextPrincipal principal) { + this.principal = principal; + } + + @Override + public Principal getUserPrincipal() { + return (Principal)principal.getContext().getAuthenticated(); + } + + @Override + public boolean isUserInRole(String role) { + return false; + } + + @Override + public boolean isSecure() { + return principal != null; + } + + @Override + public String getAuthenticationScheme() { + return null; + } + + public WebSocketSessionContext getSessionContext() { + return principal.getContext(); + } +} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebSocketAuthenticator.java b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebSocketAuthenticator.java index 86ffcb611..326de5641 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebSocketAuthenticator.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebSocketAuthenticator.java @@ -19,9 +19,10 @@ package org.whispersystems.websocket.auth; import org.eclipse.jetty.server.Authentication; import org.eclipse.jetty.websocket.api.UpgradeRequest; +import java.security.Principal; import java.util.Optional; -public interface WebSocketAuthenticator { +public interface WebSocketAuthenticator { AuthenticationResult authenticate(UpgradeRequest request) throws AuthenticationException; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebsocketAuthValueFactoryProvider.java b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebsocketAuthValueFactoryProvider.java new file mode 100644 index 000000000..768de8dfb --- /dev/null +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/WebsocketAuthValueFactoryProvider.java @@ -0,0 +1,114 @@ +package org.whispersystems.websocket.auth; + +import org.glassfish.jersey.internal.inject.AbstractBinder; +import org.glassfish.jersey.server.ContainerRequest; +import org.glassfish.jersey.server.internal.inject.AbstractValueParamProvider; +import org.glassfish.jersey.server.internal.inject.MultivaluedParameterExtractorProvider; +import org.glassfish.jersey.server.model.Parameter; +import org.glassfish.jersey.server.spi.internal.ValueParamProvider; + +import javax.annotation.Nullable; +import javax.inject.Inject; +import javax.inject.Singleton; +import javax.ws.rs.WebApplicationException; +import java.lang.reflect.ParameterizedType; +import java.security.Principal; +import java.util.Optional; +import java.util.function.Function; + +import io.dropwizard.auth.Auth; + +@Singleton +public class WebsocketAuthValueFactoryProvider extends AbstractValueParamProvider { + + private final Class principalClass; + + @Inject + public WebsocketAuthValueFactoryProvider(MultivaluedParameterExtractorProvider mpep, WebsocketPrincipalClassProvider principalClassProvider) { + super(() -> mpep, Parameter.Source.UNKNOWN); + this.principalClass = principalClassProvider.clazz; + } + + @Nullable + @Override + protected Function createValueProvider(Parameter parameter) { + if (!parameter.isAnnotationPresent(Auth.class)) { + return null; + } + + if (parameter.getRawType() == Optional.class && + ParameterizedType.class.isAssignableFrom(parameter.getType().getClass()) && + principalClass == ((ParameterizedType)parameter.getType()).getActualTypeArguments()[0]) + { + return request -> new OptionalContainerRequestValueFactory(request).provide(); + } else if (principalClass.equals(parameter.getRawType())) { + return request -> new StandardContainerRequestValueFactory(request).provide(); + } else { + throw new IllegalStateException("Can't inject unassignable principal: " + principalClass + " for parameter: " + parameter); + } + } + + @Singleton + static class WebsocketPrincipalClassProvider { + + private final Class clazz; + + WebsocketPrincipalClassProvider(Class clazz) { + this.clazz = clazz; + } + } + + /** + * Injection binder for {@link io.dropwizard.auth.AuthValueFactoryProvider}. + * + * @param the type of the principal + */ + public static class Binder extends AbstractBinder { + + private final Class principalClass; + + public Binder(Class principalClass) { + this.principalClass = principalClass; + } + + @Override + protected void configure() { + bind(new WebsocketPrincipalClassProvider<>(principalClass)).to(WebsocketPrincipalClassProvider.class); + bind(WebsocketAuthValueFactoryProvider.class).to(ValueParamProvider.class).in(Singleton.class); + } + } + + private static class StandardContainerRequestValueFactory { + + private final ContainerRequest request; + + public StandardContainerRequestValueFactory(ContainerRequest request) { + this.request = request; + } + + public Principal provide() { + final Principal principal = request.getSecurityContext().getUserPrincipal(); + + if (principal == null) { + throw new WebApplicationException("Authenticated resource", 401); + } + + return principal; + } + + } + + private static class OptionalContainerRequestValueFactory { + + private final ContainerRequest request; + + public OptionalContainerRequestValueFactory(ContainerRequest request) { + this.request = request; + } + + public Optional provide() { + return Optional.ofNullable(request.getSecurityContext().getUserPrincipal()); + } + } + +} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/internal/WebSocketAuthValueFactoryProvider.java b/websocket-resources/src/main/java/org/whispersystems/websocket/auth/internal/WebSocketAuthValueFactoryProvider.java deleted file mode 100644 index 104f89a82..000000000 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/auth/internal/WebSocketAuthValueFactoryProvider.java +++ /dev/null @@ -1,120 +0,0 @@ -package org.whispersystems.websocket.auth.internal; - -import org.glassfish.hk2.api.InjectionResolver; -import org.glassfish.hk2.api.ServiceLocator; -import org.glassfish.hk2.api.TypeLiteral; -import org.glassfish.hk2.utilities.binding.AbstractBinder; -import org.glassfish.jersey.server.internal.inject.AbstractContainerRequestValueFactory; -import org.glassfish.jersey.server.internal.inject.AbstractValueFactoryProvider; -import org.glassfish.jersey.server.internal.inject.MultivaluedParameterExtractorProvider; -import org.glassfish.jersey.server.internal.inject.ParamInjectionResolver; -import org.glassfish.jersey.server.model.Parameter; -import org.glassfish.jersey.server.spi.internal.ValueFactoryProvider; -import org.whispersystems.websocket.servlet.WebSocketServletRequest; - -import javax.inject.Inject; -import javax.inject.Singleton; -import javax.ws.rs.WebApplicationException; -import java.security.Principal; -import java.util.Optional; - -import io.dropwizard.auth.Auth; - -@Singleton -public class WebSocketAuthValueFactoryProvider extends AbstractValueFactoryProvider { - - @Inject - public WebSocketAuthValueFactoryProvider(MultivaluedParameterExtractorProvider mpep, - ServiceLocator injector) - { - super(mpep, injector, Parameter.Source.UNKNOWN); - } - - @Override - public AbstractContainerRequestValueFactory createValueFactory(final Parameter parameter) { - if (parameter.getAnnotation(Auth.class) == null) { - return null; - } - - if (parameter.getRawType() == Optional.class) { - return new OptionalContainerRequestValueFactory(parameter); - } else { - return new StandardContainerRequestValueFactory(parameter); - } - } - - private static class OptionalContainerRequestValueFactory extends AbstractContainerRequestValueFactory { - private final Parameter parameter; - - private OptionalContainerRequestValueFactory(Parameter parameter) { - this.parameter = parameter; - } - - @Override - public Object provide() { - Principal principal = getContainerRequest().getSecurityContext().getUserPrincipal(); - - if (principal != null && !(principal instanceof WebSocketServletRequest.ContextPrincipal)) { - throw new IllegalArgumentException("Can't inject non-ContextPrincipal into request"); - } - - if (principal == null) return Optional.empty(); - else return Optional.ofNullable(((WebSocketServletRequest.ContextPrincipal)principal).getContext().getAuthenticated()); - - } - } - - private static class StandardContainerRequestValueFactory extends AbstractContainerRequestValueFactory { - private final Parameter parameter; - - private StandardContainerRequestValueFactory(Parameter parameter) { - this.parameter = parameter; - } - - @Override - public Object provide() { - Principal principal = getContainerRequest().getSecurityContext().getUserPrincipal(); - - if (principal == null) { - throw new IllegalStateException("Cannot inject a custom principal into unauthenticated request"); - } - - if (!(principal instanceof WebSocketServletRequest.ContextPrincipal)) { - throw new IllegalArgumentException("Cannot inject a non-WebSocket AuthPrincipal into request"); - } - - Object authenticated = ((WebSocketServletRequest.ContextPrincipal)principal).getContext().getAuthenticated(); - - if (authenticated == null) { - throw new WebApplicationException("Authenticated resource", 401); - } - - if (!parameter.getRawType().isAssignableFrom(authenticated.getClass())) { - throw new IllegalArgumentException("Authenticated principal is of the wrong type: " + authenticated.getClass() + " looking for: " + parameter.getRawType()); - } - - return parameter.getRawType().cast(authenticated); - } - } - - @Singleton - private static class AuthInjectionResolver extends ParamInjectionResolver { - public AuthInjectionResolver() { - super(WebSocketAuthValueFactoryProvider.class); - } - } - - public static class Binder extends AbstractBinder { - - - public Binder() { - } - - @Override - protected void configure() { - bind(WebSocketAuthValueFactoryProvider.class).to(ValueFactoryProvider.class).in(Singleton.class); - bind(AuthInjectionResolver.class).to(new TypeLiteral>() { - }).in(Singleton.class); - } - } -} \ No newline at end of file diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/configuration/WebSocketConfiguration.java b/websocket-resources/src/main/java/org/whispersystems/websocket/configuration/WebSocketConfiguration.java index 05178ae5b..ab380ea8b 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/configuration/WebSocketConfiguration.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/configuration/WebSocketConfiguration.java @@ -2,6 +2,8 @@ package org.whispersystems.websocket.configuration; import com.fasterxml.jackson.annotation.JsonProperty; +import org.whispersystems.websocket.logging.WebsocketRequestLoggerFactory; + import javax.validation.Valid; import javax.validation.constraints.NotNull; @@ -13,9 +15,9 @@ public class WebSocketConfiguration { @Valid @NotNull @JsonProperty - private RequestLogFactory requestLog = new LogbackAccessRequestLogFactory(); + private WebsocketRequestLoggerFactory requestLog = new WebsocketRequestLoggerFactory(); - public RequestLogFactory getRequestLog() { + public WebsocketRequestLoggerFactory getRequestLog() { return requestLog; } } diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/logging/AsyncWebsocketEventAppenderFactory.java b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/AsyncWebsocketEventAppenderFactory.java new file mode 100644 index 000000000..44b51554c --- /dev/null +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/AsyncWebsocketEventAppenderFactory.java @@ -0,0 +1,16 @@ +package org.whispersystems.websocket.logging; + +import ch.qos.logback.core.AsyncAppenderBase; +import io.dropwizard.logging.async.AsyncAppenderFactory; + +public class AsyncWebsocketEventAppenderFactory implements AsyncAppenderFactory { + @Override + public AsyncAppenderBase build() { + return new AsyncAppenderBase() { + @Override + protected void preprocess(WebsocketEvent event) { + event.prepareForDeferredProcessing(); + } + }; + } +} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/logging/WebsocketEvent.java b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/WebsocketEvent.java new file mode 100644 index 000000000..4622b226e --- /dev/null +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/WebsocketEvent.java @@ -0,0 +1,73 @@ +package org.whispersystems.websocket.logging; + +import com.google.common.annotations.VisibleForTesting; +import org.glassfish.jersey.server.ContainerRequest; +import org.glassfish.jersey.server.ContainerResponse; + +import javax.ws.rs.core.MultivaluedMap; + +import java.util.List; + +import ch.qos.logback.core.spi.DeferredProcessingAware; + +public class WebsocketEvent implements DeferredProcessingAware { + + public static final int SENTINEL = -1; + public static final String NA = "-"; + + private final String remoteAddress; + private final ContainerRequest request; + private final ContainerResponse response; + private final long timestamp; + + public WebsocketEvent(String remoteAddress, ContainerRequest jerseyRequest, ContainerResponse jettyResponse) { + this.timestamp = System.currentTimeMillis(); + this.remoteAddress = remoteAddress; + this.request = jerseyRequest; + this.response = jettyResponse; + } + + public String getRemoteHost() { + return remoteAddress; + } + + public long getTimestamp() { + return timestamp; + } + + @Override + public void prepareForDeferredProcessing() { + + } + + public String getMethod() { + return request.getMethod(); + } + + public String getPath() { + return request.getBaseUri().getPath() + request.getPath(false); + } + + public String getProtocol() { + return "WS"; + } + + public int getStatusCode() { + return response.getStatus(); + } + + public long getContentLength() { + return response.getLength(); + } + + public String getRequestHeader(String key) { + List values = request.getRequestHeader(key); + + if (values == null) return NA; + else return values.stream().findFirst().orElse(NA); + } + + public MultivaluedMap getRequestHeaderMap() { + return request.getRequestHeaders(); + } +} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/logging/WebsocketRequestLog.java b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/WebsocketRequestLog.java new file mode 100644 index 000000000..8b82aab28 --- /dev/null +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/WebsocketRequestLog.java @@ -0,0 +1,43 @@ +package org.whispersystems.websocket.logging; + +import com.google.common.annotations.VisibleForTesting; +import org.glassfish.jersey.server.ContainerRequest; +import org.glassfish.jersey.server.ContainerResponse; + +import ch.qos.logback.core.Appender; +import ch.qos.logback.core.filter.Filter; +import ch.qos.logback.core.spi.AppenderAttachableImpl; +import ch.qos.logback.core.spi.FilterAttachableImpl; +import ch.qos.logback.core.spi.FilterReply; + +public class WebsocketRequestLog { + + private AppenderAttachableImpl aai = new AppenderAttachableImpl<>(); + private FilterAttachableImpl fai = new FilterAttachableImpl<>(); + + public WebsocketRequestLog() { + } + + public void log(String remoteAddress, ContainerRequest jerseyRequest, ContainerResponse jettyResponse) { + WebsocketEvent event = new WebsocketEvent(remoteAddress, jerseyRequest, jettyResponse); + + if (getFilterChainDecision(event) == FilterReply.DENY) { + return; + } + + aai.appendLoopOnAppenders(event); + } + + + public void addAppender(Appender newAppender) { + aai.addAppender(newAppender); + } + + public void addFilter(Filter newFilter) { + fai.addFilter(newFilter); + } + + public FilterReply getFilterChainDecision(WebsocketEvent event) { + return fai.getFilterChainDecision(event); + } +} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/logging/WebsocketRequestLoggerFactory.java b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/WebsocketRequestLoggerFactory.java new file mode 100644 index 000000000..9bbe0baa3 --- /dev/null +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/WebsocketRequestLoggerFactory.java @@ -0,0 +1,45 @@ +package org.whispersystems.websocket.logging; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.LoggerFactory; +import org.whispersystems.websocket.logging.layout.WebsocketEventLayoutFactory; + +import javax.validation.Valid; +import javax.validation.constraints.NotNull; +import java.util.Collections; +import java.util.List; + +import ch.qos.logback.classic.Logger; +import ch.qos.logback.classic.LoggerContext; +import io.dropwizard.logging.AppenderFactory; +import io.dropwizard.logging.ConsoleAppenderFactory; +import io.dropwizard.logging.async.AsyncAppenderFactory; +import io.dropwizard.logging.filter.LevelFilterFactory; +import io.dropwizard.logging.filter.NullLevelFilterFactory; +import io.dropwizard.logging.layout.LayoutFactory; + +public class WebsocketRequestLoggerFactory { + + @VisibleForTesting + @Valid + @NotNull + public List> appenders = Collections.singletonList(new ConsoleAppenderFactory<>()); + + public WebsocketRequestLog build(String name) { + final Logger logger = (Logger) LoggerFactory.getLogger("websocket.request"); + logger.setAdditive(false); + + final LoggerContext context = logger.getLoggerContext(); + final WebsocketRequestLog requestLog = new WebsocketRequestLog(); + final LevelFilterFactory levelFilterFactory = new NullLevelFilterFactory<>(); + final AsyncAppenderFactory asyncAppenderFactory = new AsyncWebsocketEventAppenderFactory(); + final LayoutFactory layoutFactory = new WebsocketEventLayoutFactory(); + + for (AppenderFactory output : appenders) { + requestLog.addAppender(output.build(context, name, layoutFactory, levelFilterFactory, asyncAppenderFactory)); + } + + return requestLog; + } + +} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/WebsocketEventLayout.java b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/WebsocketEventLayout.java new file mode 100644 index 000000000..fd92d3f38 --- /dev/null +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/WebsocketEventLayout.java @@ -0,0 +1,77 @@ +package org.whispersystems.websocket.logging.layout; + +import org.whispersystems.websocket.logging.WebsocketEvent; +import org.whispersystems.websocket.logging.layout.converters.ContentLengthConverter; +import org.whispersystems.websocket.logging.layout.converters.DateConverter; +import org.whispersystems.websocket.logging.layout.converters.EnsureLineSeparation; +import org.whispersystems.websocket.logging.layout.converters.NAConverter; +import org.whispersystems.websocket.logging.layout.converters.RemoteHostConverter; +import org.whispersystems.websocket.logging.layout.converters.RequestHeaderConverter; +import org.whispersystems.websocket.logging.layout.converters.RequestUrlConverter; +import org.whispersystems.websocket.logging.layout.converters.StatusCodeConverter; + +import java.util.HashMap; +import java.util.Map; + +import ch.qos.logback.core.Context; +import ch.qos.logback.core.pattern.PatternLayoutBase; + +public class WebsocketEventLayout extends PatternLayoutBase { + + private static final Map DEFAULT_CONVERTERS = new HashMap<>() {{ + put("h", RemoteHostConverter.class.getName()); + put("l", NAConverter.class.getName()); + put("u", NAConverter.class.getName()); + put("t", DateConverter.class.getName()); + put("r", RequestUrlConverter.class.getName()); + put("s", StatusCodeConverter.class.getName()); + put("b", ContentLengthConverter.class.getName()); + put("i", RequestHeaderConverter.class.getName()); + }}; + + public static final String CLF_PATTERN = "%h %l %u [%t] \"%r\" %s %b"; + public static final String CLF_PATTERN_NAME = "common"; + public static final String CLF_PATTERN_NAME_2 = "clf"; + public static final String COMBINED_PATTERN = "%h %l %u [%t] \"%r\" %s %b \"%i{Referer}\" \"%i{User-Agent}\""; + public static final String COMBINED_PATTERN_NAME = "combined"; + public static final String HEADER_PREFIX = "#logback.access pattern: "; + + public WebsocketEventLayout(Context context) { + setOutputPatternAsHeader(false); + setPattern(COMBINED_PATTERN); + setContext(context); + + this.postCompileProcessor = new EnsureLineSeparation(); + } + + @Override + public Map getDefaultConverterMap() { + return DEFAULT_CONVERTERS; + } + + @Override + public String doLayout(WebsocketEvent event) { + if (!isStarted()) { + return null; + } + + return writeLoopOnConverters(event); + } + + @Override + public void start() { + if (getPattern().equalsIgnoreCase(CLF_PATTERN_NAME) || getPattern().equalsIgnoreCase(CLF_PATTERN_NAME_2)) { + setPattern(CLF_PATTERN); + } else if (getPattern().equalsIgnoreCase(COMBINED_PATTERN_NAME)) { + setPattern(COMBINED_PATTERN); + } + + super.start(); + } + + @Override + protected String getPresentationHeaderPrefix() { + return HEADER_PREFIX; + } + +} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/WebsocketEventLayoutFactory.java b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/WebsocketEventLayoutFactory.java new file mode 100644 index 000000000..d6f35b3b4 --- /dev/null +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/WebsocketEventLayoutFactory.java @@ -0,0 +1,16 @@ +package org.whispersystems.websocket.logging.layout; + +import org.whispersystems.websocket.logging.WebsocketEvent; + +import java.util.TimeZone; + +import ch.qos.logback.classic.LoggerContext; +import ch.qos.logback.core.pattern.PatternLayoutBase; +import io.dropwizard.logging.layout.LayoutFactory; + +public class WebsocketEventLayoutFactory implements LayoutFactory { + @Override + public PatternLayoutBase build(LoggerContext context, TimeZone timeZone) { + return new WebsocketEventLayout(context); + } +} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/ContentLengthConverter.java b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/ContentLengthConverter.java new file mode 100644 index 000000000..b01bcdd85 --- /dev/null +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/ContentLengthConverter.java @@ -0,0 +1,14 @@ +package org.whispersystems.websocket.logging.layout.converters; + +import org.whispersystems.websocket.logging.WebsocketEvent; + +public class ContentLengthConverter extends WebSocketEventConverter { + @Override + public String convert(WebsocketEvent event) { + if (event.getContentLength() == WebsocketEvent.SENTINEL) { + return WebsocketEvent.NA; + } else { + return Long.toString(event.getContentLength()); + } + } +} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/DateConverter.java b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/DateConverter.java new file mode 100644 index 000000000..baa950d45 --- /dev/null +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/DateConverter.java @@ -0,0 +1,51 @@ +package org.whispersystems.websocket.logging.layout.converters; + +import org.whispersystems.websocket.logging.WebsocketEvent; + +import java.util.List; +import java.util.TimeZone; + +import ch.qos.logback.core.CoreConstants; +import ch.qos.logback.core.util.CachingDateFormatter; + +public class DateConverter extends WebSocketEventConverter { + + private CachingDateFormatter cachingDateFormatter = null; + + @Override + public void start() { + + String datePattern = getFirstOption(); + if (datePattern == null) { + datePattern = CoreConstants.CLF_DATE_PATTERN; + } + + if (datePattern.equals(CoreConstants.ISO8601_STR)) { + datePattern = CoreConstants.ISO8601_PATTERN; + } + + try { + cachingDateFormatter = new CachingDateFormatter(datePattern); + // maximumCacheValidity = CachedDateFormat.getMaximumCacheValidity(pattern); + } catch (IllegalArgumentException e) { + addWarn("Could not instantiate SimpleDateFormat with pattern " + datePattern, e); + addWarn("Defaulting to " + CoreConstants.CLF_DATE_PATTERN); + cachingDateFormatter = new CachingDateFormatter(CoreConstants.CLF_DATE_PATTERN); + } + + List optionList = getOptionList(); + + // if the option list contains a TZ option, then set it. + if (optionList != null && optionList.size() > 1) { + TimeZone tz = TimeZone.getTimeZone((String) optionList.get(1)); + cachingDateFormatter.setTimeZone(tz); + } + } + + @Override + public String convert(WebsocketEvent websocketEvent) { + long timestamp = websocketEvent.getTimestamp(); + return cachingDateFormatter.format(timestamp); + } + +} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/EnsureLineSeparation.java b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/EnsureLineSeparation.java new file mode 100644 index 000000000..653240d66 --- /dev/null +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/EnsureLineSeparation.java @@ -0,0 +1,29 @@ +package org.whispersystems.websocket.logging.layout.converters; + +import org.whispersystems.websocket.logging.WebsocketEvent; + +import ch.qos.logback.core.Context; +import ch.qos.logback.core.pattern.Converter; +import ch.qos.logback.core.pattern.ConverterUtil; +import ch.qos.logback.core.pattern.PostCompileProcessor; + +public class EnsureLineSeparation implements PostCompileProcessor { + + /** + * Add a line separator converter so that access event appears on a separate + * line. + */ + @Override + public void process(Context context, Converter head) { + if (head == null) + throw new IllegalArgumentException("Empty converter chain"); + + // if head != null, then tail != null as well + Converter tail = ConverterUtil.findTail(head); + Converter newLineConverter = new LineSeparatorConverter(); + + if (!(tail instanceof LineSeparatorConverter)) { + tail.setNext(newLineConverter); + } + } +} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/LineSeparatorConverter.java b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/LineSeparatorConverter.java new file mode 100644 index 000000000..63dd2673e --- /dev/null +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/LineSeparatorConverter.java @@ -0,0 +1,14 @@ +package org.whispersystems.websocket.logging.layout.converters; + +import org.whispersystems.websocket.logging.WebsocketEvent; + +import ch.qos.logback.core.CoreConstants; + +public class LineSeparatorConverter extends WebSocketEventConverter { + public LineSeparatorConverter() { + } + + public String convert(WebsocketEvent event) { + return CoreConstants.LINE_SEPARATOR; + } +} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/NAConverter.java b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/NAConverter.java new file mode 100644 index 000000000..fe675159e --- /dev/null +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/NAConverter.java @@ -0,0 +1,10 @@ +package org.whispersystems.websocket.logging.layout.converters; + +import org.whispersystems.websocket.logging.WebsocketEvent; + +public class NAConverter extends WebSocketEventConverter { + @Override + public String convert(WebsocketEvent event) { + return WebsocketEvent.NA; + } +} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/RemoteHostConverter.java b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/RemoteHostConverter.java new file mode 100644 index 000000000..c3f3b2eee --- /dev/null +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/RemoteHostConverter.java @@ -0,0 +1,10 @@ +package org.whispersystems.websocket.logging.layout.converters; + +import org.whispersystems.websocket.logging.WebsocketEvent; + +public class RemoteHostConverter extends WebSocketEventConverter { + @Override + public String convert(WebsocketEvent event) { + return event.getRemoteHost(); + } +} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/RequestHeaderConverter.java b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/RequestHeaderConverter.java new file mode 100644 index 000000000..6bb4957bd --- /dev/null +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/RequestHeaderConverter.java @@ -0,0 +1,33 @@ +package org.whispersystems.websocket.logging.layout.converters; + +import org.whispersystems.websocket.logging.WebsocketEvent; + +import ch.qos.logback.core.util.OptionHelper; + +public class RequestHeaderConverter extends WebSocketEventConverter { + + private String key; + + @Override + public void start() { + key = getFirstOption(); + if (OptionHelper.isEmpty(key)) { + addWarn("Missing key for the requested header. Defaulting to all keys."); + key = null; + } + super.start(); + } + + @Override + public String convert(WebsocketEvent websocketEvent) { + if (!isStarted()) { + return "INACTIVE_HEADER_CONV"; + } + + if (key != null) { + return websocketEvent.getRequestHeader(key); + } else { + return websocketEvent.getRequestHeaderMap().toString(); + } + } +} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/RequestUrlConverter.java b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/RequestUrlConverter.java new file mode 100644 index 000000000..046a8a8da --- /dev/null +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/RequestUrlConverter.java @@ -0,0 +1,15 @@ +package org.whispersystems.websocket.logging.layout.converters; + +import org.whispersystems.websocket.logging.WebsocketEvent; + +public class RequestUrlConverter extends WebSocketEventConverter { + @Override + public String convert(WebsocketEvent event) { + return + event.getMethod() + + WebSocketEventConverter.SPACE_CHAR + + event.getPath() + + WebSocketEventConverter.SPACE_CHAR + + event.getProtocol(); + } +} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/StatusCodeConverter.java b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/StatusCodeConverter.java new file mode 100644 index 000000000..bde399fcf --- /dev/null +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/StatusCodeConverter.java @@ -0,0 +1,14 @@ +package org.whispersystems.websocket.logging.layout.converters; + +import org.whispersystems.websocket.logging.WebsocketEvent; + +public class StatusCodeConverter extends WebSocketEventConverter { + @Override + public String convert(WebsocketEvent event) { + if (event.getStatusCode() == WebsocketEvent.SENTINEL) { + return WebsocketEvent.NA; + } else { + return Integer.toString(event.getStatusCode()); + } + } +} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/WebSocketEventConverter.java b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/WebSocketEventConverter.java new file mode 100644 index 000000000..a40214d7f --- /dev/null +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/logging/layout/converters/WebSocketEventConverter.java @@ -0,0 +1,63 @@ +package org.whispersystems.websocket.logging.layout.converters; + +import org.whispersystems.websocket.logging.WebsocketEvent; + +import ch.qos.logback.core.Context; +import ch.qos.logback.core.pattern.DynamicConverter; +import ch.qos.logback.core.spi.ContextAware; +import ch.qos.logback.core.spi.ContextAwareBase; +import ch.qos.logback.core.status.Status; + +public abstract class WebSocketEventConverter extends DynamicConverter implements ContextAware { + + public final static char SPACE_CHAR = ' '; + public final static char QUESTION_CHAR = '?'; + + ContextAwareBase cab = new ContextAwareBase(); + + @Override + public void setContext(Context context) { + cab.setContext(context); + } + + @Override + public Context getContext() { + return cab.getContext(); + } + + @Override + public void addStatus(Status status) { + cab.addStatus(status); + } + + @Override + public void addInfo(String msg) { + cab.addInfo(msg); + } + + @Override + public void addInfo(String msg, Throwable ex) { + cab.addInfo(msg, ex); + } + + @Override + public void addWarn(String msg) { + cab.addWarn(msg); + } + + @Override + public void addWarn(String msg, Throwable ex) { + cab.addWarn(msg, ex); + } + + @Override + public void addError(String msg) { + cab.addError(msg); + } + + @Override + public void addError(String msg, Throwable ex) { + cab.addError(msg, ex); + } + +} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/servlet/BufferingServletInputStream.java b/websocket-resources/src/main/java/org/whispersystems/websocket/servlet/BufferingServletInputStream.java deleted file mode 100644 index c35070857..000000000 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/servlet/BufferingServletInputStream.java +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright (C) 2014 Open WhisperSystems - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ -package org.whispersystems.websocket.servlet; - -import javax.servlet.ReadListener; -import javax.servlet.ServletInputStream; -import java.io.ByteArrayInputStream; -import java.io.IOException; - -public class BufferingServletInputStream extends ServletInputStream { - - private final ByteArrayInputStream buffer; - - public BufferingServletInputStream(byte[] body) { - this.buffer = new ByteArrayInputStream(body); - } - - @Override - public int read(byte[] buf, int offset, int length) { - return buffer.read(buf, offset, length); - } - - @Override - public int read(byte[] buf) { - return read(buf, 0, buf.length); - } - - @Override - public int read() throws IOException { - return buffer.read(); - } - - @Override - public int available() { - return buffer.available(); - } - - @Override - public boolean isFinished() { - return available() > 0; - } - - @Override - public boolean isReady() { - return true; - } - - @Override - public void setReadListener(ReadListener readListener) { - - } -} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/servlet/BufferingServletOutputStream.java b/websocket-resources/src/main/java/org/whispersystems/websocket/servlet/BufferingServletOutputStream.java deleted file mode 100644 index 4956da977..000000000 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/servlet/BufferingServletOutputStream.java +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright (C) 2014 Open WhisperSystems - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ -package org.whispersystems.websocket.servlet; - -import javax.servlet.ServletOutputStream; -import javax.servlet.WriteListener; -import java.io.ByteArrayOutputStream; -import java.io.IOException; - -public class BufferingServletOutputStream extends ServletOutputStream { - - private final ByteArrayOutputStream buffer; - - public BufferingServletOutputStream(ByteArrayOutputStream buffer) { - this.buffer = buffer; - } - - @Override - public void write(byte[] buf, int offset, int length) { - buffer.write(buf, offset, length); - } - - @Override - public void write(byte[] buf) { - write(buf, 0, buf.length); - } - - @Override - public void write(int b) throws IOException { - buffer.write(b); - } - - @Override - public void flush() { - - } - - @Override - public void close() { - - } - - @Override - public boolean isReady() { - return true; - } - - @Override - public void setWriteListener(WriteListener writeListener) { - - } -} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/servlet/LoggableRequest.java b/websocket-resources/src/main/java/org/whispersystems/websocket/servlet/LoggableRequest.java deleted file mode 100644 index 5728b3388..000000000 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/servlet/LoggableRequest.java +++ /dev/null @@ -1,629 +0,0 @@ -package org.whispersystems.websocket.servlet; - -import org.eclipse.jetty.http.HttpFields; -import org.eclipse.jetty.http.HttpURI; -import org.eclipse.jetty.http.HttpVersion; -import org.eclipse.jetty.server.Authentication; -import org.eclipse.jetty.server.HttpChannel; -import org.eclipse.jetty.server.HttpChannelState; -import org.eclipse.jetty.server.HttpInput; -import org.eclipse.jetty.server.Request; -import org.eclipse.jetty.server.Response; -import org.eclipse.jetty.server.UserIdentity; -import org.eclipse.jetty.server.handler.ContextHandler; -import org.eclipse.jetty.util.Attributes; - -import javax.servlet.AsyncContext; -import javax.servlet.DispatcherType; -import javax.servlet.RequestDispatcher; -import javax.servlet.ServletContext; -import javax.servlet.ServletException; -import javax.servlet.ServletInputStream; -import javax.servlet.ServletRequest; -import javax.servlet.ServletResponse; -import javax.servlet.http.Cookie; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import javax.servlet.http.HttpSession; -import javax.servlet.http.Part; -import java.io.BufferedReader; -import java.io.IOException; -import java.io.UnsupportedEncodingException; -import java.net.InetSocketAddress; -import java.security.Principal; -import java.util.Collection; -import java.util.Enumeration; -import java.util.EventListener; -import java.util.Locale; -import java.util.Map; - -public class LoggableRequest extends Request { - - private final HttpServletRequest request; - - public LoggableRequest(HttpServletRequest request) { - super(null, null); - this.request = request; - } - - @Override - public HttpFields getHttpFields() { - throw new AssertionError(); - } - - @Override - public HttpInput getHttpInput() { - throw new AssertionError(); - } - - @Override - public void addEventListener(EventListener listener) { - throw new AssertionError(); - } - - @Override - public AsyncContext getAsyncContext() { - throw new AssertionError(); - } - - @Override - public HttpChannelState getHttpChannelState() { - throw new AssertionError(); - } - - @Override - public Object getAttribute(String name) { - return request.getAttribute(name); - } - - @Override - public Enumeration getAttributeNames() { - return request.getAttributeNames(); - } - - @Override - public Attributes getAttributes() { - throw new AssertionError(); - } - - @Override - public Authentication getAuthentication() { - return null; - } - - @Override - public String getAuthType() { - return request.getAuthType(); - } - - @Override - public String getCharacterEncoding() { - return request.getCharacterEncoding(); - } - - @Override - public HttpChannel getHttpChannel() { - throw new AssertionError(); - } - - @Override - public int getContentLength() { - return request.getContentLength(); - } - - @Override - public String getContentType() { - return request.getContentType(); - } - - @Override - public ContextHandler.Context getContext() { - throw new AssertionError(); - } - - @Override - public String getContextPath() { - return request.getContextPath(); - } - - @Override - public Cookie[] getCookies() { - return request.getCookies(); - } - - @Override - public long getDateHeader(String name) { - return request.getDateHeader(name); - } - - @Override - public DispatcherType getDispatcherType() { - return request.getDispatcherType(); - } - - @Override - public String getHeader(String name) { - return request.getHeader(name); - } - - @Override - public Enumeration getHeaderNames() { - return request.getHeaderNames(); - } - - @Override - public Enumeration getHeaders(String name) { - return request.getHeaders(name); - } - - @Override - public int getInputState() { - throw new AssertionError(); - } - - @Override - public ServletInputStream getInputStream() throws IOException { - return request.getInputStream(); - } - - @Override - public int getIntHeader(String name) { - return request.getIntHeader(name); - } - - @Override - public Locale getLocale() { - return request.getLocale(); - } - - @Override - public Enumeration getLocales() { - return request.getLocales(); - } - - @Override - public String getLocalAddr() { - return request.getLocalAddr(); - } - - @Override - public String getLocalName() { - return request.getLocalName(); - } - - @Override - public int getLocalPort() { - return request.getLocalPort(); - } - - @Override - public String getMethod() { - return request.getMethod(); - } - - @Override - public String getParameter(String name) { - return request.getParameter(name); - } - - @Override - public Map getParameterMap() { - return request.getParameterMap(); - } - - @Override - public Enumeration getParameterNames() { - return request.getParameterNames(); - } - - @Override - public String[] getParameterValues(String name) { - return request.getParameterValues(name); - } - - @Override - public String getPathInfo() { - return request.getPathInfo(); - } - - @Override - public String getPathTranslated() { - return request.getPathTranslated(); - } - - @Override - public String getProtocol() { - return request.getProtocol(); - } - - @Override - public HttpVersion getHttpVersion() { - throw new AssertionError(); - } - - @Override - public String getQueryEncoding() { - throw new AssertionError(); - } - - @Override - public String getQueryString() { - return request.getQueryString(); - } - - @Override - public BufferedReader getReader() throws IOException { - throw new AssertionError(); - } - - @Override - public String getRealPath(String path) { - return request.getRealPath(path); - } - - @Override - public String getRemoteAddr() { - return request.getRemoteAddr(); - } - - @Override - public String getRemoteHost() { - return request.getRemoteHost(); - } - - @Override - public int getRemotePort() { - return request.getRemotePort(); - } - - @Override - public String getRemoteUser() { - return request.getRemoteUser(); - } - - @Override - public RequestDispatcher getRequestDispatcher(String path) { - return request.getRequestDispatcher(path); - } - - @Override - public String getRequestedSessionId() { - return request.getRequestedSessionId(); - } - - @Override - public String getRequestURI() { - return request.getRequestURI(); - } - - @Override - public StringBuffer getRequestURL() { - return request.getRequestURL(); - } - - @Override - public Response getResponse() { - throw new AssertionError(); - } - - @Override - public StringBuilder getRootURL() { - throw new AssertionError(); - } - - @Override - public String getScheme() { - return request.getScheme(); - } - - @Override - public String getServerName() { - return request.getServerName(); - } - - @Override - public int getServerPort() { - return request.getServerPort(); - } - - @Override - public ServletContext getServletContext() { - return request.getServletContext(); - } - - @Override - public String getServletName() { - throw new AssertionError(); - } - - @Override - public String getServletPath() { - return request.getServletPath(); - } - - @Override - public ServletResponse getServletResponse() { - throw new AssertionError(); - } - - @Override - public String changeSessionId() { - throw new AssertionError(); - } - - @Override - public HttpSession getSession() { - return request.getSession(); - } - - @Override - public HttpSession getSession(boolean create) { - return request.getSession(create); - } - - @Override - public long getTimeStamp() { - return System.currentTimeMillis(); - } - - @Override - public HttpURI getHttpURI() { - return new HttpURI(getRequestURI()); - } - - @Override - public UserIdentity getUserIdentity() { - throw new AssertionError(); - } - - @Override - public UserIdentity getResolvedUserIdentity() { - throw new AssertionError(); - } - - @Override - public UserIdentity.Scope getUserIdentityScope() { - throw new AssertionError(); - } - - @Override - public Principal getUserPrincipal() { - throw new AssertionError(); - } - - @Override - public boolean isHandled() { - throw new AssertionError(); - } - - @Override - public boolean isAsyncStarted() { - return request.isAsyncStarted(); - } - - @Override - public boolean isAsyncSupported() { - return request.isAsyncSupported(); - } - - @Override - public boolean isRequestedSessionIdFromCookie() { - return request.isRequestedSessionIdFromCookie(); - } - - @Override - public boolean isRequestedSessionIdFromUrl() { - return request.isRequestedSessionIdFromUrl(); - } - - @Override - public boolean isRequestedSessionIdFromURL() { - return request.isRequestedSessionIdFromURL(); - } - - @Override - public boolean isRequestedSessionIdValid() { - return request.isRequestedSessionIdValid(); - } - - @Override - public boolean isSecure() { - return request.isSecure(); - } - - @Override - public void setSecure(boolean secure) { - throw new AssertionError(); - } - - @Override - public boolean isUserInRole(String role) { - return request.isUserInRole(role); - } - - @Override - public void removeAttribute(String name) { - request.removeAttribute(name); - } - - @Override - public void removeEventListener(EventListener listener) { - throw new AssertionError(); - } - - @Override - public void setAsyncSupported(boolean supported, String source) { - throw new AssertionError(); - } - - @Override - public void setAttribute(String name, Object value) { - throw new AssertionError(); - } - - @Override - public void setAttributes(Attributes attributes) { - throw new AssertionError(); - } - - @Override - public void setAuthentication(Authentication authentication) { - throw new AssertionError(); - } - - @Override - public void setCharacterEncoding(String encoding) throws UnsupportedEncodingException { - throw new AssertionError(); - } - - @Override - public void setCharacterEncodingUnchecked(String encoding) { - throw new AssertionError(); - } - - @Override - public void setContentType(String contentType) { - throw new AssertionError(); - } - - @Override - public void setContext(ContextHandler.Context context) { - throw new AssertionError(); - } - - @Override - public boolean takeNewContext() { - throw new AssertionError(); - } - - @Override - public void setContextPath(String contextPath) { - throw new AssertionError(); - } - - @Override - public void setCookies(Cookie[] cookies) { - throw new AssertionError(); - } - - @Override - public void setDispatcherType(DispatcherType type) { - throw new AssertionError(); - } - - @Override - public void setHandled(boolean h) { - throw new AssertionError(); - } - - @Override - public boolean isHead() { - throw new AssertionError(); - } - - @Override - public void setPathInfo(String pathInfo) { - throw new AssertionError(); - } - - @Override - public void setHttpVersion(HttpVersion version) { - throw new AssertionError(); - } - - @Override - public void setQueryEncoding(String queryEncoding) { - throw new AssertionError(); - } - - @Override - public void setQueryString(String queryString) { - throw new AssertionError(); - } - - @Override - public void setRemoteAddr(InetSocketAddress addr) { - throw new AssertionError(); - } - - @Override - public void setRequestedSessionId(String requestedSessionId) { - throw new AssertionError(); - } - - @Override - public void setRequestedSessionIdFromCookie(boolean requestedSessionIdCookie) { - throw new AssertionError(); - } - - @Override - public void setScheme(String scheme) { - throw new AssertionError(); - } - - @Override - public void setServletPath(String servletPath) { - throw new AssertionError(); - } - - @Override - public void setSession(HttpSession session) { - throw new AssertionError(); - } - - @Override - public void setTimeStamp(long ts) { - throw new AssertionError(); - } - - @Override - public void setHttpURI(HttpURI uri) { - throw new AssertionError(); - } - - @Override - public void setUserIdentityScope(UserIdentity.Scope scope) { - throw new AssertionError(); - } - - @Override - public AsyncContext startAsync() throws IllegalStateException { - throw new AssertionError(); - } - - @Override - public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) throws IllegalStateException { - throw new AssertionError(); - } - - @Override - public String toString() { - return request.toString(); - } - - @Override - public boolean authenticate(HttpServletResponse response) throws IOException, ServletException { - throw new AssertionError(); - } - - @Override - public Part getPart(String name) throws IOException, ServletException { - return request.getPart(name); - } - - @Override - public Collection getParts() throws IOException, ServletException { - return request.getParts(); - } - - @Override - public void login(String username, String password) throws ServletException { - throw new AssertionError(); - } - - @Override - public void logout() throws ServletException { - throw new AssertionError(); - } - -} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/servlet/LoggableResponse.java b/websocket-resources/src/main/java/org/whispersystems/websocket/servlet/LoggableResponse.java deleted file mode 100644 index be1dbfb9a..000000000 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/servlet/LoggableResponse.java +++ /dev/null @@ -1,449 +0,0 @@ -package org.whispersystems.websocket.servlet; - -import org.eclipse.jetty.http.HttpContent; -import org.eclipse.jetty.http.HttpCookie; -import org.eclipse.jetty.http.HttpFields; -import org.eclipse.jetty.http.HttpHeader; -import org.eclipse.jetty.http.HttpVersion; -import org.eclipse.jetty.http.MetaData; -import org.eclipse.jetty.io.Connection; -import org.eclipse.jetty.io.EndPoint; -import org.eclipse.jetty.server.Connector; -import org.eclipse.jetty.server.HttpChannel; -import org.eclipse.jetty.server.HttpConfiguration; -import org.eclipse.jetty.server.HttpOutput; -import org.eclipse.jetty.server.HttpTransport; -import org.eclipse.jetty.server.Response; -import org.eclipse.jetty.util.Callback; - -import javax.servlet.ServletOutputStream; -import javax.servlet.http.Cookie; -import javax.servlet.http.HttpServletResponse; -import java.io.IOException; -import java.io.PrintWriter; -import java.net.InetSocketAddress; -import java.nio.ByteBuffer; -import java.nio.channels.ReadPendingException; -import java.nio.channels.WritePendingException; -import java.util.Collection; -import java.util.Locale; - -public class LoggableResponse extends Response { - - private final HttpServletResponse response; - - public LoggableResponse(HttpServletResponse response) { - super(null, null); - this.response = response; - } - - @Override - public void putHeaders(HttpContent httpContent, long contentLength, boolean etag) { - throw new AssertionError(); - } - - @Override - public HttpOutput getHttpOutput() { - throw new AssertionError(); - } - - @Override - public boolean isIncluding() { - throw new AssertionError(); - } - - @Override - public void include() { - throw new AssertionError(); - } - - @Override - public void included() { - throw new AssertionError(); - } - - @Override - public void addCookie(HttpCookie cookie) { - throw new AssertionError(); - } - - @Override - public void addCookie(Cookie cookie) { - throw new AssertionError(); - } - - @Override - public boolean containsHeader(String name) { - return response.containsHeader(name); - } - - @Override - public String encodeURL(String url) { - return response.encodeURL(url); - } - - @Override - public String encodeRedirectURL(String url) { - return response.encodeRedirectURL(url); - } - - @Override - public String encodeUrl(String url) { - return response.encodeUrl(url); - } - - @Override - public String encodeRedirectUrl(String url) { - return response.encodeRedirectUrl(url); - } - - @Override - public void sendError(int sc) throws IOException { - throw new AssertionError(); - } - - @Override - public void sendError(int code, String message) throws IOException { - throw new AssertionError(); - } - - @Override - public void sendProcessing() throws IOException { - throw new AssertionError(); - } - - @Override - public void sendRedirect(String location) throws IOException { - throw new AssertionError(); - } - - @Override - public void setDateHeader(String name, long date) { - throw new AssertionError(); - } - - @Override - public void addDateHeader(String name, long date) { - throw new AssertionError(); - } - - @Override - public void setHeader(HttpHeader name, String value) { - throw new AssertionError(); - } - - @Override - public void setHeader(String name, String value) { - throw new AssertionError(); - } - - @Override - public Collection getHeaderNames() { - return response.getHeaderNames(); - } - - @Override - public String getHeader(String name) { - return response.getHeader(name); - } - - @Override - public Collection getHeaders(String name) { - return response.getHeaders(name); - } - - @Override - public void addHeader(String name, String value) { - throw new AssertionError(); - } - - @Override - public void setIntHeader(String name, int value) { - throw new AssertionError(); - } - - @Override - public void addIntHeader(String name, int value) { - throw new AssertionError(); - } - - @Override - public void setStatus(int sc) { - throw new AssertionError(); - } - - @Override - public void setStatus(int sc, String sm) { - throw new AssertionError(); - } - - @Override - public void setStatusWithReason(int sc, String sm) { - throw new AssertionError(); - } - - @Override - public String getCharacterEncoding() { - return response.getCharacterEncoding(); - } - - @Override - public String getContentType() { - return response.getContentType(); - } - - @Override - public ServletOutputStream getOutputStream() throws IOException { - throw new AssertionError(); - } - - @Override - public boolean isWriting() { - throw new AssertionError(); - } - - @Override - public PrintWriter getWriter() throws IOException { - throw new AssertionError(); - } - - @Override - public void setContentLength(int len) { - throw new AssertionError(); - } - - @Override - public boolean isAllContentWritten(long written) { - throw new AssertionError(); - } - - @Override - public void closeOutput() throws IOException { - throw new AssertionError(); - } - - @Override - public long getLongContentLength() { - return response.getBufferSize(); - } - - @Override - public void setLongContentLength(long len) { - throw new AssertionError(); - } - - @Override - public void setCharacterEncoding(String encoding) { - throw new AssertionError(); - } - - @Override - public void setContentType(String contentType) { - throw new AssertionError(); - } - - @Override - public void setBufferSize(int size) { - throw new AssertionError(); - } - - @Override - public int getBufferSize() { - return response.getBufferSize(); - } - - @Override - public void flushBuffer() throws IOException { - throw new AssertionError(); - } - - @Override - public void reset() { - throw new AssertionError(); - } - - @Override - public void reset(boolean preserveCookies) { - throw new AssertionError(); - } - - @Override - public void resetForForward() { - throw new AssertionError(); - } - - @Override - public void resetBuffer() { - throw new AssertionError(); - } - - @Override - public boolean isCommitted() { - throw new AssertionError(); - } - - @Override - public void setLocale(Locale locale) { - throw new AssertionError(); - } - - @Override - public Locale getLocale() { - return response.getLocale(); - } - - @Override - public int getStatus() { - return response.getStatus(); - } - - @Override - public String getReason() { - throw new AssertionError(); - } - - @Override - public HttpFields getHttpFields() { - return new HttpFields(); - } - - @Override - public long getContentCount() { - return 0; - } - - @Override - public String toString() { - return response.toString(); - } - - @Override - public MetaData.Response getCommittedMetaData() { - return new MetaData.Response(HttpVersion.HTTP_2, getStatus(), null); - } - - @Override - public HttpChannel getHttpChannel() - { - return new HttpChannel(null, new HttpConfiguration(), new NullEndPoint(), null); - } - - private static class NullEndPoint implements EndPoint { - - @Override - public InetSocketAddress getLocalAddress() { - return null; - } - - @Override - public InetSocketAddress getRemoteAddress() { - return null; - } - - @Override - public boolean isOpen() { - return false; - } - - @Override - public long getCreatedTimeStamp() { - return 0; - } - - @Override - public void shutdownOutput() { - - } - - @Override - public boolean isOutputShutdown() { - return false; - } - - @Override - public boolean isInputShutdown() { - return false; - } - - @Override - public void close() { - - } - - @Override - public int fill(ByteBuffer buffer) throws IOException { - return 0; - } - - @Override - public boolean flush(ByteBuffer... buffer) throws IOException { - return false; - } - - @Override - public Object getTransport() { - return null; - } - - @Override - public long getIdleTimeout() { - return 0; - } - - @Override - public void setIdleTimeout(long idleTimeout) { - - } - - @Override - public void fillInterested(Callback callback) throws ReadPendingException { - - } - - @Override - public boolean tryFillInterested(Callback callback) { - return false; - } - - @Override - public boolean isFillInterested() { - return false; - } - - @Override - public void write(Callback callback, ByteBuffer... buffers) throws WritePendingException { - - } - - @Override - public Connection getConnection() { - return null; - } - - @Override - public void setConnection(Connection connection) { - - } - - @Override - public void onOpen() { - - } - - @Override - public void onClose() { - - } - - @Override - public boolean isOptimizedForDirectBuffers() { - return false; - } - - @Override - public void upgrade(Connection newConnection) { - - } - } - -} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/servlet/NullServletOutputStream.java b/websocket-resources/src/main/java/org/whispersystems/websocket/servlet/NullServletOutputStream.java deleted file mode 100644 index 7e3a856cd..000000000 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/servlet/NullServletOutputStream.java +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright (C) 2014 Open WhisperSystems - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ -package org.whispersystems.websocket.servlet; - -import javax.servlet.ServletOutputStream; -import javax.servlet.WriteListener; -import java.io.IOException; - -public class NullServletOutputStream extends ServletOutputStream { - @Override - public void write(int b) throws IOException {} - - @Override - public void write(byte[] buf) {} - - @Override - public void write(byte[] buf, int offset, int len) {} - - @Override - public boolean isReady() { - return false; - } - - @Override - public void setWriteListener(WriteListener writeListener) { - - } -} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/servlet/NullServletResponse.java b/websocket-resources/src/main/java/org/whispersystems/websocket/servlet/NullServletResponse.java deleted file mode 100644 index 6783d9307..000000000 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/servlet/NullServletResponse.java +++ /dev/null @@ -1,171 +0,0 @@ -/** - * Copyright (C) 2014 Open WhisperSystems - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ -package org.whispersystems.websocket.servlet; - -import javax.servlet.ServletOutputStream; -import javax.servlet.http.Cookie; -import javax.servlet.http.HttpServletResponse; -import java.io.IOException; -import java.io.PrintWriter; -import java.util.Collection; -import java.util.LinkedList; -import java.util.Locale; - -public class NullServletResponse implements HttpServletResponse { - @Override - public void addCookie(Cookie cookie) {} - - @Override - public boolean containsHeader(String name) { - return false; - } - - @Override - public String encodeURL(String url) { - return url; - } - - @Override - public String encodeRedirectURL(String url) { - return url; - } - - @Override - public String encodeUrl(String url) { - return url; - } - - @Override - public String encodeRedirectUrl(String url) { - return url; - } - - @Override - public void sendError(int sc, String msg) throws IOException {} - - @Override - public void sendError(int sc) throws IOException {} - - @Override - public void sendRedirect(String location) throws IOException {} - - @Override - public void setDateHeader(String name, long date) {} - - @Override - public void addDateHeader(String name, long date) {} - - @Override - public void setHeader(String name, String value) {} - - @Override - public void addHeader(String name, String value) {} - - @Override - public void setIntHeader(String name, int value) {} - - @Override - public void addIntHeader(String name, int value) {} - - @Override - public void setStatus(int sc) {} - - @Override - public void setStatus(int sc, String sm) {} - - @Override - public int getStatus() { - return 200; - } - - @Override - public String getHeader(String name) { - return null; - } - - @Override - public Collection getHeaders(String name) { - return new LinkedList<>(); - } - - @Override - public Collection getHeaderNames() { - return new LinkedList<>(); - } - - @Override - public String getCharacterEncoding() { - return "UTF-8"; - } - - @Override - public String getContentType() { - return null; - } - - @Override - public ServletOutputStream getOutputStream() throws IOException { - return new NullServletOutputStream(); - } - - @Override - public PrintWriter getWriter() throws IOException { - return new PrintWriter(new NullServletOutputStream()); - } - - @Override - public void setCharacterEncoding(String charset) {} - - @Override - public void setContentLength(int len) {} - - @Override - public void setContentLengthLong(long len) {} - - @Override - public void setContentType(String type) {} - - @Override - public void setBufferSize(int size) {} - - @Override - public int getBufferSize() { - return 0; - } - - @Override - public void flushBuffer() throws IOException {} - - @Override - public void resetBuffer() {} - - @Override - public boolean isCommitted() { - return true; - } - - @Override - public void reset() {} - - @Override - public void setLocale(Locale loc) {} - - @Override - public Locale getLocale() { - return Locale.US; - } -} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/servlet/WebSocketServletRequest.java b/websocket-resources/src/main/java/org/whispersystems/websocket/servlet/WebSocketServletRequest.java deleted file mode 100644 index 014b514ac..000000000 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/servlet/WebSocketServletRequest.java +++ /dev/null @@ -1,506 +0,0 @@ -/** - * Copyright (C) 2014 Open WhisperSystems - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ -package org.whispersystems.websocket.servlet; - -import org.whispersystems.websocket.messages.WebSocketRequestMessage; -import org.whispersystems.websocket.session.WebSocketSessionContext; - -import javax.servlet.AsyncContext; -import javax.servlet.DispatcherType; -import javax.servlet.RequestDispatcher; -import javax.servlet.ServletContext; -import javax.servlet.ServletException; -import javax.servlet.ServletInputStream; -import javax.servlet.ServletRequest; -import javax.servlet.ServletResponse; -import javax.servlet.http.Cookie; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import javax.servlet.http.HttpSession; -import javax.servlet.http.HttpUpgradeHandler; -import javax.servlet.http.Part; -import java.io.BufferedReader; -import java.io.IOException; -import java.io.InputStreamReader; -import java.io.UnsupportedEncodingException; -import java.security.Principal; -import java.util.Collection; -import java.util.Enumeration; -import java.util.HashMap; -import java.util.LinkedList; -import java.util.Locale; -import java.util.Map; -import java.util.Vector; - - -public class WebSocketServletRequest implements HttpServletRequest { - - private final Map headers = new HashMap<>(); - private final Map attributes = new HashMap<>(); - - private final WebSocketRequestMessage requestMessage; - private final ServletInputStream inputStream; - private final ServletContext servletContext; - private final WebSocketSessionContext sessionContext; - - public WebSocketServletRequest(WebSocketSessionContext sessionContext, - WebSocketRequestMessage requestMessage, - ServletContext servletContext) - { - this.requestMessage = requestMessage; - this.servletContext = servletContext; - this.sessionContext = sessionContext; - - if (requestMessage.getBody().isPresent()) { - inputStream = new BufferingServletInputStream(requestMessage.getBody().get()); - } else { - inputStream = new BufferingServletInputStream(new byte[0]); - } - - headers.putAll(requestMessage.getHeaders()); - } - - @Override - public String getAuthType() { - return BASIC_AUTH; - } - - @Override - public Cookie[] getCookies() { - return new Cookie[0]; - } - - @Override - public long getDateHeader(String name) { - return -1; - } - - @Override - public String getHeader(String name) { - return headers.get(name.toLowerCase()); - } - - @Override - public Enumeration getHeaders(String name) { - String header = this.headers.get(name.toLowerCase()); - Vector results = new Vector<>(); - - if (header != null) { - results.add(header); - } - - return results.elements(); - } - - @Override - public Enumeration getHeaderNames() { - return new Vector<>(headers.keySet()).elements(); - } - - @Override - public int getIntHeader(String name) { - return -1; - } - - @Override - public String getMethod() { - return requestMessage.getVerb(); - } - - @Override - public String getPathInfo() { - return requestMessage.getPath(); - } - - @Override - public String getPathTranslated() { - return requestMessage.getPath(); - } - - @Override - public String getContextPath() { - return ""; - } - - @Override - public String getQueryString() { - if (requestMessage.getPath().contains("?")) { - return requestMessage.getPath().substring(requestMessage.getPath().indexOf("?") + 1); - } - - return null; - } - - @Override - public String getRemoteUser() { - return null; - } - - @Override - public boolean isUserInRole(String role) { - return false; - } - - @Override - public Principal getUserPrincipal() { - return new ContextPrincipal(sessionContext); - } - - @Override - public String getRequestedSessionId() { - return null; - } - - @Override - public String getRequestURI() { - if (requestMessage.getPath().contains("?")) { - return requestMessage.getPath().substring(0, requestMessage.getPath().indexOf("?")); - } else { - return requestMessage.getPath(); - } - } - - @Override - public StringBuffer getRequestURL() { - StringBuffer stringBuffer = new StringBuffer(); - stringBuffer.append("http://websocket"); - stringBuffer.append(getRequestURI()); - - return stringBuffer; - } - - @Override - public String getServletPath() { - return ""; - } - - @Override - public HttpSession getSession(boolean create) { - return null; - } - - @Override - public HttpSession getSession() { - return null; - } - - @Override - public String changeSessionId() { - return null; - } - - @Override - public boolean isRequestedSessionIdValid() { - return false; - } - - @Override - public boolean isRequestedSessionIdFromCookie() { - return false; - } - - @Override - public boolean isRequestedSessionIdFromURL() { - return false; - } - - @Override - public boolean isRequestedSessionIdFromUrl() { - return false; - } - - @Override - public boolean authenticate(HttpServletResponse response) throws IOException, ServletException { - return false; - } - - @Override - public void login(String username, String password) throws ServletException { - - } - - @Override - public void logout() throws ServletException { - - } - - @Override - public Collection getParts() throws IOException, ServletException { - return new LinkedList<>(); - } - - @Override - public Part getPart(String name) throws IOException, ServletException { - return null; - } - - @Override - public T upgrade(Class handlerClass) throws IOException, ServletException { - return null; - } - - @Override - public Object getAttribute(String name) { - return attributes.get(name); - } - - @Override - public Enumeration getAttributeNames() { - return new Vector<>(attributes.keySet()).elements(); - } - - @Override - public String getCharacterEncoding() { - return null; - } - - @Override - public void setCharacterEncoding(String env) throws UnsupportedEncodingException {} - - @Override - public int getContentLength() { - if (requestMessage.getBody().isPresent()) { - return requestMessage.getBody().get().length; - } else { - return 0; - } - } - - @Override - public long getContentLengthLong() { - return getContentLength(); - } - - @Override - public String getContentType() { - if (requestMessage.getBody().isPresent()) { - return "application/json"; - } else { - return null; - } - } - - @Override - public ServletInputStream getInputStream() throws IOException { - return inputStream; - } - - @Override - public String getParameter(String name) { - String[] result = getParameterMap().get(name); - - if (result != null && result.length > 0) { - return result[0]; - } - - return null; - } - - @Override - public Enumeration getParameterNames() { - return new Vector<>(getParameterMap().keySet()).elements(); - } - - @Override - public String[] getParameterValues(String name) { - return getParameterMap().get(name); - } - - @Override - public Map getParameterMap() { - Map parameterMap = new HashMap<>(); - String queryParameters = getQueryString(); - - if (queryParameters == null) { - return parameterMap; - } - - String[] tokens = queryParameters.split("&"); - - for (String token : tokens) { - String[] parts = token.split("="); - - if (parts != null && parts.length > 1) { - parameterMap.put(parts[0], new String[] {parts[1]}); - } - } - - return parameterMap; - } - - @Override - public String getProtocol() { - return "HTTP/1.0"; - } - - @Override - public String getScheme() { - return "http"; - } - - @Override - public String getServerName() { - return "websocket"; - } - - @Override - public int getServerPort() { - return 8080; - } - - @Override - public BufferedReader getReader() throws IOException { - return new BufferedReader(new InputStreamReader(inputStream)); - } - - @Override - public String getRemoteAddr() { - return "127.0.0.1"; - } - - @Override - public String getRemoteHost() { - return "localhost"; - } - - @Override - public void setAttribute(String name, Object o) { - if (o != null) attributes.put(name, o); - else removeAttribute(name); - } - - @Override - public void removeAttribute(String name) { - attributes.remove(name); - } - - @Override - public Locale getLocale() { - return Locale.US; - } - - @Override - public Enumeration getLocales() { - Vector results = new Vector<>(); - results.add(getLocale()); - return results.elements(); - } - - @Override - public boolean isSecure() { - return false; - } - - @Override - public RequestDispatcher getRequestDispatcher(String path) { - return servletContext.getRequestDispatcher(path); - } - - @Override - public String getRealPath(String path) { - return path; - } - - @Override - public int getRemotePort() { - return 31337; - } - - @Override - public String getLocalName() { - return "localhost"; - } - - @Override - public String getLocalAddr() { - return "127.0.0.1"; - } - - @Override - public int getLocalPort() { - return 8080; - } - - @Override - public ServletContext getServletContext() { - return servletContext; - } - - @Override - public AsyncContext startAsync() throws IllegalStateException { - throw new AssertionError("nyi"); - } - - @Override - public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) throws IllegalStateException { - throw new AssertionError("nyi"); - } - - @Override - public boolean isAsyncStarted() { - return false; - } - - @Override - public boolean isAsyncSupported() { - return false; - } - - @Override - public AsyncContext getAsyncContext() { - return null; - } - - @Override - public DispatcherType getDispatcherType() { - return DispatcherType.REQUEST; - } - - public static class ContextPrincipal implements Principal { - - private final WebSocketSessionContext context; - - public ContextPrincipal(WebSocketSessionContext context) { - this.context = context; - } - - @Override - public boolean equals(Object another) { - return another instanceof ContextPrincipal && - context.equals(((ContextPrincipal) another).context); - } - - @Override - public String toString() { - return super.toString(); - } - - @Override - public int hashCode() { - return context.hashCode(); - } - - @Override - public String getName() { - return "WebSocketSessionContext"; - } - - public WebSocketSessionContext getContext() { - return context; - } - } -} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/servlet/WebSocketServletResponse.java b/websocket-resources/src/main/java/org/whispersystems/websocket/servlet/WebSocketServletResponse.java deleted file mode 100644 index 6295bbfea..000000000 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/servlet/WebSocketServletResponse.java +++ /dev/null @@ -1,270 +0,0 @@ -/** - * Copyright (C) 2014 Open WhisperSystems - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ -package org.whispersystems.websocket.servlet; - -import org.eclipse.jetty.websocket.api.RemoteEndpoint; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.whispersystems.websocket.messages.WebSocketMessageFactory; - -import javax.servlet.ServletOutputStream; -import javax.servlet.http.Cookie; -import javax.servlet.http.HttpServletResponse; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.PrintWriter; -import java.nio.ByteBuffer; -import java.util.Collection; -import java.util.LinkedList; -import java.util.Locale; -import java.util.Optional; - - -public class WebSocketServletResponse implements HttpServletResponse { - - @SuppressWarnings("unused") - private static final Logger logger = LoggerFactory.getLogger(WebSocketServletResponse.class); - - private final RemoteEndpoint endPoint; - private final long requestId; - private final WebSocketMessageFactory messageFactory; - - private ResponseBuilder responseBuilder = new ResponseBuilder(); - private ByteArrayOutputStream responseBody = new ByteArrayOutputStream(); - private ServletOutputStream servletOutputStream = new BufferingServletOutputStream(responseBody); - private boolean isCommitted = false; - - public WebSocketServletResponse(RemoteEndpoint endPoint, long requestId, - WebSocketMessageFactory messageFactory) - { - this.endPoint = endPoint; - this.requestId = requestId; - this.messageFactory = messageFactory; - - this.responseBuilder.setRequestId(requestId); - } - - @Override - public void addCookie(Cookie cookie) {} - - @Override - public boolean containsHeader(String name) { - return false; - } - - @Override - public String encodeURL(String url) { - return url; - } - - @Override - public String encodeRedirectURL(String url) { - return url; - } - - @Override - public String encodeUrl(String url) { - return url; - } - - @Override - public String encodeRedirectUrl(String url) { - return url; - } - - @Override - public void sendError(int sc, String msg) throws IOException { - setStatus(sc, msg); - } - - @Override - public void sendError(int sc) throws IOException { - setStatus(sc); - } - - @Override - public void sendRedirect(String location) throws IOException { - throw new IOException("Not supported!"); - } - - @Override - public void setDateHeader(String name, long date) {} - - @Override - public void addDateHeader(String name, long date) {} - - @Override - public void setHeader(String name, String value) {} - - @Override - public void addHeader(String name, String value) {} - - @Override - public void setIntHeader(String name, int value) {} - - @Override - public void addIntHeader(String name, int value) {} - - @Override - public void setStatus(int sc) { - setStatus(sc, ""); - } - - @Override - public void setStatus(int sc, String sm) { - this.responseBuilder.setStatusCode(sc); - this.responseBuilder.setMessage(sm); - } - - @Override - public int getStatus() { - return this.responseBuilder.getStatusCode(); - } - - @Override - public String getHeader(String name) { - return null; - } - - @Override - public Collection getHeaders(String name) { - return new LinkedList<>(); - } - - @Override - public Collection getHeaderNames() { - return new LinkedList<>(); - } - - @Override - public String getCharacterEncoding() { - return "UTF-8"; - } - - @Override - public String getContentType() { - return null; - } - - @Override - public ServletOutputStream getOutputStream() throws IOException { - return servletOutputStream; - } - - @Override - public PrintWriter getWriter() throws IOException { - return new PrintWriter(servletOutputStream); - } - - @Override - public void setCharacterEncoding(String charset) {} - - @Override - public void setContentLength(int len) {} - - @Override - public void setContentLengthLong(long len) {} - - @Override - public void setContentType(String type) {} - - @Override - public void setBufferSize(int size) {} - - @Override - public int getBufferSize() { - return 0; - } - - @Override - public void flushBuffer() throws IOException { - if (!isCommitted) { - byte[] body = responseBody.toByteArray(); - - if (body.length <= 0) { - body = null; - } - - byte[] response = messageFactory.createResponse(responseBuilder.getRequestId(), - responseBuilder.getStatusCode(), - responseBuilder.getMessage(), - new LinkedList<>(), - Optional.ofNullable(body)) - .toByteArray(); - - endPoint.sendBytesByFuture(ByteBuffer.wrap(response)); - isCommitted = true; - } - } - - @Override - public void resetBuffer() { - if (isCommitted) throw new IllegalStateException("Buffer already flushed!"); - responseBody.reset(); - } - - @Override - public boolean isCommitted() { - return isCommitted; - } - - @Override - public void reset() { - if (isCommitted) throw new IllegalStateException("Buffer already flushed!"); - responseBuilder = new ResponseBuilder(); - responseBuilder.setRequestId(requestId); - resetBuffer(); - } - - @Override - public void setLocale(Locale loc) {} - - @Override - public Locale getLocale() { - return Locale.US; - } - - private static class ResponseBuilder { - private long requestId; - private int statusCode; - private String message = ""; - - public long getRequestId() { - return requestId; - } - - public void setRequestId(long requestId) { - this.requestId = requestId; - } - - public int getStatusCode() { - return statusCode; - } - - public void setStatusCode(int statusCode) { - this.statusCode = statusCode; - } - - public String getMessage() { - return message; - } - - public void setMessage(String message) { - this.message = message; - } - } -} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/session/ContextPrincipal.java b/websocket-resources/src/main/java/org/whispersystems/websocket/session/ContextPrincipal.java new file mode 100644 index 000000000..e85eec60e --- /dev/null +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/session/ContextPrincipal.java @@ -0,0 +1,37 @@ +package org.whispersystems.websocket.session; + +import java.security.Principal; + +public class ContextPrincipal implements Principal { + + private final WebSocketSessionContext context; + + public ContextPrincipal(WebSocketSessionContext context) { + this.context = context; + } + + @Override + public boolean equals(Object another) { + return another instanceof ContextPrincipal && + context.equals(((ContextPrincipal) another).context); + } + + @Override + public String toString() { + return super.toString(); + } + + @Override + public int hashCode() { + return context.hashCode(); + } + + @Override + public String getName() { + return "WebSocketSessionContext"; + } + + public WebSocketSessionContext getContext() { + return context; + } +} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/session/WebSocketSessionContainerRequestValueFactory.java b/websocket-resources/src/main/java/org/whispersystems/websocket/session/WebSocketSessionContainerRequestValueFactory.java new file mode 100644 index 000000000..f830c6ee4 --- /dev/null +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/session/WebSocketSessionContainerRequestValueFactory.java @@ -0,0 +1,31 @@ +package org.whispersystems.websocket.session; + +import org.glassfish.jersey.server.ContainerRequest; +import org.whispersystems.websocket.WebSocketSecurityContext; + +import javax.ws.rs.core.SecurityContext; + +public class WebSocketSessionContainerRequestValueFactory { + private final ContainerRequest request; + + public WebSocketSessionContainerRequestValueFactory(ContainerRequest request) { + this.request = request; + } + + public WebSocketSessionContext provide() { + SecurityContext securityContext = request.getSecurityContext(); + + if (!(securityContext instanceof WebSocketSecurityContext)) { + throw new IllegalStateException("Security context isn't for websocket!"); + } + + WebSocketSessionContext sessionContext = ((WebSocketSecurityContext)securityContext).getSessionContext(); + + if (sessionContext == null) { + throw new IllegalStateException("No session context found for websocket!"); + } + + return sessionContext; + } + +} diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/session/WebSocketSessionContextValueFactoryProvider.java b/websocket-resources/src/main/java/org/whispersystems/websocket/session/WebSocketSessionContextValueFactoryProvider.java index e451d20e7..3773de171 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/session/WebSocketSessionContextValueFactoryProvider.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/session/WebSocketSessionContextValueFactoryProvider.java @@ -1,73 +1,45 @@ package org.whispersystems.websocket.session; -import org.glassfish.hk2.api.InjectionResolver; -import org.glassfish.hk2.api.ServiceLocator; -import org.glassfish.hk2.api.TypeLiteral; -import org.glassfish.hk2.utilities.binding.AbstractBinder; -import org.glassfish.jersey.server.internal.inject.AbstractContainerRequestValueFactory; -import org.glassfish.jersey.server.internal.inject.AbstractValueFactoryProvider; +import org.glassfish.jersey.internal.inject.AbstractBinder; +import org.glassfish.jersey.server.ContainerRequest; +import org.glassfish.jersey.server.internal.inject.AbstractValueParamProvider; import org.glassfish.jersey.server.internal.inject.MultivaluedParameterExtractorProvider; -import org.glassfish.jersey.server.internal.inject.ParamInjectionResolver; import org.glassfish.jersey.server.model.Parameter; -import org.glassfish.jersey.server.spi.internal.ValueFactoryProvider; -import org.whispersystems.websocket.servlet.WebSocketServletRequest; +import org.glassfish.jersey.server.spi.internal.ValueParamProvider; +import javax.annotation.Nullable; import javax.inject.Inject; import javax.inject.Singleton; -import java.security.Principal; +import java.util.function.Function; @Singleton -public class WebSocketSessionContextValueFactoryProvider extends AbstractValueFactoryProvider { +public class WebSocketSessionContextValueFactoryProvider extends AbstractValueParamProvider { @Inject - public WebSocketSessionContextValueFactoryProvider(MultivaluedParameterExtractorProvider mpep, - ServiceLocator injector) - { - super(mpep, injector, Parameter.Source.UNKNOWN); + public WebSocketSessionContextValueFactoryProvider(MultivaluedParameterExtractorProvider mpep) { + super(() -> mpep, Parameter.Source.UNKNOWN); } + @Nullable @Override - public AbstractContainerRequestValueFactory createValueFactory(Parameter parameter) { - if (parameter.getAnnotation(WebSocketSession.class) == null) { + protected Function createValueProvider(Parameter parameter) { + if (!parameter.isAnnotationPresent(WebSocketSession.class)) { return null; - } - - return new AbstractContainerRequestValueFactory() { - - public WebSocketSessionContext provide() { - Principal principal = getContainerRequest().getSecurityContext().getUserPrincipal(); - - if (principal == null) { - throw new IllegalStateException("Cannot inject a custom principal into unauthenticated request"); - } - - if (!(principal instanceof WebSocketServletRequest.ContextPrincipal)) { - throw new IllegalArgumentException("Cannot inject a non-WebSocket AuthPrincipal into request"); - } - - return ((WebSocketServletRequest.ContextPrincipal)principal).getContext(); - } - }; - } - - @Singleton - private static class WebSocketSessionInjectionResolver extends ParamInjectionResolver { - public WebSocketSessionInjectionResolver() { - super(WebSocketSessionContextValueFactoryProvider.class); + } else if (WebSocketSessionContext.class.equals(parameter.getRawType())) { + return request -> new WebSocketSessionContainerRequestValueFactory(request).provide(); + } else { + throw new IllegalArgumentException("Can't inject custom type"); } } public static class Binder extends AbstractBinder { - public Binder() { - } + public Binder() { } @Override protected void configure() { - bind(WebSocketSessionContextValueFactoryProvider.class).to(ValueFactoryProvider.class).in(Singleton.class); - bind(WebSocketSessionInjectionResolver.class).to(new TypeLiteral>() { - }).in(Singleton.class); + bind(WebSocketSessionContextValueFactoryProvider.class).to(ValueParamProvider.class).in(Singleton.class); } } } \ No newline at end of file diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/setup/WebSocketEnvironment.java b/websocket-resources/src/main/java/org/whispersystems/websocket/setup/WebSocketEnvironment.java index 7ec1bacd5..6133dbef5 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/setup/WebSocketEnvironment.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/setup/WebSocketEnvironment.java @@ -17,33 +17,30 @@ package org.whispersystems.websocket.setup; import com.fasterxml.jackson.databind.ObjectMapper; -import org.eclipse.jetty.server.RequestLog; -import org.glassfish.jersey.servlet.ServletContainer; +import org.glassfish.jersey.server.ResourceConfig; import org.whispersystems.websocket.auth.WebSocketAuthenticator; import org.whispersystems.websocket.configuration.WebSocketConfiguration; +import org.whispersystems.websocket.logging.WebsocketRequestLog; import org.whispersystems.websocket.messages.WebSocketMessageFactory; import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory; -import javax.servlet.http.HttpServlet; import javax.validation.Validator; +import java.security.Principal; import io.dropwizard.jersey.DropwizardResourceConfig; -import io.dropwizard.jersey.setup.JerseyContainerHolder; -import io.dropwizard.jersey.setup.JerseyEnvironment; import io.dropwizard.setup.Environment; -public class WebSocketEnvironment { +public class WebSocketEnvironment { - private final JerseyContainerHolder jerseyServletContainer; - private final JerseyEnvironment jerseyEnvironment; + private final ResourceConfig jerseyConfig; private final ObjectMapper objectMapper; private final Validator validator; - private final RequestLog requestLog; + private final WebsocketRequestLog requestLog; private final long idleTimeoutMillis; - private WebSocketAuthenticator authenticator; - private WebSocketMessageFactory messageFactory; - private WebSocketConnectListener connectListener; + private WebSocketAuthenticator authenticator; + private WebSocketMessageFactory messageFactory; + private WebSocketConnectListener connectListener; public WebSocketEnvironment(Environment environment, WebSocketConfiguration configuration) { this(environment, configuration, 60000); @@ -53,27 +50,24 @@ public class WebSocketEnvironment { this(environment, configuration.getRequestLog().build("websocket"), idleTimeoutMillis); } - public WebSocketEnvironment(Environment environment, RequestLog requestLog, long idleTimeoutMillis) { - DropwizardResourceConfig jerseyConfig = new DropwizardResourceConfig(environment.metrics()); - - this.objectMapper = environment.getObjectMapper(); - this.validator = environment.getValidator(); - this.requestLog = requestLog; - this.jerseyServletContainer = new JerseyContainerHolder(new ServletContainer(jerseyConfig) ); - this.jerseyEnvironment = new JerseyEnvironment(jerseyServletContainer, jerseyConfig); - this.messageFactory = new ProtobufWebSocketMessageFactory(); - this.idleTimeoutMillis = idleTimeoutMillis; + public WebSocketEnvironment(Environment environment, WebsocketRequestLog requestLog, long idleTimeoutMillis) { + this.jerseyConfig = new DropwizardResourceConfig(environment.metrics()); + this.objectMapper = environment.getObjectMapper(); + this.validator = environment.getValidator(); + this.requestLog = requestLog; + this.messageFactory = new ProtobufWebSocketMessageFactory(); + this.idleTimeoutMillis = idleTimeoutMillis; } - public JerseyEnvironment jersey() { - return jerseyEnvironment; + public ResourceConfig jersey() { + return jerseyConfig; } - public WebSocketAuthenticator getAuthenticator() { + public WebSocketAuthenticator getAuthenticator() { return authenticator; } - public void setAuthenticator(WebSocketAuthenticator authenticator) { + public void setAuthenticator(WebSocketAuthenticator authenticator) { this.authenticator = authenticator; } @@ -85,7 +79,7 @@ public class WebSocketEnvironment { return objectMapper; } - public RequestLog getRequestLog() { + public WebsocketRequestLog getRequestLog() { return requestLog; } @@ -93,10 +87,6 @@ public class WebSocketEnvironment { return validator; } - public HttpServlet getJerseyServletContainer() { - return (HttpServlet)jerseyServletContainer.getContainer(); - } - public WebSocketMessageFactory getMessageFactory() { return messageFactory; } diff --git a/websocket-resources/src/test/java/org/whispersystems/websocket/LoggableRequestResponseTest.java b/websocket-resources/src/test/java/org/whispersystems/websocket/LoggableRequestResponseTest.java deleted file mode 100644 index 4faa1874b..000000000 --- a/websocket-resources/src/test/java/org/whispersystems/websocket/LoggableRequestResponseTest.java +++ /dev/null @@ -1,64 +0,0 @@ -package org.whispersystems.websocket; - -import org.eclipse.jetty.server.AbstractNCSARequestLog; -import org.eclipse.jetty.server.NCSARequestLog; -import org.eclipse.jetty.server.RequestLog; -import org.eclipse.jetty.util.component.AbstractLifeCycle; -import org.eclipse.jetty.websocket.api.RemoteEndpoint; -import org.junit.Test; -import org.whispersystems.websocket.messages.WebSocketMessageFactory; -import org.whispersystems.websocket.messages.WebSocketRequestMessage; -import org.whispersystems.websocket.servlet.LoggableRequest; -import org.whispersystems.websocket.servlet.LoggableResponse; -import org.whispersystems.websocket.servlet.WebSocketServletRequest; -import org.whispersystems.websocket.servlet.WebSocketServletResponse; -import org.whispersystems.websocket.session.WebSocketSessionContext; - -import javax.servlet.ServletContext; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - -import java.util.HashMap; -import java.util.Optional; - -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -public class LoggableRequestResponseTest { - - @Test - public void testLogging() { - NCSARequestLog requestLog = new EnabledNCSARequestLog(); - - WebSocketClient webSocketClient = mock(WebSocketClient.class ); - WebSocketRequestMessage requestMessage = mock(WebSocketRequestMessage.class); - ServletContext servletContext = mock(ServletContext.class ); - RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class ); - WebSocketMessageFactory messageFactory = mock(WebSocketMessageFactory.class); - - when(requestMessage.getVerb()).thenReturn("GET"); - when(requestMessage.getBody()).thenReturn(Optional.empty()); - when(requestMessage.getHeaders()).thenReturn(new HashMap<>()); - when(requestMessage.getPath()).thenReturn("/api/v1/test"); - when(requestMessage.getRequestId()).thenReturn(1L); - when(requestMessage.hasRequestId()).thenReturn(true); - - WebSocketSessionContext sessionContext = new WebSocketSessionContext (webSocketClient ); - HttpServletRequest servletRequest = new WebSocketServletRequest (sessionContext, requestMessage, servletContext); - HttpServletResponse servletResponse = new WebSocketServletResponse(remoteEndpoint, 1, messageFactory ); - - LoggableRequest loggableRequest = new LoggableRequest (servletRequest ); - LoggableResponse loggableResponse = new LoggableResponse(servletResponse); - - requestLog.log(loggableRequest, loggableResponse); - } - - - private class EnabledNCSARequestLog extends NCSARequestLog { - @Override - public boolean isEnabled() { - return true; - } - } - -} diff --git a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java index 68a995c26..820c64c5a 100644 --- a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java +++ b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderFactoryTest.java @@ -4,16 +4,21 @@ package org.whispersystems.websocket; import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest; import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse; +import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory; +import org.glassfish.jersey.server.ResourceConfig; import org.junit.Test; import org.whispersystems.websocket.auth.AuthenticationException; import org.whispersystems.websocket.auth.WebSocketAuthenticator; +import org.whispersystems.websocket.setup.WebSocketConnectListener; import org.whispersystems.websocket.setup.WebSocketEnvironment; +import javax.security.auth.Subject; import javax.servlet.ServletException; import java.io.IOException; +import java.security.Principal; import java.util.Optional; -import io.dropwizard.jersey.setup.JerseyEnvironment; +import io.dropwizard.jersey.DropwizardResourceConfig; import static org.junit.Assert.*; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.*; @@ -21,8 +26,8 @@ import static org.mockito.Mockito.*; public class WebSocketResourceProviderFactoryTest { @Test - public void testUnauthorized() throws ServletException, AuthenticationException, IOException { - JerseyEnvironment jerseyEnvironment = mock(JerseyEnvironment.class ); + public void testUnauthorized() throws AuthenticationException, IOException { + ResourceConfig jerseyEnvironment = new DropwizardResourceConfig(); WebSocketEnvironment environment = mock(WebSocketEnvironment.class ); WebSocketAuthenticator authenticator = mock(WebSocketAuthenticator.class); ServletUpgradeRequest request = mock(ServletUpgradeRequest.class ); @@ -32,7 +37,7 @@ public class WebSocketResourceProviderFactoryTest { when(authenticator.authenticate(eq(request))).thenReturn(new WebSocketAuthenticator.AuthenticationResult<>(Optional.empty(), true)); when(environment.jersey()).thenReturn(jerseyEnvironment); - WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory(environment); + WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory(environment, Account.class); Object connection = factory.createWebSocket(request, response); assertNull(connection); @@ -42,19 +47,19 @@ public class WebSocketResourceProviderFactoryTest { @Test public void testValidAuthorization() throws AuthenticationException, ServletException { - JerseyEnvironment jerseyEnvironment = mock(JerseyEnvironment.class ); - WebSocketEnvironment environment = mock(WebSocketEnvironment.class ); - WebSocketAuthenticator authenticator = mock(WebSocketAuthenticator.class); - ServletUpgradeRequest request = mock(ServletUpgradeRequest.class ); - ServletUpgradeResponse response = mock(ServletUpgradeResponse.class); - Session session = mock(Session.class ); + ResourceConfig jerseyEnvironment = new DropwizardResourceConfig(); + WebSocketEnvironment environment = mock(WebSocketEnvironment.class ); + WebSocketAuthenticator authenticator = mock(WebSocketAuthenticator.class ); + ServletUpgradeRequest request = mock(ServletUpgradeRequest.class ); + ServletUpgradeResponse response = mock(ServletUpgradeResponse.class ); + Session session = mock(Session.class ); Account account = new Account(); when(environment.getAuthenticator()).thenReturn(authenticator); when(authenticator.authenticate(eq(request))).thenReturn(new WebSocketAuthenticator.AuthenticationResult<>(Optional.of(account), true)); when(environment.jersey()).thenReturn(jerseyEnvironment); - WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory(environment); + WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory(environment, Account.class); Object connection = factory.createWebSocket(request, response); assertNotNull(connection); @@ -67,7 +72,51 @@ public class WebSocketResourceProviderFactoryTest { assertEquals(((WebSocketResourceProvider)connection).getContext().getAuthenticated(), account); } - private static class Account {} + @Test + public void testErrorAuthorization() throws AuthenticationException, ServletException, IOException { + ResourceConfig jerseyEnvironment = new DropwizardResourceConfig(); + WebSocketEnvironment environment = mock(WebSocketEnvironment.class ); + WebSocketAuthenticator authenticator = mock(WebSocketAuthenticator.class ); + ServletUpgradeRequest request = mock(ServletUpgradeRequest.class ); + ServletUpgradeResponse response = mock(ServletUpgradeResponse.class ); + + when(environment.getAuthenticator()).thenReturn(authenticator); + when(authenticator.authenticate(eq(request))).thenThrow(new AuthenticationException("database failure")); + when(environment.jersey()).thenReturn(jerseyEnvironment); + + WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory(environment, Account.class); + Object connection = factory.createWebSocket(request, response); + + assertNull(connection); + verify(response).sendError(eq(500), eq("Failure")); + verify(authenticator).authenticate(eq(request)); + } + + @Test + public void testConfigure() { + ResourceConfig jerseyEnvironment = new DropwizardResourceConfig(); + WebSocketEnvironment environment = mock(WebSocketEnvironment.class ); + WebSocketServletFactory servletFactory = mock(WebSocketServletFactory.class ); + when(environment.jersey()).thenReturn(jerseyEnvironment); + + WebSocketResourceProviderFactory factory = new WebSocketResourceProviderFactory(environment, Account.class); + factory.configure(servletFactory); + + verify(servletFactory).setCreator(eq(factory)); + } + + + private static class Account implements Principal { + @Override + public String getName() { + return null; + } + + @Override + public boolean implies(Subject subject) { + return false; + } + } } diff --git a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java index f5745fc67..a5d874733 100644 --- a/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java +++ b/websocket-resources/src/test/java/org/whispersystems/websocket/WebSocketResourceProviderTest.java @@ -1,60 +1,95 @@ package org.whispersystems.websocket; -import org.eclipse.jetty.server.RequestLog; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.protobuf.ByteString; +import com.google.protobuf.InvalidProtocolBufferException; import org.eclipse.jetty.websocket.api.CloseStatus; import org.eclipse.jetty.websocket.api.RemoteEndpoint; import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.UpgradeRequest; +import org.eclipse.jetty.websocket.api.WriteCallback; +import org.glassfish.jersey.server.ApplicationHandler; +import org.glassfish.jersey.server.ContainerRequest; +import org.glassfish.jersey.server.ContainerResponse; +import org.glassfish.jersey.server.ResourceConfig; import org.junit.Test; import org.mockito.ArgumentCaptor; -import org.whispersystems.websocket.WebSocketResourceProvider; -import org.whispersystems.websocket.auth.AuthenticationException; -import org.whispersystems.websocket.auth.WebSocketAuthenticator; +import org.mockito.stubbing.Answer; +import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider; +import org.whispersystems.websocket.logging.WebsocketRequestLog; import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory; +import org.whispersystems.websocket.messages.protobuf.SubProtocol; +import org.whispersystems.websocket.session.WebSocketSession; +import org.whispersystems.websocket.session.WebSocketSessionContext; +import org.whispersystems.websocket.session.WebSocketSessionContextValueFactoryProvider; import org.whispersystems.websocket.setup.WebSocketConnectListener; -import javax.servlet.http.HttpServlet; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.io.IOException; +import javax.validation.Valid; +import javax.validation.constraints.Min; +import javax.validation.constraints.NotEmpty; +import javax.ws.rs.Consumes; +import javax.ws.rs.GET; +import javax.ws.rs.PUT; +import javax.ws.rs.Path; +import javax.ws.rs.PathParam; +import javax.ws.rs.Produces; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; +import javax.ws.rs.ext.ExceptionMapper; +import javax.ws.rs.ext.Provider; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.security.Principal; import java.util.LinkedList; +import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import io.dropwizard.auth.Auth; +import io.dropwizard.auth.AuthValueFactoryProvider; +import io.dropwizard.jersey.DropwizardResourceConfig; +import io.dropwizard.jersey.jackson.JacksonMessageBodyProvider; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.*; public class WebSocketResourceProviderTest { @Test - public void testOnConnect() throws AuthenticationException, IOException { - HttpServlet contextHandler = mock(HttpServlet.class); - WebSocketAuthenticator authenticator = mock(WebSocketAuthenticator.class); - RequestLog requestLog = mock(RequestLog.class); - WebSocketResourceProvider provider = new WebSocketResourceProvider(contextHandler, requestLog, - null, - new ProtobufWebSocketMessageFactory(), - Optional.empty(), - 30000); + public void testOnConnect() { + ApplicationHandler applicationHandler = mock(ApplicationHandler.class ); + WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class ); + WebSocketConnectListener connectListener = mock(WebSocketConnectListener.class); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", + applicationHandler, requestLog, + new TestPrincipal("fooz"), + new ProtobufWebSocketMessageFactory(), + Optional.of(connectListener), + 30000 ); Session session = mock(Session.class ); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); - when(authenticator.authenticate(request)).thenReturn(new WebSocketAuthenticator.AuthenticationResult<>(Optional.of("fooz"), true)); provider.onWebSocketConnect(session); verify(session, never()).close(anyInt(), anyString()); verify(session, never()).close(); verify(session, never()).close(any(CloseStatus.class)); + + ArgumentCaptor contextArgumentCaptor = ArgumentCaptor.forClass(WebSocketSessionContext.class); + verify(connectListener).onWebSocketConnect(contextArgumentCaptor.capture()); + + assertThat(contextArgumentCaptor.getValue().getAuthenticated(TestPrincipal.class).getName()).isEqualTo("fooz"); } @Test - public void testRouteMessage() throws Exception { - HttpServlet servlet = mock(HttpServlet.class ); - WebSocketAuthenticator authenticator = mock(WebSocketAuthenticator.class); - RequestLog requestLog = mock(RequestLog.class ); - WebSocketResourceProvider provider = new WebSocketResourceProvider(servlet, requestLog, Optional.of((WebSocketAuthenticator)authenticator), new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000); + public void testMockedRouteMessageSuccess() throws Exception { + ApplicationHandler applicationHandler = mock(ApplicationHandler.class ); + WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000); Session session = mock(Session.class ); RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); @@ -62,7 +97,33 @@ public class WebSocketResourceProviderTest { when(session.getUpgradeRequest()).thenReturn(request); when(session.getRemote()).thenReturn(remoteEndpoint); - when(authenticator.authenticate(request)).thenReturn(new WebSocketAuthenticator.AuthenticationResult<>(Optional.of("foo"), true)); + + ContainerResponse response = mock(ContainerResponse.class); + when(response.getStatus()).thenReturn(200); + when(response.getStatusInfo()).thenReturn(new Response.StatusType() { + @Override + public int getStatusCode() { + return 200; + } + + @Override + public Response.Status.Family getFamily() { + return Response.Status.Family.SUCCESSFUL; + } + + @Override + public String getReasonPhrase() { + return "OK"; + } + }); + + ArgumentCaptor responseOutputStream = ArgumentCaptor.forClass(OutputStream.class); + + when(applicationHandler.apply(any(ContainerRequest.class), responseOutputStream.capture())) + .thenAnswer((Answer>) invocation -> { + responseOutputStream.getValue().write("hello world!".getBytes()); + return CompletableFuture.completedFuture(response); + }); provider.onWebSocketConnect(session); @@ -70,21 +131,567 @@ public class WebSocketResourceProviderTest { verify(session, never()).close(); verify(session, never()).close(any(CloseStatus.class)); - byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/bar", new LinkedList(), Optional.of("hello world!".getBytes())).toByteArray(); + byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/bar", new LinkedList<>(), Optional.of("hello world!".getBytes())).toByteArray(); provider.onWebSocketBinary(message, 0, message.length); - ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpServletRequest.class); + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ContainerRequest.class); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(ByteBuffer.class ); - verify(servlet).service(requestCaptor.capture(), any(HttpServletResponse.class)); + verify(applicationHandler).apply(requestCaptor.capture(), any(OutputStream.class)); - HttpServletRequest bundledRequest = requestCaptor.getValue(); + ContainerRequest bundledRequest = requestCaptor.getValue(); - byte[] expected = new byte[bundledRequest.getInputStream().available()]; - int read = bundledRequest.getInputStream().read(expected); + assertThat(bundledRequest.getRequest().getMethod()).isEqualTo("GET"); + assertThat(bundledRequest.getBaseUri().toString()).isEqualTo("/"); + assertThat(bundledRequest.getPath(false)).isEqualTo("bar"); - assertThat(read).isEqualTo(expected.length); - assertThat(new String(expected)).isEqualTo("hello world!"); + verify(requestLog).log(eq("127.0.0.1"), eq(bundledRequest), eq(response)); + verify(remoteEndpoint).sendBytesByFuture(responseCaptor.capture()); + + SubProtocol.WebSocketMessage responseMessageContainer = SubProtocol.WebSocketMessage.parseFrom(responseCaptor.getValue().array()); + assertThat(responseMessageContainer.getResponse().getId()).isEqualTo(111L); + assertThat(responseMessageContainer.getResponse().getStatus()).isEqualTo(200); + assertThat(responseMessageContainer.getResponse().getMessage()).isEqualTo("OK"); + assertThat(responseMessageContainer.getResponse().getBody()).isEqualTo(ByteString.copyFrom("hello world!".getBytes())); + } + + @Test + public void testMockedRouteMessageFailure() throws Exception { + ApplicationHandler applicationHandler = mock(ApplicationHandler.class ); + WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000); + + Session session = mock(Session.class ); + RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); + UpgradeRequest request = mock(UpgradeRequest.class); + + when(session.getUpgradeRequest()).thenReturn(request); + when(session.getRemote()).thenReturn(remoteEndpoint); + + when(applicationHandler.apply(any(ContainerRequest.class), any(OutputStream.class))).thenReturn(CompletableFuture.failedFuture(new IllegalStateException("foo"))); + + provider.onWebSocketConnect(session); + + verify(session, never()).close(anyInt(), anyString()); + verify(session, never()).close(); + verify(session, never()).close(any(CloseStatus.class)); + + byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/bar", new LinkedList<>(), Optional.of("hello world!".getBytes())).toByteArray(); + + provider.onWebSocketBinary(message, 0, message.length); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ContainerRequest.class); + + verify(applicationHandler).apply(requestCaptor.capture(), any(OutputStream.class)); + + ContainerRequest bundledRequest = requestCaptor.getValue(); + + assertThat(bundledRequest.getRequest().getMethod()).isEqualTo("GET"); + assertThat(bundledRequest.getBaseUri().toString()).isEqualTo("/"); + assertThat(bundledRequest.getPath(false)).isEqualTo("bar"); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(ByteBuffer.class); + + verify(remoteEndpoint).sendBytesByFuture(responseCaptor.capture()); + + SubProtocol.WebSocketMessage responseMessageContainer = SubProtocol.WebSocketMessage.parseFrom(responseCaptor.getValue().array()); + assertThat(responseMessageContainer.getResponse().getStatus()).isEqualTo(500); + assertThat(responseMessageContainer.getResponse().getMessage()).isEqualTo("Error response"); + assertThat(responseMessageContainer.getResponse().hasBody()).isFalse(); + } + + @Test + public void testActualRouteMessageSuccess() throws InvalidProtocolBufferException { + ResourceConfig resourceConfig = new DropwizardResourceConfig(); + resourceConfig.register(new TestResource()); + resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder()); + resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class)); + resourceConfig.register(new JacksonMessageBodyProvider(new ObjectMapper())); + + ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); + WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000); + + Session session = mock(Session.class ); + RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); + UpgradeRequest request = mock(UpgradeRequest.class); + + when(session.getUpgradeRequest()).thenReturn(request); + when(session.getRemote()).thenReturn(remoteEndpoint); + + provider.onWebSocketConnect(session); + + byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello", new LinkedList<>(), Optional.empty()).toByteArray(); + + provider.onWebSocketBinary(message, 0, message.length); + + ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); + + verify(remoteEndpoint).sendBytesByFuture(responseBytesCaptor.capture()); + + SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); + + assertThat(response.getId()).isEqualTo(111L); + assertThat(response.getStatus()).isEqualTo(200); + assertThat(response.getMessage()).isEqualTo("OK"); + assertThat(response.getBody()).isEqualTo(ByteString.copyFrom("Hello!".getBytes())); + } + + @Test + public void testActualRouteMessageNotFound() throws InvalidProtocolBufferException { + ResourceConfig resourceConfig = new DropwizardResourceConfig(); + resourceConfig.register(new TestResource()); + resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder()); + resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class)); + resourceConfig.register(new JacksonMessageBodyProvider(new ObjectMapper())); + + + ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); + WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000); + + Session session = mock(Session.class ); + RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); + UpgradeRequest request = mock(UpgradeRequest.class); + + when(session.getUpgradeRequest()).thenReturn(request); + when(session.getRemote()).thenReturn(remoteEndpoint); + + provider.onWebSocketConnect(session); + + byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/doesntexist", new LinkedList<>(), Optional.empty()).toByteArray(); + + provider.onWebSocketBinary(message, 0, message.length); + + ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); + + verify(remoteEndpoint).sendBytesByFuture(responseBytesCaptor.capture()); + + SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); + + assertThat(response.getId()).isEqualTo(111L); + assertThat(response.getStatus()).isEqualTo(404); + assertThat(response.getMessage()).isEqualTo("Not Found"); + assertThat(response.hasBody()).isFalse(); + } + + @Test + public void testActualRouteMessageAuthorized() throws InvalidProtocolBufferException { + ResourceConfig resourceConfig = new DropwizardResourceConfig(); + resourceConfig.register(new TestResource()); + resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder()); + resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class)); + resourceConfig.register(new JacksonMessageBodyProvider(new ObjectMapper())); + + ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); + WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, requestLog, new TestPrincipal("authorizedUserName"), new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000); + + Session session = mock(Session.class ); + RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); + UpgradeRequest request = mock(UpgradeRequest.class); + + when(session.getUpgradeRequest()).thenReturn(request); + when(session.getRemote()).thenReturn(remoteEndpoint); + + provider.onWebSocketConnect(session); + + byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/world", new LinkedList<>(), Optional.empty()).toByteArray(); + + provider.onWebSocketBinary(message, 0, message.length); + + ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); + + verify(remoteEndpoint).sendBytesByFuture(responseBytesCaptor.capture()); + + SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); + + assertThat(response.getId()).isEqualTo(111L); + assertThat(response.getStatus()).isEqualTo(200); + assertThat(response.getMessage()).isEqualTo("OK"); + assertThat(response.getBody().toStringUtf8()).isEqualTo("World: authorizedUserName"); + } + + @Test + public void testActualRouteMessageUnauthorized() throws InvalidProtocolBufferException { + ResourceConfig resourceConfig = new DropwizardResourceConfig(); + resourceConfig.register(new TestResource()); + resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder()); + resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class)); + resourceConfig.register(new JacksonMessageBodyProvider(new ObjectMapper())); + + ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); + WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, requestLog, null, new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000); + + Session session = mock(Session.class ); + RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); + UpgradeRequest request = mock(UpgradeRequest.class); + + when(session.getUpgradeRequest()).thenReturn(request); + when(session.getRemote()).thenReturn(remoteEndpoint); + + provider.onWebSocketConnect(session); + + byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/world", new LinkedList<>(), Optional.empty()).toByteArray(); + + provider.onWebSocketBinary(message, 0, message.length); + + ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); + + verify(remoteEndpoint).sendBytesByFuture(responseBytesCaptor.capture()); + + SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); + + assertThat(response.getId()).isEqualTo(111L); + assertThat(response.getStatus()).isEqualTo(401); + assertThat(response.hasBody()).isFalse(); + } + + @Test + public void testActualRouteMessageOptionalAuthorizedPresent() throws InvalidProtocolBufferException { + ResourceConfig resourceConfig = new DropwizardResourceConfig(); + resourceConfig.register(new TestResource()); + resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder()); + resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class)); + resourceConfig.register(new JacksonMessageBodyProvider(new ObjectMapper())); + + ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); + WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, requestLog, new TestPrincipal("something"), new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000); + + Session session = mock(Session.class ); + RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); + UpgradeRequest request = mock(UpgradeRequest.class); + + when(session.getUpgradeRequest()).thenReturn(request); + when(session.getRemote()).thenReturn(remoteEndpoint); + + provider.onWebSocketConnect(session); + + byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/optional", new LinkedList<>(), Optional.empty()).toByteArray(); + + provider.onWebSocketBinary(message, 0, message.length); + + ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); + + verify(remoteEndpoint).sendBytesByFuture(responseBytesCaptor.capture()); + + SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); + + assertThat(response.getId()).isEqualTo(111L); + assertThat(response.getStatus()).isEqualTo(200); + assertThat(response.getMessage()).isEqualTo("OK"); + assertThat(response.getBody().toStringUtf8()).isEqualTo("World: something"); + } + + @Test + public void testActualRouteMessageOptionalAuthorizedEmpty() throws InvalidProtocolBufferException { + ResourceConfig resourceConfig = new DropwizardResourceConfig(); + resourceConfig.register(new TestResource()); + resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder()); + resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class)); + resourceConfig.register(new JacksonMessageBodyProvider(new ObjectMapper())); + + ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); + WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, requestLog, null, new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000); + + Session session = mock(Session.class ); + RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); + UpgradeRequest request = mock(UpgradeRequest.class); + + when(session.getUpgradeRequest()).thenReturn(request); + when(session.getRemote()).thenReturn(remoteEndpoint); + + provider.onWebSocketConnect(session); + + byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/optional", new LinkedList<>(), Optional.empty()).toByteArray(); + + provider.onWebSocketBinary(message, 0, message.length); + + ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); + + verify(remoteEndpoint).sendBytesByFuture(responseBytesCaptor.capture()); + + SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); + + assertThat(response.getId()).isEqualTo(111L); + assertThat(response.getStatus()).isEqualTo(200); + assertThat(response.getMessage()).isEqualTo("OK"); + assertThat(response.getBody().toStringUtf8()).isEqualTo("Empty world"); + } + + @Test + public void testActualRouteMessagePutAuthenticatedEntity() throws InvalidProtocolBufferException, JsonProcessingException { + ResourceConfig resourceConfig = new DropwizardResourceConfig(); + resourceConfig.register(new TestResource()); + resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder()); + resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class)); + resourceConfig.register(new JacksonMessageBodyProvider(new ObjectMapper())); + + ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); + WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, requestLog, new TestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000); + + Session session = mock(Session.class ); + RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); + UpgradeRequest request = mock(UpgradeRequest.class); + + when(session.getUpgradeRequest()).thenReturn(request); + when(session.getRemote()).thenReturn(remoteEndpoint); + + provider.onWebSocketConnect(session); + + byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "PUT", "/v1/test/some/testparam", List.of("Content-Type: application/json"), Optional.of(new ObjectMapper().writeValueAsBytes(new TestResource.TestEntity("mykey", 1001)))).toByteArray(); + + provider.onWebSocketBinary(message, 0, message.length); + + ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); + + verify(remoteEndpoint).sendBytesByFuture(responseBytesCaptor.capture()); + + SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); + + assertThat(response.getId()).isEqualTo(111L); + assertThat(response.getStatus()).isEqualTo(200); + assertThat(response.getMessage()).isEqualTo("OK"); + assertThat(response.getBody().toStringUtf8()).isEqualTo("gooduser:testparam:mykey:1001"); + } + + @Test + public void testActualRouteMessagePutAuthenticatedBadEntity() throws InvalidProtocolBufferException, JsonProcessingException { + ResourceConfig resourceConfig = new DropwizardResourceConfig(); + resourceConfig.register(new TestResource()); + resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder()); + resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class)); + resourceConfig.register(new JacksonMessageBodyProvider(new ObjectMapper())); + + ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); + WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, requestLog, new TestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000); + + Session session = mock(Session.class ); + RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); + UpgradeRequest request = mock(UpgradeRequest.class); + + when(session.getUpgradeRequest()).thenReturn(request); + when(session.getRemote()).thenReturn(remoteEndpoint); + + provider.onWebSocketConnect(session); + + byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "PUT", "/v1/test/some/testparam", List.of("Content-Type: application/json"), Optional.of(new ObjectMapper().writeValueAsBytes(new TestResource.TestEntity("mykey", 5)))).toByteArray(); + + provider.onWebSocketBinary(message, 0, message.length); + + ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); + + verify(remoteEndpoint).sendBytesByFuture(responseBytesCaptor.capture()); + + SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); + + assertThat(response.getId()).isEqualTo(111L); + assertThat(response.getStatus()).isEqualTo(400); + assertThat(response.getMessage()).isEqualTo("Bad Request"); + assertThat(response.hasBody()).isFalse(); + } + + @Test + public void testActualRouteMessageExceptionMapping() throws InvalidProtocolBufferException { + ResourceConfig resourceConfig = new DropwizardResourceConfig(); + resourceConfig.register(new TestResource()); + resourceConfig.register(new TestExceptionMapper()); + resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder()); + resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class)); + resourceConfig.register(new JacksonMessageBodyProvider(new ObjectMapper())); + + ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); + WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, requestLog, new TestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000); + + Session session = mock(Session.class ); + RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); + UpgradeRequest request = mock(UpgradeRequest.class); + + when(session.getUpgradeRequest()).thenReturn(request); + when(session.getRemote()).thenReturn(remoteEndpoint); + + provider.onWebSocketConnect(session); + + byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/exception/map", List.of("Content-Type: application/json"), Optional.empty()).toByteArray(); + + provider.onWebSocketBinary(message, 0, message.length); + + ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); + + verify(remoteEndpoint).sendBytesByFuture(responseBytesCaptor.capture()); + + SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); + + assertThat(response.getId()).isEqualTo(111L); + assertThat(response.getStatus()).isEqualTo(1337); + assertThat(response.hasBody()).isFalse(); + } + + @Test + public void testActualRouteSessionContextInjection() throws InvalidProtocolBufferException { + ResourceConfig resourceConfig = new DropwizardResourceConfig(); + resourceConfig.register(new TestResource()); + resourceConfig.register(new TestExceptionMapper()); + resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder()); + resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class)); + resourceConfig.register(new JacksonMessageBodyProvider(new ObjectMapper())); + + ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); + WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, requestLog, new TestPrincipal("gooduser"), new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000); + + Session session = mock(Session.class ); + RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); + UpgradeRequest request = mock(UpgradeRequest.class); + + when(session.getUpgradeRequest()).thenReturn(request); + when(session.getRemote()).thenReturn(remoteEndpoint); + + provider.onWebSocketConnect(session); + + byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/keepalive", new LinkedList<>(), Optional.empty()).toByteArray(); + + provider.onWebSocketBinary(message, 0, message.length); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(ByteBuffer.class); + + verify(remoteEndpoint).sendBytes(requestCaptor.capture(), any(WriteCallback.class)); + + SubProtocol.WebSocketRequestMessage requestMessage = getRequest(requestCaptor); + assertThat(requestMessage.getVerb()).isEqualTo("GET"); + assertThat(requestMessage.getPath()).isEqualTo("/v1/miccheck"); + assertThat(requestMessage.getBody().toStringUtf8()).isEqualTo("smert ze smert"); + + byte[] clientResponse = new ProtobufWebSocketMessageFactory().createResponse(requestMessage.getId(), 200, "OK", new LinkedList<>(), Optional.of("my response".getBytes())).toByteArray(); + + provider.onWebSocketBinary(clientResponse, 0, clientResponse.length); + + ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class); + + verify(remoteEndpoint).sendBytesByFuture(responseBytesCaptor.capture()); + + SubProtocol.WebSocketResponseMessage response = getResponse(responseBytesCaptor); + + assertThat(response.getId()).isEqualTo(111L); + assertThat(response.getStatus()).isEqualTo(200); + assertThat(response.getMessage()).isEqualTo("OK"); + assertThat(response.getBody().toStringUtf8()).isEqualTo("my response"); + } + + private SubProtocol.WebSocketResponseMessage getResponse(ArgumentCaptor responseCaptor) throws InvalidProtocolBufferException { + return SubProtocol.WebSocketMessage.parseFrom(responseCaptor.getValue().array()).getResponse(); + } + + private SubProtocol.WebSocketRequestMessage getRequest(ArgumentCaptor requestCaptor) throws InvalidProtocolBufferException { + return SubProtocol.WebSocketMessage.parseFrom(requestCaptor.getValue().array()).getRequest(); + } + + + public static class TestPrincipal implements Principal { + + private final String name; + + private TestPrincipal(String name) { + this.name = name; + } + + @Override + public String getName() { + return name; + } + } + + public static class TestException extends Exception { + public TestException(String message) { + super(message); + } + } + + @Provider + public static class TestExceptionMapper implements ExceptionMapper { + + @Override + public Response toResponse(TestException exception) { + return Response.status(1337).build(); + } + } + + @Path("/v1/test") + public static class TestResource { + + @GET + @Path("/hello") + public String testGetHello() { + return "Hello!"; + } + + @GET + @Path("/world") + public String testAuthorizedHello(@Auth TestPrincipal user) { + if (user == null) throw new AssertionError(); + + return "World: " + user.getName(); + } + + @GET + @Path("/optional") + public String testAuthorizedHello(@Auth Optional user) { + if (user.isPresent()) return "World: " + user.get().getName(); + else return "Empty world"; + } + + @PUT + @Path("/some/{param}") + @Consumes(MediaType.APPLICATION_JSON) + @Produces(MediaType.APPLICATION_JSON) + public Response testSet(@Auth TestPrincipal user, @PathParam ("param") String param, @Valid TestEntity entity) { + return Response.ok(user.name + ":" + param + ":" + entity.key + ":" + entity.value).build(); + } + + @GET + @Path("/exception/map") + public Response testExceptionMapping() throws TestException { + throw new TestException("I'd like to map this"); + } + + @GET + @Path("/keepalive") + public CompletableFuture testContextInjection(@WebSocketSession WebSocketSessionContext context) { + if (context == null) { + throw new AssertionError(); + } + + return context.getClient() + .sendRequest("GET", "/v1/miccheck", new LinkedList<>(), Optional.of("smert ze smert".getBytes())) + .thenApply(response -> Response.ok().entity(new String(response.getBody().get())).build()); + } + + public static class TestEntity { + + public TestEntity(String key, long value) { + this.key = key; + this.value = value; + } + + public TestEntity() { + } + + @JsonProperty + @NotEmpty + private String key; + + @JsonProperty + @Min(100) + private long value; + + } } } diff --git a/websocket-resources/src/test/java/org/whispersystems/websocket/logging/WebSocketRequestLogTest.java b/websocket-resources/src/test/java/org/whispersystems/websocket/logging/WebSocketRequestLogTest.java new file mode 100644 index 000000000..9efeca676 --- /dev/null +++ b/websocket-resources/src/test/java/org/whispersystems/websocket/logging/WebSocketRequestLogTest.java @@ -0,0 +1,120 @@ +package org.whispersystems.websocket.logging; + +import org.glassfish.jersey.internal.MapPropertiesDelegate; +import org.glassfish.jersey.server.ContainerRequest; +import org.glassfish.jersey.server.ContainerResponse; +import org.junit.Test; +import org.whispersystems.websocket.WebSocketSecurityContext; +import org.whispersystems.websocket.session.ContextPrincipal; +import org.whispersystems.websocket.session.WebSocketSessionContext; + +import javax.ws.rs.core.Response; +import java.io.ByteArrayOutputStream; +import java.net.URI; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; + +import ch.qos.logback.classic.LoggerContext; +import ch.qos.logback.core.OutputStreamAppender; +import ch.qos.logback.core.spi.DeferredProcessingAware; +import io.dropwizard.logging.AbstractOutputStreamAppenderFactory; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; + +public class WebSocketRequestLogTest { + + @Test + public void testLogLineWithoutHeaders() throws InterruptedException { + WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); + + ListAppender listAppender = new ListAppender<>(); + WebsocketRequestLoggerFactory requestLoggerFactory = new WebsocketRequestLoggerFactory(); + requestLoggerFactory.appenders = List.of(new ListAppenderFactory<>(listAppender)); + + WebsocketRequestLog requestLog = requestLoggerFactory.build("test-logger"); + ContainerRequest request = new ContainerRequest (null, URI.create("/v1/test"), "GET", new WebSocketSecurityContext(new ContextPrincipal(sessionContext)), new MapPropertiesDelegate(new HashMap<>()), null); + ContainerResponse response = new ContainerResponse(request, Response.ok("My response body").build()); + + requestLog.log("123.456.789.123", request, response); + + listAppender.waitForListSize(1); + assertThat(listAppender.list.size()).isEqualTo(1); + + String loggedLine = new String(listAppender.outputStream.toByteArray()); + assertThat(loggedLine.matches("123\\.456\\.789\\.123 \\- \\- \\[[0-9]{2}\\/[a-zA-Z]{3}\\/[0-9]{4}:[0-9]{2}:[0-9]{2}:[0-9]{2} \\-[0-9]{4}\\] \"GET \\/v1\\/test WS\" 200 \\- \"\\-\" \"\\-\"\n")).isTrue(); + } + + @Test + public void testLogLineWithHeaders() throws InterruptedException { + WebSocketSessionContext sessionContext = mock(WebSocketSessionContext.class); + + ListAppender listAppender = new ListAppender<>(); + WebsocketRequestLoggerFactory requestLoggerFactory = new WebsocketRequestLoggerFactory(); + requestLoggerFactory.appenders = List.of(new ListAppenderFactory<>(listAppender)); + + WebsocketRequestLog requestLog = requestLoggerFactory.build("test-logger"); + ContainerRequest request = new ContainerRequest (null, URI.create("/v1/test"), "GET", new WebSocketSecurityContext(new ContextPrincipal(sessionContext)), new MapPropertiesDelegate(new HashMap<>()), null); + request.header("User-Agent", "SmertZeSmert"); + request.header("Referer", "https://moxie.org"); + ContainerResponse response = new ContainerResponse(request, Response.ok("My response body").build()); + + requestLog.log("123.456.789.123", request, response); + + listAppender.waitForListSize(1); + assertThat(listAppender.list.size()).isEqualTo(1); + + String loggedLine = new String(listAppender.outputStream.toByteArray()); + assertThat(loggedLine.matches("123\\.456\\.789\\.123 \\- \\- \\[[0-9]{2}\\/[a-zA-Z]{3}\\/[0-9]{4}:[0-9]{2}:[0-9]{2}:[0-9]{2} \\-[0-9]{4}\\] \"GET \\/v1\\/test WS\" 200 \\- \"https://moxie.org\" \"SmertZeSmert\"\n")).isTrue(); + + System.out.println(listAppender.list.get(0)); + System.out.println(new String(listAppender.outputStream.toByteArray())); + } + + + private static class ListAppenderFactory extends AbstractOutputStreamAppenderFactory { + private final ListAppender listAppender; + + public ListAppenderFactory(ListAppender listAppender) { + this.listAppender = listAppender; + } + + @Override + protected OutputStreamAppender appender(LoggerContext context) { + listAppender.setContext(context); + return listAppender; + } + } + + private static class ListAppender extends OutputStreamAppender { + + public final List list = new ArrayList(); + public final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + + protected void append(E e) { + super.append(e); + + synchronized (list) { + list.add(e); + list.notifyAll(); + } + } + + @Override + public void start() { + setOutputStream(outputStream); + super.start(); + } + + public void waitForListSize(int size) throws InterruptedException { + synchronized (list) { + while (list.size() < size) { + list.wait(5000); + } + } + } + + } + + +}