Cycle all connected websockets on any device or account enabled state change

This commit is contained in:
Jon Chambers 2021-09-29 11:10:05 -04:00 committed by Jon Chambers
parent c6bb649adb
commit 8359ef73f4
2 changed files with 35 additions and 128 deletions

View File

@ -7,11 +7,9 @@ package org.whispersystems.textsecuregcm.auth;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -26,11 +24,9 @@ import org.whispersystems.textsecuregcm.util.Pair;
* This {@link WebsocketRefreshRequirementProvider} observes intra-request changes in {@link Account#isEnabled()} and * This {@link WebsocketRefreshRequirementProvider} observes intra-request changes in {@link Account#isEnabled()} and
* {@link Device#isEnabled()}. * {@link Device#isEnabled()}.
* <p> * <p>
* If a change in {@link Account#isEnabled()} is observed, then any active WebSocket connections for the account must be * If a change in {@link Account#isEnabled()} or any associated {@link Device#isEnabled()} is observed, then any active
* closed, in order for clients to get a refreshed {@link io.dropwizard.auth.Auth} object. * WebSocket connections for the account must be closed in order for clients to get a refreshed
* <p> * {@link io.dropwizard.auth.Auth} object with a current device list.
* If a change in {@link Device#isEnabled()} is observed, including deletion of the {@link Device}, then any active
* WebSocket connections for the device must be closed and re-authenticated.
* *
* @see AuthenticatedAccount * @see AuthenticatedAccount
* @see DisabledPermittedAuthenticatedAccount * @see DisabledPermittedAuthenticatedAccount
@ -39,7 +35,6 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
private static final Logger logger = LoggerFactory.getLogger(AuthEnablementRefreshRequirementProvider.class); private static final Logger logger = LoggerFactory.getLogger(AuthEnablementRefreshRequirementProvider.class);
private static final String ACCOUNT_ENABLED = AuthEnablementRefreshRequirementProvider.class.getName() + ".accountEnabled";
private static final String DEVICES_ENABLED = AuthEnablementRefreshRequirementProvider.class.getName() + ".devicesEnabled"; private static final String DEVICES_ENABLED = AuthEnablementRefreshRequirementProvider.class.getName() + ".devicesEnabled";
@VisibleForTesting @VisibleForTesting
@ -50,57 +45,34 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
@Override @Override
public void handleRequestFiltered(final ContainerRequest request) { public void handleRequestFiltered(final ContainerRequest request) {
// The authenticated principal, if any, will be available after filters have run. // The authenticated principal, if any, will be available after filters have run.
// Now that the account is known, capture a snapshot of `isEnabled` for the account and its devices, // Now that the account is known, capture a snapshot of `isEnabled` for the account's devices before carrying out
// before carrying out the requests business logic. // the requests business logic.
ContainerRequestUtil.getAuthenticatedAccount(request) ContainerRequestUtil.getAuthenticatedAccount(request)
.ifPresent( .ifPresent(account -> request.setProperty(DEVICES_ENABLED, buildDevicesEnabledMap(account)));
account -> {
request.setProperty(ACCOUNT_ENABLED, account.isEnabled());
request.setProperty(DEVICES_ENABLED, buildDevicesEnabledMap(account));
});
} }
@Override @Override
public List<Pair<UUID, Long>> handleRequestFinished(final ContainerRequest request) { public List<Pair<UUID, Long>> handleRequestFinished(final ContainerRequest request) {
// Now that the request is finished, check whether `isEnabled` changed for any of the devices, or the account // Now that the request is finished, check whether `isEnabled` changed for any of the devices. If the value did
// as a whole. If the value did change, the affected device(s) must disconnect and reauthenticate. // change or if a devices was added or removed, all devices must disconnect and reauthenticate.
// If a device was removed, it must also disconnect. if (request.getProperty(DEVICES_ENABLED) != null) {
if (request.getProperty(ACCOUNT_ENABLED) != null &&
request.getProperty(DEVICES_ENABLED) != null) {
final boolean accountInitiallyEnabled = (boolean) request.getProperty(ACCOUNT_ENABLED);
@SuppressWarnings("unchecked") final Map<Long, Boolean> initialDevicesEnabled = @SuppressWarnings("unchecked") final Map<Long, Boolean> initialDevicesEnabled =
(Map<Long, Boolean>) request.getProperty(DEVICES_ENABLED); (Map<Long, Boolean>) request.getProperty(DEVICES_ENABLED);
return ContainerRequestUtil.getAuthenticatedAccount(request).map(account -> { return ContainerRequestUtil.getAuthenticatedAccount(request).map(account -> {
final Set<Long> deviceIdsToDisplace; final Set<Long> deviceIdsToDisplace;
final Map<Long, Boolean> currentDevicesEnabled = buildDevicesEnabledMap(account);
if (account.isEnabled() != accountInitiallyEnabled) { if (!initialDevicesEnabled.equals(currentDevicesEnabled)) {
// the @Auth for all active connections must change when account.isEnabled() changes deviceIdsToDisplace = new HashSet<>(initialDevicesEnabled.keySet());
deviceIdsToDisplace = account.getDevices().stream() deviceIdsToDisplace.addAll(currentDevicesEnabled.keySet());
.map(Device::getId).collect(Collectors.toSet());
deviceIdsToDisplace.addAll(initialDevicesEnabled.keySet());
} else if (!initialDevicesEnabled.isEmpty()) {
deviceIdsToDisplace = new HashSet<>();
final Map<Long, Boolean> currentDevicesEnabled = buildDevicesEnabledMap(account);
initialDevicesEnabled.forEach((deviceId, enabled) -> {
// `null` indicates the device was removed from the account. Any active presence should be removed.
final boolean enabledMatches = Objects.equals(enabled,
currentDevicesEnabled.getOrDefault(deviceId, null));
if (!enabledMatches) {
deviceIdsToDisplace.add(deviceId);
}
});
} else { } else {
deviceIdsToDisplace = Collections.emptySet(); deviceIdsToDisplace = Collections.emptySet();
} }
return deviceIdsToDisplace.stream().map(deviceId -> new Pair<>(account.getUuid(), deviceId)) return deviceIdsToDisplace.stream()
.map(deviceId -> new Pair<>(account.getUuid(), deviceId))
.collect(Collectors.toList()); .collect(Collectors.toList());
}).orElseGet(() -> { }).orElseGet(() -> {
logger.error("Request had account, but it is no longer present"); logger.error("Request had account, but it is no longer present");

View File

@ -10,12 +10,9 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -90,7 +87,7 @@ class AuthEnablementRefreshRequirementProviderTest {
private final ApplicationEventListener applicationEventListener = mock(ApplicationEventListener.class); private final ApplicationEventListener applicationEventListener = mock(ApplicationEventListener.class);
private final Account account = new Account(); private final Account account = new Account();
private Device authenticatedDevice = DevicesHelper.createDevice(1L); private final Device authenticatedDevice = DevicesHelper.createDevice(1L);
private final Supplier<Optional<TestPrincipal>> principalSupplier = () -> Optional.of( private final Supplier<Optional<TestPrincipal>> principalSupplier = () -> Optional.of(
new TestPrincipal("test", account, authenticatedDevice)); new TestPrincipal("test", account, authenticatedDevice));
@ -123,9 +120,7 @@ class AuthEnablementRefreshRequirementProviderTest {
final UUID uuid = UUID.randomUUID(); final UUID uuid = UUID.randomUUID();
account.setUuid(uuid); account.setUuid(uuid);
account.addDevice(authenticatedDevice); account.addDevice(authenticatedDevice);
LongStream.range(2, 4).forEach(deviceId -> { LongStream.range(2, 4).forEach(deviceId -> account.addDevice(DevicesHelper.createDevice(deviceId)));
account.addDevice(DevicesHelper.createDevice(deviceId));
});
account.getDevices() account.getDevices()
.forEach(device -> when(clientPresenceManager.isPresent(uuid, device.getId())).thenReturn(true)); .forEach(device -> when(clientPresenceManager.isPresent(uuid, device.getId())).thenReturn(true));
@ -163,45 +158,6 @@ class AuthEnablementRefreshRequirementProviderTest {
})); }));
} }
@ParameterizedTest
@MethodSource
void testAccountEnabledChanged(final long authenticatedDeviceId, final boolean initialEnabled,
final boolean finalEnabled) {
DevicesHelper.setEnabled(account.getMasterDevice().orElseThrow(), initialEnabled);
authenticatedDevice = account.getDevice(authenticatedDeviceId).orElseThrow();
final Response response = resources.getJerseyTest()
.target("/v1/test/account/enabled/" + finalEnabled)
.request()
.header("Authorization",
"Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
.put(Entity.entity("", MediaType.TEXT_PLAIN));
assertEquals(200, response.getStatus());
if (initialEnabled != finalEnabled) {
verify(clientPresenceManager, times(account.getDevices().size())).displacePresence(eq(account.getUuid()),
anyLong());
} else {
verifyNoInteractions(clientPresenceManager);
}
}
static Stream<Arguments> testAccountEnabledChanged() {
return Stream.of(
Arguments.of(1L, true, false),
Arguments.of(1L, false, true),
Arguments.of(1L, true, true),
Arguments.of(1L, false, false),
Arguments.of(2L, true, false),
Arguments.of(2L, false, true),
Arguments.of(2L, true, true),
Arguments.of(2L, false, false)
);
}
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void testDeviceEnabledChanged(final Map<Long, Boolean> initialEnabled, final Map<Long, Boolean> finalEnabled) { void testDeviceEnabledChanged(final Map<Long, Boolean> initialEnabled, final Map<Long, Boolean> finalEnabled) {
@ -221,21 +177,22 @@ class AuthEnablementRefreshRequirementProviderTest {
assertEquals(200, response.getStatus()); assertEquals(200, response.getStatus());
assertAll( final boolean expectDisplacedPresence = !initialEnabled.equals(finalEnabled);
finalEnabled.entrySet().stream()
.map(deviceIdEnabled -> () -> {
final boolean expectDisplacedPresence =
initialEnabled.get(deviceIdEnabled.getKey()) != deviceIdEnabled.getValue();
verify(clientPresenceManager, times(expectDisplacedPresence ? 1 : 0)).displacePresence(account.getUuid(), assertAll(
deviceIdEnabled.getKey()); initialEnabled.keySet().stream()
}) .map(deviceId -> () -> verify(clientPresenceManager, times(expectDisplacedPresence ? 1 : 0))
); .displacePresence(account.getUuid(), deviceId)));
assertAll(
finalEnabled.keySet().stream()
.map(deviceId -> () -> verify(clientPresenceManager, times(expectDisplacedPresence ? 1 : 0))
.displacePresence(account.getUuid(), deviceId)));
} }
static Stream<Arguments> testDeviceEnabledChanged() { static Stream<Arguments> testDeviceEnabledChanged() {
return Stream.of( return Stream.of(
// Not testing device ID 1L because that will trigger "account enabled changed" Arguments.of(Map.of(1L, false, 2L, false), Map.of(1L, true, 2L, false)),
Arguments.of(Map.of(2L, false, 3L, false), Map.of(2L, true, 3L, true)), Arguments.of(Map.of(2L, false, 3L, false), Map.of(2L, true, 3L, true)),
Arguments.of(Map.of(2L, true, 3L, true), Map.of(2L, false, 3L, false)), Arguments.of(Map.of(2L, true, 3L, true), Map.of(2L, false, 3L, false)),
Arguments.of(Map.of(2L, true, 3L, true), Map.of(2L, true, 3L, true)), Arguments.of(Map.of(2L, true, 3L, true), Map.of(2L, true, 3L, true)),
@ -262,7 +219,9 @@ class AuthEnablementRefreshRequirementProviderTest {
assertEquals(initialDeviceCount + addedDeviceNames.size(), account.getDevices().size()); assertEquals(initialDeviceCount + addedDeviceNames.size(), account.getDevices().size());
verifyNoInteractions(clientPresenceManager); verify(clientPresenceManager).displacePresence(account.getUuid(), 1);
verify(clientPresenceManager).displacePresence(account.getUuid(), 2);
verify(clientPresenceManager).displacePresence(account.getUuid(), 3);
} }
@ParameterizedTest @ParameterizedTest
@ -270,6 +229,8 @@ class AuthEnablementRefreshRequirementProviderTest {
void testDeviceRemoved(final int removedDeviceCount) { void testDeviceRemoved(final int removedDeviceCount) {
assert account.getMasterDevice().orElseThrow().isEnabled(); assert account.getMasterDevice().orElseThrow().isEnabled();
final List<Long> initialDeviceIds = account.getDevices().stream().map(Device::getId).collect(Collectors.toList());
final List<Long> deletedDeviceIds = account.getDevices().stream() final List<Long> deletedDeviceIds = account.getDevices().stream()
.map(Device::getId) .map(Device::getId)
.filter(deviceId -> deviceId != 1L) .filter(deviceId -> deviceId != 1L)
@ -290,8 +251,8 @@ class AuthEnablementRefreshRequirementProviderTest {
assertEquals(200, response.getStatus()); assertEquals(200, response.getStatus());
deletedDeviceIds.forEach(deletedDeviceId -> initialDeviceIds.forEach(deviceId ->
verify(clientPresenceManager).displacePresence(account.getUuid(), deletedDeviceId)); verify(clientPresenceManager).displacePresence(account.getUuid(), deviceId));
verifyNoMoreInteractions(clientPresenceManager); verifyNoMoreInteractions(clientPresenceManager);
} }
@ -374,32 +335,6 @@ class AuthEnablementRefreshRequirementProviderTest {
provider.onWebSocketConnect(session); provider.onWebSocketConnect(session);
} }
@ParameterizedTest
@MethodSource("org.whispersystems.textsecuregcm.auth.AuthEnablementRefreshRequirementProviderTest#testAccountEnabledChanged")
void testAccountEnabledChangedWebSocket(final long authenticatedDeviceId, final boolean initialEnabled,
final boolean finalEnabled) throws Exception {
DevicesHelper.setEnabled(account.getMasterDevice().orElseThrow(), initialEnabled);
authenticatedDevice = account.getDevice(authenticatedDeviceId).orElseThrow();
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "PUT",
"/v1/test/account/enabled/" + finalEnabled,
new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
final SubProtocol.WebSocketResponseMessage response = verifyAndGetResponse(remoteEndpoint);
assertEquals(200, response.getStatus());
if (initialEnabled != finalEnabled) {
verify(clientPresenceManager, times(account.getDevices().size())).displacePresence(eq(account.getUuid()),
anyLong());
} else {
verifyNoInteractions(clientPresenceManager);
}
}
@Test @Test
void testOnEvent() throws Exception { void testOnEvent() throws Exception {