diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java index 6e835a362..797ffc153 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/AuthenticatedConnectListener.java @@ -52,6 +52,10 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { private static final String OPEN_WEBSOCKET_COUNTER_NAME = MetricsUtil.name(WebSocketConnection.class, "openWebsockets"); + private static final String CONNECTED_DURATION_TIMER_NAME = MetricsUtil.name(AuthenticatedConnectListener.class, + "connectedDuration"); + + private static final String AUTHENTICATED_TAG_NAME = "authenticated"; private static final long RENEW_PRESENCE_INTERVAL_MINUTES = 5; @@ -64,8 +68,15 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { private final ScheduledExecutorService scheduledExecutorService; private final Scheduler messageDeliveryScheduler; - private final Map openWebsocketsByClientPlatform; - private final AtomicInteger openWebsocketsFromUnknownPlatforms; + private final Map openAuthenticatedWebsocketsByClientPlatform; + private final Map openUnauthenticatedWebsocketsByClientPlatform; + private final Map durationTimersByClientPlatform; + private final Map unauthenticatedDurationTimersByClientPlatform; + + private final AtomicInteger openAuthenticatedWebsocketsFromUnknownPlatforms; + private final AtomicInteger openUnauthenticatedWebsocketsFromUnknownPlatforms; + private final io.micrometer.core.instrument.Timer durationTimerForUnknownPlatforms; + private final io.micrometer.core.instrument.Timer unauthenticatedDurationTimerForUnknownPlatforms; public AuthenticatedConnectListener(ReceiptSender receiptSender, MessagesManager messagesManager, @@ -80,39 +91,72 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { this.scheduledExecutorService = scheduledExecutorService; this.messageDeliveryScheduler = messageDeliveryScheduler; - openWebsocketsByClientPlatform = new EnumMap<>(ClientPlatform.class); + openAuthenticatedWebsocketsByClientPlatform = new EnumMap<>(ClientPlatform.class); + openUnauthenticatedWebsocketsByClientPlatform = new EnumMap<>(ClientPlatform.class); + durationTimersByClientPlatform = new EnumMap<>(ClientPlatform.class); + unauthenticatedDurationTimersByClientPlatform = new EnumMap<>(ClientPlatform.class); + + final Tags authenticatedTag = Tags.of(AUTHENTICATED_TAG_NAME, "true"); + final Tags unauthenticatedTag = Tags.of(AUTHENTICATED_TAG_NAME, "false"); for (final ClientPlatform clientPlatform : ClientPlatform.values()) { - openWebsocketsByClientPlatform.put(clientPlatform, new AtomicInteger(0)); + openAuthenticatedWebsocketsByClientPlatform.put(clientPlatform, new AtomicInteger(0)); + openUnauthenticatedWebsocketsByClientPlatform.put(clientPlatform, new AtomicInteger(0)); - Metrics.gauge(OPEN_WEBSOCKET_COUNTER_NAME, Tags.of(UserAgentTagUtil.PLATFORM_TAG, clientPlatform.name().toLowerCase()), - openWebsocketsByClientPlatform.get(clientPlatform)); + final Tags clientPlatformTag = Tags.of(UserAgentTagUtil.PLATFORM_TAG, clientPlatform.name().toLowerCase()); + Metrics.gauge(OPEN_WEBSOCKET_COUNTER_NAME, clientPlatformTag.and(authenticatedTag), + openAuthenticatedWebsocketsByClientPlatform.get(clientPlatform)); + + Metrics.gauge(OPEN_WEBSOCKET_COUNTER_NAME, clientPlatformTag.and(unauthenticatedTag), + openUnauthenticatedWebsocketsByClientPlatform.get(clientPlatform)); + + durationTimersByClientPlatform.put(clientPlatform, + Metrics.timer(CONNECTED_DURATION_TIMER_NAME, clientPlatformTag.and(authenticatedTag))); + + unauthenticatedDurationTimersByClientPlatform.put(clientPlatform, + Metrics.timer(CONNECTED_DURATION_TIMER_NAME, clientPlatformTag.and(unauthenticatedTag))); } - openWebsocketsFromUnknownPlatforms = new AtomicInteger(0); + openAuthenticatedWebsocketsFromUnknownPlatforms = new AtomicInteger(0); + openUnauthenticatedWebsocketsFromUnknownPlatforms = new AtomicInteger(0); - Metrics.gauge(OPEN_WEBSOCKET_COUNTER_NAME, Tags.of(UserAgentTagUtil.PLATFORM_TAG, "unrecognized"), - openWebsocketsFromUnknownPlatforms); + final Tags unrecognizedPlatform = Tags.of(UserAgentTagUtil.PLATFORM_TAG, "unrecognized"); + Metrics.gauge(OPEN_WEBSOCKET_COUNTER_NAME, unrecognizedPlatform.and(authenticatedTag), + openAuthenticatedWebsocketsFromUnknownPlatforms); + + Metrics.gauge(OPEN_WEBSOCKET_COUNTER_NAME, unrecognizedPlatform.and(unauthenticatedTag), + openUnauthenticatedWebsocketsFromUnknownPlatforms); + + durationTimerForUnknownPlatforms = Metrics.timer(CONNECTED_DURATION_TIMER_NAME, + unrecognizedPlatform.and(authenticatedTag)); + + unauthenticatedDurationTimerForUnknownPlatforms = Metrics.timer(CONNECTED_DURATION_TIMER_NAME, + unrecognizedPlatform.and(unauthenticatedTag)); } @Override public void onWebSocketConnect(WebSocketSessionContext context) { - if (context.getAuthenticated() != null) { + + final boolean authenticated = (context.getAuthenticated() != null); + final String userAgent = context.getClient().getUserAgent(); + final AtomicInteger openWebsocketAtomicInteger = getOpenWebsocketCounter(userAgent, authenticated); + final io.micrometer.core.instrument.Timer connectionTimer = getConnectionTimer(userAgent, authenticated); + + if (authenticated) { final AuthenticatedAccount auth = context.getAuthenticated(AuthenticatedAccount.class); final Device device = auth.getAuthenticatedDevice(); final Timer.Context timer = durationTimer.time(); + final io.micrometer.core.instrument.Timer.Sample sample = io.micrometer.core.instrument.Timer.start(); final WebSocketConnection connection = new WebSocketConnection(receiptSender, messagesManager, auth, device, context.getClient(), scheduledExecutorService, messageDeliveryScheduler); - final AtomicInteger openWebsocketAtomicInteger = getOpenWebsocketCounter(context.getClient().getUserAgent()); - openWebsocketAtomicInteger.incrementAndGet(); openWebsocketCounter.inc(); - pushNotificationManager.handleMessagesRetrieved(auth.getAccount(), device, context.getClient().getUserAgent()); + pushNotificationManager.handleMessagesRetrieved(auth.getAccount(), device, userAgent); final AtomicReference> renewPresenceFutureReference = new AtomicReference<>(); @@ -121,6 +165,7 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { openWebsocketCounter.dec(); timer.stop(); + sample.stop(connectionTimer); final ScheduledFuture renewPresenceFuture = renewPresenceFutureReference.get(); @@ -159,16 +204,45 @@ public class AuthenticatedConnectListener implements WebSocketConnectListener { context.getClient().close(1011, "Unexpected error initializing connection"); } } else { + + openWebsocketAtomicInteger.incrementAndGet(); + openWebsocketCounter.inc(); + final Timer.Context timer = unauthenticatedDurationTimer.time(); - context.addWebsocketClosedListener((context1, statusCode, reason) -> timer.stop()); + final io.micrometer.core.instrument.Timer.Sample sample = io.micrometer.core.instrument.Timer.start(); + context.addWebsocketClosedListener((context1, statusCode, reason) -> { + openWebsocketAtomicInteger.decrementAndGet(); + openWebsocketCounter.dec(); + timer.stop(); + sample.stop(connectionTimer); + }); } } - private AtomicInteger getOpenWebsocketCounter(final String userAgentString) { + private AtomicInteger getOpenWebsocketCounter(final String userAgentString, final boolean authenticated) { try { - return openWebsocketsByClientPlatform.get(UserAgentUtil.parseUserAgentString(userAgentString).getPlatform()); + final ClientPlatform platform = UserAgentUtil.parseUserAgentString(userAgentString).getPlatform(); + return authenticated + ? openAuthenticatedWebsocketsByClientPlatform.get(platform) + : openUnauthenticatedWebsocketsByClientPlatform.get(platform); } catch (final UnrecognizedUserAgentException e) { - return openWebsocketsFromUnknownPlatforms; + return authenticated + ? openAuthenticatedWebsocketsFromUnknownPlatforms + : openUnauthenticatedWebsocketsFromUnknownPlatforms; + } + } + + private io.micrometer.core.instrument.Timer getConnectionTimer(final String userAgentString, + final boolean authenticated) { + try { + final ClientPlatform platform = UserAgentUtil.parseUserAgentString(userAgentString).getPlatform(); + return authenticated + ? durationTimersByClientPlatform.get(platform) + : unauthenticatedDurationTimersByClientPlatform.get(platform); + } catch (final UnrecognizedUserAgentException e) { + return authenticated + ? durationTimerForUnknownPlatforms + : unauthenticatedDurationTimerForUnknownPlatforms; } } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticatorTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticatorTest.java new file mode 100644 index 000000000..8148e41b3 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketAccountAuthenticatorTest.java @@ -0,0 +1,87 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.websocket; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.i18n.phonenumbers.PhoneNumberUtil; +import io.dropwizard.auth.basic.BasicCredentials; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Stream; +import org.eclipse.jetty.websocket.api.UpgradeRequest; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.whispersystems.textsecuregcm.auth.AccountAuthenticator; +import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount; +import org.whispersystems.textsecuregcm.storage.Account; +import org.whispersystems.textsecuregcm.storage.Device; +import org.whispersystems.textsecuregcm.util.Pair; +import org.whispersystems.websocket.auth.WebSocketAuthenticator; + +class WebSocketAccountAuthenticatorTest { + + private static final String VALID_USER = PhoneNumberUtil.getInstance().format( + PhoneNumberUtil.getInstance().getExampleNumber("NZ"), PhoneNumberUtil.PhoneNumberFormat.E164); + private static final String VALID_PASSWORD = "valid"; + private static final String INVALID_USER = PhoneNumberUtil.getInstance().format( + PhoneNumberUtil.getInstance().getExampleNumber("AU"), PhoneNumberUtil.PhoneNumberFormat.E164); + private static final String INVALID_PASSWORD = "invalid"; + + private AccountAuthenticator accountAuthenticator; + + private UpgradeRequest upgradeRequest; + + @BeforeEach + void setUp() { + accountAuthenticator = mock(AccountAuthenticator.class); + + when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD)))) + .thenReturn(Optional.of(new AuthenticatedAccount(() -> new Pair<>(mock(Account.class), mock(Device.class))))); + + when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD)))) + .thenReturn(Optional.empty()); + + upgradeRequest = mock(UpgradeRequest.class); + } + + @ParameterizedTest + @MethodSource + void testAuthenticate(final Map> upgradeRequestParameters, final boolean expectAccount, + final boolean expectRequired) throws Exception { + + when(upgradeRequest.getParameterMap()).thenReturn(upgradeRequestParameters); + + final WebSocketAccountAuthenticator webSocketAuthenticator = new WebSocketAccountAuthenticator( + accountAuthenticator); + + final WebSocketAuthenticator.AuthenticationResult result = webSocketAuthenticator.authenticate( + upgradeRequest); + + if (expectAccount) { + assertTrue(result.getUser().isPresent()); + } else { + assertTrue(result.getUser().isEmpty()); + } + + assertEquals(expectRequired, result.isRequired()); + } + + private static Stream testAuthenticate() { + return Stream.of( + Arguments.of(Map.of("login", List.of(VALID_USER), "password", List.of(VALID_PASSWORD)), true, true), + Arguments.of(Map.of("login", List.of(INVALID_USER), "password", List.of(INVALID_PASSWORD)), false, true), + Arguments.of(Map.of(), false, false) + ); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java index be377cecc..efbed8f8b 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnectionTest.java @@ -21,6 +21,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; @@ -75,12 +76,10 @@ import reactor.test.StepVerifier; class WebSocketConnectionTest { private static final String VALID_USER = "+14152222222"; - private static final String INVALID_USER = "+14151111111"; private static final int SOURCE_DEVICE_ID = 1; private static final String VALID_PASSWORD = "secure"; - private static final String INVALID_PASSWORD = "insecure"; private AccountAuthenticator accountAuthenticator; private AccountsManager accountsManager; @@ -124,28 +123,31 @@ class WebSocketConnectionTest { when(accountAuthenticator.authenticate(eq(new BasicCredentials(VALID_USER, VALID_PASSWORD)))) .thenReturn(Optional.of(new AuthenticatedAccount(() -> new Pair<>(account, device)))); - when(accountAuthenticator.authenticate(eq(new BasicCredentials(INVALID_USER, INVALID_PASSWORD)))) - .thenReturn(Optional.empty()); - - when(upgradeRequest.getParameterMap()).thenReturn(Map.of( - "login", List.of(VALID_USER), - "password", List.of(VALID_PASSWORD))); - AuthenticationResult account = webSocketAuthenticator.authenticate(upgradeRequest); + when(sessionContext.getAuthenticated()).thenReturn(account.getUser().orElse(null)); when(sessionContext.getAuthenticated(AuthenticatedAccount.class)).thenReturn(account.getUser().orElse(null)); + final WebSocketClient webSocketClient = mock(WebSocketClient.class); + when(webSocketClient.getUserAgent()).thenReturn("Signal-Android/6.22.8"); + when(sessionContext.getClient()).thenReturn(webSocketClient); + + // authenticated - valid user connectListener.onWebSocketConnect(sessionContext); - verify(sessionContext).addWebsocketClosedListener(any(WebSocketSessionContext.WebSocketEventListener.class)); - - when(upgradeRequest.getParameterMap()).thenReturn(Map.of( - "login", List.of(INVALID_USER), - "password", List.of(INVALID_PASSWORD) - )); + verify(sessionContext, times(1)).addWebsocketClosedListener( + any(WebSocketSessionContext.WebSocketEventListener.class)); + // unauthenticated + when(upgradeRequest.getParameterMap()).thenReturn(Map.of()); account = webSocketAuthenticator.authenticate(upgradeRequest); assertFalse(account.getUser().isPresent()); - assertTrue(account.isRequired()); + assertFalse(account.isRequired()); + + connectListener.onWebSocketConnect(sessionContext); + verify(sessionContext, times(2)).addWebsocketClosedListener( + any(WebSocketSessionContext.WebSocketEventListener.class)); + + verifyNoMoreInteractions(messagesManager); } @Test diff --git a/websocket-resources/src/main/java/org/whispersystems/websocket/session/WebSocketSessionContext.java b/websocket-resources/src/main/java/org/whispersystems/websocket/session/WebSocketSessionContext.java index b5298ef3a..a83b41075 100644 --- a/websocket-resources/src/main/java/org/whispersystems/websocket/session/WebSocketSessionContext.java +++ b/websocket-resources/src/main/java/org/whispersystems/websocket/session/WebSocketSessionContext.java @@ -4,10 +4,10 @@ */ package org.whispersystems.websocket.session; -import org.whispersystems.websocket.WebSocketClient; - import java.util.LinkedList; import java.util.List; +import javax.annotation.Nullable; +import org.whispersystems.websocket.WebSocketClient; public class WebSocketSessionContext { @@ -34,6 +34,7 @@ public class WebSocketSessionContext { throw new IllegalArgumentException("No authenticated type for: " + clazz + ", we have: " + authenticated); } + @Nullable public Object getAuthenticated() { return authenticated; }