Cycle all connected websockets on any device or account enabled state change

This commit is contained in:
Jon Chambers 2021-09-29 11:10:05 -04:00 committed by Jon Chambers
parent c6bb649adb
commit 8359ef73f4
2 changed files with 35 additions and 128 deletions

View File

@ -7,11 +7,9 @@ package org.whispersystems.textsecuregcm.auth;
import com.google.common.annotations.VisibleForTesting;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
@ -26,11 +24,9 @@ import org.whispersystems.textsecuregcm.util.Pair;
* This {@link WebsocketRefreshRequirementProvider} observes intra-request changes in {@link Account#isEnabled()} and
* {@link Device#isEnabled()}.
* <p>
* If a change in {@link Account#isEnabled()} is observed, then any active WebSocket connections for the account must be
* closed, in order for clients to get a refreshed {@link io.dropwizard.auth.Auth} object.
* <p>
* If a change in {@link Device#isEnabled()} is observed, including deletion of the {@link Device}, then any active
* WebSocket connections for the device must be closed and re-authenticated.
* If a change in {@link Account#isEnabled()} or any associated {@link Device#isEnabled()} is observed, then any active
* WebSocket connections for the account must be closed in order for clients to get a refreshed
* {@link io.dropwizard.auth.Auth} object with a current device list.
*
* @see AuthenticatedAccount
* @see DisabledPermittedAuthenticatedAccount
@ -39,7 +35,6 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
private static final Logger logger = LoggerFactory.getLogger(AuthEnablementRefreshRequirementProvider.class);
private static final String ACCOUNT_ENABLED = AuthEnablementRefreshRequirementProvider.class.getName() + ".accountEnabled";
private static final String DEVICES_ENABLED = AuthEnablementRefreshRequirementProvider.class.getName() + ".devicesEnabled";
@VisibleForTesting
@ -50,57 +45,34 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
@Override
public void handleRequestFiltered(final ContainerRequest request) {
// The authenticated principal, if any, will be available after filters have run.
// Now that the account is known, capture a snapshot of `isEnabled` for the account and its devices,
// before carrying out the requests business logic.
// Now that the account is known, capture a snapshot of `isEnabled` for the account's devices before carrying out
// the requests business logic.
ContainerRequestUtil.getAuthenticatedAccount(request)
.ifPresent(
account -> {
request.setProperty(ACCOUNT_ENABLED, account.isEnabled());
request.setProperty(DEVICES_ENABLED, buildDevicesEnabledMap(account));
});
.ifPresent(account -> request.setProperty(DEVICES_ENABLED, buildDevicesEnabledMap(account)));
}
@Override
public List<Pair<UUID, Long>> handleRequestFinished(final ContainerRequest request) {
// Now that the request is finished, check whether `isEnabled` changed for any of the devices, or the account
// as a whole. If the value did change, the affected device(s) must disconnect and reauthenticate.
// If a device was removed, it must also disconnect.
if (request.getProperty(ACCOUNT_ENABLED) != null &&
request.getProperty(DEVICES_ENABLED) != null) {
// Now that the request is finished, check whether `isEnabled` changed for any of the devices. If the value did
// change or if a devices was added or removed, all devices must disconnect and reauthenticate.
if (request.getProperty(DEVICES_ENABLED) != null) {
final boolean accountInitiallyEnabled = (boolean) request.getProperty(ACCOUNT_ENABLED);
@SuppressWarnings("unchecked") final Map<Long, Boolean> initialDevicesEnabled =
(Map<Long, Boolean>) request.getProperty(DEVICES_ENABLED);
return ContainerRequestUtil.getAuthenticatedAccount(request).map(account -> {
final Set<Long> deviceIdsToDisplace;
final Map<Long, Boolean> currentDevicesEnabled = buildDevicesEnabledMap(account);
if (account.isEnabled() != accountInitiallyEnabled) {
// the @Auth for all active connections must change when account.isEnabled() changes
deviceIdsToDisplace = account.getDevices().stream()
.map(Device::getId).collect(Collectors.toSet());
deviceIdsToDisplace.addAll(initialDevicesEnabled.keySet());
} else if (!initialDevicesEnabled.isEmpty()) {
deviceIdsToDisplace = new HashSet<>();
final Map<Long, Boolean> currentDevicesEnabled = buildDevicesEnabledMap(account);
initialDevicesEnabled.forEach((deviceId, enabled) -> {
// `null` indicates the device was removed from the account. Any active presence should be removed.
final boolean enabledMatches = Objects.equals(enabled,
currentDevicesEnabled.getOrDefault(deviceId, null));
if (!enabledMatches) {
deviceIdsToDisplace.add(deviceId);
}
});
if (!initialDevicesEnabled.equals(currentDevicesEnabled)) {
deviceIdsToDisplace = new HashSet<>(initialDevicesEnabled.keySet());
deviceIdsToDisplace.addAll(currentDevicesEnabled.keySet());
} else {
deviceIdsToDisplace = Collections.emptySet();
}
return deviceIdsToDisplace.stream().map(deviceId -> new Pair<>(account.getUuid(), deviceId))
return deviceIdsToDisplace.stream()
.map(deviceId -> new Pair<>(account.getUuid(), deviceId))
.collect(Collectors.toList());
}).orElseGet(() -> {
logger.error("Request had account, but it is no longer present");

View File

@ -10,12 +10,9 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
@ -90,7 +87,7 @@ class AuthEnablementRefreshRequirementProviderTest {
private final ApplicationEventListener applicationEventListener = mock(ApplicationEventListener.class);
private final Account account = new Account();
private Device authenticatedDevice = DevicesHelper.createDevice(1L);
private final Device authenticatedDevice = DevicesHelper.createDevice(1L);
private final Supplier<Optional<TestPrincipal>> principalSupplier = () -> Optional.of(
new TestPrincipal("test", account, authenticatedDevice));
@ -123,9 +120,7 @@ class AuthEnablementRefreshRequirementProviderTest {
final UUID uuid = UUID.randomUUID();
account.setUuid(uuid);
account.addDevice(authenticatedDevice);
LongStream.range(2, 4).forEach(deviceId -> {
account.addDevice(DevicesHelper.createDevice(deviceId));
});
LongStream.range(2, 4).forEach(deviceId -> account.addDevice(DevicesHelper.createDevice(deviceId)));
account.getDevices()
.forEach(device -> when(clientPresenceManager.isPresent(uuid, device.getId())).thenReturn(true));
@ -163,45 +158,6 @@ class AuthEnablementRefreshRequirementProviderTest {
}));
}
@ParameterizedTest
@MethodSource
void testAccountEnabledChanged(final long authenticatedDeviceId, final boolean initialEnabled,
final boolean finalEnabled) {
DevicesHelper.setEnabled(account.getMasterDevice().orElseThrow(), initialEnabled);
authenticatedDevice = account.getDevice(authenticatedDeviceId).orElseThrow();
final Response response = resources.getJerseyTest()
.target("/v1/test/account/enabled/" + finalEnabled)
.request()
.header("Authorization",
"Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
.put(Entity.entity("", MediaType.TEXT_PLAIN));
assertEquals(200, response.getStatus());
if (initialEnabled != finalEnabled) {
verify(clientPresenceManager, times(account.getDevices().size())).displacePresence(eq(account.getUuid()),
anyLong());
} else {
verifyNoInteractions(clientPresenceManager);
}
}
static Stream<Arguments> testAccountEnabledChanged() {
return Stream.of(
Arguments.of(1L, true, false),
Arguments.of(1L, false, true),
Arguments.of(1L, true, true),
Arguments.of(1L, false, false),
Arguments.of(2L, true, false),
Arguments.of(2L, false, true),
Arguments.of(2L, true, true),
Arguments.of(2L, false, false)
);
}
@ParameterizedTest
@MethodSource
void testDeviceEnabledChanged(final Map<Long, Boolean> initialEnabled, final Map<Long, Boolean> finalEnabled) {
@ -221,21 +177,22 @@ class AuthEnablementRefreshRequirementProviderTest {
assertEquals(200, response.getStatus());
assertAll(
finalEnabled.entrySet().stream()
.map(deviceIdEnabled -> () -> {
final boolean expectDisplacedPresence =
initialEnabled.get(deviceIdEnabled.getKey()) != deviceIdEnabled.getValue();
final boolean expectDisplacedPresence = !initialEnabled.equals(finalEnabled);
verify(clientPresenceManager, times(expectDisplacedPresence ? 1 : 0)).displacePresence(account.getUuid(),
deviceIdEnabled.getKey());
})
);
assertAll(
initialEnabled.keySet().stream()
.map(deviceId -> () -> verify(clientPresenceManager, times(expectDisplacedPresence ? 1 : 0))
.displacePresence(account.getUuid(), deviceId)));
assertAll(
finalEnabled.keySet().stream()
.map(deviceId -> () -> verify(clientPresenceManager, times(expectDisplacedPresence ? 1 : 0))
.displacePresence(account.getUuid(), deviceId)));
}
static Stream<Arguments> testDeviceEnabledChanged() {
return Stream.of(
// Not testing device ID 1L because that will trigger "account enabled changed"
Arguments.of(Map.of(1L, false, 2L, false), Map.of(1L, true, 2L, false)),
Arguments.of(Map.of(2L, false, 3L, false), Map.of(2L, true, 3L, true)),
Arguments.of(Map.of(2L, true, 3L, true), Map.of(2L, false, 3L, false)),
Arguments.of(Map.of(2L, true, 3L, true), Map.of(2L, true, 3L, true)),
@ -262,7 +219,9 @@ class AuthEnablementRefreshRequirementProviderTest {
assertEquals(initialDeviceCount + addedDeviceNames.size(), account.getDevices().size());
verifyNoInteractions(clientPresenceManager);
verify(clientPresenceManager).displacePresence(account.getUuid(), 1);
verify(clientPresenceManager).displacePresence(account.getUuid(), 2);
verify(clientPresenceManager).displacePresence(account.getUuid(), 3);
}
@ParameterizedTest
@ -270,6 +229,8 @@ class AuthEnablementRefreshRequirementProviderTest {
void testDeviceRemoved(final int removedDeviceCount) {
assert account.getMasterDevice().orElseThrow().isEnabled();
final List<Long> initialDeviceIds = account.getDevices().stream().map(Device::getId).collect(Collectors.toList());
final List<Long> deletedDeviceIds = account.getDevices().stream()
.map(Device::getId)
.filter(deviceId -> deviceId != 1L)
@ -290,8 +251,8 @@ class AuthEnablementRefreshRequirementProviderTest {
assertEquals(200, response.getStatus());
deletedDeviceIds.forEach(deletedDeviceId ->
verify(clientPresenceManager).displacePresence(account.getUuid(), deletedDeviceId));
initialDeviceIds.forEach(deviceId ->
verify(clientPresenceManager).displacePresence(account.getUuid(), deviceId));
verifyNoMoreInteractions(clientPresenceManager);
}
@ -374,32 +335,6 @@ class AuthEnablementRefreshRequirementProviderTest {
provider.onWebSocketConnect(session);
}
@ParameterizedTest
@MethodSource("org.whispersystems.textsecuregcm.auth.AuthEnablementRefreshRequirementProviderTest#testAccountEnabledChanged")
void testAccountEnabledChangedWebSocket(final long authenticatedDeviceId, final boolean initialEnabled,
final boolean finalEnabled) throws Exception {
DevicesHelper.setEnabled(account.getMasterDevice().orElseThrow(), initialEnabled);
authenticatedDevice = account.getDevice(authenticatedDeviceId).orElseThrow();
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "PUT",
"/v1/test/account/enabled/" + finalEnabled,
new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
final SubProtocol.WebSocketResponseMessage response = verifyAndGetResponse(remoteEndpoint);
assertEquals(200, response.getStatus());
if (initialEnabled != finalEnabled) {
verify(clientPresenceManager, times(account.getDevices().size())).displacePresence(eq(account.getUuid()),
anyLong());
} else {
verifyNoInteractions(clientPresenceManager);
}
}
@Test
void testOnEvent() throws Exception {