Stop monitoring device "enabled" state changes from auth enablement refresh requirement provider

Device enabled states no longer affect anything at an authentication level
This commit is contained in:
Jon Chambers 2024-06-10 11:47:20 -04:00 committed by Jon Chambers
parent 2f76738b50
commit 2d1610b075
9 changed files with 32 additions and 217 deletions

View File

@ -8,7 +8,6 @@ package org.whispersystems.textsecuregcm.auth;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
@ -18,16 +17,13 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Pair;
/**
* This {@link WebsocketRefreshRequirementProvider} observes intra-request changes in
* {@link Device#hasMessageDeliveryChannel()}.
* <p>
* If a change in any associated {@link Device#hasMessageDeliveryChannel()} is observed, then any active WebSocket
* connections for the account must be closed in order for clients to get a refreshed {@link io.dropwizard.auth.Auth}
* object with a current device list.
* This {@link WebsocketRefreshRequirementProvider} observes intra-request changes in devices linked to an
* {@link Account} and triggers a WebSocket refresh if that set changes. If a change in linked devices is observed, then
* any active WebSocket connections for the account must be closed in order for clients to get a refreshed
* {@link io.dropwizard.auth.Auth} object with a current device list.
*
* @see AuthenticatedAccount
*/
@ -38,55 +34,56 @@ public class AuthEnablementRefreshRequirementProvider implements WebsocketRefres
private static final Logger logger = LoggerFactory.getLogger(AuthEnablementRefreshRequirementProvider.class);
private static final String ACCOUNT_UUID = AuthEnablementRefreshRequirementProvider.class.getName() + ".accountUuid";
private static final String DEVICES_ENABLED = AuthEnablementRefreshRequirementProvider.class.getName() + ".devicesEnabled";
private static final String LINKED_DEVICE_IDS = AuthEnablementRefreshRequirementProvider.class.getName() + ".deviceIds";
public AuthEnablementRefreshRequirementProvider(final AccountsManager accountsManager) {
this.accountsManager = accountsManager;
}
@Override
public void handleRequestFiltered(final RequestEvent requestEvent) {
if (requestEvent.getUriInfo().getMatchedResourceMethod().getInvocable().getHandlingMethod().getAnnotation(ChangesDeviceEnabledState.class) != null) {
if (requestEvent.getUriInfo().getMatchedResourceMethod().getInvocable().getHandlingMethod().getAnnotation(
ChangesLinkedDevices.class) != null) {
// The authenticated principal, if any, will be available after filters have run. Now that the account is known,
// capture a snapshot of the account's devices before carrying out the requests business logic.
ContainerRequestUtil.getAuthenticatedAccount(requestEvent.getContainerRequest()).ifPresent(account ->
setAccount(requestEvent.getContainerRequest(), account));
// capture a snapshot of the account's linked devices before carrying out the requests business logic.
ContainerRequestUtil.getAuthenticatedAccount(requestEvent.getContainerRequest())
.ifPresent(account -> setAccount(requestEvent.getContainerRequest(), account));
}
}
public static void setAccount(final ContainerRequest containerRequest, final Account account) {
setAccount(containerRequest, ContainerRequestUtil.AccountInfo.fromAccount(account));
}
private static void setAccount(final ContainerRequest containerRequest, final ContainerRequestUtil.AccountInfo info) {
containerRequest.setProperty(ACCOUNT_UUID, info.accountId());
containerRequest.setProperty(DEVICES_ENABLED, info.devicesEnabled());
containerRequest.setProperty(LINKED_DEVICE_IDS, info.deviceIds());
}
@Override
public List<Pair<UUID, Byte>> handleRequestFinished(final RequestEvent requestEvent) {
// Now that the request is finished, check whether `hasMessageDeliveryChannel` changed for any of the devices. If
// the value did change or if a devices was added or removed, all devices must disconnect and reauthenticate.
if (requestEvent.getContainerRequest().getProperty(DEVICES_ENABLED) != null) {
// Now that the request is finished, check whether the set of linked devices has changed. If the value did change or
// if a devices was added or removed, all devices must disconnect and reauthenticate.
if (requestEvent.getContainerRequest().getProperty(LINKED_DEVICE_IDS) != null) {
@SuppressWarnings("unchecked") final Map<Byte, Boolean> initialDevicesEnabled =
(Map<Byte, Boolean>) requestEvent.getContainerRequest().getProperty(DEVICES_ENABLED);
@SuppressWarnings("unchecked") final Set<Byte> initialLinkedDeviceIds =
(Set<Byte>) requestEvent.getContainerRequest().getProperty(LINKED_DEVICE_IDS);
return accountsManager.getByAccountIdentifier((UUID) requestEvent.getContainerRequest().getProperty(ACCOUNT_UUID))
.map(ContainerRequestUtil.AccountInfo::fromAccount)
.map(account -> {
.map(accountInfo -> {
final Set<Byte> deviceIdsToDisplace;
final Map<Byte, Boolean> currentDevicesEnabled = account.devicesEnabled();
final Set<Byte> currentLinkedDeviceIds = accountInfo.deviceIds();
if (!initialDevicesEnabled.equals(currentDevicesEnabled)) {
deviceIdsToDisplace = new HashSet<>(initialDevicesEnabled.keySet());
deviceIdsToDisplace.addAll(currentDevicesEnabled.keySet());
if (!initialLinkedDeviceIds.equals(currentLinkedDeviceIds)) {
deviceIdsToDisplace = new HashSet<>(initialLinkedDeviceIds);
deviceIdsToDisplace.addAll(currentLinkedDeviceIds);
} else {
deviceIdsToDisplace = Collections.emptySet();
}
return deviceIdsToDisplace.stream()
.map(deviceId -> new Pair<>(account.accountId(), deviceId))
.map(deviceId -> new Pair<>(accountInfo.accountId(), deviceId))
.collect(Collectors.toList());
}).orElseGet(() -> {
logger.error("Request had account, but it is no longer present");

View File

@ -16,5 +16,5 @@ import java.lang.annotation.Target;
*/
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface ChangesDeviceEnabledState {
public @interface ChangesLinkedDevices {
}

View File

@ -11,26 +11,23 @@ import org.whispersystems.textsecuregcm.storage.Device;
import javax.ws.rs.core.SecurityContext;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
class ContainerRequestUtil {
private static Map<Byte, Boolean> buildDevicesEnabledMap(final Account account) {
return account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::hasMessageDeliveryChannel));
}
/**
* A read-only subset of the authenticated Account object, to enforce that filter-based consumers do not perform
* account modifying operations.
*/
record AccountInfo(UUID accountId, String e164, Map<Byte, Boolean> devicesEnabled) {
record AccountInfo(UUID accountId, String e164, Set<Byte> deviceIds) {
static AccountInfo fromAccount(final Account account) {
return new AccountInfo(
account.getUuid(),
account.getNumber(),
buildDevicesEnabledMap(account));
account.getDevices().stream().map(Device::getId).collect(Collectors.toSet()));
}
}

View File

@ -34,7 +34,6 @@ import javax.ws.rs.core.Response.Status;
import org.signal.libsignal.usernames.BaseUsernameException;
import org.whispersystems.textsecuregcm.auth.AccountAndAuthenticatedDeviceHolder;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.ChangesDeviceEnabledState;
import org.whispersystems.textsecuregcm.auth.SaltedTokenHash;
import org.whispersystems.textsecuregcm.auth.TurnToken;
import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator;
@ -109,7 +108,6 @@ public class AccountController {
@Path("/gcm/")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
@ChangesDeviceEnabledState
public void setGcmRegistrationId(@Mutable @Auth AuthenticatedAccount auth,
@NotNull @Valid GcmRegistrationId registrationId) {
@ -130,7 +128,6 @@ public class AccountController {
@DELETE
@Path("/gcm/")
@ChangesDeviceEnabledState
public void deleteGcmRegistrationId(@Mutable @Auth AuthenticatedAccount auth) {
Account account = auth.getAccount();
Device device = auth.getAuthenticatedDevice();
@ -146,7 +143,6 @@ public class AccountController {
@Path("/apn/")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
@ChangesDeviceEnabledState
public void setApnRegistrationId(@Mutable @Auth AuthenticatedAccount auth,
@NotNull @Valid ApnRegistrationId registrationId) {
@ -165,7 +161,6 @@ public class AccountController {
@DELETE
@Path("/apn/")
@ChangesDeviceEnabledState
public void deleteApnRegistrationId(@Mutable @Auth AuthenticatedAccount auth) {
Account account = auth.getAccount();
Device device = auth.getAuthenticatedDevice();
@ -210,7 +205,6 @@ public class AccountController {
@Path("/attributes/")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
@ChangesDeviceEnabledState
public void setAccountAttributes(
@Mutable @Auth AuthenticatedAccount auth,
@HeaderParam(HeaderUtils.X_SIGNAL_AGENT) String userAgent,

View File

@ -49,7 +49,7 @@ import org.glassfish.jersey.server.ContainerRequest;
import org.whispersystems.textsecuregcm.auth.AuthEnablementRefreshRequirementProvider;
import org.whispersystems.textsecuregcm.auth.AuthenticatedAccount;
import org.whispersystems.textsecuregcm.auth.BasicAuthorizationHeader;
import org.whispersystems.textsecuregcm.auth.ChangesDeviceEnabledState;
import org.whispersystems.textsecuregcm.auth.ChangesLinkedDevices;
import org.whispersystems.textsecuregcm.entities.AccountAttributes;
import org.whispersystems.textsecuregcm.entities.DeviceActivationRequest;
import org.whispersystems.textsecuregcm.entities.DeviceInfo;
@ -132,7 +132,7 @@ public class DeviceController {
@DELETE
@Produces(MediaType.APPLICATION_JSON)
@Path("/{device_id}")
@ChangesDeviceEnabledState
@ChangesLinkedDevices
public void removeDevice(@Mutable @Auth AuthenticatedAccount auth, @PathParam("device_id") byte deviceId) {
if (auth.getAuthenticatedDevice().getId() != Device.PRIMARY_ID) {
throw new WebApplicationException(Response.Status.UNAUTHORIZED);
@ -176,7 +176,7 @@ public class DeviceController {
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
@Path("/link")
@ChangesDeviceEnabledState
@ChangesLinkedDevices
@Operation(summary = "Link a device to an account",
description = """
Links a device to an account identified by a given phone number.

View File

@ -131,54 +131,8 @@ class AuthEnablementRefreshRequirementProviderTest {
.forEach(device -> when(clientPresenceManager.isPresent(uuid, device.getId())).thenReturn(true));
}
@ParameterizedTest
@MethodSource
void testDeviceEnabledChanged(final Map<Byte, Boolean> initialEnabled, final Map<Byte, Boolean> finalEnabled) {
assert initialEnabled.size() == finalEnabled.size();
assert account.getPrimaryDevice().hasMessageDeliveryChannel();
initialEnabled.forEach((deviceId, enabled) ->
DevicesHelper.setEnabled(account.getDevice(deviceId).orElseThrow(), enabled));
final Response response = resources.getJerseyTest()
.target("/v1/test/account/devices/enabled")
.request()
.header("Authorization",
"Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
.post(Entity.entity(finalEnabled, MediaType.APPLICATION_JSON));
assertEquals(200, response.getStatus());
final boolean expectDisplacedPresence = !initialEnabled.equals(finalEnabled);
assertAll(
initialEnabled.keySet().stream()
.map(deviceId -> () -> verify(clientPresenceManager, times(expectDisplacedPresence ? 1 : 0))
.disconnectPresence(account.getUuid(), deviceId)));
assertAll(
finalEnabled.keySet().stream()
.map(deviceId -> () -> verify(clientPresenceManager, times(expectDisplacedPresence ? 1 : 0))
.disconnectPresence(account.getUuid(), deviceId)));
}
static Stream<Arguments> testDeviceEnabledChanged() {
final byte deviceId2 = 2;
final byte deviceId3 = 3;
return Stream.of(
Arguments.of(Map.of(deviceId2, false, deviceId3, false), Map.of(deviceId2, true, deviceId3, true)),
Arguments.of(Map.of(deviceId2, true, deviceId3, true), Map.of(deviceId2, false, deviceId3, false)),
Arguments.of(Map.of(deviceId2, true, deviceId3, true), Map.of(deviceId2, true, deviceId3, true)),
Arguments.of(Map.of(deviceId2, false, deviceId3, true), Map.of(deviceId2, true, deviceId3, true)),
Arguments.of(Map.of(deviceId2, true, deviceId3, false), Map.of(deviceId2, true, deviceId3, true))
);
}
@Test
void testDeviceAdded() {
assert account.getPrimaryDevice().hasMessageDeliveryChannel();
final int initialDeviceCount = account.getDevices().size();
final List<String> addedDeviceNames = List.of(
@ -204,8 +158,6 @@ class AuthEnablementRefreshRequirementProviderTest {
@ParameterizedTest
@ValueSource(ints = {1, 2})
void testDeviceRemoved(final int removedDeviceCount) {
assert account.getPrimaryDevice().hasMessageDeliveryChannel();
final List<Byte> initialDeviceIds = account.getDevices().stream().map(Device::getId).toList();
final List<Byte> deletedDeviceIds = account.getDevices().stream()
@ -358,40 +310,9 @@ class AuthEnablementRefreshRequirementProviderTest {
return "Youre in!";
}
@PUT
@Path("/account/enabled/{enabled}")
@ChangesDeviceEnabledState
public String setAccountEnabled(@Auth TestPrincipal principal, @PathParam("enabled") final boolean enabled) {
final Device device = principal.getAccount().getPrimaryDevice();
DevicesHelper.setEnabled(device, enabled);
assert device.hasMessageDeliveryChannel() == enabled;
return String.format("Set account to %s", enabled);
}
@POST
@Path("/account/devices/enabled")
@ChangesDeviceEnabledState
public String setEnabled(@Auth TestPrincipal principal, Map<Byte, Boolean> deviceIdsEnabled) {
final StringBuilder response = new StringBuilder();
for (Entry<Byte, Boolean> deviceIdEnabled : deviceIdsEnabled.entrySet()) {
final Device device = principal.getAccount().getDevice(deviceIdEnabled.getKey()).orElseThrow();
DevicesHelper.setEnabled(device, deviceIdEnabled.getValue());
response.append(String.format("Set device enabled %s", deviceIdEnabled));
}
return response.toString();
}
@PUT
@Path("/account/devices")
@ChangesDeviceEnabledState
@ChangesLinkedDevices
public String addDevices(@Auth TestPrincipal auth, List<byte[]> deviceNames) {
deviceNames.forEach(name -> {
@ -406,7 +327,7 @@ class AuthEnablementRefreshRequirementProviderTest {
@DELETE
@Path("/account/devices/{deviceIds}")
@ChangesDeviceEnabledState
@ChangesLinkedDevices
public String removeDevices(@Auth TestPrincipal auth, @PathParam("deviceIds") String deviceIds) {
Arrays.stream(deviceIds.split(","))
@ -415,17 +336,5 @@ class AuthEnablementRefreshRequirementProviderTest {
return "Removed device(s) " + deviceIds;
}
@POST
@Path("/account/disablePrimaryDeviceAndDeleteDevice/{deviceId}")
@ChangesDeviceEnabledState
public String disablePrimaryDeviceAndRemoveDevice(@Auth TestPrincipal auth, @PathParam("deviceId") byte deviceId) {
DevicesHelper.setEnabled(auth.getAccount().getPrimaryDevice(), false);
auth.getAccount().removeDevice(deviceId);
return "Removed device " + deviceId;
}
}
}

View File

@ -1,55 +0,0 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import org.junit.jupiter.api.Test;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.IntStream;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class ContainerRequestUtilTest {
@Test
void testBuildDevicesEnabled() {
final byte disabledDeviceId = 3;
final Account account = mock(Account.class);
final List<Device> devices = new ArrayList<>();
when(account.getDevices()).thenReturn(devices);
IntStream.range(1, 5)
.forEach(id -> {
final Device device = mock(Device.class);
when(device.getId()).thenReturn((byte) id);
when(device.hasMessageDeliveryChannel()).thenReturn(id != disabledDeviceId);
devices.add(device);
});
final Map<Byte, Boolean> devicesEnabled = ContainerRequestUtil.AccountInfo.fromAccount(account).devicesEnabled();
assertEquals(4, devicesEnabled.size());
assertAll(devicesEnabled.entrySet().stream()
.map(deviceAndEnabled -> () -> {
if (deviceAndEnabled.getKey().equals(disabledDeviceId)) {
assertFalse(deviceAndEnabled.getValue());
} else {
assertTrue(deviceAndEnabled.getValue());
}
}));
}
}

View File

@ -15,7 +15,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.tests.util.DevicesHelper.createDevice;
import static org.whispersystems.textsecuregcm.tests.util.DevicesHelper.setEnabled;
import com.fasterxml.jackson.annotation.JsonFilter;
import java.lang.annotation.Annotation;
@ -28,12 +27,8 @@ import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.tests.util.AccountsHelper;
import org.whispersystems.textsecuregcm.util.TestClock;
@ -207,12 +202,6 @@ class AccountTest {
final byte deviceId3 = 3;
assertThat(account.getNextDeviceId()).isEqualTo(deviceId3);
account.addDevice(createDevice(deviceId3));
setEnabled(account.getDevice(deviceId2).orElseThrow(), false);
assertThat(account.getNextDeviceId()).isEqualTo((byte) 4);
account.removeDevice(deviceId2);
assertThat(account.getNextDeviceId()).isEqualTo(deviceId2);

View File

@ -27,8 +27,6 @@ public class DevicesHelper {
device.setUserAgent("OWT");
device.setRegistrationId(registrationId);
setEnabled(device, true);
return device;
}
@ -38,20 +36,6 @@ public class DevicesHelper {
device.setUserAgent("OWT");
device.setRegistrationId(registrationId);
setEnabled(device, false);
return device;
}
public static void setEnabled(Device device, boolean enabled) {
if (enabled) {
device.setGcmId("testGcmId" + RANDOM.nextLong());
} else {
device.setGcmId(null);
}
// fail fast, to guard against a change to the isEnabled() implementation causing unexpected test behavior
assert enabled == device.hasMessageDeliveryChannel();
}
}