Cycle all connected websockets on any device or account enabled state change
This commit is contained in:
parent
c6bb649adb
commit
8359ef73f4
|
@ -7,11 +7,9 @@ package org.whispersystems.textsecuregcm.auth;
|
|||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.UUID;
|
||||
import java.util.stream.Collectors;
|
||||
|
@ -26,11 +24,9 @@ import org.whispersystems.textsecuregcm.util.Pair;
|
|||
* This {@link WebsocketRefreshRequirementProvider} observes intra-request changes in {@link Account#isEnabled()} and
|
||||
* {@link Device#isEnabled()}.
|
||||
* <p>
|
||||
* If a change in {@link Account#isEnabled()} is observed, then any active WebSocket connections for the account must be
|
||||
* closed, in order for clients to get a refreshed {@link io.dropwizard.auth.Auth} object.
|
||||
* <p>
|
||||
* If a change in {@link Device#isEnabled()} is observed, including deletion of the {@link Device}, then any active
|
||||
* WebSocket connections for the device must be closed and re-authenticated.
|
||||
* If a change in {@link Account#isEnabled()} or any associated {@link Device#isEnabled()} is observed, then any active
|
||||
* WebSocket connections for the account must be closed in order for clients to get a refreshed
|
||||
* {@link io.dropwizard.auth.Auth} object with a current device list.
|
||||
*
|
||||
* @see AuthenticatedAccount
|
||||
* @see DisabledPermittedAuthenticatedAccount
|
||||
|
@ -39,7 +35,6 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
|
|||
|
||||
private static final Logger logger = LoggerFactory.getLogger(AuthEnablementRefreshRequirementProvider.class);
|
||||
|
||||
private static final String ACCOUNT_ENABLED = AuthEnablementRefreshRequirementProvider.class.getName() + ".accountEnabled";
|
||||
private static final String DEVICES_ENABLED = AuthEnablementRefreshRequirementProvider.class.getName() + ".devicesEnabled";
|
||||
|
||||
@VisibleForTesting
|
||||
|
@ -50,57 +45,34 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
|
|||
@Override
|
||||
public void handleRequestFiltered(final ContainerRequest request) {
|
||||
// The authenticated principal, if any, will be available after filters have run.
|
||||
// Now that the account is known, capture a snapshot of `isEnabled` for the account and its devices,
|
||||
// before carrying out the request’s business logic.
|
||||
// Now that the account is known, capture a snapshot of `isEnabled` for the account's devices before carrying out
|
||||
// the request’s business logic.
|
||||
ContainerRequestUtil.getAuthenticatedAccount(request)
|
||||
.ifPresent(
|
||||
account -> {
|
||||
request.setProperty(ACCOUNT_ENABLED, account.isEnabled());
|
||||
request.setProperty(DEVICES_ENABLED, buildDevicesEnabledMap(account));
|
||||
});
|
||||
.ifPresent(account -> request.setProperty(DEVICES_ENABLED, buildDevicesEnabledMap(account)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Pair<UUID, Long>> handleRequestFinished(final ContainerRequest request) {
|
||||
// Now that the request is finished, check whether `isEnabled` changed for any of the devices, or the account
|
||||
// as a whole. If the value did change, the affected device(s) must disconnect and reauthenticate.
|
||||
// If a device was removed, it must also disconnect.
|
||||
if (request.getProperty(ACCOUNT_ENABLED) != null &&
|
||||
request.getProperty(DEVICES_ENABLED) != null) {
|
||||
// Now that the request is finished, check whether `isEnabled` changed for any of the devices. If the value did
|
||||
// change or if a devices was added or removed, all devices must disconnect and reauthenticate.
|
||||
if (request.getProperty(DEVICES_ENABLED) != null) {
|
||||
|
||||
final boolean accountInitiallyEnabled = (boolean) request.getProperty(ACCOUNT_ENABLED);
|
||||
@SuppressWarnings("unchecked") final Map<Long, Boolean> initialDevicesEnabled =
|
||||
(Map<Long, Boolean>) request.getProperty(DEVICES_ENABLED);
|
||||
|
||||
return ContainerRequestUtil.getAuthenticatedAccount(request).map(account -> {
|
||||
final Set<Long> deviceIdsToDisplace;
|
||||
final Map<Long, Boolean> currentDevicesEnabled = buildDevicesEnabledMap(account);
|
||||
|
||||
if (account.isEnabled() != accountInitiallyEnabled) {
|
||||
// the @Auth for all active connections must change when account.isEnabled() changes
|
||||
deviceIdsToDisplace = account.getDevices().stream()
|
||||
.map(Device::getId).collect(Collectors.toSet());
|
||||
|
||||
deviceIdsToDisplace.addAll(initialDevicesEnabled.keySet());
|
||||
|
||||
} else if (!initialDevicesEnabled.isEmpty()) {
|
||||
|
||||
deviceIdsToDisplace = new HashSet<>();
|
||||
final Map<Long, Boolean> currentDevicesEnabled = buildDevicesEnabledMap(account);
|
||||
|
||||
initialDevicesEnabled.forEach((deviceId, enabled) -> {
|
||||
// `null` indicates the device was removed from the account. Any active presence should be removed.
|
||||
final boolean enabledMatches = Objects.equals(enabled,
|
||||
currentDevicesEnabled.getOrDefault(deviceId, null));
|
||||
|
||||
if (!enabledMatches) {
|
||||
deviceIdsToDisplace.add(deviceId);
|
||||
}
|
||||
});
|
||||
if (!initialDevicesEnabled.equals(currentDevicesEnabled)) {
|
||||
deviceIdsToDisplace = new HashSet<>(initialDevicesEnabled.keySet());
|
||||
deviceIdsToDisplace.addAll(currentDevicesEnabled.keySet());
|
||||
} else {
|
||||
deviceIdsToDisplace = Collections.emptySet();
|
||||
}
|
||||
|
||||
return deviceIdsToDisplace.stream().map(deviceId -> new Pair<>(account.getUuid(), deviceId))
|
||||
return deviceIdsToDisplace.stream()
|
||||
.map(deviceId -> new Pair<>(account.getUuid(), deviceId))
|
||||
.collect(Collectors.toList());
|
||||
}).orElseGet(() -> {
|
||||
logger.error("Request had account, but it is no longer present");
|
||||
|
|
|
@ -10,12 +10,9 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
|||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyLong;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoInteractions;
|
||||
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
|
@ -90,7 +87,7 @@ class AuthEnablementRefreshRequirementProviderTest {
|
|||
private final ApplicationEventListener applicationEventListener = mock(ApplicationEventListener.class);
|
||||
|
||||
private final Account account = new Account();
|
||||
private Device authenticatedDevice = DevicesHelper.createDevice(1L);
|
||||
private final Device authenticatedDevice = DevicesHelper.createDevice(1L);
|
||||
|
||||
private final Supplier<Optional<TestPrincipal>> principalSupplier = () -> Optional.of(
|
||||
new TestPrincipal("test", account, authenticatedDevice));
|
||||
|
@ -123,9 +120,7 @@ class AuthEnablementRefreshRequirementProviderTest {
|
|||
final UUID uuid = UUID.randomUUID();
|
||||
account.setUuid(uuid);
|
||||
account.addDevice(authenticatedDevice);
|
||||
LongStream.range(2, 4).forEach(deviceId -> {
|
||||
account.addDevice(DevicesHelper.createDevice(deviceId));
|
||||
});
|
||||
LongStream.range(2, 4).forEach(deviceId -> account.addDevice(DevicesHelper.createDevice(deviceId)));
|
||||
|
||||
account.getDevices()
|
||||
.forEach(device -> when(clientPresenceManager.isPresent(uuid, device.getId())).thenReturn(true));
|
||||
|
@ -163,45 +158,6 @@ class AuthEnablementRefreshRequirementProviderTest {
|
|||
}));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void testAccountEnabledChanged(final long authenticatedDeviceId, final boolean initialEnabled,
|
||||
final boolean finalEnabled) {
|
||||
|
||||
DevicesHelper.setEnabled(account.getMasterDevice().orElseThrow(), initialEnabled);
|
||||
|
||||
authenticatedDevice = account.getDevice(authenticatedDeviceId).orElseThrow();
|
||||
|
||||
final Response response = resources.getJerseyTest()
|
||||
.target("/v1/test/account/enabled/" + finalEnabled)
|
||||
.request()
|
||||
.header("Authorization",
|
||||
"Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
|
||||
.put(Entity.entity("", MediaType.TEXT_PLAIN));
|
||||
|
||||
assertEquals(200, response.getStatus());
|
||||
|
||||
if (initialEnabled != finalEnabled) {
|
||||
verify(clientPresenceManager, times(account.getDevices().size())).displacePresence(eq(account.getUuid()),
|
||||
anyLong());
|
||||
} else {
|
||||
verifyNoInteractions(clientPresenceManager);
|
||||
}
|
||||
}
|
||||
|
||||
static Stream<Arguments> testAccountEnabledChanged() {
|
||||
return Stream.of(
|
||||
Arguments.of(1L, true, false),
|
||||
Arguments.of(1L, false, true),
|
||||
Arguments.of(1L, true, true),
|
||||
Arguments.of(1L, false, false),
|
||||
Arguments.of(2L, true, false),
|
||||
Arguments.of(2L, false, true),
|
||||
Arguments.of(2L, true, true),
|
||||
Arguments.of(2L, false, false)
|
||||
);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource
|
||||
void testDeviceEnabledChanged(final Map<Long, Boolean> initialEnabled, final Map<Long, Boolean> finalEnabled) {
|
||||
|
@ -221,21 +177,22 @@ class AuthEnablementRefreshRequirementProviderTest {
|
|||
|
||||
assertEquals(200, response.getStatus());
|
||||
|
||||
assertAll(
|
||||
finalEnabled.entrySet().stream()
|
||||
.map(deviceIdEnabled -> () -> {
|
||||
final boolean expectDisplacedPresence =
|
||||
initialEnabled.get(deviceIdEnabled.getKey()) != deviceIdEnabled.getValue();
|
||||
final boolean expectDisplacedPresence = !initialEnabled.equals(finalEnabled);
|
||||
|
||||
verify(clientPresenceManager, times(expectDisplacedPresence ? 1 : 0)).displacePresence(account.getUuid(),
|
||||
deviceIdEnabled.getKey());
|
||||
})
|
||||
);
|
||||
assertAll(
|
||||
initialEnabled.keySet().stream()
|
||||
.map(deviceId -> () -> verify(clientPresenceManager, times(expectDisplacedPresence ? 1 : 0))
|
||||
.displacePresence(account.getUuid(), deviceId)));
|
||||
|
||||
assertAll(
|
||||
finalEnabled.keySet().stream()
|
||||
.map(deviceId -> () -> verify(clientPresenceManager, times(expectDisplacedPresence ? 1 : 0))
|
||||
.displacePresence(account.getUuid(), deviceId)));
|
||||
}
|
||||
|
||||
static Stream<Arguments> testDeviceEnabledChanged() {
|
||||
return Stream.of(
|
||||
// Not testing device ID 1L because that will trigger "account enabled changed"
|
||||
Arguments.of(Map.of(1L, false, 2L, false), Map.of(1L, true, 2L, false)),
|
||||
Arguments.of(Map.of(2L, false, 3L, false), Map.of(2L, true, 3L, true)),
|
||||
Arguments.of(Map.of(2L, true, 3L, true), Map.of(2L, false, 3L, false)),
|
||||
Arguments.of(Map.of(2L, true, 3L, true), Map.of(2L, true, 3L, true)),
|
||||
|
@ -262,7 +219,9 @@ class AuthEnablementRefreshRequirementProviderTest {
|
|||
|
||||
assertEquals(initialDeviceCount + addedDeviceNames.size(), account.getDevices().size());
|
||||
|
||||
verifyNoInteractions(clientPresenceManager);
|
||||
verify(clientPresenceManager).displacePresence(account.getUuid(), 1);
|
||||
verify(clientPresenceManager).displacePresence(account.getUuid(), 2);
|
||||
verify(clientPresenceManager).displacePresence(account.getUuid(), 3);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
|
@ -270,6 +229,8 @@ class AuthEnablementRefreshRequirementProviderTest {
|
|||
void testDeviceRemoved(final int removedDeviceCount) {
|
||||
assert account.getMasterDevice().orElseThrow().isEnabled();
|
||||
|
||||
final List<Long> initialDeviceIds = account.getDevices().stream().map(Device::getId).collect(Collectors.toList());
|
||||
|
||||
final List<Long> deletedDeviceIds = account.getDevices().stream()
|
||||
.map(Device::getId)
|
||||
.filter(deviceId -> deviceId != 1L)
|
||||
|
@ -290,8 +251,8 @@ class AuthEnablementRefreshRequirementProviderTest {
|
|||
|
||||
assertEquals(200, response.getStatus());
|
||||
|
||||
deletedDeviceIds.forEach(deletedDeviceId ->
|
||||
verify(clientPresenceManager).displacePresence(account.getUuid(), deletedDeviceId));
|
||||
initialDeviceIds.forEach(deviceId ->
|
||||
verify(clientPresenceManager).displacePresence(account.getUuid(), deviceId));
|
||||
|
||||
verifyNoMoreInteractions(clientPresenceManager);
|
||||
}
|
||||
|
@ -374,32 +335,6 @@ class AuthEnablementRefreshRequirementProviderTest {
|
|||
provider.onWebSocketConnect(session);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("org.whispersystems.textsecuregcm.auth.AuthEnablementRefreshRequirementProviderTest#testAccountEnabledChanged")
|
||||
void testAccountEnabledChangedWebSocket(final long authenticatedDeviceId, final boolean initialEnabled,
|
||||
final boolean finalEnabled) throws Exception {
|
||||
|
||||
DevicesHelper.setEnabled(account.getMasterDevice().orElseThrow(), initialEnabled);
|
||||
|
||||
authenticatedDevice = account.getDevice(authenticatedDeviceId).orElseThrow();
|
||||
|
||||
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "PUT",
|
||||
"/v1/test/account/enabled/" + finalEnabled,
|
||||
new LinkedList<>(), Optional.empty()).toByteArray();
|
||||
|
||||
provider.onWebSocketBinary(message, 0, message.length);
|
||||
|
||||
final SubProtocol.WebSocketResponseMessage response = verifyAndGetResponse(remoteEndpoint);
|
||||
|
||||
assertEquals(200, response.getStatus());
|
||||
if (initialEnabled != finalEnabled) {
|
||||
verify(clientPresenceManager, times(account.getDevices().size())).displacePresence(eq(account.getUuid()),
|
||||
anyLong());
|
||||
} else {
|
||||
verifyNoInteractions(clientPresenceManager);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testOnEvent() throws Exception {
|
||||
|
||||
|
|
Loading…
Reference in New Issue