diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AccountAndAuthenticatedDeviceHolder.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AccountAndAuthenticatedDeviceHolder.java
new file mode 100644
index 000000000..bf10bd657
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AccountAndAuthenticatedDeviceHolder.java
@@ -0,0 +1,16 @@
+/*
+ * Copyright 2021 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.auth;
+
+import org.whispersystems.textsecuregcm.storage.Account;
+import org.whispersystems.textsecuregcm.storage.Device;
+
+public interface AccountAndAuthenticatedDeviceHolder {
+
+ Account getAccount();
+
+ Device getAuthenticatedDevice();
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRequestEventListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRequestEventListener.java
new file mode 100644
index 000000000..51c5d458d
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthEnablementRequestEventListener.java
@@ -0,0 +1,139 @@
+/*
+ * Copyright 2021 Signal Messenger, LLC
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+package org.whispersystems.textsecuregcm.auth;
+
+import com.google.common.annotations.VisibleForTesting;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+import javax.ws.rs.core.SecurityContext;
+import org.glassfish.jersey.server.ContainerRequest;
+import org.glassfish.jersey.server.monitoring.RequestEvent;
+import org.glassfish.jersey.server.monitoring.RequestEvent.Type;
+import org.glassfish.jersey.server.monitoring.RequestEventListener;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
+import org.whispersystems.textsecuregcm.storage.Account;
+import org.whispersystems.textsecuregcm.storage.Device;
+
+/**
+ * This {@link RequestEventListener} observes intra-request changes in {@link Account#isEnabled()} and {@link
+ * Device#isEnabled()}.
+ *
+ * If a change in {@link Account#isEnabled()} is observed, then any active WebSocket connections for the account must be
+ * closed, in order for clients to get a refreshed {@link io.dropwizard.auth.Auth} object.
+ *
+ * If a change in {@link Device#isEnabled()} is observed, including deletion of the {@link Device}, then any active
+ * WebSocket connections for the device must be closed and re-authenticated.
+ *
+ * @see AuthenticatedAccount
+ * @see DisabledPermittedAuthenticatedAccount
+ */
+public class AuthEnablementRequestEventListener implements RequestEventListener {
+
+ private static final Logger logger = LoggerFactory.getLogger(AuthEnablementRequestEventListener.class);
+
+ private static final String ACCOUNT_ENABLED = AuthEnablementRequestEventListener.class.getName() + ".accountEnabled";
+ private static final String DEVICES_ENABLED = AuthEnablementRequestEventListener.class.getName() + ".devicesEnabled";
+
+ private final ClientPresenceManager clientPresenceManager;
+
+ public AuthEnablementRequestEventListener(final ClientPresenceManager clientPresenceManager) {
+ this.clientPresenceManager = clientPresenceManager;
+ }
+
+ @Override
+ public void onEvent(final RequestEvent event) {
+
+ if (event.getType() == Type.REQUEST_FILTERED) {
+ // The authenticated principal, if any, will be available after filters have run.
+ // Now that the account is known, capture a snapshot of `isEnabled` for the account and its devices,
+ // before carrying out the request’s business logic.
+ findAccount(event.getContainerRequest())
+ .ifPresent(
+ account -> {
+ event.getContainerRequest().setProperty(ACCOUNT_ENABLED, account.isEnabled());
+ event.getContainerRequest().setProperty(DEVICES_ENABLED, buildDevicesEnabledMap(account));
+ });
+
+ } else if (event.getType() == Type.FINISHED) {
+ // Now that the request is finished, check whether `isEnabled` changed for any of the devices, or the account
+ // as a whole. If the value did change, the affected device(s) must disconnect and reauthenticate.
+ // If a device was removed, it must also disconnect.
+ if (event.getContainerRequest().getProperty(ACCOUNT_ENABLED) != null &&
+ event.getContainerRequest().getProperty(DEVICES_ENABLED) != null) {
+
+ final boolean accountInitiallyEnabled = (boolean) event.getContainerRequest().getProperty(ACCOUNT_ENABLED);
+ @SuppressWarnings("unchecked") final Map initialDevicesEnabled = (Map) event.getContainerRequest()
+ .getProperty(DEVICES_ENABLED);
+
+ findAccount(event.getContainerRequest()).ifPresentOrElse(account -> {
+ final Set deviceIdsToDisplace;
+
+ if (account.isEnabled() != accountInitiallyEnabled) {
+ // the @Auth for all active connections must change when account.isEnabled() changes
+ deviceIdsToDisplace = account.getDevices().stream()
+ .map(Device::getId).collect(Collectors.toSet());
+
+ deviceIdsToDisplace.addAll(initialDevicesEnabled.keySet());
+
+ } else if (!initialDevicesEnabled.isEmpty()) {
+
+ deviceIdsToDisplace = new HashSet<>();
+ final Map currentDevicesEnabled = buildDevicesEnabledMap(account);
+
+ initialDevicesEnabled.forEach((deviceId, enabled) -> {
+ // `null` indicates the device was removed from the account. Any active presence should be removed.
+ final boolean enabledMatches = Objects.equals(enabled,
+ currentDevicesEnabled.getOrDefault(deviceId, null));
+
+ if (!enabledMatches) {
+ deviceIdsToDisplace.add(deviceId);
+ }
+ });
+ } else {
+ deviceIdsToDisplace = Collections.emptySet();
+ }
+
+ deviceIdsToDisplace.forEach(deviceId -> {
+ try {
+ // displacing presence will cause a reauthorization for the device’s active connections
+ clientPresenceManager.displacePresence(account.getUuid(), deviceId);
+ } catch (final Exception e) {
+ logger.error("Could not displace device presence", e);
+ }
+ });
+ },
+ () -> logger.error("Request had account, but it is no longer present")
+ );
+ }
+ }
+ }
+
+ private Optional findAccount(final ContainerRequest containerRequest) {
+ return Optional.ofNullable(containerRequest.getSecurityContext())
+ .map(SecurityContext::getUserPrincipal)
+ .map(principal -> {
+ if (principal instanceof AccountAndAuthenticatedDeviceHolder) {
+ return ((AccountAndAuthenticatedDeviceHolder) principal).getAccount();
+ }
+ return null;
+ });
+ }
+
+ @VisibleForTesting
+ Map buildDevicesEnabledMap(final Account account) {
+ return account.getDevices().stream()
+ .collect(() -> new HashMap<>(account.getDevices().size()),
+ (map, device) -> map.put(device.getId(), device.isEnabled()), HashMap::putAll);
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthenticatedAccount.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthenticatedAccount.java
index e1a34f45d..13c9e504c 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthenticatedAccount.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/AuthenticatedAccount.java
@@ -12,7 +12,7 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Pair;
-public class AuthenticatedAccount implements Principal {
+public class AuthenticatedAccount implements Principal, AccountAndAuthenticatedDeviceHolder {
private final Supplier> accountAndDevice;
@@ -20,10 +20,12 @@ public class AuthenticatedAccount implements Principal {
this.accountAndDevice = accountAndDevice;
}
+ @Override
public Account getAccount() {
return accountAndDevice.get().first();
}
+ @Override
public Device getAuthenticatedDevice() {
return accountAndDevice.get().second();
}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisabledPermittedAuthenticatedAccount.java b/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisabledPermittedAuthenticatedAccount.java
index 4001a9573..2b4fd73f1 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisabledPermittedAuthenticatedAccount.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/auth/DisabledPermittedAuthenticatedAccount.java
@@ -10,7 +10,7 @@ import javax.security.auth.Subject;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
-public class DisabledPermittedAuthenticatedAccount implements Principal {
+public class DisabledPermittedAuthenticatedAccount implements Principal, AccountAndAuthenticatedDeviceHolder {
private final AuthenticatedAccount authenticatedAccount;
@@ -18,10 +18,12 @@ public class DisabledPermittedAuthenticatedAccount implements Principal {
this.authenticatedAccount = authenticatedAccount;
}
+ @Override
public Account getAccount() {
return authenticatedAccount.getAccount();
}
+ @Override
public Device getAuthenticatedDevice() {
return authenticatedAccount.getAuthenticatedDevice();
}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java
index 0b8790764..5fdffb715 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/push/ClientPresenceManager.java
@@ -171,6 +171,10 @@ public class ClientPresenceManager extends RedisClusterPubSubAdapter> principalSupplier = () -> Optional.of(
+ new TestPrincipal("test", account, authenticatedDevice));
+
+ private final ResourceExtension resources = ResourceExtension.builder()
+ .addProvider(
+ new PolymorphicAuthDynamicFeature<>(ImmutableMap.of(
+ TestPrincipal.class,
+ new BasicCredentialAuthFilter.Builder()
+ .setAuthenticator(c -> principalSupplier.get()).buildAuthFilter())))
+ .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(TestPrincipal.class)))
+ .addProvider(applicationEventListener)
+ .setTestContainerFactory(new GrizzlyWebTestContainerFactory())
+ .addResource(new TestResource())
+ .build();
+
+ private ClientPresenceManager clientPresenceManager;
+
+ private AuthEnablementRequestEventListener listener;
+
+ @BeforeEach
+ void setup() {
+ clientPresenceManager = mock(ClientPresenceManager.class);
+ listener = new AuthEnablementRequestEventListener(clientPresenceManager);
+ when(applicationEventListener.onRequest(any())).thenReturn(listener);
+
+ final UUID uuid = UUID.randomUUID();
+ account.setUuid(uuid);
+ account.addDevice(authenticatedDevice);
+ LongStream.range(2, 4).forEach(deviceId -> {
+ account.addDevice(createDevice(deviceId));
+ });
+
+ account.getDevices()
+ .forEach(device -> when(clientPresenceManager.isPresent(uuid, device.getId())).thenReturn(true));
+ }
+
+ @Test
+ void testBuildDevicesEnabled() {
+
+ final long disabledDeviceId = 3L;
+
+ final Account account = mock(Account.class);
+
+ final Set devices = new HashSet<>();
+ when(account.getDevices()).thenReturn(devices);
+
+ LongStream.range(1, 5)
+ .forEach(id -> {
+ final Device device = mock(Device.class);
+ when(device.getId()).thenReturn(id);
+ when(device.isEnabled()).thenReturn(id != disabledDeviceId);
+ devices.add(device);
+ });
+
+ final Map devicesEnabled = listener.buildDevicesEnabledMap(account);
+
+ assertEquals(4, devicesEnabled.size());
+
+ assertAll(devicesEnabled.entrySet().stream()
+ .map(deviceAndEnabled -> () -> {
+ if (deviceAndEnabled.getKey().equals(disabledDeviceId)) {
+ assertFalse(deviceAndEnabled.getValue());
+ } else {
+ assertTrue(deviceAndEnabled.getValue());
+ }
+ }));
+ }
+
+ @ParameterizedTest
+ @MethodSource
+ void testAccountEnabledChanged(final long authenticatedDeviceId, final boolean initialEnabled,
+ final boolean finalEnabled) {
+
+ setDeviceEnabled(account.getMasterDevice().orElseThrow(), initialEnabled);
+
+ authenticatedDevice = account.getDevice(authenticatedDeviceId).orElseThrow();
+
+ final Response response = resources.getJerseyTest()
+ .target("/v1/test/account/enabled/" + finalEnabled)
+ .request()
+ .header("Authorization",
+ "Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
+ .put(Entity.entity("", MediaType.TEXT_PLAIN));
+
+ assertEquals(200, response.getStatus());
+
+ if (initialEnabled != finalEnabled) {
+ verify(clientPresenceManager, times(account.getDevices().size())).displacePresence(eq(account.getUuid()),
+ anyLong());
+ } else {
+ verifyNoInteractions(clientPresenceManager);
+ }
+ }
+
+ static Stream testAccountEnabledChanged() {
+ return Stream.of(
+ Arguments.of(1L, true, false),
+ Arguments.of(1L, false, true),
+ Arguments.of(1L, true, true),
+ Arguments.of(1L, false, false),
+ Arguments.of(2L, true, false),
+ Arguments.of(2L, false, true),
+ Arguments.of(2L, true, true),
+ Arguments.of(2L, false, false)
+ );
+ }
+
+ @ParameterizedTest
+ @MethodSource
+ void testDeviceEnabledChanged(final Map initialEnabled, final Map finalEnabled) {
+ assert initialEnabled.size() == finalEnabled.size();
+
+ assert account.getMasterDevice().orElseThrow().isEnabled();
+
+ initialEnabled.forEach((deviceId, enabled) ->
+ setDeviceEnabled(account.getDevice(deviceId).orElseThrow(), enabled));
+
+ final Response response = resources.getJerseyTest()
+ .target("/v1/test/account/devices/enabled")
+ .request()
+ .header("Authorization",
+ "Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
+ .post(Entity.entity(finalEnabled, MediaType.APPLICATION_JSON));
+
+ assertEquals(200, response.getStatus());
+
+ assertAll(
+ finalEnabled.entrySet().stream()
+ .map(deviceIdEnabled -> () -> {
+ final boolean expectDisplacedPresence =
+ initialEnabled.get(deviceIdEnabled.getKey()) != deviceIdEnabled.getValue();
+
+ verify(clientPresenceManager, times(expectDisplacedPresence ? 1 : 0)).displacePresence(account.getUuid(),
+ deviceIdEnabled.getKey());
+ })
+ );
+ }
+
+ static Stream testDeviceEnabledChanged() {
+ return Stream.of(
+ // Not testing device ID 1L because that will trigger "account enabled changed"
+ Arguments.of(Map.of(2L, false, 3L, false), Map.of(2L, true, 3L, true)),
+ Arguments.of(Map.of(2L, true, 3L, true), Map.of(2L, false, 3L, false)),
+ Arguments.of(Map.of(2L, true, 3L, true), Map.of(2L, true, 3L, true)),
+ Arguments.of(Map.of(2L, false, 3L, true), Map.of(2L, true, 3L, true)),
+ Arguments.of(Map.of(2L, true, 3L, false), Map.of(2L, true, 3L, true))
+ );
+ }
+
+ @Test
+ void testDeviceAdded() {
+ assert account.getMasterDevice().orElseThrow().isEnabled();
+
+ final int initialDeviceCount = account.getDevices().size();
+
+ final List addedDeviceNames = List.of("newDevice1", "newDevice2");
+ final Response response = resources.getJerseyTest()
+ .target("/v1/test/account/devices")
+ .request()
+ .header("Authorization",
+ "Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
+ .put(Entity.entity(addedDeviceNames, MediaType.APPLICATION_JSON_PATCH_JSON));
+
+ assertEquals(200, response.getStatus());
+
+ assertEquals(initialDeviceCount + addedDeviceNames.size(), account.getDevices().size());
+
+ verifyNoInteractions(clientPresenceManager);
+ }
+
+ @ParameterizedTest
+ @ValueSource(ints = {1, 2})
+ void testDeviceRemoved(final int removedDeviceCount) {
+ assert account.getMasterDevice().orElseThrow().isEnabled();
+
+ final List deletedDeviceIds = account.getDevices().stream()
+ .map(Device::getId)
+ .filter(deviceId -> deviceId != 1L)
+ .limit(removedDeviceCount)
+ .collect(Collectors.toList());
+
+ assert deletedDeviceIds.size() == removedDeviceCount;
+
+ final String deletedDeviceIdsParam = deletedDeviceIds.stream().map(String::valueOf)
+ .collect(Collectors.joining(","));
+
+ final Response response = resources.getJerseyTest()
+ .target("/v1/test/account/devices/" + deletedDeviceIdsParam)
+ .request()
+ .header("Authorization",
+ "Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
+ .delete();
+
+ assertEquals(200, response.getStatus());
+
+ deletedDeviceIds.forEach(deletedDeviceId ->
+ verify(clientPresenceManager).displacePresence(account.getUuid(), deletedDeviceId));
+
+ verifyNoMoreInteractions(clientPresenceManager);
+ }
+
+ @Test
+ void testMasterDeviceDisabledAndDeviceRemoved() {
+ assert account.getMasterDevice().orElseThrow().isEnabled();
+
+ final Set initialDeviceIds = account.getDevices().stream().map(Device::getId).collect(Collectors.toSet());
+
+ final long deletedDeviceId = 2L;
+ assertTrue(initialDeviceIds.remove(deletedDeviceId));
+
+ final Response response = resources.getJerseyTest()
+ .target("/v1/test/account/disableMasterDeviceAndDeleteDevice/" + deletedDeviceId)
+ .request()
+ .header("Authorization",
+ "Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
+ .post(Entity.entity("", MediaType.TEXT_PLAIN));
+
+ assertEquals(200, response.getStatus());
+
+ assertTrue(account.getDevice(deletedDeviceId).isEmpty());
+
+ initialDeviceIds.forEach(deviceId -> verify(clientPresenceManager).displacePresence(account.getUuid(), deviceId));
+ verify(clientPresenceManager).displacePresence(account.getUuid(), deletedDeviceId);
+
+ verifyNoMoreInteractions(clientPresenceManager);
+ }
+
+ @Test
+ void testOnEvent() {
+ Response response = resources.getJerseyTest()
+ .target("/v1/test/hello")
+ .request()
+ // no authorization required
+ .get();
+
+ assertEquals(200, response.getStatus());
+
+ response = resources.getJerseyTest()
+ .target("/v1/test/authorized")
+ .request()
+ .header("Authorization",
+ "Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
+ .get();
+
+ assertEquals(200, response.getStatus());
+ }
+
+ @Nested
+ class WebSocket {
+
+ private WebSocketResourceProvider provider;
+ private RemoteEndpoint remoteEndpoint;
+
+ @BeforeEach
+ void setup() {
+ ResourceConfig resourceConfig = new DropwizardResourceConfig();
+ resourceConfig.register(applicationEventListener);
+ resourceConfig.register(new TestResource());
+ resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder());
+ resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class));
+ resourceConfig.register(new JacksonMessageBodyProvider(new ObjectMapper()));
+
+ ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
+ WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
+
+ provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
+ requestLog, new TestPrincipal("test", account, authenticatedDevice), new ProtobufWebSocketMessageFactory(),
+ Optional.empty(), 30000);
+
+ remoteEndpoint = mock(RemoteEndpoint.class);
+ Session session = mock(Session.class);
+ UpgradeRequest request = mock(UpgradeRequest.class);
+
+ when(session.getRemote()).thenReturn(remoteEndpoint);
+ when(session.getUpgradeRequest()).thenReturn(request);
+
+ provider.onWebSocketConnect(session);
+ }
+
+ @ParameterizedTest
+ @MethodSource("org.whispersystems.textsecuregcm.auth.AuthEnablementRequestEventListenerTest#testAccountEnabledChanged")
+ void testAccountEnabledChangedWebSocket(final long authenticatedDeviceId, final boolean initialEnabled,
+ final boolean finalEnabled) throws Exception {
+
+ setDeviceEnabled(account.getMasterDevice().orElseThrow(), initialEnabled);
+
+ authenticatedDevice = account.getDevice(authenticatedDeviceId).orElseThrow();
+
+ byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "PUT",
+ "/v1/test/account/enabled/" + finalEnabled,
+ new LinkedList<>(), Optional.empty()).toByteArray();
+
+ provider.onWebSocketBinary(message, 0, message.length);
+
+ final SubProtocol.WebSocketResponseMessage response = verifyAndGetResponse(remoteEndpoint);
+
+ assertEquals(200, response.getStatus());
+ if (initialEnabled != finalEnabled) {
+ verify(clientPresenceManager, times(account.getDevices().size())).displacePresence(eq(account.getUuid()),
+ anyLong());
+ } else {
+ verifyNoInteractions(clientPresenceManager);
+ }
+ }
+
+ @Test
+ void testOnEvent() throws Exception {
+
+ byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello",
+ new LinkedList<>(), Optional.empty()).toByteArray();
+
+ provider.onWebSocketBinary(message, 0, message.length);
+
+ final SubProtocol.WebSocketResponseMessage response = verifyAndGetResponse(remoteEndpoint);
+
+ assertEquals(200, response.getStatus());
+ }
+
+ private SubProtocol.WebSocketResponseMessage verifyAndGetResponse(final RemoteEndpoint remoteEndpoint)
+ throws InvalidProtocolBufferException {
+ ArgumentCaptor responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
+ verify(remoteEndpoint).sendBytesByFuture(responseBytesCaptor.capture());
+
+ return SubProtocol.WebSocketMessage.parseFrom(responseBytesCaptor.getValue().array()).getResponse();
+ }
+ }
+
+ private static Device createDevice(final long deviceId) {
+ final Device device = new Device(deviceId, null, null, null, null, null, null, false, 0, null, 0, 0, "OWT", 0,
+ null);
+
+ setDeviceEnabled(device, true);
+
+ return device;
+ }
+
+ private static void setDeviceEnabled(Device device, boolean enabled) {
+ if (enabled) {
+ device.setSignedPreKey(new SignedPreKey(RANDOM.nextLong(), "testPublicKey-" + RANDOM.nextLong(),
+ "testSignature-" + RANDOM.nextLong()));
+ device.setGcmId("testGcmId" + RANDOM.nextLong());
+ device.setLastSeen(Util.todayInMillis());
+ } else {
+ device.setSignedPreKey(null);
+ }
+
+ assert enabled == device.isEnabled();
+ }
+
+ public static class TestPrincipal implements Principal, AccountAndAuthenticatedDeviceHolder {
+
+ private final String name;
+ private final Account account;
+ private final Device device;
+
+ private TestPrincipal(String name, final Account account, final Device device) {
+ this.name = name;
+ this.account = account;
+ this.device = device;
+ }
+
+ @Override
+ public String getName() {
+ return name;
+ }
+
+ @Override
+ public Account getAccount() {
+ return account;
+ }
+
+ @Override
+ public Device getAuthenticatedDevice() {
+ return device;
+ }
+ }
+
+ @Path("/v1/test")
+ public static class TestResource {
+
+ @GET
+ @Path("/hello")
+ public String testGetHello() {
+ return "Hello!";
+ }
+
+ @GET
+ @Path("/authorized")
+ public String testAuth(@Auth TestPrincipal principal) {
+ return "You’re in!";
+ }
+
+ @PUT
+ @Path("/account/enabled/{enabled}")
+ public String setAccountEnabled(@Auth TestPrincipal principal, @PathParam("enabled") final boolean enabled) {
+
+ final Device device = principal.getAccount().getMasterDevice().orElseThrow();
+
+ AuthEnablementRequestEventListenerTest.setDeviceEnabled(device, enabled);
+
+ assert device.isEnabled() == enabled;
+
+ return String.format("Set account to %s", enabled);
+ }
+
+ @POST
+ @Path("/account/devices/enabled")
+ public String setDeviceEnabled(@Auth TestPrincipal principal, Map deviceIdsEnabled) {
+
+ final StringBuilder response = new StringBuilder();
+
+ for (Entry deviceIdEnabled : deviceIdsEnabled.entrySet()) {
+ final Device device = principal.getAccount().getDevice(deviceIdEnabled.getKey()).orElseThrow();
+ AuthEnablementRequestEventListenerTest.setDeviceEnabled(device, deviceIdEnabled.getValue());
+
+ response.append(String.format("Set device enabled %s", deviceIdEnabled));
+ }
+
+ return response.toString();
+ }
+
+ @PUT
+ @Path("/account/devices")
+ public String addDevices(@Auth TestPrincipal auth, List deviceNames) {
+
+ deviceNames.forEach(name -> {
+ final Device device = createDevice(auth.getAccount().getNextDeviceId());
+ auth.getAccount().addDevice(device);
+
+ device.setName(name);
+ });
+
+ return "Added devices " + deviceNames;
+ }
+
+ @DELETE
+ @Path("/account/devices/{deviceIds}")
+ public String removeDevices(@Auth TestPrincipal auth, @PathParam("deviceIds") String deviceIds) {
+
+ Arrays.stream(deviceIds.split(","))
+ .map(Long::valueOf)
+ .forEach(auth.getAccount()::removeDevice);
+
+ return "Removed device(s) " + deviceIds;
+ }
+
+ @POST
+ @Path("/account/disableMasterDeviceAndDeleteDevice/{deviceId}")
+ public String disableMasterDeviceAndRemoveDevice(@Auth TestPrincipal auth, @PathParam("deviceId") long deviceId) {
+
+ AuthEnablementRequestEventListenerTest.setDeviceEnabled(auth.getAccount().getMasterDevice().orElseThrow(), false);
+
+ auth.getAccount().removeDevice(deviceId);
+
+ return "Removed device " + deviceId;
+ }
+ }
+}