Add request event listener that handles device.isEnabled changes

This commit is contained in:
Chris Eager 2021-08-17 18:11:56 -05:00 committed by Chris Eager
parent 2866f1b213
commit 539b62a829
6 changed files with 721 additions and 2 deletions

View File

@ -0,0 +1,16 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
public interface AccountAndAuthenticatedDeviceHolder {
Account getAccount();
Device getAuthenticatedDevice();
}

View File

@ -0,0 +1,139 @@
/*
* Copyright 2021 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import com.google.common.annotations.VisibleForTesting;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import javax.ws.rs.core.SecurityContext;
import org.glassfish.jersey.server.ContainerRequest;
import org.glassfish.jersey.server.monitoring.RequestEvent;
import org.glassfish.jersey.server.monitoring.RequestEvent.Type;
import org.glassfish.jersey.server.monitoring.RequestEventListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
/**
* This {@link RequestEventListener} observes intra-request changes in {@link Account#isEnabled()} and {@link
* Device#isEnabled()}.
* <p>
* If a change in {@link Account#isEnabled()} is observed, then any active WebSocket connections for the account must be
* closed, in order for clients to get a refreshed {@link io.dropwizard.auth.Auth} object.
* <p>
* If a change in {@link Device#isEnabled()} is observed, including deletion of the {@link Device}, then any active
* WebSocket connections for the device must be closed and re-authenticated.
*
* @see AuthenticatedAccount
* @see DisabledPermittedAuthenticatedAccount
*/
public class AuthEnablementRequestEventListener implements RequestEventListener {
private static final Logger logger = LoggerFactory.getLogger(AuthEnablementRequestEventListener.class);
private static final String ACCOUNT_ENABLED = AuthEnablementRequestEventListener.class.getName() + ".accountEnabled";
private static final String DEVICES_ENABLED = AuthEnablementRequestEventListener.class.getName() + ".devicesEnabled";
private final ClientPresenceManager clientPresenceManager;
public AuthEnablementRequestEventListener(final ClientPresenceManager clientPresenceManager) {
this.clientPresenceManager = clientPresenceManager;
}
@Override
public void onEvent(final RequestEvent event) {
if (event.getType() == Type.REQUEST_FILTERED) {
// The authenticated principal, if any, will be available after filters have run.
// Now that the account is known, capture a snapshot of `isEnabled` for the account and its devices,
// before carrying out the requests business logic.
findAccount(event.getContainerRequest())
.ifPresent(
account -> {
event.getContainerRequest().setProperty(ACCOUNT_ENABLED, account.isEnabled());
event.getContainerRequest().setProperty(DEVICES_ENABLED, buildDevicesEnabledMap(account));
});
} else if (event.getType() == Type.FINISHED) {
// Now that the request is finished, check whether `isEnabled` changed for any of the devices, or the account
// as a whole. If the value did change, the affected device(s) must disconnect and reauthenticate.
// If a device was removed, it must also disconnect.
if (event.getContainerRequest().getProperty(ACCOUNT_ENABLED) != null &&
event.getContainerRequest().getProperty(DEVICES_ENABLED) != null) {
final boolean accountInitiallyEnabled = (boolean) event.getContainerRequest().getProperty(ACCOUNT_ENABLED);
@SuppressWarnings("unchecked") final Map<Long, Boolean> initialDevicesEnabled = (Map<Long, Boolean>) event.getContainerRequest()
.getProperty(DEVICES_ENABLED);
findAccount(event.getContainerRequest()).ifPresentOrElse(account -> {
final Set<Long> deviceIdsToDisplace;
if (account.isEnabled() != accountInitiallyEnabled) {
// the @Auth for all active connections must change when account.isEnabled() changes
deviceIdsToDisplace = account.getDevices().stream()
.map(Device::getId).collect(Collectors.toSet());
deviceIdsToDisplace.addAll(initialDevicesEnabled.keySet());
} else if (!initialDevicesEnabled.isEmpty()) {
deviceIdsToDisplace = new HashSet<>();
final Map<Long, Boolean> currentDevicesEnabled = buildDevicesEnabledMap(account);
initialDevicesEnabled.forEach((deviceId, enabled) -> {
// `null` indicates the device was removed from the account. Any active presence should be removed.
final boolean enabledMatches = Objects.equals(enabled,
currentDevicesEnabled.getOrDefault(deviceId, null));
if (!enabledMatches) {
deviceIdsToDisplace.add(deviceId);
}
});
} else {
deviceIdsToDisplace = Collections.emptySet();
}
deviceIdsToDisplace.forEach(deviceId -> {
try {
// displacing presence will cause a reauthorization for the devices active connections
clientPresenceManager.displacePresence(account.getUuid(), deviceId);
} catch (final Exception e) {
logger.error("Could not displace device presence", e);
}
});
},
() -> logger.error("Request had account, but it is no longer present")
);
}
}
}
private Optional<Account> findAccount(final ContainerRequest containerRequest) {
return Optional.ofNullable(containerRequest.getSecurityContext())
.map(SecurityContext::getUserPrincipal)
.map(principal -> {
if (principal instanceof AccountAndAuthenticatedDeviceHolder) {
return ((AccountAndAuthenticatedDeviceHolder) principal).getAccount();
}
return null;
});
}
@VisibleForTesting
Map<Long, Boolean> buildDevicesEnabledMap(final Account account) {
return account.getDevices().stream()
.collect(() -> new HashMap<>(account.getDevices().size()),
(map, device) -> map.put(device.getId(), device.isEnabled()), HashMap::putAll);
}
}

View File

@ -12,7 +12,7 @@ import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Pair;
public class AuthenticatedAccount implements Principal {
public class AuthenticatedAccount implements Principal, AccountAndAuthenticatedDeviceHolder {
private final Supplier<Pair<Account, Device>> accountAndDevice;
@ -20,10 +20,12 @@ public class AuthenticatedAccount implements Principal {
this.accountAndDevice = accountAndDevice;
}
@Override
public Account getAccount() {
return accountAndDevice.get().first();
}
@Override
public Device getAuthenticatedDevice() {
return accountAndDevice.get().second();
}

View File

@ -10,7 +10,7 @@ import javax.security.auth.Subject;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
public class DisabledPermittedAuthenticatedAccount implements Principal {
public class DisabledPermittedAuthenticatedAccount implements Principal, AccountAndAuthenticatedDeviceHolder {
private final AuthenticatedAccount authenticatedAccount;
@ -18,10 +18,12 @@ public class DisabledPermittedAuthenticatedAccount implements Principal {
this.authenticatedAccount = authenticatedAccount;
}
@Override
public Account getAccount() {
return authenticatedAccount.getAccount();
}
@Override
public Device getAuthenticatedDevice() {
return authenticatedAccount.getAuthenticatedDevice();
}

View File

@ -171,6 +171,10 @@ public class ClientPresenceManager extends RedisClusterPubSubAdapter<String, Str
}
}
public void displacePresence(final UUID accountUuid, final long deviceId) {
displacePresence(getPresenceKey(accountUuid, deviceId));
}
private void displacePresence(final String presenceKey) {
final DisplacedPresenceListener displacementListener = displacementListenersByPresenceKey.get(presenceKey);

View File

@ -0,0 +1,556 @@
/*
* Copyright 2013-2020 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.whispersystems.textsecuregcm.auth;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.protobuf.InvalidProtocolBufferException;
import io.dropwizard.auth.Auth;
import io.dropwizard.auth.PolymorphicAuthDynamicFeature;
import io.dropwizard.auth.PolymorphicAuthValueFactoryProvider;
import io.dropwizard.auth.basic.BasicCredentialAuthFilter;
import io.dropwizard.jersey.DropwizardResourceConfig;
import io.dropwizard.jersey.jackson.JacksonMessageBodyProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.Principal;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.UUID;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import javax.ws.rs.DELETE;
import javax.ws.rs.GET;
import javax.ws.rs.POST;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.eclipse.jetty.websocket.api.RemoteEndpoint;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.glassfish.jersey.server.ApplicationHandler;
import org.glassfish.jersey.server.ResourceConfig;
import org.glassfish.jersey.server.monitoring.ApplicationEventListener;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.websocket.WebSocketResourceProvider;
import org.whispersystems.websocket.auth.WebsocketAuthValueFactoryProvider;
import org.whispersystems.websocket.logging.WebsocketRequestLog;
import org.whispersystems.websocket.messages.protobuf.ProtobufWebSocketMessageFactory;
import org.whispersystems.websocket.messages.protobuf.SubProtocol;
import org.whispersystems.websocket.session.WebSocketSessionContextValueFactoryProvider;
@ExtendWith(DropwizardExtensionsSupport.class)
class AuthEnablementRequestEventListenerTest {
private static final Random RANDOM = new Random();
private final ApplicationEventListener applicationEventListener = mock(ApplicationEventListener.class);
private Account account = new Account();
private Device authenticatedDevice = createDevice(1L);
private Supplier<Optional<TestPrincipal>> principalSupplier = () -> Optional.of(
new TestPrincipal("test", account, authenticatedDevice));
private final ResourceExtension resources = ResourceExtension.builder()
.addProvider(
new PolymorphicAuthDynamicFeature<>(ImmutableMap.of(
TestPrincipal.class,
new BasicCredentialAuthFilter.Builder<TestPrincipal>()
.setAuthenticator(c -> principalSupplier.get()).buildAuthFilter())))
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(TestPrincipal.class)))
.addProvider(applicationEventListener)
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
.addResource(new TestResource())
.build();
private ClientPresenceManager clientPresenceManager;
private AuthEnablementRequestEventListener listener;
@BeforeEach
void setup() {
clientPresenceManager = mock(ClientPresenceManager.class);
listener = new AuthEnablementRequestEventListener(clientPresenceManager);
when(applicationEventListener.onRequest(any())).thenReturn(listener);
final UUID uuid = UUID.randomUUID();
account.setUuid(uuid);
account.addDevice(authenticatedDevice);
LongStream.range(2, 4).forEach(deviceId -> {
account.addDevice(createDevice(deviceId));
});
account.getDevices()
.forEach(device -> when(clientPresenceManager.isPresent(uuid, device.getId())).thenReturn(true));
}
@Test
void testBuildDevicesEnabled() {
final long disabledDeviceId = 3L;
final Account account = mock(Account.class);
final Set<Device> devices = new HashSet<>();
when(account.getDevices()).thenReturn(devices);
LongStream.range(1, 5)
.forEach(id -> {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(id);
when(device.isEnabled()).thenReturn(id != disabledDeviceId);
devices.add(device);
});
final Map<Long, Boolean> devicesEnabled = listener.buildDevicesEnabledMap(account);
assertEquals(4, devicesEnabled.size());
assertAll(devicesEnabled.entrySet().stream()
.map(deviceAndEnabled -> () -> {
if (deviceAndEnabled.getKey().equals(disabledDeviceId)) {
assertFalse(deviceAndEnabled.getValue());
} else {
assertTrue(deviceAndEnabled.getValue());
}
}));
}
@ParameterizedTest
@MethodSource
void testAccountEnabledChanged(final long authenticatedDeviceId, final boolean initialEnabled,
final boolean finalEnabled) {
setDeviceEnabled(account.getMasterDevice().orElseThrow(), initialEnabled);
authenticatedDevice = account.getDevice(authenticatedDeviceId).orElseThrow();
final Response response = resources.getJerseyTest()
.target("/v1/test/account/enabled/" + finalEnabled)
.request()
.header("Authorization",
"Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
.put(Entity.entity("", MediaType.TEXT_PLAIN));
assertEquals(200, response.getStatus());
if (initialEnabled != finalEnabled) {
verify(clientPresenceManager, times(account.getDevices().size())).displacePresence(eq(account.getUuid()),
anyLong());
} else {
verifyNoInteractions(clientPresenceManager);
}
}
static Stream<Arguments> testAccountEnabledChanged() {
return Stream.of(
Arguments.of(1L, true, false),
Arguments.of(1L, false, true),
Arguments.of(1L, true, true),
Arguments.of(1L, false, false),
Arguments.of(2L, true, false),
Arguments.of(2L, false, true),
Arguments.of(2L, true, true),
Arguments.of(2L, false, false)
);
}
@ParameterizedTest
@MethodSource
void testDeviceEnabledChanged(final Map<Long, Boolean> initialEnabled, final Map<Long, Boolean> finalEnabled) {
assert initialEnabled.size() == finalEnabled.size();
assert account.getMasterDevice().orElseThrow().isEnabled();
initialEnabled.forEach((deviceId, enabled) ->
setDeviceEnabled(account.getDevice(deviceId).orElseThrow(), enabled));
final Response response = resources.getJerseyTest()
.target("/v1/test/account/devices/enabled")
.request()
.header("Authorization",
"Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
.post(Entity.entity(finalEnabled, MediaType.APPLICATION_JSON));
assertEquals(200, response.getStatus());
assertAll(
finalEnabled.entrySet().stream()
.map(deviceIdEnabled -> () -> {
final boolean expectDisplacedPresence =
initialEnabled.get(deviceIdEnabled.getKey()) != deviceIdEnabled.getValue();
verify(clientPresenceManager, times(expectDisplacedPresence ? 1 : 0)).displacePresence(account.getUuid(),
deviceIdEnabled.getKey());
})
);
}
static Stream<Arguments> testDeviceEnabledChanged() {
return Stream.of(
// Not testing device ID 1L because that will trigger "account enabled changed"
Arguments.of(Map.of(2L, false, 3L, false), Map.of(2L, true, 3L, true)),
Arguments.of(Map.of(2L, true, 3L, true), Map.of(2L, false, 3L, false)),
Arguments.of(Map.of(2L, true, 3L, true), Map.of(2L, true, 3L, true)),
Arguments.of(Map.of(2L, false, 3L, true), Map.of(2L, true, 3L, true)),
Arguments.of(Map.of(2L, true, 3L, false), Map.of(2L, true, 3L, true))
);
}
@Test
void testDeviceAdded() {
assert account.getMasterDevice().orElseThrow().isEnabled();
final int initialDeviceCount = account.getDevices().size();
final List<String> addedDeviceNames = List.of("newDevice1", "newDevice2");
final Response response = resources.getJerseyTest()
.target("/v1/test/account/devices")
.request()
.header("Authorization",
"Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
.put(Entity.entity(addedDeviceNames, MediaType.APPLICATION_JSON_PATCH_JSON));
assertEquals(200, response.getStatus());
assertEquals(initialDeviceCount + addedDeviceNames.size(), account.getDevices().size());
verifyNoInteractions(clientPresenceManager);
}
@ParameterizedTest
@ValueSource(ints = {1, 2})
void testDeviceRemoved(final int removedDeviceCount) {
assert account.getMasterDevice().orElseThrow().isEnabled();
final List<Long> deletedDeviceIds = account.getDevices().stream()
.map(Device::getId)
.filter(deviceId -> deviceId != 1L)
.limit(removedDeviceCount)
.collect(Collectors.toList());
assert deletedDeviceIds.size() == removedDeviceCount;
final String deletedDeviceIdsParam = deletedDeviceIds.stream().map(String::valueOf)
.collect(Collectors.joining(","));
final Response response = resources.getJerseyTest()
.target("/v1/test/account/devices/" + deletedDeviceIdsParam)
.request()
.header("Authorization",
"Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
.delete();
assertEquals(200, response.getStatus());
deletedDeviceIds.forEach(deletedDeviceId ->
verify(clientPresenceManager).displacePresence(account.getUuid(), deletedDeviceId));
verifyNoMoreInteractions(clientPresenceManager);
}
@Test
void testMasterDeviceDisabledAndDeviceRemoved() {
assert account.getMasterDevice().orElseThrow().isEnabled();
final Set<Long> initialDeviceIds = account.getDevices().stream().map(Device::getId).collect(Collectors.toSet());
final long deletedDeviceId = 2L;
assertTrue(initialDeviceIds.remove(deletedDeviceId));
final Response response = resources.getJerseyTest()
.target("/v1/test/account/disableMasterDeviceAndDeleteDevice/" + deletedDeviceId)
.request()
.header("Authorization",
"Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
.post(Entity.entity("", MediaType.TEXT_PLAIN));
assertEquals(200, response.getStatus());
assertTrue(account.getDevice(deletedDeviceId).isEmpty());
initialDeviceIds.forEach(deviceId -> verify(clientPresenceManager).displacePresence(account.getUuid(), deviceId));
verify(clientPresenceManager).displacePresence(account.getUuid(), deletedDeviceId);
verifyNoMoreInteractions(clientPresenceManager);
}
@Test
void testOnEvent() {
Response response = resources.getJerseyTest()
.target("/v1/test/hello")
.request()
// no authorization required
.get();
assertEquals(200, response.getStatus());
response = resources.getJerseyTest()
.target("/v1/test/authorized")
.request()
.header("Authorization",
"Basic " + Base64.getEncoder().encodeToString("user:pass".getBytes(StandardCharsets.UTF_8)))
.get();
assertEquals(200, response.getStatus());
}
@Nested
class WebSocket {
private WebSocketResourceProvider<TestPrincipal> provider;
private RemoteEndpoint remoteEndpoint;
@BeforeEach
void setup() {
ResourceConfig resourceConfig = new DropwizardResourceConfig();
resourceConfig.register(applicationEventListener);
resourceConfig.register(new TestResource());
resourceConfig.register(new WebSocketSessionContextValueFactoryProvider.Binder());
resourceConfig.register(new WebsocketAuthValueFactoryProvider.Binder<>(TestPrincipal.class));
resourceConfig.register(new JacksonMessageBodyProvider(new ObjectMapper()));
ApplicationHandler applicationHandler = new ApplicationHandler(resourceConfig);
WebsocketRequestLog requestLog = mock(WebsocketRequestLog.class);
provider = new WebSocketResourceProvider<>("127.0.0.1", applicationHandler,
requestLog, new TestPrincipal("test", account, authenticatedDevice), new ProtobufWebSocketMessageFactory(),
Optional.empty(), 30000);
remoteEndpoint = mock(RemoteEndpoint.class);
Session session = mock(Session.class);
UpgradeRequest request = mock(UpgradeRequest.class);
when(session.getRemote()).thenReturn(remoteEndpoint);
when(session.getUpgradeRequest()).thenReturn(request);
provider.onWebSocketConnect(session);
}
@ParameterizedTest
@MethodSource("org.whispersystems.textsecuregcm.auth.AuthEnablementRequestEventListenerTest#testAccountEnabledChanged")
void testAccountEnabledChangedWebSocket(final long authenticatedDeviceId, final boolean initialEnabled,
final boolean finalEnabled) throws Exception {
setDeviceEnabled(account.getMasterDevice().orElseThrow(), initialEnabled);
authenticatedDevice = account.getDevice(authenticatedDeviceId).orElseThrow();
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "PUT",
"/v1/test/account/enabled/" + finalEnabled,
new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
final SubProtocol.WebSocketResponseMessage response = verifyAndGetResponse(remoteEndpoint);
assertEquals(200, response.getStatus());
if (initialEnabled != finalEnabled) {
verify(clientPresenceManager, times(account.getDevices().size())).displacePresence(eq(account.getUuid()),
anyLong());
} else {
verifyNoInteractions(clientPresenceManager);
}
}
@Test
void testOnEvent() throws Exception {
byte[] message = new ProtobufWebSocketMessageFactory().createRequest(Optional.of(111L), "GET", "/v1/test/hello",
new LinkedList<>(), Optional.empty()).toByteArray();
provider.onWebSocketBinary(message, 0, message.length);
final SubProtocol.WebSocketResponseMessage response = verifyAndGetResponse(remoteEndpoint);
assertEquals(200, response.getStatus());
}
private SubProtocol.WebSocketResponseMessage verifyAndGetResponse(final RemoteEndpoint remoteEndpoint)
throws InvalidProtocolBufferException {
ArgumentCaptor<ByteBuffer> responseBytesCaptor = ArgumentCaptor.forClass(ByteBuffer.class);
verify(remoteEndpoint).sendBytesByFuture(responseBytesCaptor.capture());
return SubProtocol.WebSocketMessage.parseFrom(responseBytesCaptor.getValue().array()).getResponse();
}
}
private static Device createDevice(final long deviceId) {
final Device device = new Device(deviceId, null, null, null, null, null, null, false, 0, null, 0, 0, "OWT", 0,
null);
setDeviceEnabled(device, true);
return device;
}
private static void setDeviceEnabled(Device device, boolean enabled) {
if (enabled) {
device.setSignedPreKey(new SignedPreKey(RANDOM.nextLong(), "testPublicKey-" + RANDOM.nextLong(),
"testSignature-" + RANDOM.nextLong()));
device.setGcmId("testGcmId" + RANDOM.nextLong());
device.setLastSeen(Util.todayInMillis());
} else {
device.setSignedPreKey(null);
}
assert enabled == device.isEnabled();
}
public static class TestPrincipal implements Principal, AccountAndAuthenticatedDeviceHolder {
private final String name;
private final Account account;
private final Device device;
private TestPrincipal(String name, final Account account, final Device device) {
this.name = name;
this.account = account;
this.device = device;
}
@Override
public String getName() {
return name;
}
@Override
public Account getAccount() {
return account;
}
@Override
public Device getAuthenticatedDevice() {
return device;
}
}
@Path("/v1/test")
public static class TestResource {
@GET
@Path("/hello")
public String testGetHello() {
return "Hello!";
}
@GET
@Path("/authorized")
public String testAuth(@Auth TestPrincipal principal) {
return "Youre in!";
}
@PUT
@Path("/account/enabled/{enabled}")
public String setAccountEnabled(@Auth TestPrincipal principal, @PathParam("enabled") final boolean enabled) {
final Device device = principal.getAccount().getMasterDevice().orElseThrow();
AuthEnablementRequestEventListenerTest.setDeviceEnabled(device, enabled);
assert device.isEnabled() == enabled;
return String.format("Set account to %s", enabled);
}
@POST
@Path("/account/devices/enabled")
public String setDeviceEnabled(@Auth TestPrincipal principal, Map<Long, Boolean> deviceIdsEnabled) {
final StringBuilder response = new StringBuilder();
for (Entry<Long, Boolean> deviceIdEnabled : deviceIdsEnabled.entrySet()) {
final Device device = principal.getAccount().getDevice(deviceIdEnabled.getKey()).orElseThrow();
AuthEnablementRequestEventListenerTest.setDeviceEnabled(device, deviceIdEnabled.getValue());
response.append(String.format("Set device enabled %s", deviceIdEnabled));
}
return response.toString();
}
@PUT
@Path("/account/devices")
public String addDevices(@Auth TestPrincipal auth, List<String> deviceNames) {
deviceNames.forEach(name -> {
final Device device = createDevice(auth.getAccount().getNextDeviceId());
auth.getAccount().addDevice(device);
device.setName(name);
});
return "Added devices " + deviceNames;
}
@DELETE
@Path("/account/devices/{deviceIds}")
public String removeDevices(@Auth TestPrincipal auth, @PathParam("deviceIds") String deviceIds) {
Arrays.stream(deviceIds.split(","))
.map(Long::valueOf)
.forEach(auth.getAccount()::removeDevice);
return "Removed device(s) " + deviceIds;
}
@POST
@Path("/account/disableMasterDeviceAndDeleteDevice/{deviceId}")
public String disableMasterDeviceAndRemoveDevice(@Auth TestPrincipal auth, @PathParam("deviceId") long deviceId) {
AuthEnablementRequestEventListenerTest.setDeviceEnabled(auth.getAccount().getMasterDevice().orElseThrow(), false);
auth.getAccount().removeDevice(deviceId);
return "Removed device " + deviceId;
}
}
}