diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapper.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapper.java index e4accb904..9f7ba458a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapper.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapper.java @@ -7,9 +7,9 @@ package org.whispersystems.textsecuregcm.util.logging; import com.google.common.annotations.VisibleForTesting; import io.dropwizard.jersey.errors.LoggingExceptionMapper; -import javax.servlet.http.HttpServletRequest; +import javax.inject.Provider; import javax.ws.rs.core.Context; -import org.glassfish.jersey.server.ExtendedUriInfo; +import org.glassfish.jersey.server.ContainerRequest; import org.slf4j.Logger; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.util.ua.UserAgent; @@ -18,10 +18,7 @@ import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; public class LoggingUnhandledExceptionMapper extends LoggingExceptionMapper { @Context - private HttpServletRequest request; - - @Context - private ExtendedUriInfo uriInfo; + private Provider request; public LoggingUnhandledExceptionMapper() { super(); @@ -38,10 +35,10 @@ public class LoggingUnhandledExceptionMapper extends LoggingExceptionMapper testExceptionMapper() { + return Stream.of( + Arguments.of(false, "/v1/test/no-exception", "/v1/test/no-exception", "Signal-Android/5.1.2 Android/30", null), + Arguments.of(true, "/v1/test/unhandled-runtime-exception", "/v1/test/unhandled-runtime-exception", "Signal-Android/5.1.2 Android/30", "ANDROID 5.1.2"), + Arguments.of(true, "/v1/test/unhandled-runtime-exception/1/and/two", "/v1/test/unhandled-runtime-exception/\\{parameter1\\}/and/\\{parameter2\\}", "Signal-iOS/5.10.2 iOS/14.1", "IOS 5.10.2"), + Arguments.of(true, "/v1/test/unhandled-runtime-exception", "/v1/test/unhandled-runtime-exception", "Some literal user-agent", "Some literal user-agent") + ); + } + + @AfterEach + void teardown() { + reset(exceptionMapper, logger); } @ParameterizedTest @MethodSource - void testExceptionMapper(final boolean expectException, final String targetPath, final String loggedPath, final String userAgentHeader, - final String userAgentLog) { + void testExceptionMapper(final boolean expectException, final String targetPath, final String loggedPath, + final String userAgentHeader, final String userAgentLog) { resources.getJerseyTest() .target(targetPath) @@ -59,21 +87,63 @@ class LoggingUnhandledExceptionMapperTest { .get(); if (expectException) { - verify(exceptionMapper, times(1)).toResponse(any(Exception.class)); - verify(logger, times(1)).error(matches(String.format(".* at GET %s \\(%s\\)", loggedPath, userAgentLog)), any(Exception.class)); + verify(logger, times(1)) + .error(matches(String.format(".* at GET %s \\(%s\\)", loggedPath, userAgentLog)), any(Exception.class)); } else { verifyNoInteractions(exceptionMapper); } } - static Stream testExceptionMapper() { - return Stream.of( - Arguments.of(false, "/v1/test/no-exception", "/v1/test/no-exception", null, null, null), - Arguments.of(true, "/v1/test/unhandled-runtime-exception", "/v1/test/unhandled-runtime-exception", "Signal-Android/5.1.2 Android/30", "ANDROID 5.1.2"), - Arguments.of(true, "/v1/test/unhandled-runtime-exception/1/and/two", "/v1/test/unhandled-runtime-exception/\\{parameter1\\}/and/\\{parameter2\\}", "Signal-iOS/5.10.2 iOS/14.1", "IOS 5.10.2") - ); + @ParameterizedTest + @MethodSource("testExceptionMapper") + void testWebsocketExceptionMapper(final boolean expectException, final String targetPath, final String loggedPath, + final String userAgentHeader, final String userAgentLog) { + + Session session = mock(Session.class); + WebSocketResourceProvider provider = createWebsocketProvider(userAgentHeader, session); + + provider.onWebSocketConnect(session); + + byte[] message = new ProtobufWebSocketMessageFactory() + .createRequest(Optional.of(111L), "GET", targetPath, new LinkedList<>(), Optional.empty()).toByteArray(); + + provider.onWebSocketBinary(message, 0, message.length); + + if (expectException) { + verify(exceptionMapper, times(1)).toResponse(any(Exception.class)); + verify(logger, times(1)) + .error(matches(String.format(".* at GET %s \\(%s\\)", loggedPath, userAgentLog)), any(Exception.class)); + + } else { + verifyNoInteractions(exceptionMapper); + } + + } + + private WebSocketResourceProvider createWebsocketProvider(final String userAgentHeader, final Session session) { + ResourceConfig resourceConfig = new DropwizardResourceConfig(); + resourceConfig.register(exceptionMapper); + resourceConfig.register(new TestController()); + resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder()); + resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class)); + resourceConfig.register(new JacksonMessageBodyProvider(new ObjectMapper())); + + ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig); + WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class); + WebSocketResourceProvider provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler, + requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000); + + RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); + UpgradeRequest request = mock(UpgradeRequest.class); + + when(session.getUpgradeRequest()).thenReturn(request); + when(session.getRemote()).thenReturn(remoteEndpoint); + when(request.getHeader("User-Agent")).thenReturn(userAgentHeader); + when(request.getHeaders()).thenReturn(Map.of("User-Agent", List.of(userAgentHeader))); + + return provider; } @Path("/v1/test") @@ -93,8 +163,23 @@ class LoggingUnhandledExceptionMapperTest { @GET @Path("/unhandled-runtime-exception/{parameter1}/and/{parameter2}") - public Response testUnhandledExceptionWithPathParameter(@PathParam("parameter1") String parameter1, @PathParam("parameter2") String parameter2) { + public Response testUnhandledExceptionWithPathParameter(@PathParam("parameter1") String parameter1, + @PathParam("parameter2") String parameter2) { throw new RuntimeException(); } } + + public static class TestPrincipal implements Principal { + + private final String name; + + private TestPrincipal(String name) { + this.name = name; + } + + @Override + public String getName() { + return name; + } + } }