Migrate WebSocket duration instrumentation to `OpenWebSocketCounter`

This commit is contained in:
Jon Chambers 2024-10-01 15:31:52 -04:00 committed by ravi-signal
parent 68814813c3
commit 100955a7db
4 changed files with 65 additions and 77 deletions

View File

@ -3,6 +3,7 @@ package org.whispersystems.textsecuregcm.metrics;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag; import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags; import io.micrometer.core.instrument.Tags;
import io.micrometer.core.instrument.Timer;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
@ -11,6 +12,7 @@ import java.util.Arrays;
import java.util.EnumMap; import java.util.EnumMap;
import java.util.Map; import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
public class OpenWebSocketCounter { public class OpenWebSocketCounter {
@ -18,14 +20,17 @@ public class OpenWebSocketCounter {
private final Map<ClientPlatform, AtomicInteger> openWebsocketsByClientPlatform; private final Map<ClientPlatform, AtomicInteger> openWebsocketsByClientPlatform;
private final AtomicInteger openWebsocketsFromUnknownPlatforms; private final AtomicInteger openWebsocketsFromUnknownPlatforms;
public OpenWebSocketCounter(final String openWebSocketGaugeName) { private final Map<ClientPlatform, Timer> durationTimersByClientPlatform;
this(openWebSocketGaugeName, Tags.empty()); private final Timer durationTimerForUnknownPlatforms;
public OpenWebSocketCounter(final String openWebSocketGaugeName, final String durationTimerName) {
this(openWebSocketGaugeName, durationTimerName, Tags.empty());
} }
public OpenWebSocketCounter(final String openWebSocketGaugeName, final Tags tags) { public OpenWebSocketCounter(final String openWebSocketGaugeName, final String durationTimerName, final Tags tags) {
openWebsocketsByClientPlatform = Arrays.stream(ClientPlatform.values()) openWebsocketsByClientPlatform = Arrays.stream(ClientPlatform.values())
.collect(Collectors.toMap( .collect(Collectors.toMap(
clientPlatform -> clientPlatform, Function.identity(),
clientPlatform -> buildGauge(openWebSocketGaugeName, clientPlatform.name().toLowerCase(), tags), clientPlatform -> buildGauge(openWebSocketGaugeName, clientPlatform.name().toLowerCase(), tags),
(a, b) -> { (a, b) -> {
throw new AssertionError("Duplicate client platform enumeration key"); throw new AssertionError("Duplicate client platform enumeration key");
@ -34,6 +39,18 @@ public class OpenWebSocketCounter {
)); ));
openWebsocketsFromUnknownPlatforms = buildGauge(openWebSocketGaugeName, "unknown", tags); openWebsocketsFromUnknownPlatforms = buildGauge(openWebSocketGaugeName, "unknown", tags);
durationTimersByClientPlatform = Arrays.stream(ClientPlatform.values())
.collect(Collectors.toMap(
clientPlatform -> clientPlatform,
clientPlatform -> buildTimer(durationTimerName, clientPlatform.name().toLowerCase(), tags),
(a, b) -> {
throw new AssertionError("Duplicate client platform enumeration key");
},
() -> new EnumMap<>(ClientPlatform.class)
));
durationTimerForUnknownPlatforms = buildTimer(durationTimerName, "unknown", tags);
} }
private static AtomicInteger buildGauge(final String gaugeName, final String clientPlatformName, final Tags tags) { private static AtomicInteger buildGauge(final String gaugeName, final String clientPlatformName, final Tags tags) {
@ -42,18 +59,45 @@ public class OpenWebSocketCounter {
new AtomicInteger(0)); new AtomicInteger(0));
} }
private static Timer buildTimer(final String timerName, final String clientPlatformName, final Tags tags) {
return Timer.builder(timerName)
.publishPercentileHistogram(true)
.tags(tags.and(Tag.of(UserAgentTagUtil.PLATFORM_TAG, clientPlatformName)))
.register(Metrics.globalRegistry);
}
public void countOpenWebSocket(final WebSocketSessionContext context) { public void countOpenWebSocket(final WebSocketSessionContext context) {
final AtomicInteger openWebSocketCounter = getOpenWebsocketCounter(context.getClient().getUserAgent()); final Timer.Sample sample = Timer.start();
// We have to jump through some hoops here to have something "effectively final" for the close listener, but
// assignable from a `catch` block.
final AtomicInteger openWebSocketCounter;
final Timer durationTimer;
{
AtomicInteger calculatedOpenWebSocketCounter;
Timer calculatedDurationTimer;
try {
final ClientPlatform clientPlatform =
UserAgentUtil.parseUserAgentString(context.getClient().getUserAgent()).getPlatform();
calculatedOpenWebSocketCounter = openWebsocketsByClientPlatform.get(clientPlatform);
calculatedDurationTimer = durationTimersByClientPlatform.get(clientPlatform);
} catch (final UnrecognizedUserAgentException e) {
calculatedOpenWebSocketCounter = openWebsocketsFromUnknownPlatforms;
calculatedDurationTimer = durationTimerForUnknownPlatforms;
}
openWebSocketCounter = calculatedOpenWebSocketCounter;
durationTimer = calculatedDurationTimer;
}
openWebSocketCounter.incrementAndGet(); openWebSocketCounter.incrementAndGet();
context.addWebsocketClosedListener((context1, statusCode, reason) -> openWebSocketCounter.decrementAndGet());
}
private AtomicInteger getOpenWebsocketCounter(final String userAgentString) { context.addWebsocketClosedListener((context1, statusCode, reason) -> {
try { sample.stop(durationTimer);
return openWebsocketsByClientPlatform.get(UserAgentUtil.parseUserAgentString(userAgentString).getPlatform()); openWebSocketCounter.decrementAndGet();
} catch (final UnrecognizedUserAgentException e) { });
return openWebsocketsFromUnknownPlatforms;
}
} }
} }

View File

@ -7,11 +7,7 @@ package org.whispersystems.textsecuregcm.websocket;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags; import io.micrometer.core.instrument.Tags;
import io.micrometer.core.instrument.Timer;
import java.util.EnumMap;
import java.util.Map;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture; import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -22,7 +18,6 @@ import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor; import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.metrics.OpenWebSocketCounter; import org.whispersystems.textsecuregcm.metrics.OpenWebSocketCounter;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler; import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
@ -30,9 +25,6 @@ import org.whispersystems.textsecuregcm.push.ReceiptSender;
import org.whispersystems.textsecuregcm.redis.RedisOperation; import org.whispersystems.textsecuregcm.redis.RedisOperation;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.MessagesManager;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import org.whispersystems.websocket.session.WebSocketSessionContext; import org.whispersystems.websocket.session.WebSocketSessionContext;
import org.whispersystems.websocket.setup.WebSocketConnectListener; import org.whispersystems.websocket.setup.WebSocketConnectListener;
import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Scheduler;
@ -63,11 +55,6 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
private final OpenWebSocketCounter openAuthenticatedWebSocketCounter; private final OpenWebSocketCounter openAuthenticatedWebSocketCounter;
private final OpenWebSocketCounter openUnauthenticatedWebSocketCounter; private final OpenWebSocketCounter openUnauthenticatedWebSocketCounter;
private final Map<ClientPlatform, Timer> durationTimersByClientPlatform;
private final Map<ClientPlatform, Timer> unauthenticatedDurationTimersByClientPlatform;
private final Timer durationTimerForUnknownPlatforms;
private final Timer unauthenticatedDurationTimerForUnknownPlatforms;
public AuthenticatedConnectListener(ReceiptSender receiptSender, public AuthenticatedConnectListener(ReceiptSender receiptSender,
MessagesManager messagesManager, MessagesManager messagesManager,
MessageMetrics messageMetrics, MessageMetrics messageMetrics,
@ -89,41 +76,17 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
this.clientReleaseManager = clientReleaseManager; this.clientReleaseManager = clientReleaseManager;
this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor; this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor;
durationTimersByClientPlatform = new EnumMap<>(ClientPlatform.class); openAuthenticatedWebSocketCounter =
unauthenticatedDurationTimersByClientPlatform = new EnumMap<>(ClientPlatform.class); new OpenWebSocketCounter(OPEN_WEBSOCKET_GAUGE_NAME, CONNECTED_DURATION_TIMER_NAME, Tags.of(AUTHENTICATED_TAG_NAME, "true"));
final Tags authenticatedTag = Tags.of(AUTHENTICATED_TAG_NAME, "true"); openUnauthenticatedWebSocketCounter =
final Tags unauthenticatedTag = Tags.of(AUTHENTICATED_TAG_NAME, "false"); new OpenWebSocketCounter(OPEN_WEBSOCKET_GAUGE_NAME, CONNECTED_DURATION_TIMER_NAME, Tags.of(AUTHENTICATED_TAG_NAME, "false"));
openAuthenticatedWebSocketCounter = new OpenWebSocketCounter(OPEN_WEBSOCKET_GAUGE_NAME, authenticatedTag);
openUnauthenticatedWebSocketCounter = new OpenWebSocketCounter(OPEN_WEBSOCKET_GAUGE_NAME, unauthenticatedTag);
for (final ClientPlatform clientPlatform : ClientPlatform.values()) {
final Tags clientPlatformTag = Tags.of(UserAgentTagUtil.PLATFORM_TAG, clientPlatform.name().toLowerCase());
durationTimersByClientPlatform.put(clientPlatform,
Metrics.timer(CONNECTED_DURATION_TIMER_NAME, clientPlatformTag.and(authenticatedTag)));
unauthenticatedDurationTimersByClientPlatform.put(clientPlatform,
Metrics.timer(CONNECTED_DURATION_TIMER_NAME, clientPlatformTag.and(unauthenticatedTag)));
}
final Tags unrecognizedPlatform = Tags.of(UserAgentTagUtil.PLATFORM_TAG, "unrecognized");
durationTimerForUnknownPlatforms = Metrics.timer(CONNECTED_DURATION_TIMER_NAME,
unrecognizedPlatform.and(authenticatedTag));
unauthenticatedDurationTimerForUnknownPlatforms = Metrics.timer(CONNECTED_DURATION_TIMER_NAME,
unrecognizedPlatform.and(unauthenticatedTag));
} }
@Override @Override
public void onWebSocketConnect(WebSocketSessionContext context) { public void onWebSocketConnect(WebSocketSessionContext context) {
final boolean authenticated = (context.getAuthenticated() != null); final boolean authenticated = (context.getAuthenticated() != null);
final String userAgent = context.getClient().getUserAgent();
final Timer connectionTimer = getConnectionTimer(userAgent, authenticated);
final OpenWebSocketCounter openWebSocketCounter = final OpenWebSocketCounter openWebSocketCounter =
authenticated ? openAuthenticatedWebSocketCounter : openUnauthenticatedWebSocketCounter; authenticated ? openAuthenticatedWebSocketCounter : openUnauthenticatedWebSocketCounter;
@ -131,7 +94,6 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
if (authenticated) { if (authenticated) {
final AuthenticatedDevice auth = context.getAuthenticated(AuthenticatedDevice.class); final AuthenticatedDevice auth = context.getAuthenticated(AuthenticatedDevice.class);
final Timer.Sample sample = Timer.start();
final WebSocketConnection connection = new WebSocketConnection(receiptSender, final WebSocketConnection connection = new WebSocketConnection(receiptSender,
messagesManager, messagesManager,
messageMetrics, messageMetrics,
@ -147,8 +109,6 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
final AtomicReference<ScheduledFuture<?>> renewPresenceFutureReference = new AtomicReference<>(); final AtomicReference<ScheduledFuture<?>> renewPresenceFutureReference = new AtomicReference<>();
context.addWebsocketClosedListener((closingContext, statusCode, reason) -> { context.addWebsocketClosedListener((closingContext, statusCode, reason) -> {
sample.stop(connectionTimer);
final ScheduledFuture<?> renewPresenceFuture = renewPresenceFutureReference.get(); final ScheduledFuture<?> renewPresenceFuture = renewPresenceFutureReference.get();
if (renewPresenceFuture != null) { if (renewPresenceFuture != null) {
@ -197,23 +157,6 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
log.warn("Failed to initialize websocket", e); log.warn("Failed to initialize websocket", e);
context.getClient().close(1011, "Unexpected error initializing connection"); context.getClient().close(1011, "Unexpected error initializing connection");
} }
} else {
final Timer.Sample sample = Timer.start();
context.addWebsocketClosedListener((context1, statusCode, reason) -> sample.stop(connectionTimer));
}
}
private Timer getConnectionTimer(final String userAgentString,
final boolean authenticated) {
try {
final ClientPlatform platform = UserAgentUtil.parseUserAgentString(userAgentString).getPlatform();
return authenticated
? durationTimersByClientPlatform.get(platform)
: unauthenticatedDurationTimersByClientPlatform.get(platform);
} catch (final UnrecognizedUserAgentException e) {
return authenticated
? durationTimerForUnknownPlatforms
: unauthenticatedDurationTimerForUnknownPlatforms;
} }
} }
} }

View File

@ -42,7 +42,8 @@ public class ProvisioningConnectListener implements WebSocketConnectListener {
public ProvisioningConnectListener(final ProvisioningManager provisioningManager) { public ProvisioningConnectListener(final ProvisioningManager provisioningManager) {
this.provisioningManager = provisioningManager; this.provisioningManager = provisioningManager;
this.openWebSocketCounter = new OpenWebSocketCounter(MetricsUtil.name(getClass(), "openWebsockets")); this.openWebSocketCounter = new OpenWebSocketCounter(MetricsUtil.name(getClass(), "openWebsockets"),
MetricsUtil.name(getClass(), "sessionDuration"));
} }
@Override @Override

View File

@ -142,7 +142,7 @@ class WebSocketConnectionTest {
// authenticated - valid user // authenticated - valid user
connectListener.onWebSocketConnect(sessionContext); connectListener.onWebSocketConnect(sessionContext);
verify(sessionContext, times(2)).addWebsocketClosedListener( verify(sessionContext, times(1)).addWebsocketClosedListener(
any(WebSocketSessionContext.WebSocketEventListener.class)); any(WebSocketSessionContext.WebSocketEventListener.class));
// unauthenticated // unauthenticated
@ -152,7 +152,7 @@ class WebSocketConnectionTest {
assertFalse(account.invalidCredentialsProvided()); assertFalse(account.invalidCredentialsProvided());
connectListener.onWebSocketConnect(sessionContext); connectListener.onWebSocketConnect(sessionContext);
verify(sessionContext, times(4)).addWebsocketClosedListener( verify(sessionContext, times(2)).addWebsocketClosedListener(
any(WebSocketSessionContext.WebSocketEventListener.class)); any(WebSocketSessionContext.WebSocketEventListener.class));
verifyNoMoreInteractions(messagesManager); verifyNoMoreInteractions(messagesManager);