Add static servlet paths to MetricsHttpChannelListener

This commit is contained in:
Chris Eager 2024-02-14 15:27:13 -06:00 committed by Chris Eager
parent f90ccd3391
commit 9ce2b7555c
4 changed files with 169 additions and 22 deletions

View File

@ -37,6 +37,7 @@ import java.util.EnumSet;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.ServiceLoader; import java.util.ServiceLoader;
import java.util.Set;
import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue; import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
@ -97,8 +98,8 @@ import org.whispersystems.textsecuregcm.controllers.ArtController;
import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV2; import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV2;
import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV3; import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV3;
import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV4; import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV4;
import org.whispersystems.textsecuregcm.controllers.CallRoutingController;
import org.whispersystems.textsecuregcm.controllers.CallLinkController; import org.whispersystems.textsecuregcm.controllers.CallLinkController;
import org.whispersystems.textsecuregcm.controllers.CallRoutingController;
import org.whispersystems.textsecuregcm.controllers.CertificateController; import org.whispersystems.textsecuregcm.controllers.CertificateController;
import org.whispersystems.textsecuregcm.controllers.ChallengeController; import org.whispersystems.textsecuregcm.controllers.ChallengeController;
import org.whispersystems.textsecuregcm.controllers.DeviceController; import org.whispersystems.textsecuregcm.controllers.DeviceController;
@ -792,7 +793,11 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
.setAuthenticator(accountAuthenticator) .setAuthenticator(accountAuthenticator)
.buildAuthFilter(); .buildAuthFilter();
final MetricsHttpChannelListener metricsHttpChannelListener = new MetricsHttpChannelListener(clientReleaseManager); final String websocketServletPath = "/v1/websocket/";
final String provisioningWebsocketServletPath = "/v1/websocket/provisioning/";
final MetricsHttpChannelListener metricsHttpChannelListener = new MetricsHttpChannelListener(clientReleaseManager,
Set.of(websocketServletPath, provisioningWebsocketServletPath, "/health-check"));
metricsHttpChannelListener.configure(environment); metricsHttpChannelListener.configure(environment);
environment.jersey().register(new VirtualExecutorServiceProvider("managed-async-virtual-thread-")); environment.jersey().register(new VirtualExecutorServiceProvider("managed-async-virtual-thread-"));
@ -950,10 +955,10 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", webSocketServlet); ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", webSocketServlet);
ServletRegistration.Dynamic provisioning = environment.servlets().addServlet("Provisioning", provisioningServlet); ServletRegistration.Dynamic provisioning = environment.servlets().addServlet("Provisioning", provisioningServlet);
websocket.addMapping("/v1/websocket/"); websocket.addMapping(websocketServletPath);
websocket.setAsyncSupported(true); websocket.setAsyncSupported(true);
provisioning.addMapping("/v1/websocket/provisioning/"); provisioning.addMapping(provisioningWebsocketServletPath);
provisioning.setAsyncSupported(true); provisioning.setAsyncSupported(true);
environment.admin().addTask(new SetRequestLoggingEnabledTask()); environment.admin().addTask(new SetRequestLoggingEnabledTask());

View File

@ -18,6 +18,7 @@ import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
import java.util.Set;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import javax.ws.rs.container.ContainerRequestContext; import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.container.ContainerResponseContext; import javax.ws.rs.container.ContainerResponseContext;
@ -54,6 +55,7 @@ public class MetricsHttpChannelListener implements HttpChannel.Listener, Contain
} }
private final ClientReleaseManager clientReleaseManager; private final ClientReleaseManager clientReleaseManager;
private final Set<String> servletPaths;
public static final String REQUEST_COUNTER_NAME = name(MetricsHttpChannelListener.class, "request"); public static final String REQUEST_COUNTER_NAME = name(MetricsHttpChannelListener.class, "request");
public static final String REQUESTS_BY_VERSION_COUNTER_NAME = name(MetricsHttpChannelListener.class, public static final String REQUESTS_BY_VERSION_COUNTER_NAME = name(MetricsHttpChannelListener.class,
@ -76,14 +78,16 @@ public class MetricsHttpChannelListener implements HttpChannel.Listener, Contain
private final MeterRegistry meterRegistry; private final MeterRegistry meterRegistry;
public MetricsHttpChannelListener(final ClientReleaseManager clientReleaseManager) { public MetricsHttpChannelListener(final ClientReleaseManager clientReleaseManager, final Set<String> servletPaths) {
this(Metrics.globalRegistry, clientReleaseManager); this(Metrics.globalRegistry, clientReleaseManager, servletPaths);
} }
@VisibleForTesting @VisibleForTesting
MetricsHttpChannelListener(final MeterRegistry meterRegistry, final ClientReleaseManager clientReleaseManager) { MetricsHttpChannelListener(final MeterRegistry meterRegistry, final ClientReleaseManager clientReleaseManager,
final Set<String> servletPaths) {
this.meterRegistry = meterRegistry; this.meterRegistry = meterRegistry;
this.clientReleaseManager = clientReleaseManager; this.clientReleaseManager = clientReleaseManager;
this.servletPaths = servletPaths;
} }
public void configure(final Environment environment) { public void configure(final Environment environment) {
@ -158,7 +162,12 @@ public class MetricsHttpChannelListener implements HttpChannel.Listener, Contain
private RequestInfo getRequestInfo(Request request) { private RequestInfo getRequestInfo(Request request) {
final String path = Optional.ofNullable(request.getAttribute(URI_INFO_PROPERTY_NAME)) final String path = Optional.ofNullable(request.getAttribute(URI_INFO_PROPERTY_NAME))
.map(attr -> UriInfoUtil.getPathTemplate((ExtendedUriInfo) attr)) .map(attr -> UriInfoUtil.getPathTemplate((ExtendedUriInfo) attr))
.orElse("unknown"); .orElseGet(() -> {
if (servletPaths.contains(request.getPathInfo())) {
return request.getPathInfo();
}
return "unknown";
});
final String method = Optional.ofNullable(request.getMethod()).orElse("unknown"); final String method = Optional.ofNullable(request.getMethod()).orElse("unknown");
// Response cannot be null, but its status might not always reflect an actual response status, since it gets // Response cannot be null, but its status might not always reflect an actual response status, since it gets
// initialized to 200 // initialized to 200

View File

@ -27,8 +27,10 @@ import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tag;
import java.io.IOException; import java.io.IOException;
import java.net.URI;
import java.security.Principal; import java.security.Principal;
import java.time.Duration; import java.time.Duration;
import java.util.EnumSet;
import java.util.HashSet; import java.util.HashSet;
import java.util.Set; import java.util.Set;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
@ -38,6 +40,7 @@ import java.util.function.Supplier;
import java.util.stream.Stream; import java.util.stream.Stream;
import javax.annotation.Priority; import javax.annotation.Priority;
import javax.security.auth.Subject; import javax.security.auth.Subject;
import javax.servlet.DispatcherType;
import javax.ws.rs.GET; import javax.ws.rs.GET;
import javax.ws.rs.InternalServerErrorException; import javax.ws.rs.InternalServerErrorException;
import javax.ws.rs.NotAuthorizedException; import javax.ws.rs.NotAuthorizedException;
@ -55,13 +58,21 @@ import org.eclipse.jetty.server.HttpChannel;
import org.eclipse.jetty.server.Request; import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.util.component.Container; import org.eclipse.jetty.util.component.Container;
import org.eclipse.jetty.util.component.LifeCycle; import org.eclipse.jetty.util.component.LifeCycle;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.websocket.WebSocketResourceProviderFactory; import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.configuration.WebSocketConfiguration; import org.whispersystems.websocket.configuration.WebSocketConfiguration;
@ -148,6 +159,64 @@ class MetricsHttpChannelListenerIntegrationTest {
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android"))); assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android")));
} }
@Nested
class WebSocket {
private WebSocketClient client;
@BeforeEach
void setUp() throws Exception {
client = new WebSocketClient();
client.start();
}
@AfterEach
void tearDown() throws Exception {
client.stop();
}
@Test
void testWebSocketUpgrade() throws Exception {
final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
upgradeRequest.setHeader(HttpHeaders.USER_AGENT, "Signal-Android/4.53.7 (Android 8.1)");
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
when(METER_REGISTRY.counter(anyString(), any(Iterable.class)))
.thenAnswer(a -> MetricsHttpChannelListener.REQUEST_COUNTER_NAME.equals(a.getArgument(0, String.class))
? COUNTER
: mock(Counter.class))
.thenReturn(COUNTER);
client.connect(new WebSocketListener() {
@Override
public void onWebSocketConnect(final Session session) {
session.close(1000, "OK");
}
},
URI.create(String.format("ws://localhost:%d%s", EXTENSION.getLocalPort(), "/v1/websocket")), upgradeRequest)
.get(1, TimeUnit.SECONDS);
verify(METER_REGISTRY).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture());
verify(COUNTER).increment();
final Iterable<Tag> tagIterable = tagCaptor.getValue();
final Set<Tag> tags = new HashSet<>();
for (final Tag tag : tagIterable) {
tags.add(tag);
}
assertEquals(5, tags.size());
assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.PATH_TAG, "/v1/websocket")));
assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.METHOD_TAG, "GET")));
assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.STATUS_CODE_TAG, String.valueOf(101))));
assertTrue(
tags.contains(Tag.of(MetricsHttpChannelListener.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase())));
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android")));
}
}
static Stream<Arguments> testSimplePath() { static Stream<Arguments> testSimplePath() {
return Stream.of( return Stream.of(
Arguments.of("/v1/test/hello", "/v1/test/hello", "Hello!", 200), Arguments.of("/v1/test/hello", "/v1/test/hello", "Hello!", 200),
@ -166,11 +235,16 @@ class MetricsHttpChannelListenerIntegrationTest {
final MetricsHttpChannelListener metricsHttpChannelListener = new MetricsHttpChannelListener( final MetricsHttpChannelListener metricsHttpChannelListener = new MetricsHttpChannelListener(
METER_REGISTRY, METER_REGISTRY,
mock(ClientReleaseManager.class)); mock(ClientReleaseManager.class),
Set.of("/v1/websocket")
);
metricsHttpChannelListener.configure(environment); metricsHttpChannelListener.configure(environment);
environment.lifecycle().addEventListener(new TestListener(LISTENER_FUTURE_REFERENCE)); environment.lifecycle().addEventListener(new TestListener(LISTENER_FUTURE_REFERENCE));
environment.servlets().addFilter("RemoteAddressFilter", new RemoteAddressFilter(true))
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
environment.jersey().register(new TestResource()); environment.jersey().register(new TestResource());
environment.jersey().register(new TestAuthFilter()); environment.jersey().register(new TestAuthFilter());
@ -185,9 +259,11 @@ class MetricsHttpChannelListenerIntegrationTest {
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null); JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);
WebSocketResourceProviderFactory<TestPrincipal> webSocketServlet = new WebSocketResourceProviderFactory<>( WebSocketResourceProviderFactory<TestPrincipal> webSocketServlet = new WebSocketResourceProviderFactory<>(
webSocketEnvironment, TestPrincipal.class, webSocketConfiguration, "ignored"); webSocketEnvironment, TestPrincipal.class, webSocketConfiguration,
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);
environment.servlets().addServlet("WebSocket", webSocketServlet); environment.servlets().addServlet("WebSocket", webSocketServlet)
.addMapping("/v1/websocket");
} }
} }
@ -273,4 +349,5 @@ class MetricsHttpChannelListenerIntegrationTest {
return false; return false;
} }
} }
} }

View File

@ -11,12 +11,14 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import com.google.common.net.HttpHeaders; import com.google.common.net.HttpHeaders;
import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tag;
import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
@ -27,29 +29,39 @@ import org.glassfish.jersey.server.ExtendedUriInfo;
import org.glassfish.jersey.uri.UriTemplate; import org.glassfish.jersey.uri.UriTemplate;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
class MetricsHttpChannelListenerTest { class MetricsHttpChannelListenerTest {
private MeterRegistry meterRegistry; private MeterRegistry meterRegistry;
private Counter counter; private Counter requestCounter;
private Counter requestsByVersionCounter;
private ClientReleaseManager clientReleaseManager;
private MetricsHttpChannelListener listener; private MetricsHttpChannelListener listener;
@BeforeEach @BeforeEach
void setup() { void setup() {
meterRegistry = mock(MeterRegistry.class); meterRegistry = mock(MeterRegistry.class);
counter = mock(Counter.class); requestCounter = mock(Counter.class);
requestsByVersionCounter = mock(Counter.class);
final ClientReleaseManager clientReleaseManager = mock(ClientReleaseManager.class); when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), any(Iterable.class)))
when(clientReleaseManager.isVersionActive(any(), any())).thenReturn(false); .thenReturn(requestCounter);
listener = new MetricsHttpChannelListener(meterRegistry, clientReleaseManager); when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUESTS_BY_VERSION_COUNTER_NAME), any(Iterable.class)))
.thenReturn(requestsByVersionCounter);
clientReleaseManager = mock(ClientReleaseManager.class);
listener = new MetricsHttpChannelListener(meterRegistry, clientReleaseManager, Collections.emptySet());
} }
@Test @Test
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
void testOnEvent() { void testRequests() {
final String path = "/test"; final String path = "/test";
final String method = "GET"; final String method = "GET";
final int statusCode = 200; final int statusCode = 200;
@ -70,17 +82,15 @@ class MetricsHttpChannelListenerTest {
when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate(path))); when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate(path)));
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class); final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), any(Iterable.class)))
.thenReturn(counter);
listener.onComplete(request); listener.onComplete(request);
verify(requestCounter).increment();
verify(meterRegistry).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture()); verify(meterRegistry).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture());
final Iterable<Tag> tagIterable = tagCaptor.getValue();
final Set<Tag> tags = new HashSet<>(); final Set<Tag> tags = new HashSet<>();
for (final Tag tag : tagCaptor.getValue()) {
for (final Tag tag : tagIterable) {
tags.add(tag); tags.add(tag);
} }
@ -92,4 +102,50 @@ class MetricsHttpChannelListenerTest {
tags.contains(Tag.of(MetricsHttpChannelListener.TRAFFIC_SOURCE_TAG, TrafficSource.HTTP.name().toLowerCase()))); tags.contains(Tag.of(MetricsHttpChannelListener.TRAFFIC_SOURCE_TAG, TrafficSource.HTTP.name().toLowerCase())));
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android"))); assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android")));
} }
@ParameterizedTest
@ValueSource(booleans = {true, false})
@SuppressWarnings("unchecked")
void testRequestsByVersion(final boolean versionActive) {
when(clientReleaseManager.isVersionActive(any(), any())).thenReturn(versionActive);
final String path = "/test";
final String method = "GET";
final int statusCode = 200;
final HttpURI httpUri = mock(HttpURI.class);
when(httpUri.getPath()).thenReturn(path);
final Request request = mock(Request.class);
when(request.getMethod()).thenReturn(method);
when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn("Signal-Android/6.53.7 (Android 8.1)");
when(request.getHttpURI()).thenReturn(httpUri);
final Response response = mock(Response.class);
when(response.getStatus()).thenReturn(statusCode);
when(request.getResponse()).thenReturn(response);
final ExtendedUriInfo extendedUriInfo = mock(ExtendedUriInfo.class);
when(request.getAttribute(MetricsHttpChannelListener.URI_INFO_PROPERTY_NAME)).thenReturn(extendedUriInfo);
when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate(path)));
listener.onComplete(request);
if (versionActive) {
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
verify(meterRegistry).counter(eq(MetricsHttpChannelListener.REQUESTS_BY_VERSION_COUNTER_NAME),
tagCaptor.capture());
final Set<Tag> tags = new HashSet<>();
tags.clear();
for (final Tag tag : tagCaptor.getValue()) {
tags.add(tag);
}
assertEquals(2, tags.size());
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.VERSION_TAG, "6.53.7")));
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android")));
} else {
verifyNoInteractions(requestsByVersionCounter);
}
}
} }