diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index f6aa1a83a..cac78049c 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -213,6 +213,7 @@ import org.whispersystems.textsecuregcm.util.ManagedAwsCrt; import org.whispersystems.textsecuregcm.util.SystemMapper; import org.whispersystems.textsecuregcm.util.UsernameHashZkProofVerifier; import org.whispersystems.textsecuregcm.util.VirtualThreadPinEventMonitor; +import org.whispersystems.textsecuregcm.util.VirtualExecutorServiceProvider; import org.whispersystems.textsecuregcm.util.logging.LoggingUnhandledExceptionMapper; import org.whispersystems.textsecuregcm.util.logging.UncaughtExceptionHandler; import org.whispersystems.textsecuregcm.websocket.AuthenticatedConnectListener; @@ -740,6 +741,7 @@ public class WhisperServerService extends Application webSocketEnvironment = new WebSocketEnvironment<>(environment, config.getWebSocketConfiguration(), Duration.ofMillis(90000)); + webSocketEnvironment.jersey().register(new VirtualExecutorServiceProvider("managed-async-websocket-virtual-thread-")); webSocketEnvironment.setAuthenticator(new WebSocketAccountAuthenticator(accountAuthenticator)); webSocketEnvironment.setConnectListener( new AuthenticatedConnectListener(receiptSender, messagesManager, pushNotificationManager, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/VirtualExecutorServiceProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/VirtualExecutorServiceProvider.java new file mode 100644 index 000000000..db040c1ab --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/VirtualExecutorServiceProvider.java @@ -0,0 +1,56 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.util; + +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import org.glassfish.jersey.server.ManagedAsyncExecutor; +import org.glassfish.jersey.spi.ExecutorServiceProvider; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@ManagedAsyncExecutor +public class VirtualExecutorServiceProvider implements ExecutorServiceProvider { + private static final Logger logger = LoggerFactory.getLogger(VirtualExecutorServiceProvider.class); + + + /** + * Default thread pool executor termination timeout in milliseconds. + */ + public static final int TERMINATION_TIMEOUT = 5000; + private final String virtualThreadNamePrefix; + + public VirtualExecutorServiceProvider(final String virtualThreadNamePrefix) { + this.virtualThreadNamePrefix = virtualThreadNamePrefix; + } + + + @Override + public ExecutorService getExecutorService() { + logger.info("Creating executor service with virtual thread per task"); + return Executors.newThreadPerTaskExecutor(Thread.ofVirtual().name(virtualThreadNamePrefix, 0).factory()); + } + + @Override + public void dispose(final ExecutorService executorService) { + logger.info("Shutting down virtual thread pool executor"); + + executorService.shutdown(); + boolean terminated = false; + try { + terminated = executorService.awaitTermination(TERMINATION_TIMEOUT, TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + if (!terminated) { + // virtual thread per task executor has no queue, so shouldn't have any un-run tasks + final List unrunTasks = executorService.shutdownNow(); + logger.info("Force terminated executor with {} un-run tasks", unrunTasks.size()); + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/VirtualExecutorServiceProviderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/VirtualExecutorServiceProviderTest.java new file mode 100644 index 000000000..34cf890dd --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/VirtualExecutorServiceProviderTest.java @@ -0,0 +1,81 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.util; + +import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; +import io.dropwizard.testing.junit5.ResourceExtension; +import java.security.Principal; +import javax.ws.rs.GET; +import javax.ws.rs.Path; +import javax.ws.rs.core.Response; +import org.glassfish.jersey.server.ManagedAsync; +import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +@ExtendWith(DropwizardExtensionsSupport.class) +class VirtualExecutorServiceProviderTest { + + private static final ResourceExtension resources = ResourceExtension.builder() + .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) + .addProvider(new VirtualExecutorServiceProvider("virtual-thread-")) + .addResource(new TestController()) + .build(); + + @Test + public void testManagedAsyncThread() { + final Response response = resources.getJerseyTest() + .target("/v1/test/managed-async") + .request() + .get(); + String threadName = response.readEntity(String.class); + assertThat(threadName).startsWith("virtual-thread-"); + } + + @Test + public void testUnmanagedThread() { + final Response response = resources.getJerseyTest() + .target("/v1/test/unmanaged") + .request() + .get(); + String threadName = response.readEntity(String.class); + assertThat(threadName).doesNotContain("virtual-thread-"); + } + + @Path("/v1/test") + public static class TestController { + + @GET + @Path("/managed-async") + @ManagedAsync + public Response managedAsync() { + return Response.ok().entity(Thread.currentThread().getName()).build(); + } + + @GET + @Path("/unmanaged") + public Response unmanaged() { + return Response.ok().entity(Thread.currentThread().getName()).build(); + } + + } + + public static class TestPrincipal implements Principal { + + private final String name; + + private TestPrincipal(String name) { + this.name = name; + } + + @Override + public String getName() { + return name; + } + } +}