diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProvider.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProvider.java
index 2ad60c025..f44c2d2c0 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProvider.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProvider.java
@@ -7,11 +7,9 @@ package org.whispersystems.textsecuregcm.auth;
import com.google.common.annotations.VisibleForTesting;
import java.util.Collections;
-import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
-import java.util.Objects;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
@@ -26,11 +24,9 @@ import org.whispersystems.textsecuregcm.util.Pair;
* This {@link WebsocketRefreshRequirementProvider} observes intra-request changes in {@link Account#isEnabled()} and
* {@link Device#isEnabled()}.
*
- * If a change in {@link Account#isEnabled()} is observed, then any active WebSocket connections for the account must be
- * closed, in order for clients to get a refreshed {@link io.dropwizard.auth.Auth} object.
- *
- * If a change in {@link Device#isEnabled()} is observed, including deletion of the {@link Device}, then any active
- * WebSocket connections for the device must be closed and re-authenticated.
+ * If a change in {@link Account#isEnabled()} or any associated {@link Device#isEnabled()} is observed, then any active
+ * WebSocket connections for the account must be closed in order for clients to get a refreshed
+ * {@link io.dropwizard.auth.Auth} object with a current device list.
*
* @see AuthenticatedAccount
* @see DisabledPermittedAuthenticatedAccount
@@ -39,7 +35,6 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
private static final Logger logger = LoggerFactory.getLogger(AuthEnablementRefreshRequirementProvider.class);
- private static final String ACCOUNT_ENABLED = AuthEnablementRefreshRequirementProvider.class.getName() + ".accountEnabled";
private static final String DEVICES_ENABLED = AuthEnablementRefreshRequirementProvider.class.getName() + ".devicesEnabled";
@VisibleForTesting
@@ -50,57 +45,34 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
@Override
public void handleRequestFiltered(final ContainerRequest request) {
// The authenticated principal, if any, will be available after filters have run.
- // Now that the account is known, capture a snapshot of `isEnabled` for the account and its devices,
- // before carrying out the request’s business logic.
+ // Now that the account is known, capture a snapshot of `isEnabled` for the account's devices before carrying out
+ // the request’s business logic.
ContainerRequestUtil.getAuthenticatedAccount(request)
- .ifPresent(
- account -> {
- request.setProperty(ACCOUNT_ENABLED, account.isEnabled());
- request.setProperty(DEVICES_ENABLED, buildDevicesEnabledMap(account));
- });
+ .ifPresent(account -> request.setProperty(DEVICES_ENABLED, buildDevicesEnabledMap(account)));
}
@Override
public List> handleRequestFinished(final ContainerRequest request) {
- // Now that the request is finished, check whether `isEnabled` changed for any of the devices, or the account
- // as a whole. If the value did change, the affected device(s) must disconnect and reauthenticate.
- // If a device was removed, it must also disconnect.
- if (request.getProperty(ACCOUNT_ENABLED) != null &&
- request.getProperty(DEVICES_ENABLED) != null) {
+ // Now that the request is finished, check whether `isEnabled` changed for any of the devices. If the value did
+ // change or if a devices was added or removed, all devices must disconnect and reauthenticate.
+ if (request.getProperty(DEVICES_ENABLED) != null) {
- final boolean accountInitiallyEnabled = (boolean) request.getProperty(ACCOUNT_ENABLED);
@SuppressWarnings("unchecked") final Map initialDevicesEnabled =
(Map) request.getProperty(DEVICES_ENABLED);
return ContainerRequestUtil.getAuthenticatedAccount(request).map(account -> {
final Set deviceIdsToDisplace;
+ final Map currentDevicesEnabled = buildDevicesEnabledMap(account);
- if (account.isEnabled() != accountInitiallyEnabled) {
- // the @Auth for all active connections must change when account.isEnabled() changes
- deviceIdsToDisplace = account.getDevices().stream()
- .map(Device::getId).collect(Collectors.toSet());
-
- deviceIdsToDisplace.addAll(initialDevicesEnabled.keySet());
-
- } else if (!initialDevicesEnabled.isEmpty()) {
-
- deviceIdsToDisplace = new HashSet<>();
- final Map currentDevicesEnabled = buildDevicesEnabledMap(account);
-
- initialDevicesEnabled.forEach((deviceId, enabled) -> {
- // `null` indicates the device was removed from the account. Any active presence should be removed.
- final boolean enabledMatches = Objects.equals(enabled,
- currentDevicesEnabled.getOrDefault(deviceId, null));
-
- if (!enabledMatches) {
- deviceIdsToDisplace.add(deviceId);
- }
- });
+ if (!initialDevicesEnabled.equals(currentDevicesEnabled)) {
+ deviceIdsToDisplace = new HashSet<>(initialDevicesEnabled.keySet());
+ deviceIdsToDisplace.addAll(currentDevicesEnabled.keySet());
} else {
deviceIdsToDisplace = Collections.emptySet();
}
- return deviceIdsToDisplace.stream().map(deviceId -> new Pair<>(account.getUuid(), deviceId))
+ return deviceIdsToDisplace.stream()
+ .map(deviceId -> new Pair<>(account.getUuid(), deviceId))
.collect(Collectors.toList());
}).orElseGet(() -> {
logger.error("Request had account, but it is no longer present");
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java
index 73411f824..5ce87a0ae 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRefreshRequirementProviderTest.java
@@ -10,12 +10,9 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyLong;
-import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
@@ -90,7 +87,7 @@ class AuthEnablementRefreshRequirementProviderTest {
private final ApplicationEventListener applicationEventListener = mock(ApplicationEventListener.class);
private final Account account = new Account();
- private Device authenticatedDevice = DevicesHelper.createDevice(1L);
+ private final Device authenticatedDevice = DevicesHelper.createDevice(1L);
private final Supplier> principalSupplier = () -> Optional.of(
new TestPrincipal("test", account, authenticatedDevice));
@@ -123,9 +120,7 @@ class AuthEnablementRefreshRequirementProviderTest {
final UUID uuid = UUID.randomUUID();
account.setUuid(uuid);
account.addDevice(authenticatedDevice);
- LongStream.range(2, 4).forEach(deviceId -> {
- account.addDevice(DevicesHelper.createDevice(deviceId));
- });
+ LongStream.range(2, 4).forEach(deviceId -> account.addDevice(DevicesHelper.createDevice(deviceId)));
account.getDevices()
.forEach(device -> when(clientPresenceManager.isPresent(uuid, device.getId())).thenReturn(true));
@@ -163,45 +158,6 @@ class AuthEnablementRefreshRequirementProviderTest {
}));
}
- @ParameterizedTest
- @MethodSource
- void testAccountEnabledChanged(final long authenticatedDeviceId, final boolean initialEnabled,
- final boolean finalEnabled) {
-
- DevicesHelper.setEnabled(account.getMasterDevice().orElseThrow(), initialEnabled);
-
- authenticatedDevice = account.getDevice(authenticatedDeviceId).orElseThrow();
-
- final Response response = resources.getJerseyTest()
- .target("/v1/test/account/enabled/" + finalEnabled)
- .request()
- .header("Authorization",
- "Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
- .put(Entity.entity("", MediaType.TEXT_PLAIN));
-
- assertEquals(200, response.getStatus());
-
- if (initialEnabled != finalEnabled) {
- verify(clientPresenceManager, times(account.getDevices().size())).displacePresence(eq(account.getUuid()),
- anyLong());
- } else {
- verifyNoInteractions(clientPresenceManager);
- }
- }
-
- static Stream testAccountEnabledChanged() {
- return Stream.of(
- Arguments.of(1L, true, false),
- Arguments.of(1L, false, true),
- Arguments.of(1L, true, true),
- Arguments.of(1L, false, false),
- Arguments.of(2L, true, false),
- Arguments.of(2L, false, true),
- Arguments.of(2L, true, true),
- Arguments.of(2L, false, false)
- );
- }
-
@ParameterizedTest
@MethodSource
void testDeviceEnabledChanged(final Map initialEnabled, final Map finalEnabled) {
@@ -221,21 +177,22 @@ class AuthEnablementRefreshRequirementProviderTest {
assertEquals(200, response.getStatus());
- assertAll(
- finalEnabled.entrySet().stream()
- .map(deviceIdEnabled -> () -> {
- final boolean expectDisplacedPresence =
- initialEnabled.get(deviceIdEnabled.getKey()) != deviceIdEnabled.getValue();
+ final boolean expectDisplacedPresence = !initialEnabled.equals(finalEnabled);
- verify(clientPresenceManager, times(expectDisplacedPresence ? 1 : 0)).displacePresence(account.getUuid(),
- deviceIdEnabled.getKey());
- })
- );
+ assertAll(
+ initialEnabled.keySet().stream()
+ .map(deviceId -> () -> verify(clientPresenceManager, times(expectDisplacedPresence ? 1 : 0))
+ .displacePresence(account.getUuid(), deviceId)));
+
+ assertAll(
+ finalEnabled.keySet().stream()
+ .map(deviceId -> () -> verify(clientPresenceManager, times(expectDisplacedPresence ? 1 : 0))
+ .displacePresence(account.getUuid(), deviceId)));
}
static Stream testDeviceEnabledChanged() {
return Stream.of(
- // Not testing device ID 1L because that will trigger "account enabled changed"
+ Arguments.of(Map.of(1L, false, 2L, false), Map.of(1L, true, 2L, false)),
Arguments.of(Map.of(2L, false, 3L, false), Map.of(2L, true, 3L, true)),
Arguments.of(Map.of(2L, true, 3L, true), Map.of(2L, false, 3L, false)),
Arguments.of(Map.of(2L, true, 3L, true), Map.of(2L, true, 3L, true)),
@@ -262,7 +219,9 @@ class AuthEnablementRefreshRequirementProviderTest {
assertEquals(initialDeviceCount + addedDeviceNames.size(), account.getDevices().size());
- verifyNoInteractions(clientPresenceManager);
+ verify(clientPresenceManager).displacePresence(account.getUuid(), 1);
+ verify(clientPresenceManager).displacePresence(account.getUuid(), 2);
+ verify(clientPresenceManager).displacePresence(account.getUuid(), 3);
}
@ParameterizedTest
@@ -270,6 +229,8 @@ class AuthEnablementRefreshRequirementProviderTest {
void testDeviceRemoved(final int removedDeviceCount) {
assert account.getMasterDevice().orElseThrow().isEnabled();
+ final List initialDeviceIds = account.getDevices().stream().map(Device::getId).collect(Collectors.toList());
+
final List deletedDeviceIds = account.getDevices().stream()
.map(Device::getId)
.filter(deviceId -> deviceId != 1L)
@@ -290,8 +251,8 @@ class AuthEnablementRefreshRequirementProviderTest {
assertEquals(200, response.getStatus());
- deletedDeviceIds.forEach(deletedDeviceId ->
- verify(clientPresenceManager).displacePresence(account.getUuid(), deletedDeviceId));
+ initialDeviceIds.forEach(deviceId ->
+ verify(clientPresenceManager).displacePresence(account.getUuid(), deviceId));
verifyNoMoreInteractions(clientPresenceManager);
}
@@ -374,32 +335,6 @@ class AuthEnablementRefreshRequirementProviderTest {
provider.onWebSocketConnect(session);
}
- @ParameterizedTest
- @MethodSource("org.whispersystems.textsecuregcm.auth.AuthEnablementRefreshRequirementProviderTest#testAccountEnabledChanged")
- void testAccountEnabledChangedWebSocket(final long authenticatedDeviceId, final boolean initialEnabled,
- final boolean finalEnabled) throws Exception {
-
- DevicesHelper.setEnabled(account.getMasterDevice().orElseThrow(), initialEnabled);
-
- authenticatedDevice = account.getDevice(authenticatedDeviceId).orElseThrow();
-
- byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "PUT",
- "/v1/test/account/enabled/" + finalEnabled,
- new LinkedList<>(), Optional.empty()).toByteArray();
-
- provider.onWebSocketBinary(message, 0, message.length);
-
- final SubProtocol.WebSocketResponseMessage response = verifyAndGetResponse(remoteEndpoint);
-
- assertEquals(200, response.getStatus());
- if (initialEnabled != finalEnabled) {
- verify(clientPresenceManager, times(account.getDevices().size())).displacePresence(eq(account.getUuid()),
- anyLong());
- } else {
- verifyNoInteractions(clientPresenceManager);
- }
- }
-
@Test
void testOnEvent() throws Exception {