From 3ed142d0a9e5c7266b1515baf2cb4a1fb2a3cd09 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Mon, 30 Sep 2024 20:03:03 -0400 Subject: [PATCH] Introduce `OpenWebSocketCounter` --- .../metrics/OpenWebSocketCounter.java | 59 +++++++++++++++++++ .../AuthenticatedConnectListener.java | 59 ++++--------------- .../ProvisioningConnectListener.java | 56 ++++-------------- .../websocket/WebSocketConnectionTest.java | 4 +- 4 files changed, 84 insertions(+), 94 deletions(-) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/metrics/OpenWebSocketCounter.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/OpenWebSocketCounter.java b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/OpenWebSocketCounter.java new file mode 100644 index 000000000..dcc25d6e0 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/OpenWebSocketCounter.java @@ -0,0 +1,59 @@ +package org.whispersystems.textsecuregcm.metrics; + +import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Tag; +import io.micrometer.core.instrument.Tags; +import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; +import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; +import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; +import org.whispersystems.websocket.session.WebSocketSessionContext; +import java.util.Arrays; +import java.util.EnumMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +public class OpenWebSocketCounter { + + private final Map openWebsocketsByClientPlatform; + private final AtomicInteger openWebsocketsFromUnknownPlatforms; + + public OpenWebSocketCounter(final String openWebSocketGaugeName) { + this(openWebSocketGaugeName, Tags.empty()); + } + + public OpenWebSocketCounter(final String openWebSocketGaugeName, final Tags tags) { + openWebsocketsByClientPlatform = Arrays.stream(ClientPlatform.values()) + .collect(Collectors.toMap( + clientPlatform -> clientPlatform, + clientPlatform -> buildGauge(openWebSocketGaugeName, clientPlatform.name().toLowerCase(), tags), + (a, b) -> { + throw new AssertionError("Duplicate client platform enumeration key"); + }, + () -> new EnumMap<>(ClientPlatform.class) + )); + + openWebsocketsFromUnknownPlatforms = buildGauge(openWebSocketGaugeName, "unknown", tags); + } + + private static AtomicInteger buildGauge(final String gaugeName, final String clientPlatformName, final Tags tags) { + return Metrics.gauge(gaugeName, + tags.and(Tag.of(UserAgentTagUtil.PLATFORM_TAG, clientPlatformName)), + new AtomicInteger(0)); + } + + public void countOpenWebSocket(final WebSocketSessionContext context) { + final AtomicInteger openWebSocketCounter = getOpenWebsocketCounter(context.getClient().getUserAgent()); + + openWebSocketCounter.incrementAndGet(); + context.addWebsocketClosedListener((context1, statusCode, reason) -> openWebSocketCounter.decrementAndGet()); + } + + private AtomicInteger getOpenWebsocketCounter(final String userAgentString) { + try { + return openWebsocketsByClientPlatform.get(UserAgentUtil.parseUserAgentString(userAgentString).getPlatform()); + } catch (final UnrecognizedUserAgentException e) { + return openWebsocketsFromUnknownPlatforms; + } + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java index 25520f43e..1e791f384 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java @@ -15,13 +15,13 @@ import java.util.Map; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.limits.MessageDeliveryLoopMonitor; import org.whispersystems.textsecuregcm.metrics.MessageMetrics; +import org.whispersystems.textsecuregcm.metrics.OpenWebSocketCounter; import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.PushNotificationManager; @@ -60,13 +60,11 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { private final ClientReleaseManager clientReleaseManager; private final MessageDeliveryLoopMonitor messageDeliveryLoopMonitor; - private final Map openAuthenticatedWebsocketsByClientPlatform; - private final Map openUnauthenticatedWebsocketsByClientPlatform; + private final OpenWebSocketCounter openAuthenticatedWebSocketCounter; + private final OpenWebSocketCounter openUnauthenticatedWebSocketCounter; + private final Map durationTimersByClientPlatform; private final Map unauthenticatedDurationTimersByClientPlatform; - - private final AtomicInteger openAuthenticatedWebsocketsFromUnknownPlatforms; - private final AtomicInteger openUnauthenticatedWebsocketsFromUnknownPlatforms; private final Timer durationTimerForUnknownPlatforms; private final Timer unauthenticatedDurationTimerForUnknownPlatforms; @@ -91,24 +89,17 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { this.clientReleaseManager = clientReleaseManager; this.messageDeliveryLoopMonitor = messageDeliveryLoopMonitor; - openAuthenticatedWebsocketsByClientPlatform = new EnumMap<>(ClientPlatform.class); - openUnauthenticatedWebsocketsByClientPlatform = new EnumMap<>(ClientPlatform.class); durationTimersByClientPlatform = new EnumMap<>(ClientPlatform.class); unauthenticatedDurationTimersByClientPlatform = new EnumMap<>(ClientPlatform.class); final Tags authenticatedTag = Tags.of(AUTHENTICATED_TAG_NAME, "true"); final Tags unauthenticatedTag = Tags.of(AUTHENTICATED_TAG_NAME, "false"); + openAuthenticatedWebSocketCounter = new OpenWebSocketCounter(OPEN_WEBSOCKET_GAUGE_NAME, authenticatedTag); + openUnauthenticatedWebSocketCounter = new OpenWebSocketCounter(OPEN_WEBSOCKET_GAUGE_NAME, unauthenticatedTag); + for (final ClientPlatform clientPlatform : ClientPlatform.values()) { - openAuthenticatedWebsocketsByClientPlatform.put(clientPlatform, new AtomicInteger(0)); - openUnauthenticatedWebsocketsByClientPlatform.put(clientPlatform, new AtomicInteger(0)); - final Tags clientPlatformTag = Tags.of(UserAgentTagUtil.PLATFORM_TAG, clientPlatform.name().toLowerCase()); - Metrics.gauge(OPEN_WEBSOCKET_GAUGE_NAME, clientPlatformTag.and(authenticatedTag), - openAuthenticatedWebsocketsByClientPlatform.get(clientPlatform)); - - Metrics.gauge(OPEN_WEBSOCKET_GAUGE_NAME, clientPlatformTag.and(unauthenticatedTag), - openUnauthenticatedWebsocketsByClientPlatform.get(clientPlatform)); durationTimersByClientPlatform.put(clientPlatform, Metrics.timer(CONNECTED_DURATION_TIMER_NAME, clientPlatformTag.and(authenticatedTag))); @@ -117,15 +108,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { Metrics.timer(CONNECTED_DURATION_TIMER_NAME, clientPlatformTag.and(unauthenticatedTag))); } - openAuthenticatedWebsocketsFromUnknownPlatforms = new AtomicInteger(0); - openUnauthenticatedWebsocketsFromUnknownPlatforms = new AtomicInteger(0); - final Tags unrecognizedPlatform = Tags.of(UserAgentTagUtil.PLATFORM_TAG, "unrecognized"); - Metrics.gauge(OPEN_WEBSOCKET_GAUGE_NAME, unrecognizedPlatform.and(authenticatedTag), - openAuthenticatedWebsocketsFromUnknownPlatforms); - - Metrics.gauge(OPEN_WEBSOCKET_GAUGE_NAME, unrecognizedPlatform.and(unauthenticatedTag), - openUnauthenticatedWebsocketsFromUnknownPlatforms); durationTimerForUnknownPlatforms = Metrics.timer(CONNECTED_DURATION_TIMER_NAME, unrecognizedPlatform.and(authenticatedTag)); @@ -139,9 +122,13 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { final boolean authenticated = (context.getAuthenticated() != null); final String userAgent = context.getClient().getUserAgent(); - final AtomicInteger openWebsocketAtomicInteger = getOpenWebsocketCounter(userAgent, authenticated); final Timer connectionTimer = getConnectionTimer(userAgent, authenticated); + final OpenWebSocketCounter openWebSocketCounter = + authenticated ? openAuthenticatedWebSocketCounter : openUnauthenticatedWebSocketCounter; + + openWebSocketCounter.countOpenWebSocket(context); + if (authenticated) { final AuthenticatedDevice auth = context.getAuthenticated(AuthenticatedDevice.class); final Timer.Sample sample = Timer.start(); @@ -157,12 +144,9 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { clientReleaseManager, messageDeliveryLoopMonitor); - openWebsocketAtomicInteger.incrementAndGet(); - final AtomicReference> renewPresenceFutureReference = new AtomicReference<>(); context.addWebsocketClosedListener((closingContext, statusCode, reason) -> { - openWebsocketAtomicInteger.decrementAndGet(); sample.stop(connectionTimer); final ScheduledFuture renewPresenceFuture = renewPresenceFutureReference.get(); @@ -214,25 +198,8 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { context.getClient().close(1011, "Unexpected error initializing connection"); } } else { - openWebsocketAtomicInteger.incrementAndGet(); final Timer.Sample sample = Timer.start(); - context.addWebsocketClosedListener((context1, statusCode, reason) -> { - openWebsocketAtomicInteger.decrementAndGet(); - sample.stop(connectionTimer); - }); - } - } - - private AtomicInteger getOpenWebsocketCounter(final String userAgentString, final boolean authenticated) { - try { - final ClientPlatform platform = UserAgentUtil.parseUserAgentString(userAgentString).getPlatform(); - return authenticated - ? openAuthenticatedWebsocketsByClientPlatform.get(platform) - : openUnauthenticatedWebsocketsByClientPlatform.get(platform); - } catch (final UnrecognizedUserAgentException e) { - return authenticated - ? openAuthenticatedWebsocketsFromUnknownPlatforms - : openUnauthenticatedWebsocketsFromUnknownPlatforms; + context.addWebsocketClosedListener((context1, statusCode, reason) -> sample.stop(connectionTimer)); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java index 7b21a96bc..9ca3f43d3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/ProvisioningConnectListener.java @@ -6,30 +6,20 @@ package org.whispersystems.textsecuregcm.websocket; import com.google.common.annotations.VisibleForTesting; -import io.micrometer.core.instrument.Metrics; -import io.micrometer.core.instrument.Tags; +import java.security.SecureRandom; +import java.util.Base64; +import java.util.List; +import java.util.Optional; import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.ProvisioningMessage; -import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; +import org.whispersystems.textsecuregcm.metrics.OpenWebSocketCounter; import org.whispersystems.textsecuregcm.push.ProvisioningManager; import org.whispersystems.textsecuregcm.storage.PubSubProtos; import org.whispersystems.textsecuregcm.util.HeaderUtils; -import org.whispersystems.textsecuregcm.util.ua.ClientPlatform; -import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException; -import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil; import org.whispersystems.websocket.session.WebSocketSessionContext; import org.whispersystems.websocket.setup.WebSocketConnectListener; -import java.security.SecureRandom; -import java.util.Arrays; -import java.util.Base64; -import java.util.EnumMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.atomic.AtomicInteger; - -import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; /** * A "provisioning WebSocket" provides a mechanism for sending a caller-defined provisioning message from the primary @@ -48,38 +38,20 @@ import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; public class ProvisioningConnectListener implements WebSocketConnectListener { private final ProvisioningManager provisioningManager; - - private final Map openWebsocketsByClientPlatform; - private final AtomicInteger openWebsocketsFromUnknownPlatforms; - - private static final String OPEN_WEBSOCKET_GAUGE_NAME = name(ProvisioningConnectListener.class, "openWebsockets"); + private final OpenWebSocketCounter openWebSocketCounter; public ProvisioningConnectListener(final ProvisioningManager provisioningManager) { this.provisioningManager = provisioningManager; - - openWebsocketsByClientPlatform = new EnumMap<>(ClientPlatform.class); - - Arrays.stream(ClientPlatform.values()) - .forEach(clientPlatform -> openWebsocketsByClientPlatform.put(clientPlatform, - Metrics.gauge(OPEN_WEBSOCKET_GAUGE_NAME, - Tags.of(UserAgentTagUtil.PLATFORM_TAG, clientPlatform.name().toLowerCase()), - new AtomicInteger(0)))); - - openWebsocketsFromUnknownPlatforms = Metrics.gauge(OPEN_WEBSOCKET_GAUGE_NAME, - Tags.of(UserAgentTagUtil.PLATFORM_TAG, "unrecognized"), - new AtomicInteger(0)); + this.openWebSocketCounter = new OpenWebSocketCounter(MetricsUtil.name(getClass(), "openWebsockets")); } @Override public void onWebSocketConnect(WebSocketSessionContext context) { + openWebSocketCounter.countOpenWebSocket(context); + final String provisioningAddress = generateProvisioningAddress(); context.addWebsocketClosedListener((context1, statusCode, reason) -> provisioningManager.removeListener(provisioningAddress)); - getOpenWebsocketCounter(context.getClient().getUserAgent()).incrementAndGet(); - - context.addWebsocketClosedListener((context1, statusCode, reason) -> - getOpenWebsocketCounter(context.getClient().getUserAgent()).decrementAndGet()); - provisioningManager.addListener(provisioningAddress, message -> { assert message.getType() == PubSubProtos.PubSubMessage.Type.DELIVER; @@ -102,12 +74,4 @@ public class ProvisioningConnectListener implements WebSocketConnectListener { return Base64.getUrlEncoder().encodeToString(provisioningAddress); } - - private AtomicInteger getOpenWebsocketCounter(final String userAgentString) { - try { - return openWebsocketsByClientPlatform.get(UserAgentUtil.parseUserAgentString(userAgentString).getPlatform()); - } catch (final UnrecognizedUserAgentException e) { - return openWebsocketsFromUnknownPlatforms; - } - } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index 08c83f60e..a25e397b4 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -142,7 +142,7 @@ class WebSocketConnectionTest { // authenticated - valid user connectListener.onWebSocketConnect(sessionContext); - verify(sessionContext, times(1)).addWebsocketClosedListener( + verify(sessionContext, times(2)).addWebsocketClosedListener( any(WebSocketSessionContext.WebSocketEventListener.class)); // unauthenticated @@ -152,7 +152,7 @@ class WebSocketConnectionTest { assertFalse(account.invalidCredentialsProvided()); connectListener.onWebSocketConnect(sessionContext); - verify(sessionContext, times(2)).addWebsocketClosedListener( + verify(sessionContext, times(4)).addWebsocketClosedListener( any(WebSocketSessionContext.WebSocketEventListener.class)); verifyNoMoreInteractions(messagesManager);