Introduce `OpenWebSocketCounter`

This commit is contained in:
Jon Chambers 2024-09-30 20:03:03 -04:00 committed by Jon Chambers
parent 581e61a85b
commit 3ed142d0a9
4 changed files with 84 additions and 94 deletions

View File

@ -0,0 +1,59 @@
package org.whispersystems.textsecuregcm.metrics;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tag;
import io.micrometer.core.instrument.Tags;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import org.whispersystems.websocket.session.WebSocketSessionContext;
import java.util.Arrays;
import java.util.EnumMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
public class OpenWebSocketCounter {
private final Map<ClientPlatform, AtomicInteger> openWebsocketsByClientPlatform;
private final AtomicInteger openWebsocketsFromUnknownPlatforms;
public OpenWebSocketCounter(final String openWebSocketGaugeName) {
this(openWebSocketGaugeName, Tags.empty());
}
public OpenWebSocketCounter(final String openWebSocketGaugeName, final Tags tags) {
openWebsocketsByClientPlatform = Arrays.stream(ClientPlatform.values())
.collect(Collectors.toMap(
clientPlatform -> clientPlatform,
clientPlatform -> buildGauge(openWebSocketGaugeName, clientPlatform.name().toLowerCase(), tags),
(a, b) -> {
throw new AssertionError("Duplicate client platform enumeration key");
},
() -> new EnumMap<>(ClientPlatform.class)
));
openWebsocketsFromUnknownPlatforms = buildGauge(openWebSocketGaugeName, "unknown", tags);
}
private static AtomicInteger buildGauge(final String gaugeName, final String clientPlatformName, final Tags tags) {
return Metrics.gauge(gaugeName,
tags.and(Tag.of(UserAgentTagUtil.PLATFORM_TAG, clientPlatformName)),
new AtomicInteger(0));
}
public void countOpenWebSocket(final WebSocketSessionContext context) {
final AtomicInteger openWebSocketCounter = getOpenWebsocketCounter(context.getClient().getUserAgent());
openWebSocketCounter.incrementAndGet();
context.addWebsocketClosedListener((context1, statusCode, reason) -> openWebSocketCounter.decrementAndGet());
}
private AtomicInteger getOpenWebsocketCounter(final String userAgentString) {
try {
return openWebsocketsByClientPlatform.get(UserAgentUtil.parseUserAgentString(userAgentString).getPlatform());
} catch (final UnrecognizedUserAgentException e) {
return openWebsocketsFromUnknownPlatforms;
}
}
}

View File

@ -15,13 +15,13 @@ import java.util.Map;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture; import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor; import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor;
import org.whispersystems.textsecuregcm.metrics.MessageMetrics; import org.whispersystems.textsecuregcm.metrics.MessageMetrics;
import org.whispersystems.textsecuregcm.metrics.OpenWebSocketCounter;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager;
@ -60,13 +60,11 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
private final ClientReleaseManager clientReleaseManager; private final ClientReleaseManager clientReleaseManager;
private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor; private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor;
private final Map<ClientPlatform, AtomicInteger> openAuthenticatedWebsocketsByClientPlatform; private final OpenWebSocketCounter openAuthenticatedWebSocketCounter;
private final Map<ClientPlatform, AtomicInteger> openUnauthenticatedWebsocketsByClientPlatform; private final OpenWebSocketCounter openUnauthenticatedWebSocketCounter;
private final Map<ClientPlatform, Timer> durationTimersByClientPlatform; private final Map<ClientPlatform, Timer> durationTimersByClientPlatform;
private final Map<ClientPlatform, Timer> unauthenticatedDurationTimersByClientPlatform; private final Map<ClientPlatform, Timer> unauthenticatedDurationTimersByClientPlatform;
private final AtomicInteger openAuthenticatedWebsocketsFromUnknownPlatforms;
private final AtomicInteger openUnauthenticatedWebsocketsFromUnknownPlatforms;
private final Timer durationTimerForUnknownPlatforms; private final Timer durationTimerForUnknownPlatforms;
private final Timer unauthenticatedDurationTimerForUnknownPlatforms; private final Timer unauthenticatedDurationTimerForUnknownPlatforms;
@ -91,24 +89,17 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
this.clientReleaseManager = clientReleaseManager; this.clientReleaseManager = clientReleaseManager;
this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor; this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor;
openAuthenticatedWebsocketsByClientPlatform = new EnumMap<>(ClientPlatform.class);
openUnauthenticatedWebsocketsByClientPlatform = new EnumMap<>(ClientPlatform.class);
durationTimersByClientPlatform = new EnumMap<>(ClientPlatform.class); durationTimersByClientPlatform = new EnumMap<>(ClientPlatform.class);
unauthenticatedDurationTimersByClientPlatform = new EnumMap<>(ClientPlatform.class); unauthenticatedDurationTimersByClientPlatform = new EnumMap<>(ClientPlatform.class);
final Tags authenticatedTag = Tags.of(AUTHENTICATED_TAG_NAME, "true"); final Tags authenticatedTag = Tags.of(AUTHENTICATED_TAG_NAME, "true");
final Tags unauthenticatedTag = Tags.of(AUTHENTICATED_TAG_NAME, "false"); final Tags unauthenticatedTag = Tags.of(AUTHENTICATED_TAG_NAME, "false");
openAuthenticatedWebSocketCounter = new OpenWebSocketCounter(OPEN_WEBSOCKET_GAUGE_NAME, authenticatedTag);
openUnauthenticatedWebSocketCounter = new OpenWebSocketCounter(OPEN_WEBSOCKET_GAUGE_NAME, unauthenticatedTag);
for (final ClientPlatform clientPlatform : ClientPlatform.values()) { for (final ClientPlatform clientPlatform : ClientPlatform.values()) {
openAuthenticatedWebsocketsByClientPlatform.put(clientPlatform, new AtomicInteger(0));
openUnauthenticatedWebsocketsByClientPlatform.put(clientPlatform, new AtomicInteger(0));
final Tags clientPlatformTag = Tags.of(UserAgentTagUtil.PLATFORM_TAG, clientPlatform.name().toLowerCase()); final Tags clientPlatformTag = Tags.of(UserAgentTagUtil.PLATFORM_TAG, clientPlatform.name().toLowerCase());
Metrics.gauge(OPEN_WEBSOCKET_GAUGE_NAME, clientPlatformTag.and(authenticatedTag),
openAuthenticatedWebsocketsByClientPlatform.get(clientPlatform));
Metrics.gauge(OPEN_WEBSOCKET_GAUGE_NAME, clientPlatformTag.and(unauthenticatedTag),
openUnauthenticatedWebsocketsByClientPlatform.get(clientPlatform));
durationTimersByClientPlatform.put(clientPlatform, durationTimersByClientPlatform.put(clientPlatform,
Metrics.timer(CONNECTED_DURATION_TIMER_NAME, clientPlatformTag.and(authenticatedTag))); Metrics.timer(CONNECTED_DURATION_TIMER_NAME, clientPlatformTag.and(authenticatedTag)));
@ -117,15 +108,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
Metrics.timer(CONNECTED_DURATION_TIMER_NAME, clientPlatformTag.and(unauthenticatedTag))); Metrics.timer(CONNECTED_DURATION_TIMER_NAME, clientPlatformTag.and(unauthenticatedTag)));
} }
openAuthenticatedWebsocketsFromUnknownPlatforms = new AtomicInteger(0);
openUnauthenticatedWebsocketsFromUnknownPlatforms = new AtomicInteger(0);
final Tags unrecognizedPlatform = Tags.of(UserAgentTagUtil.PLATFORM_TAG, "unrecognized"); final Tags unrecognizedPlatform = Tags.of(UserAgentTagUtil.PLATFORM_TAG, "unrecognized");
Metrics.gauge(OPEN_WEBSOCKET_GAUGE_NAME, unrecognizedPlatform.and(authenticatedTag),
openAuthenticatedWebsocketsFromUnknownPlatforms);
Metrics.gauge(OPEN_WEBSOCKET_GAUGE_NAME, unrecognizedPlatform.and(unauthenticatedTag),
openUnauthenticatedWebsocketsFromUnknownPlatforms);
durationTimerForUnknownPlatforms = Metrics.timer(CONNECTED_DURATION_TIMER_NAME, durationTimerForUnknownPlatforms = Metrics.timer(CONNECTED_DURATION_TIMER_NAME,
unrecognizedPlatform.and(authenticatedTag)); unrecognizedPlatform.and(authenticatedTag));
@ -139,9 +122,13 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
final boolean authenticated = (context.getAuthenticated() != null); final boolean authenticated = (context.getAuthenticated() != null);
final String userAgent = context.getClient().getUserAgent(); final String userAgent = context.getClient().getUserAgent();
final AtomicInteger openWebsocketAtomicInteger = getOpenWebsocketCounter(userAgent, authenticated);
final Timer connectionTimer = getConnectionTimer(userAgent, authenticated); final Timer connectionTimer = getConnectionTimer(userAgent, authenticated);
final OpenWebSocketCounter openWebSocketCounter =
authenticated ? openAuthenticatedWebSocketCounter : openUnauthenticatedWebSocketCounter;
openWebSocketCounter.countOpenWebSocket(context);
if (authenticated) { if (authenticated) {
final AuthenticatedDevice auth = context.getAuthenticated(AuthenticatedDevice.class); final AuthenticatedDevice auth = context.getAuthenticated(AuthenticatedDevice.class);
final Timer.Sample sample = Timer.start(); final Timer.Sample sample = Timer.start();
@ -157,12 +144,9 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
clientReleaseManager, clientReleaseManager,
messageDeliveryLoopMonitor); messageDeliveryLoopMonitor);
openWebsocketAtomicInteger.incrementAndGet();
final AtomicReference<ScheduledFuture<?>> renewPresenceFutureReference = new AtomicReference<>(); final AtomicReference<ScheduledFuture<?>> renewPresenceFutureReference = new AtomicReference<>();
context.addWebsocketClosedListener((closingContext, statusCode, reason) -> { context.addWebsocketClosedListener((closingContext, statusCode, reason) -> {
openWebsocketAtomicInteger.decrementAndGet();
sample.stop(connectionTimer); sample.stop(connectionTimer);
final ScheduledFuture<?> renewPresenceFuture = renewPresenceFutureReference.get(); final ScheduledFuture<?> renewPresenceFuture = renewPresenceFutureReference.get();
@ -214,25 +198,8 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener {
context.getClient().close(1011, "Unexpected error initializing connection"); context.getClient().close(1011, "Unexpected error initializing connection");
} }
} else { } else {
openWebsocketAtomicInteger.incrementAndGet();
final Timer.Sample sample = Timer.start(); final Timer.Sample sample = Timer.start();
context.addWebsocketClosedListener((context1, statusCode, reason) -> { context.addWebsocketClosedListener((context1, statusCode, reason) -> sample.stop(connectionTimer));
openWebsocketAtomicInteger.decrementAndGet();
sample.stop(connectionTimer);
});
}
}
private AtomicInteger getOpenWebsocketCounter(final String userAgentString, final boolean authenticated) {
try {
final ClientPlatform platform = UserAgentUtil.parseUserAgentString(userAgentString).getPlatform();
return authenticated
? openAuthenticatedWebsocketsByClientPlatform.get(platform)
: openUnauthenticatedWebsocketsByClientPlatform.get(platform);
} catch (final UnrecognizedUserAgentException e) {
return authenticated
? openAuthenticatedWebsocketsFromUnknownPlatforms
: openUnauthenticatedWebsocketsFromUnknownPlatforms;
} }
} }

View File

@ -6,30 +6,20 @@
package org.whispersystems.textsecuregcm.websocket; package org.whispersystems.textsecuregcm.websocket;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.micrometer.core.instrument.Metrics; import java.security.SecureRandom;
import io.micrometer.core.instrument.Tags; import java.util.Base64;
import java.util.List;
import java.util.Optional;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.ProvisioningMessage; import org.whispersystems.textsecuregcm.entities.ProvisioningMessage;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.OpenWebSocketCounter;
import org.whispersystems.textsecuregcm.push.ProvisioningManager; import org.whispersystems.textsecuregcm.push.ProvisioningManager;
import org.whispersystems.textsecuregcm.storage.PubSubProtos; import org.whispersystems.textsecuregcm.storage.PubSubProtos;
import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.HeaderUtils;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import org.whispersystems.websocket.session.WebSocketSessionContext; import org.whispersystems.websocket.session.WebSocketSessionContext;
import org.whispersystems.websocket.setup.WebSocketConnectListener; import org.whispersystems.websocket.setup.WebSocketConnectListener;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Base64;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
/** /**
* A "provisioning WebSocket" provides a mechanism for sending a caller-defined provisioning message from the primary * A "provisioning WebSocket" provides a mechanism for sending a caller-defined provisioning message from the primary
@ -48,38 +38,20 @@ import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
public class ProvisioningConnectListener implements WebSocketConnectListener { public class ProvisioningConnectListener implements WebSocketConnectListener {
private final ProvisioningManager provisioningManager; private final ProvisioningManager provisioningManager;
private final OpenWebSocketCounter openWebSocketCounter;
private final Map<ClientPlatform, AtomicInteger> openWebsocketsByClientPlatform;
private final AtomicInteger openWebsocketsFromUnknownPlatforms;
private static final String OPEN_WEBSOCKET_GAUGE_NAME = name(ProvisioningConnectListener.class, "openWebsockets");
public ProvisioningConnectListener(final ProvisioningManager provisioningManager) { public ProvisioningConnectListener(final ProvisioningManager provisioningManager) {
this.provisioningManager = provisioningManager; this.provisioningManager = provisioningManager;
this.openWebSocketCounter = new OpenWebSocketCounter(MetricsUtil.name(getClass(), "openWebsockets"));
openWebsocketsByClientPlatform = new EnumMap<>(ClientPlatform.class);
Arrays.stream(ClientPlatform.values())
.forEach(clientPlatform -> openWebsocketsByClientPlatform.put(clientPlatform,
Metrics.gauge(OPEN_WEBSOCKET_GAUGE_NAME,
Tags.of(UserAgentTagUtil.PLATFORM_TAG, clientPlatform.name().toLowerCase()),
new AtomicInteger(0))));
openWebsocketsFromUnknownPlatforms = Metrics.gauge(OPEN_WEBSOCKET_GAUGE_NAME,
Tags.of(UserAgentTagUtil.PLATFORM_TAG, "unrecognized"),
new AtomicInteger(0));
} }
@Override @Override
public void onWebSocketConnect(WebSocketSessionContext context) { public void onWebSocketConnect(WebSocketSessionContext context) {
openWebSocketCounter.countOpenWebSocket(context);
final String provisioningAddress = generateProvisioningAddress(); final String provisioningAddress = generateProvisioningAddress();
context.addWebsocketClosedListener((context1, statusCode, reason) -> provisioningManager.removeListener(provisioningAddress)); context.addWebsocketClosedListener((context1, statusCode, reason) -> provisioningManager.removeListener(provisioningAddress));
getOpenWebsocketCounter(context.getClient().getUserAgent()).incrementAndGet();
context.addWebsocketClosedListener((context1, statusCode, reason) ->
getOpenWebsocketCounter(context.getClient().getUserAgent()).decrementAndGet());
provisioningManager.addListener(provisioningAddress, message -> { provisioningManager.addListener(provisioningAddress, message -> {
assert message.getType() == PubSubProtos.PubSubMessage.Type.DELIVER; assert message.getType() == PubSubProtos.PubSubMessage.Type.DELIVER;
@ -102,12 +74,4 @@ public class ProvisioningConnectListener implements WebSocketConnectListener {
return Base64.getUrlEncoder().encodeToString(provisioningAddress); return Base64.getUrlEncoder().encodeToString(provisioningAddress);
} }
private AtomicInteger getOpenWebsocketCounter(final String userAgentString) {
try {
return openWebsocketsByClientPlatform.get(UserAgentUtil.parseUserAgentString(userAgentString).getPlatform());
} catch (final UnrecognizedUserAgentException e) {
return openWebsocketsFromUnknownPlatforms;
}
}
} }

View File

@ -142,7 +142,7 @@ class WebSocketConnectionTest {
// authenticated - valid user // authenticated - valid user
connectListener.onWebSocketConnect(sessionContext); connectListener.onWebSocketConnect(sessionContext);
verify(sessionContext, times(1)).addWebsocketClosedListener( verify(sessionContext, times(2)).addWebsocketClosedListener(
any(WebSocketSessionContext.WebSocketEventListener.class)); any(WebSocketSessionContext.WebSocketEventListener.class));
// unauthenticated // unauthenticated
@ -152,7 +152,7 @@ class WebSocketConnectionTest {
assertFalse(account.invalidCredentialsProvided()); assertFalse(account.invalidCredentialsProvided());
connectListener.onWebSocketConnect(sessionContext); connectListener.onWebSocketConnect(sessionContext);
verify(sessionContext, times(2)).addWebsocketClosedListener( verify(sessionContext, times(4)).addWebsocketClosedListener(
any(WebSocketSessionContext.WebSocketEventListener.class)); any(WebSocketSessionContext.WebSocketEventListener.class));
verifyNoMoreInteractions(messagesManager); verifyNoMoreInteractions(messagesManager);