diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 48d3979d9..9d2c5cfae 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -157,6 +157,7 @@ import org.whispersystems.textsecuregcm.mappers.RegistrationServiceSenderExcepti import org.whispersystems.textsecuregcm.mappers.ServerRejectedExceptionMapper; import org.whispersystems.textsecuregcm.mappers.SubscriptionProcessorExceptionMapper; import org.whispersystems.textsecuregcm.metrics.MetricsApplicationEventListener; +import org.whispersystems.textsecuregcm.metrics.MetricsHttpChannelListener; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; import org.whispersystems.textsecuregcm.metrics.ReportedMessageMetricsListener; import org.whispersystems.textsecuregcm.metrics.TrafficSource; @@ -180,7 +181,6 @@ import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient; import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client; import org.whispersystems.textsecuregcm.spam.FilterSpam; import org.whispersystems.textsecuregcm.spam.PushChallengeConfigProvider; -import org.whispersystems.textsecuregcm.spam.RateLimitChallengeListener; import org.whispersystems.textsecuregcm.spam.ReportSpamTokenProvider; import org.whispersystems.textsecuregcm.spam.ScoreThresholdProvider; import org.whispersystems.textsecuregcm.spam.SenderOverrideProvider; @@ -792,6 +792,9 @@ public class WhisperServerService extends Application + * It implements {@link LifeCycle.Listener} without overriding methods, so that it can be an event listener that + * Dropwizard will attach to the container—the {@link Container.Listener} implementation is where it attaches + * itself to any {@link Connector}s. + * + * @see MetricsRequestEventListener + */ +public class MetricsHttpChannelListener implements HttpChannel.Listener, Container.Listener, LifeCycle.Listener, + ContainerRequestFilter { + + private static final Logger logger = LoggerFactory.getLogger(MetricsHttpChannelListener.class); + + private record RequestInfo(String path, String method, int statusCode, @Nullable String userAgent) { + + } + + private final ClientReleaseManager clientReleaseManager; + + public static final String REQUEST_COUNTER_NAME = name(MetricsHttpChannelListener.class, "request"); + public static final String REQUESTS_BY_VERSION_COUNTER_NAME = name(MetricsHttpChannelListener.class, + "requestByVersion"); + @VisibleForTesting + static final String URI_INFO_PROPERTY_NAME = MetricsHttpChannelListener.class.getName() + ".uriInfo"; + + @VisibleForTesting + static final String PATH_TAG = "path"; + + @VisibleForTesting + static final String METHOD_TAG = "method"; + + @VisibleForTesting + static final String STATUS_CODE_TAG = "status"; + + @VisibleForTesting + static final String TRAFFIC_SOURCE_TAG = "trafficSource"; + + private final MeterRegistry meterRegistry; + + + public MetricsHttpChannelListener(final ClientReleaseManager clientReleaseManager) { + this(Metrics.globalRegistry, clientReleaseManager); + } + + @VisibleForTesting + MetricsHttpChannelListener(final MeterRegistry meterRegistry, final ClientReleaseManager clientReleaseManager) { + this.meterRegistry = meterRegistry; + this.clientReleaseManager = clientReleaseManager; + } + + public void configure(final Environment environment) { + // register as ContainerRequestFilter + environment.jersey().register(this); + + // hook into lifecycle events, to react to the Connector being added + environment.lifecycle().addEventListener(this); + } + + @Override + public void onRequestFailure(final Request request, final Throwable failure) { + final RequestInfo requestInfo = getRequestInfo(request); + + logger.warn("Request failure: {} {} ({}) [{}] ", + requestInfo.method(), + requestInfo.path(), + requestInfo.userAgent(), + requestInfo.statusCode(), failure); + } + + @Override + public void onComplete(final Request request) { + + final RequestInfo requestInfo = getRequestInfo(request); + + final List tags = new ArrayList<>(5); + tags.add(Tag.of(PATH_TAG, requestInfo.path())); + tags.add(Tag.of(METHOD_TAG, requestInfo.method())); + tags.add(Tag.of(STATUS_CODE_TAG, String.valueOf(requestInfo.statusCode()))); + tags.add(Tag.of(TRAFFIC_SOURCE_TAG, TrafficSource.HTTP.name().toLowerCase())); + + final Tag platformTag = UserAgentTagUtil.getPlatformTag(requestInfo.userAgent()); + tags.add(platformTag); + + meterRegistry.counter(REQUEST_COUNTER_NAME, tags).increment(); + + UserAgentTagUtil.getClientVersionTag(requestInfo.userAgent(), clientReleaseManager).ifPresent( + clientVersionTag -> meterRegistry.counter(REQUESTS_BY_VERSION_COUNTER_NAME, + Tags.of(clientVersionTag, platformTag)).increment()); + } + + @Override + public void beanAdded(final Container parent, final Object child) { + if (child instanceof Connector connector) { + connector.addBean(this); + } + } + + @Override + public void beanRemoved(final Container parent, final Object child) { + + } + + @Override + public void filter(final ContainerRequestContext requestContext) throws IOException { + requestContext.setProperty(URI_INFO_PROPERTY_NAME, requestContext.getUriInfo()); + } + + private RequestInfo getRequestInfo(Request request) { + final String path = Optional.ofNullable(request.getAttribute(URI_INFO_PROPERTY_NAME)) + .map(attr -> UriInfoUtil.getPathTemplate((ExtendedUriInfo) attr)) + .orElse("unknown"); + final String method = Optional.ofNullable(request.getMethod()).orElse("unknown"); + // Response cannot be null, but its status might not always reflect an actual response status, since it gets + // initialized to 200 + final int status = request.getResponse().getStatus(); + + @Nullable final String userAgent = request.getHeader(HttpHeaders.USER_AGENT); + + return new RequestInfo(path, method, status, userAgent); + } + +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListenerIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListenerIntegrationTest.java new file mode 100644 index 000000000..daf38d1e6 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListenerIntegrationTest.java @@ -0,0 +1,231 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.metrics; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.net.HttpHeaders; +import io.dropwizard.core.Application; +import io.dropwizard.core.Configuration; +import io.dropwizard.core.setup.Environment; +import io.dropwizard.testing.junit5.DropwizardAppExtension; +import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.Tag; +import java.security.Principal; +import java.time.Duration; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Stream; +import javax.security.auth.Subject; +import javax.ws.rs.GET; +import javax.ws.rs.Path; +import javax.ws.rs.PathParam; +import javax.ws.rs.client.Client; +import javax.ws.rs.container.ContainerRequestContext; +import javax.ws.rs.core.Context; +import org.eclipse.jetty.server.Connector; +import org.eclipse.jetty.server.HttpChannel; +import org.eclipse.jetty.server.Request; +import org.eclipse.jetty.util.component.Container; +import org.eclipse.jetty.util.component.LifeCycle; +import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.ArgumentCaptor; +import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; +import org.whispersystems.websocket.WebSocketResourceProviderFactory; +import org.whispersystems.websocket.configuration.WebSocketConfiguration; +import org.whispersystems.websocket.setup.WebSocketEnvironment; + +@ExtendWith(DropwizardExtensionsSupport.class) +class MetricsHttpChannelListenerIntegrationTest { + + private static final TrafficSource TRAFFIC_SOURCE = TrafficSource.HTTP; + private static final MeterRegistry METER_REGISTRY = mock(MeterRegistry.class); + private static final Counter COUNTER = mock(Counter.class); + private static final AtomicReference> LISTENER_FUTURE_REFERENCE = new AtomicReference<>(); + + private static final DropwizardAppExtension EXTENSION = new DropwizardAppExtension<>( + MetricsHttpChannelListenerIntegrationTest.TestApplication.class); + + @AfterEach + void teardown() { + reset(METER_REGISTRY); + reset(COUNTER); + } + + @ParameterizedTest + @MethodSource + @SuppressWarnings("unchecked") + void testSimplePath(String requestPath, String expectedTagPath, String expectedResponse) throws Exception { + + final CompletableFuture listenerCompleteFuture = new CompletableFuture<>(); + LISTENER_FUTURE_REFERENCE.set(listenerCompleteFuture); + + final ArgumentCaptor> tagCaptor = ArgumentCaptor.forClass(Iterable.class); + when(METER_REGISTRY.counter(anyString(), any(Iterable.class))) + .thenAnswer(a -> MetricsHttpChannelListener.REQUEST_COUNTER_NAME.equals(a.getArgument(0, String.class)) + ? COUNTER + : mock(Counter.class)) + .thenReturn(COUNTER); + + Client client = EXTENSION.client(); + + final String response = client.target( + String.format("http://localhost:%d%s", EXTENSION.getLocalPort(), requestPath)) + .request() + .header(HttpHeaders.USER_AGENT, "Signal-Android/4.53.7 (Android 8.1)") + .get(String.class); + + assertEquals(expectedResponse, response); + + listenerCompleteFuture.get(1000, TimeUnit.MILLISECONDS); + + verify(METER_REGISTRY).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture()); + verify(COUNTER).increment(); + + final Iterable tagIterable = tagCaptor.getValue(); + final Set tags = new HashSet<>(); + + for (final Tag tag : tagIterable) { + tags.add(tag); + } + + assertEquals(5, tags.size()); + assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.PATH_TAG, expectedTagPath))); + assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.METHOD_TAG, "GET"))); + assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.STATUS_CODE_TAG, "200"))); + assertTrue( + tags.contains(Tag.of(MetricsHttpChannelListener.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase()))); + assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android"))); + } + + static Stream testSimplePath() { + return Stream.of( + Arguments.of("/v1/test/hello", "/v1/test/hello", "Hello!"), + Arguments.of("/v1/test/greet/friend", "/v1/test/greet/{name}", + String.format(TestResource.GREET_FORMAT, "friend")) + ); + } + + public static class TestApplication extends Application { + + @Override + public void run(final Configuration configuration, + final Environment environment) throws Exception { + + final MetricsHttpChannelListener metricsHttpChannelListener = new MetricsHttpChannelListener( + METER_REGISTRY, + mock(ClientReleaseManager.class)); + + metricsHttpChannelListener.configure(environment); + environment.lifecycle().addEventListener(new TestListener(LISTENER_FUTURE_REFERENCE)); + + environment.jersey().register(new TestResource()); + + // WebSocket set up + final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration(); + + WebSocketEnvironment webSocketEnvironment = new WebSocketEnvironment<>(environment, + webSocketConfiguration, Duration.ofMillis(1000)); + + webSocketEnvironment.jersey().register(new TestResource()); + + JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null); + + WebSocketResourceProviderFactory webSocketServlet = new WebSocketResourceProviderFactory<>( + webSocketEnvironment, TestPrincipal.class, webSocketConfiguration, "ignored"); + + environment.servlets().addServlet("WebSocket", webSocketServlet); + } + } + + /** + * A simple listener to signal that {@link HttpChannel.Listener} has completed its work, since its onComplete() is on + * a different thread from the one that sends the response, creating a race condition between the listener and the + * test assertions + */ + static class TestListener implements HttpChannel.Listener, Container.Listener, LifeCycle.Listener { + + private final AtomicReference> completableFutureAtomicReference; + + TestListener(AtomicReference> completableFutureAtomicReference) { + + this.completableFutureAtomicReference = completableFutureAtomicReference; + } + + @Override + public void onComplete(final Request request) { + completableFutureAtomicReference.get().complete(null); + } + + @Override + public void beanAdded(final Container parent, final Object child) { + if (child instanceof Connector connector) { + connector.addBean(this); + } + } + + @Override + public void beanRemoved(final Container parent, final Object child) { + + } + + } + + @Path("/v1/test") + public static class TestResource { + + static final String GREET_FORMAT = "Hello, %s!"; + + + @GET + @Path("/hello") + public String testGetHello() { + return "Hello!"; + } + + @GET + @Path("/greet/{name}") + public String testGreetByName(@PathParam("name") String name, @Context ContainerRequestContext context) { + + context.setProperty("uriInfo", context.getUriInfo()); + + return String.format(GREET_FORMAT, name); + } + } + + public static class TestPrincipal implements Principal { + + // Principal implementation + + @Override + public String getName() { + return null; + } + + @Override + public boolean implies(final Subject subject) { + return false; + } + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListenerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListenerTest.java new file mode 100644 index 000000000..614ca99af --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListenerTest.java @@ -0,0 +1,95 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.metrics; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.net.HttpHeaders; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.Tag; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import org.eclipse.jetty.http.HttpURI; +import org.eclipse.jetty.server.Request; +import org.eclipse.jetty.server.Response; +import org.glassfish.jersey.server.ExtendedUriInfo; +import org.glassfish.jersey.uri.UriTemplate; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; + +class MetricsHttpChannelListenerTest { + + private MeterRegistry meterRegistry; + private Counter counter; + private MetricsHttpChannelListener listener; + + @BeforeEach + void setup() { + meterRegistry = mock(MeterRegistry.class); + counter = mock(Counter.class); + + final ClientReleaseManager clientReleaseManager = mock(ClientReleaseManager.class); + when(clientReleaseManager.isVersionActive(any(), any())).thenReturn(false); + + listener = new MetricsHttpChannelListener(meterRegistry, clientReleaseManager); + } + + @Test + @SuppressWarnings("unchecked") + void testOnEvent() { + final String path = "/test"; + final String method = "GET"; + final int statusCode = 200; + + final HttpURI httpUri = mock(HttpURI.class); + when(httpUri.getPath()).thenReturn(path); + + final Request request = mock(Request.class); + when(request.getMethod()).thenReturn(method); + when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn("Signal-Android/4.53.7 (Android 8.1)"); + when(request.getHttpURI()).thenReturn(httpUri); + + final Response response = mock(Response.class); + when(response.getStatus()).thenReturn(statusCode); + when(request.getResponse()).thenReturn(response); + final ExtendedUriInfo extendedUriInfo = mock(ExtendedUriInfo.class); + when(request.getAttribute(MetricsHttpChannelListener.URI_INFO_PROPERTY_NAME)).thenReturn(extendedUriInfo); + when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate(path))); + + final ArgumentCaptor> tagCaptor = ArgumentCaptor.forClass(Iterable.class); + when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), any(Iterable.class))) + .thenReturn(counter); + + listener.onComplete(request); + + verify(meterRegistry).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture()); + + final Iterable tagIterable = tagCaptor.getValue(); + final Set tags = new HashSet<>(); + + for (final Tag tag : tagIterable) { + tags.add(tag); + } + + assertEquals(5, tags.size()); + assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.PATH_TAG, path))); + assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.METHOD_TAG, method))); + assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.STATUS_CODE_TAG, String.valueOf(statusCode)))); + assertTrue( + tags.contains(Tag.of(MetricsHttpChannelListener.TRAFFIC_SOURCE_TAG, TrafficSource.HTTP.name().toLowerCase()))); + assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android"))); + } +}