diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/mappers/CompletionExceptionMapper.java b/service/src/main/java/org/whispersystems/textsecuregcm/mappers/CompletionExceptionMapper.java index 9cb6dbc44..a6eb2396b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/mappers/CompletionExceptionMapper.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/mappers/CompletionExceptionMapper.java @@ -12,13 +12,13 @@ import javax.ws.rs.core.Context; import javax.ws.rs.core.Response; import javax.ws.rs.ext.ExceptionMapper; import javax.ws.rs.ext.Provider; -import javax.ws.rs.ext.Providers; +import org.glassfish.jersey.spi.ExceptionMappers; @Provider public class CompletionExceptionMapper implements ExceptionMapper { @Context - private Providers providers; + private ExceptionMappers exceptionMappers; @Override public Response toResponse(final CompletionException exception) { @@ -26,8 +26,7 @@ public class CompletionExceptionMapper implements ExceptionMapper { @Context diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java index 4d92e2a30..9ac97b4a6 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/logging/LoggingUnhandledExceptionMapperTest.java @@ -20,11 +20,18 @@ import io.dropwizard.jersey.DropwizardResourceConfig; import io.dropwizard.jersey.jackson.JacksonMessageBodyProvider; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.ResourceExtension; +import java.nio.ByteBuffer; import java.security.Principal; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; import java.util.stream.Stream; import javax.ws.rs.GET; import javax.ws.rs.Path; @@ -35,13 +42,16 @@ import org.eclipse.jetty.websocket.api.Session; import org.eclipse.jetty.websocket.api.UpgradeRequest; import org.glassfish.jersey.server.ApplicationHandler; import org.glassfish.jersey.server.ResourceConfig; +import org.glassfish.jersey.server.ServerProperties; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.slf4j.Logger; +import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.websocket.WebSocketResourceProvider; import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider; @@ -54,25 +64,45 @@ class LoggingUnhandledExceptionMapperTest { private static final Logger logger = mock(Logger.class); - private static final LoggingUnhandledExceptionMapper exceptionMapper = spy(new LoggingUnhandledExceptionMapper(logger)); + private static final LoggingUnhandledExceptionMapper exceptionMapper = spy( + new LoggingUnhandledExceptionMapper(logger)); private static final ResourceExtension resources = ResourceExtension.builder() + .addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE) + .addProvider(new CompletionExceptionMapper()) .addProvider(exceptionMapper) .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .addResource(new TestController()) .build(); + static ScheduledExecutorService scheduledExecutorService; + static Stream testExceptionMapper() { return Stream.of( Arguments.of(false, "/v1/test/no-exception", "/v1/test/no-exception", "Signal-Android/5.1.2 Android/30", null), - Arguments.of(true, "/v1/test/unhandled-runtime-exception", "/v1/test/unhandled-runtime-exception", "Signal-Android/5.1.2 Android/30", "ANDROID 5.1.2"), - Arguments.of(true, "/v1/test/unhandled-runtime-exception/1/and/two", "/v1/test/unhandled-runtime-exception/\\{parameter1\\}/and/\\{parameter2\\}", "Signal-iOS/5.10.2 iOS/14.1", "IOS 5.10.2"), - Arguments.of(true, "/v1/test/unhandled-runtime-exception", "/v1/test/unhandled-runtime-exception", "Some literal user-agent", "Some literal user-agent") + Arguments.of(true, "/v1/test/unhandled-runtime-exception", "/v1/test/unhandled-runtime-exception", + "Signal-Android/5.1.2 Android/30", "ANDROID 5.1.2"), + Arguments.of(true, "/v1/test/unhandled-runtime-exception/1/and/two", + "/v1/test/unhandled-runtime-exception/\\{parameter1\\}/and/\\{parameter2\\}", "Signal-iOS/5.10.2 iOS/14.1", + "IOS 5.10.2"), + Arguments.of(true, "/v1/test/unhandled-runtime-exception", "/v1/test/unhandled-runtime-exception", + "Some literal user-agent", "Some literal user-agent"), + Arguments.of(true, "/v1/test/unhandled-runtime-exception-async", "/v1/test/unhandled-runtime-exception-async", + "Some literal user-agent", "Some literal user-agent"), + Arguments.of(true, "/v1/test/unhandled-runtime-exception-async-completion", + "/v1/test/unhandled-runtime-exception-async-completion", + "Some literal user-agent", "Some literal user-agent") ); } + @BeforeEach + void setup() { + scheduledExecutorService = Executors.newSingleThreadScheduledExecutor(); + } + @AfterEach void teardown() { + scheduledExecutorService.shutdown(); reset(exceptionMapper, logger); } @@ -100,10 +130,13 @@ class LoggingUnhandledExceptionMapperTest { @ParameterizedTest @MethodSource("testExceptionMapper") void testWebsocketExceptionMapper(final boolean expectException, final String targetPath, final String loggedPath, - final String userAgentHeader, final String userAgentLog) { + final String userAgentHeader, final String userAgentLog) throws Exception { + + final CompletableFuture responseFuture = new CompletableFuture<>(); Session session = mock(Session.class); - WebSocketResourceProvider provider = createWebsocketProvider(userAgentHeader, session); + WebSocketResourceProvider provider = createWebsocketProvider(userAgentHeader, session, + responseFuture::complete); provider.onWebSocketConnect(session); @@ -112,6 +145,8 @@ class LoggingUnhandledExceptionMapperTest { provider.onWebSocketBinary(message, 0, message.length); + responseFuture.get(1, TimeUnit.SECONDS); + if (expectException) { verify(exceptionMapper, times(1)).toResponse(any(Exception.class)); verify(logger, times(1)) @@ -123,7 +158,8 @@ class LoggingUnhandledExceptionMapperTest { } - private WebSocketResourceProvider createWebsocketProvider(final String userAgentHeader, final Session session) { + private WebSocketResourceProvider createWebsocketProvider(final String userAgentHeader, + final Session session, final Consumer responseHandler) { ResourceConfig resourceConfig = new DropwizardResourceConfig(); resourceConfig.register(exceptionMapper); resourceConfig.register(new TestController()); @@ -137,6 +173,11 @@ class LoggingUnhandledExceptionMapperTest { requestLog, new TestPrincipal("foo"), new ProtobufWebSocketMessageFactory(), Optional.empty(), 30000); RemoteEndpoint remoteEndpoint = mock(RemoteEndpoint.class); + when(remoteEndpoint.sendBytesByFuture(any())) + .thenAnswer(answer -> { + responseHandler.accept(answer.getArgument(0, ByteBuffer.class)); + return CompletableFuture.completedFuture(null); + }); UpgradeRequest request = mock(UpgradeRequest.class); when(session.getUpgradeRequest()).thenReturn(request); @@ -162,6 +203,29 @@ class LoggingUnhandledExceptionMapperTest { throw new RuntimeException(); } + @GET + @Path("/unhandled-runtime-exception-async") + public CompletableFuture testUnhandledExceptionAsync() { + final CompletableFuture responseFuture = new CompletableFuture<>(); + + scheduledExecutorService.schedule(() -> responseFuture.completeExceptionally(new RuntimeException("async")), + 50, TimeUnit.MILLISECONDS); + + return responseFuture; + } + + @GET + @Path("/unhandled-runtime-exception-async-completion") + public CompletableFuture testUnhandledCompletionExceptionAsync() { + final CompletableFuture responseFuture = new CompletableFuture<>(); + + scheduledExecutorService.schedule( + () -> responseFuture.completeExceptionally(new CompletionException(new RuntimeException("async"))), + 50, TimeUnit.MILLISECONDS); + + return responseFuture; + } + @GET @Path("/unhandled-runtime-exception/{parameter1}/and/{parameter2}") public Response testUnhandledExceptionWithPathParameter(@PathParam("parameter1") String parameter1,