From 8359ef73f44c4b61c931f99c97e686da64b353b2 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Wed, 29 Sep 2021 11:10:05 -0400 Subject: [PATCH] Cycle all connected websockets on any device or account enabled state change --- ...hEnablementRefreshRequirementProvider.java | 58 +++------- ...blementRefreshRequirementProviderTest.java | 105 ++++-------------- 2 files changed, 35 insertions(+), 128 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProvider.java index 2ad60c025..f44c2d2c0 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProvider.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProvider.java @@ -7,11 +7,9 @@ package org.whispersystems.textsecuregcm.auth; import com.google.common.annotations.VisibleForTesting; import java.util.Collections; -import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Set; import java.util.UUID; import java.util.stream.Collectors; @@ -26,11 +24,9 @@ import org.whispersystems.textsecuregcm.util.Pair; * This {@link WebsocketRefreshRequirementProvider} observes intra-request changes in {@link Account#isEnabled()} and * {@link Device#isEnabled()}. *

- * If a change in {@link Account#isEnabled()} is observed, then any active WebSocket connections for the account must be - * closed, in order for clients to get a refreshed {@link io.dropwizard.auth.Auth} object. - *

- * If a change in {@link Device#isEnabled()} is observed, including deletion of the {@link Device}, then any active - * WebSocket connections for the device must be closed and re-authenticated. + * If a change in {@link Account#isEnabled()} or any associated {@link Device#isEnabled()} is observed, then any active + * WebSocket connections for the account must be closed in order for clients to get a refreshed + * {@link io.dropwizard.auth.Auth} object with a current device list. * * @see AuthenticatedAccount * @see DisabledPermittedAuthenticatedAccount @@ -39,7 +35,6 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres private static final Logger logger = LoggerFactory.getLogger(AuthEnablementRefreshRequirementProvider.class); - private static final String ACCOUNT_ENABLED = AuthEnablementRefreshRequirementProvider.class.getName() + ".accountEnabled"; private static final String DEVICES_ENABLED = AuthEnablementRefreshRequirementProvider.class.getName() + ".devicesEnabled"; @VisibleForTesting @@ -50,57 +45,34 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres @Override public void handleRequestFiltered(final ContainerRequest request) { // The authenticated principal, if any, will be available after filters have run. - // Now that the account is known, capture a snapshot of `isEnabled` for the account and its devices, - // before carrying out the request’s business logic. + // Now that the account is known, capture a snapshot of `isEnabled` for the account's devices before carrying out + // the request’s business logic. ContainerRequestUtil.getAuthenticatedAccount(request) - .ifPresent( - account -> { - request.setProperty(ACCOUNT_ENABLED, account.isEnabled()); - request.setProperty(DEVICES_ENABLED, buildDevicesEnabledMap(account)); - }); + .ifPresent(account -> request.setProperty(DEVICES_ENABLED, buildDevicesEnabledMap(account))); } @Override public List> handleRequestFinished(final ContainerRequest request) { - // Now that the request is finished, check whether `isEnabled` changed for any of the devices, or the account - // as a whole. If the value did change, the affected device(s) must disconnect and reauthenticate. - // If a device was removed, it must also disconnect. - if (request.getProperty(ACCOUNT_ENABLED) != null && - request.getProperty(DEVICES_ENABLED) != null) { + // Now that the request is finished, check whether `isEnabled` changed for any of the devices. If the value did + // change or if a devices was added or removed, all devices must disconnect and reauthenticate. + if (request.getProperty(DEVICES_ENABLED) != null) { - final boolean accountInitiallyEnabled = (boolean) request.getProperty(ACCOUNT_ENABLED); @SuppressWarnings("unchecked") final Map initialDevicesEnabled = (Map) request.getProperty(DEVICES_ENABLED); return ContainerRequestUtil.getAuthenticatedAccount(request).map(account -> { final Set deviceIdsToDisplace; + final Map currentDevicesEnabled = buildDevicesEnabledMap(account); - if (account.isEnabled() != accountInitiallyEnabled) { - // the @Auth for all active connections must change when account.isEnabled() changes - deviceIdsToDisplace = account.getDevices().stream() - .map(Device::getId).collect(Collectors.toSet()); - - deviceIdsToDisplace.addAll(initialDevicesEnabled.keySet()); - - } else if (!initialDevicesEnabled.isEmpty()) { - - deviceIdsToDisplace = new HashSet<>(); - final Map currentDevicesEnabled = buildDevicesEnabledMap(account); - - initialDevicesEnabled.forEach((deviceId, enabled) -> { - // `null` indicates the device was removed from the account. Any active presence should be removed. - final boolean enabledMatches = Objects.equals(enabled, - currentDevicesEnabled.getOrDefault(deviceId, null)); - - if (!enabledMatches) { - deviceIdsToDisplace.add(deviceId); - } - }); + if (!initialDevicesEnabled.equals(currentDevicesEnabled)) { + deviceIdsToDisplace = new HashSet<>(initialDevicesEnabled.keySet()); + deviceIdsToDisplace.addAll(currentDevicesEnabled.keySet()); } else { deviceIdsToDisplace = Collections.emptySet(); } - return deviceIdsToDisplace.stream().map(deviceId -> new Pair<>(account.getUuid(), deviceId)) + return deviceIdsToDisplace.stream() + .map(deviceId -> new Pair<>(account.getUuid(), deviceId)) .collect(Collectors.toList()); }).orElseGet(() -> { logger.error("Request had account, but it is no longer present"); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java index 73411f824..5ce87a0ae 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java @@ -10,12 +10,9 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; @@ -90,7 +87,7 @@ class AuthEnablementRefreshRequirementProviderTest { private final ApplicationEventListener applicationEventListener = mock(ApplicationEventListener.class); private final Account account = new Account(); - private Device authenticatedDevice = DevicesHelper.createDevice(1L); + private final Device authenticatedDevice = DevicesHelper.createDevice(1L); private final Supplier> principalSupplier = () -> Optional.of( new TestPrincipal("test", account, authenticatedDevice)); @@ -123,9 +120,7 @@ class AuthEnablementRefreshRequirementProviderTest { final UUID uuid = UUID.randomUUID(); account.setUuid(uuid); account.addDevice(authenticatedDevice); - LongStream.range(2, 4).forEach(deviceId -> { - account.addDevice(DevicesHelper.createDevice(deviceId)); - }); + LongStream.range(2, 4).forEach(deviceId -> account.addDevice(DevicesHelper.createDevice(deviceId))); account.getDevices() .forEach(device -> when(clientPresenceManager.isPresent(uuid, device.getId())).thenReturn(true)); @@ -163,45 +158,6 @@ class AuthEnablementRefreshRequirementProviderTest { })); } - @ParameterizedTest - @MethodSource - void testAccountEnabledChanged(final long authenticatedDeviceId, final boolean initialEnabled, - final boolean finalEnabled) { - - DevicesHelper.setEnabled(account.getMasterDevice().orElseThrow(), initialEnabled); - - authenticatedDevice = account.getDevice(authenticatedDeviceId).orElseThrow(); - - final Response response = resources.getJerseyTest() - .target("/v1/test/account/enabled/" + finalEnabled) - .request() - .header("Authorization", - "Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8))) - .put(Entity.entity("", MediaType.TEXT_PLAIN)); - - assertEquals(200, response.getStatus()); - - if (initialEnabled != finalEnabled) { - verify(clientPresenceManager, times(account.getDevices().size())).displacePresence(eq(account.getUuid()), - anyLong()); - } else { - verifyNoInteractions(clientPresenceManager); - } - } - - static Stream testAccountEnabledChanged() { - return Stream.of( - Arguments.of(1L, true, false), - Arguments.of(1L, false, true), - Arguments.of(1L, true, true), - Arguments.of(1L, false, false), - Arguments.of(2L, true, false), - Arguments.of(2L, false, true), - Arguments.of(2L, true, true), - Arguments.of(2L, false, false) - ); - } - @ParameterizedTest @MethodSource void testDeviceEnabledChanged(final Map initialEnabled, final Map finalEnabled) { @@ -221,21 +177,22 @@ class AuthEnablementRefreshRequirementProviderTest { assertEquals(200, response.getStatus()); - assertAll( - finalEnabled.entrySet().stream() - .map(deviceIdEnabled -> () -> { - final boolean expectDisplacedPresence = - initialEnabled.get(deviceIdEnabled.getKey()) != deviceIdEnabled.getValue(); + final boolean expectDisplacedPresence = !initialEnabled.equals(finalEnabled); - verify(clientPresenceManager, times(expectDisplacedPresence ? 1 : 0)).displacePresence(account.getUuid(), - deviceIdEnabled.getKey()); - }) - ); + assertAll( + initialEnabled.keySet().stream() + .map(deviceId -> () -> verify(clientPresenceManager, times(expectDisplacedPresence ? 1 : 0)) + .displacePresence(account.getUuid(), deviceId))); + + assertAll( + finalEnabled.keySet().stream() + .map(deviceId -> () -> verify(clientPresenceManager, times(expectDisplacedPresence ? 1 : 0)) + .displacePresence(account.getUuid(), deviceId))); } static Stream testDeviceEnabledChanged() { return Stream.of( - // Not testing device ID 1L because that will trigger "account enabled changed" + Arguments.of(Map.of(1L, false, 2L, false), Map.of(1L, true, 2L, false)), Arguments.of(Map.of(2L, false, 3L, false), Map.of(2L, true, 3L, true)), Arguments.of(Map.of(2L, true, 3L, true), Map.of(2L, false, 3L, false)), Arguments.of(Map.of(2L, true, 3L, true), Map.of(2L, true, 3L, true)), @@ -262,7 +219,9 @@ class AuthEnablementRefreshRequirementProviderTest { assertEquals(initialDeviceCount + addedDeviceNames.size(), account.getDevices().size()); - verifyNoInteractions(clientPresenceManager); + verify(clientPresenceManager).displacePresence(account.getUuid(), 1); + verify(clientPresenceManager).displacePresence(account.getUuid(), 2); + verify(clientPresenceManager).displacePresence(account.getUuid(), 3); } @ParameterizedTest @@ -270,6 +229,8 @@ class AuthEnablementRefreshRequirementProviderTest { void testDeviceRemoved(final int removedDeviceCount) { assert account.getMasterDevice().orElseThrow().isEnabled(); + final List initialDeviceIds = account.getDevices().stream().map(Device::getId).collect(Collectors.toList()); + final List deletedDeviceIds = account.getDevices().stream() .map(Device::getId) .filter(deviceId -> deviceId != 1L) @@ -290,8 +251,8 @@ class AuthEnablementRefreshRequirementProviderTest { assertEquals(200, response.getStatus()); - deletedDeviceIds.forEach(deletedDeviceId -> - verify(clientPresenceManager).displacePresence(account.getUuid(), deletedDeviceId)); + initialDeviceIds.forEach(deviceId -> + verify(clientPresenceManager).displacePresence(account.getUuid(), deviceId)); verifyNoMoreInteractions(clientPresenceManager); } @@ -374,32 +335,6 @@ class AuthEnablementRefreshRequirementProviderTest { provider.onWebSocketConnect(session); } - @ParameterizedTest - @MethodSource("org.whispersystems.textsecuregcm.auth.AuthEnablementRefreshRequirementProviderTest#testAccountEnabledChanged") - void testAccountEnabledChangedWebSocket(final long authenticatedDeviceId, final boolean initialEnabled, - final boolean finalEnabled) throws Exception { - - DevicesHelper.setEnabled(account.getMasterDevice().orElseThrow(), initialEnabled); - - authenticatedDevice = account.getDevice(authenticatedDeviceId).orElseThrow(); - - byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "PUT", - "/v1/test/account/enabled/" + finalEnabled, - new LinkedList<>(), Optional.empty()).toByteArray(); - - provider.onWebSocketBinary(message, 0, message.length); - - final SubProtocol.WebSocketResponseMessage response = verifyAndGetResponse(remoteEndpoint); - - assertEquals(200, response.getStatus()); - if (initialEnabled != finalEnabled) { - verify(clientPresenceManager, times(account.getDevices().size())).displacePresence(eq(account.getUuid()), - anyLong()); - } else { - verifyNoInteractions(clientPresenceManager); - } - } - @Test void testOnEvent() throws Exception {