Distinguish local vs remote in `ClientPresenceManager#disconnectPresence`

This commit is contained in:
Chris Eager 2021-12-02 15:32:42 -07:00 committed by GitHub
parent e507ce2f26
commit 13e346d4eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 388 additions and 293 deletions

View File

@ -5,10 +5,14 @@
package org.whispersystems.textsecuregcm.auth; package org.whispersystems.textsecuregcm.auth;
import java.util.Arrays; import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
import java.util.concurrent.atomic.AtomicInteger;
import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Metrics;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicInteger;
import javax.ws.rs.container.ResourceInfo;
import javax.ws.rs.core.Context;
import org.glassfish.jersey.server.monitoring.RequestEvent; import org.glassfish.jersey.server.monitoring.RequestEvent;
import org.glassfish.jersey.server.monitoring.RequestEvent.Type; import org.glassfish.jersey.server.monitoring.RequestEvent.Type;
import org.glassfish.jersey.server.monitoring.RequestEventListener; import org.glassfish.jersey.server.monitoring.RequestEventListener;
@ -16,11 +20,6 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager; import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import javax.ws.rs.container.ResourceInfo;
import javax.ws.rs.core.Context;
import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name;
public class WebsocketRefreshRequestEventListener implements RequestEventListener { public class WebsocketRefreshRequestEventListener implements RequestEventListener {
private final ClientPresenceManager clientPresenceManager; private final ClientPresenceManager clientPresenceManager;
@ -60,7 +59,7 @@ public class WebsocketRefreshRequestEventListener implements RequestEventListene
.forEach(pair -> { .forEach(pair -> {
try { try {
displacedDevices.incrementAndGet(); displacedDevices.incrementAndGet();
clientPresenceManager.displacePresence(pair.first(), pair.second()); clientPresenceManager.disconnectPresence(pair.first(), pair.second());
} catch (final Exception e) { } catch (final Exception e) {
logger.error("Could not displace device presence", e); logger.error("Could not displace device presence", e);
} }

View File

@ -171,8 +171,16 @@ public class ClientPresenceManager extends RedisClusterPubSubAdapter<String, Str
} }
} }
public void displacePresence(final UUID accountUuid, final long deviceId) { public void disconnectPresence(final UUID accountUuid, final long deviceId) {
displacePresence(getPresenceKey(accountUuid, deviceId)); final String presenceKey = getPresenceKey(accountUuid, deviceId);
if (isLocallyPresent(accountUuid, deviceId)) {
displacePresence(presenceKey);
}
// If connected locally, we still need to clean up the presence key.
// If connected remotely, the other server will get a keyspace message and handle the disconnect
presenceCluster.useCluster(connection -> connection.sync().del(presenceKey));
} }
private void displacePresence(final String presenceKey) { private void displacePresence(final String presenceKey) {
@ -268,18 +276,22 @@ public class ClientPresenceManager extends RedisClusterPubSubAdapter<String, Str
public void message(final RedisClusterNode node, final String channel, final String message) { public void message(final RedisClusterNode node, final String channel, final String message) {
pubSubMessageMeter.mark(); pubSubMessageMeter.mark();
if ("set".equals(message) && channel.startsWith("__keyspace@0__:presence::{")) { if (channel.startsWith("__keyspace@0__:presence::{")) {
// Another process has overwritten this presence key, which means the client has connected to another host. if ("set".equals(message) || "del".equals(message)) {
// At this point, we're on a Lettuce IO thread and need to dispatch to a separate thread before making // for "set", another process has overwritten this presence key, which means the client has connected to another host.
// synchronous Lettuce calls to avoid deadlocking. // for "del", another process has indicated the client should be disconnected
keyspaceNotificationExecutorService.execute(() -> {
try { // At this point, we're on a Lettuce IO thread and need to dispatch to a separate thread before making
displacePresence(channel.substring("__keyspace@0__:".length())); // synchronous Lettuce calls to avoid deadlocking.
remoteDisplacementMeter.mark(); keyspaceNotificationExecutorService.execute(() -> {
} catch (final Exception e) { try {
log.warn("Error displacing presence", e); displacePresence(channel.substring("__keyspace@0__:".length()));
} remoteDisplacementMeter.mark();
}); } catch (final Exception e) {
log.warn("Error displacing presence", e);
}
});
}
} }
} }

View File

@ -552,7 +552,7 @@ public class AccountsManager {
RedisOperation.unchecked(() -> RedisOperation.unchecked(() ->
account.getDevices().forEach(device -> account.getDevices().forEach(device ->
clientPresenceManager.displacePresence(account.getUuid(), device.getId()))); clientPresenceManager.disconnectPresence(account.getUuid(), device.getId())));
} }
private String getAccountMapKey(String key) { private String getAccountMapKey(String key) {

View File

@ -190,12 +190,12 @@ class AuthEnablementRefreshRequirementProviderTest {
assertAll( assertAll(
initialEnabled.keySet().stream() initialEnabled.keySet().stream()
.map(deviceId -> () -> verify(clientPresenceManager, times(expectDisplacedPresence ? 1 : 0)) .map(deviceId -> () -> verify(clientPresenceManager, times(expectDisplacedPresence ? 1 : 0))
.displacePresence(account.getUuid(), deviceId))); .disconnectPresence(account.getUuid(), deviceId)));
assertAll( assertAll(
finalEnabled.keySet().stream() finalEnabled.keySet().stream()
.map(deviceId -> () -> verify(clientPresenceManager, times(expectDisplacedPresence ? 1 : 0)) .map(deviceId -> () -> verify(clientPresenceManager, times(expectDisplacedPresence ? 1 : 0))
.displacePresence(account.getUuid(), deviceId))); .disconnectPresence(account.getUuid(), deviceId)));
} }
static Stream<Arguments> testDeviceEnabledChanged() { static Stream<Arguments> testDeviceEnabledChanged() {
@ -227,9 +227,9 @@ class AuthEnablementRefreshRequirementProviderTest {
assertEquals(initialDeviceCount + addedDeviceNames.size(), account.getDevices().size()); assertEquals(initialDeviceCount + addedDeviceNames.size(), account.getDevices().size());
verify(clientPresenceManager).displacePresence(account.getUuid(), 1); verify(clientPresenceManager).disconnectPresence(account.getUuid(), 1);
verify(clientPresenceManager).displacePresence(account.getUuid(), 2); verify(clientPresenceManager).disconnectPresence(account.getUuid(), 2);
verify(clientPresenceManager).displacePresence(account.getUuid(), 3); verify(clientPresenceManager).disconnectPresence(account.getUuid(), 3);
} }
@ParameterizedTest @ParameterizedTest
@ -260,7 +260,7 @@ class AuthEnablementRefreshRequirementProviderTest {
assertEquals(200, response.getStatus()); assertEquals(200, response.getStatus());
initialDeviceIds.forEach(deviceId -> initialDeviceIds.forEach(deviceId ->
verify(clientPresenceManager).displacePresence(account.getUuid(), deviceId)); verify(clientPresenceManager).disconnectPresence(account.getUuid(), deviceId));
verifyNoMoreInteractions(clientPresenceManager); verifyNoMoreInteractions(clientPresenceManager);
} }
@ -285,8 +285,8 @@ class AuthEnablementRefreshRequirementProviderTest {
assertTrue(account.getDevice(deletedDeviceId).isEmpty()); assertTrue(account.getDevice(deletedDeviceId).isEmpty());
initialDeviceIds.forEach(deviceId -> verify(clientPresenceManager).displacePresence(account.getUuid(), deviceId)); initialDeviceIds.forEach(deviceId -> verify(clientPresenceManager).disconnectPresence(account.getUuid(), deviceId));
verify(clientPresenceManager).displacePresence(account.getUuid(), deletedDeviceId); verify(clientPresenceManager).disconnectPresence(account.getUuid(), deletedDeviceId);
verifyNoMoreInteractions(clientPresenceManager); verifyNoMoreInteractions(clientPresenceManager);
} }

View File

@ -5,236 +5,314 @@
package org.whispersystems.textsecuregcm.push; package org.whispersystems.textsecuregcm.push;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import io.lettuce.core.cluster.api.StatefulRedisClusterConnection;
import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent; import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent;
import java.time.Duration;
import java.util.List; import java.util.List;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import org.junit.After; import java.util.function.Function;
import org.junit.Before; import org.junit.jupiter.api.AfterEach;
import org.junit.Test; import org.junit.jupiter.api.BeforeEach;
import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
public class ClientPresenceManagerTest extends AbstractRedisClusterTest { class ClientPresenceManagerTest {
private ScheduledExecutorService presenceRenewalExecutorService; @RegisterExtension
private ClientPresenceManager clientPresenceManager; static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
private static final DisplacedPresenceListener NO_OP = () -> {}; private ScheduledExecutorService presenceRenewalExecutorService;
private ClientPresenceManager clientPresenceManager;
@Override private static final DisplacedPresenceListener NO_OP = () -> {
@Before };
public void setUp() throws Exception {
super.setUp();
getRedisCluster().useCluster(connection -> { @BeforeEach
connection.sync().flushall(); void setUp() throws Exception {
connection.sync().upstream().commands().configSet("notify-keyspace-events", "K$z");
});
presenceRenewalExecutorService = Executors.newSingleThreadScheduledExecutor(); REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(connection -> {
clientPresenceManager = new ClientPresenceManager(getRedisCluster(), presenceRenewalExecutorService, presenceRenewalExecutorService); connection.sync().flushall();
connection.sync().upstream().commands().configSet("notify-keyspace-events", "K$glz");
});
presenceRenewalExecutorService = Executors.newSingleThreadScheduledExecutor();
clientPresenceManager = new ClientPresenceManager(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
presenceRenewalExecutorService,
presenceRenewalExecutorService);
}
@AfterEach
public void tearDown() throws Exception {
presenceRenewalExecutorService.shutdown();
presenceRenewalExecutorService.awaitTermination(1, TimeUnit.MINUTES);
clientPresenceManager.stop();
}
@Test
void testIsPresent() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
assertFalse(clientPresenceManager.isPresent(accountUuid, deviceId));
clientPresenceManager.setPresent(accountUuid, deviceId, NO_OP);
assertTrue(clientPresenceManager.isPresent(accountUuid, deviceId));
}
@Test
void testIsLocallyPresent() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
assertFalse(clientPresenceManager.isLocallyPresent(accountUuid, deviceId));
clientPresenceManager.setPresent(accountUuid, deviceId, NO_OP);
REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(connection -> connection.sync().flushall());
assertTrue(clientPresenceManager.isLocallyPresent(accountUuid, deviceId));
}
@Test
void testLocalDisplacement() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
final AtomicInteger displacementCounter = new AtomicInteger(0);
final DisplacedPresenceListener displacementListener = displacementCounter::incrementAndGet;
clientPresenceManager.setPresent(accountUuid, deviceId, displacementListener);
assertEquals(0, displacementCounter.get());
clientPresenceManager.setPresent(accountUuid, deviceId, displacementListener);
assertEquals(1, displacementCounter.get());
}
@Test
void testRemoteDisplacement() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
final CompletableFuture<?> displaced = new CompletableFuture<>();
clientPresenceManager.start();
clientPresenceManager.setPresent(accountUuid, deviceId, () -> displaced.complete(null));
REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(
connection -> connection.sync().set(ClientPresenceManager.getPresenceKey(accountUuid, deviceId),
UUID.randomUUID().toString()));
assertTimeoutPreemptively(Duration.ofSeconds(10), displaced::join);
}
@Test
void testRemoteDisplacementAfterTopologyChange() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
final CompletableFuture<?> displaced = new CompletableFuture<>();
clientPresenceManager.start();
clientPresenceManager.setPresent(accountUuid, deviceId, () -> displaced.complete(null));
clientPresenceManager.getPubSubConnection()
.usePubSubConnection(connection -> connection.getResources().eventBus()
.publish(new ClusterTopologyChangedEvent(List.of(), List.of())));
REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(
connection -> connection.sync().set(ClientPresenceManager.getPresenceKey(accountUuid, deviceId),
UUID.randomUUID().toString()));
assertTimeoutPreemptively(Duration.ofSeconds(10), displaced::join);
}
@Test
void testClearPresence() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
assertFalse(clientPresenceManager.isPresent(accountUuid, deviceId));
clientPresenceManager.setPresent(accountUuid, deviceId, NO_OP);
assertTrue(clientPresenceManager.clearPresence(accountUuid, deviceId));
clientPresenceManager.setPresent(accountUuid, deviceId, NO_OP);
REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(
connection -> connection.sync().set(ClientPresenceManager.getPresenceKey(accountUuid, deviceId),
UUID.randomUUID().toString()));
assertFalse(clientPresenceManager.clearPresence(accountUuid, deviceId));
}
@Test
void testPruneMissingPeers() {
final String presentPeerId = UUID.randomUUID().toString();
final String missingPeerId = UUID.randomUUID().toString();
REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(connection -> {
connection.sync().sadd(ClientPresenceManager.MANAGER_SET_KEY, presentPeerId);
connection.sync().sadd(ClientPresenceManager.MANAGER_SET_KEY, missingPeerId);
});
for (int i = 0; i < 10; i++) {
addClientPresence(presentPeerId);
addClientPresence(missingPeerId);
} }
@Override clientPresenceManager.getPubSubConnection().usePubSubConnection(
@After connection -> connection.sync().upstream().commands()
public void tearDown() throws Exception { .subscribe(ClientPresenceManager.getManagerPresenceChannel(presentPeerId)));
super.tearDown(); clientPresenceManager.pruneMissingPeers();
presenceRenewalExecutorService.shutdown(); assertEquals(1, (long) REDIS_CLUSTER_EXTENSION.getRedisCluster().withCluster(
presenceRenewalExecutorService.awaitTermination(1, TimeUnit.MINUTES); connection -> connection.sync().exists(ClientPresenceManager.getConnectedClientSetKey(presentPeerId))));
assertTrue(REDIS_CLUSTER_EXTENSION.getRedisCluster().withCluster(
(Function<StatefulRedisClusterConnection<String, String>, Boolean>) connection -> connection.sync()
.sismember(ClientPresenceManager.MANAGER_SET_KEY, presentPeerId)));
assertEquals(0, (long) REDIS_CLUSTER_EXTENSION.getRedisCluster().withCluster(
connection -> connection.sync().exists(ClientPresenceManager.getConnectedClientSetKey(missingPeerId))));
assertFalse(REDIS_CLUSTER_EXTENSION.getRedisCluster().withCluster(
(Function<StatefulRedisClusterConnection<String, String>, Boolean>) connection -> connection.sync()
.sismember(ClientPresenceManager.MANAGER_SET_KEY, missingPeerId)));
}
private void addClientPresence(final String managerId) {
final String clientPresenceKey = ClientPresenceManager.getPresenceKey(UUID.randomUUID(), 7);
REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(connection -> {
connection.sync().set(clientPresenceKey, managerId);
connection.sync().sadd(ClientPresenceManager.getConnectedClientSetKey(managerId), clientPresenceKey);
});
}
@Test
void testClearAllOnStop() {
final int localAccounts = 10;
final UUID[] localUuids = new UUID[localAccounts];
final long[] localDeviceIds = new long[localAccounts];
for (int i = 0; i < localAccounts; i++) {
localUuids[i] = UUID.randomUUID();
localDeviceIds[i] = i;
clientPresenceManager.setPresent(localUuids[i], localDeviceIds[i], NO_OP);
}
final UUID displacedAccountUuid = UUID.randomUUID();
final long displacedAccountDeviceId = 7;
clientPresenceManager.setPresent(displacedAccountUuid, displacedAccountDeviceId, NO_OP);
REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(connection -> connection.sync()
.set(ClientPresenceManager.getPresenceKey(displacedAccountUuid, displacedAccountDeviceId),
UUID.randomUUID().toString()));
clientPresenceManager.stop();
for (int i = 0; i < localAccounts; i++) {
localUuids[i] = UUID.randomUUID();
localDeviceIds[i] = i;
assertFalse(clientPresenceManager.isPresent(localUuids[i], localDeviceIds[i]));
}
assertTrue(clientPresenceManager.isPresent(displacedAccountUuid, displacedAccountDeviceId));
}
@Nested
class MultiServerTest {
private ClientPresenceManager server1;
private ClientPresenceManager server2;
@BeforeEach
void setup() throws Exception {
REDIS_CLUSTER_EXTENSION.getRedisCluster().useCluster(connection -> {
connection.sync().flushall();
connection.sync().upstream().commands().configSet("notify-keyspace-events", "K$glz");
});
final ScheduledExecutorService scheduledExecutorService1 = mock(ScheduledExecutorService.class);
final ExecutorService keyspaceNotificationExecutorService1 = Executors.newSingleThreadExecutor();
server1 = new ClientPresenceManager(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
scheduledExecutorService1, keyspaceNotificationExecutorService1);
final ScheduledExecutorService scheduledExecutorService2 = mock(ScheduledExecutorService.class);
final ExecutorService keyspaceNotificationExecutorService2 = Executors.newSingleThreadExecutor();
server2 = new ClientPresenceManager(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
scheduledExecutorService2, keyspaceNotificationExecutorService2);
server1.start();
server2.start();
}
@AfterEach
void teardown() {
server2.stop();
server1.stop();
} }
@Test @Test
public void testIsPresent() { void testSetPresentRemotely() {
final UUID accountUuid = UUID.randomUUID(); final UUID uuid1 = UUID.randomUUID();
final long deviceId = 1; final long deviceId = 1L;
assertFalse(clientPresenceManager.isPresent(accountUuid, deviceId)); final CompletableFuture<?> displaced = new CompletableFuture<>();
final DisplacedPresenceListener listener1 = () -> displaced.complete(null);
server1.setPresent(uuid1, deviceId, listener1);
clientPresenceManager.setPresent(accountUuid, deviceId, NO_OP); server2.setPresent(uuid1, deviceId, () -> {
assertTrue(clientPresenceManager.isPresent(accountUuid, deviceId)); });
assertTimeoutPreemptively(Duration.ofSeconds(10), displaced::join);
} }
@Test @Test
public void testIsLocallyPresent() { void testDisconnectPresenceLocally() {
final UUID accountUuid = UUID.randomUUID(); final UUID uuid1 = UUID.randomUUID();
final long deviceId = 1; final long deviceId = 1L;
assertFalse(clientPresenceManager.isLocallyPresent(accountUuid, deviceId)); final CompletableFuture<?> displaced = new CompletableFuture<>();
final DisplacedPresenceListener listener1 = () -> displaced.complete(null);
server1.setPresent(uuid1, deviceId, listener1);
clientPresenceManager.setPresent(accountUuid, deviceId, NO_OP); server1.disconnectPresence(uuid1, deviceId);
getRedisCluster().useCluster(connection -> connection.sync().flushall());
assertTrue(clientPresenceManager.isLocallyPresent(accountUuid, deviceId)); assertTimeoutPreemptively(Duration.ofSeconds(10), displaced::join);
} }
@Test @Test
public void testLocalDisplacement() { void testDisconnectPresenceRemotely() {
final UUID accountUuid = UUID.randomUUID(); final UUID uuid1 = UUID.randomUUID();
final long deviceId = 1; final long deviceId = 1L;
final AtomicInteger displacementCounter = new AtomicInteger(0); final CompletableFuture<?> displaced = new CompletableFuture<>();
final DisplacedPresenceListener displacementListener = displacementCounter::incrementAndGet; final DisplacedPresenceListener listener1 = () -> displaced.complete(null);
server1.setPresent(uuid1, deviceId, listener1);
clientPresenceManager.setPresent(accountUuid, deviceId, displacementListener); server2.disconnectPresence(uuid1, deviceId);
assertEquals(0, displacementCounter.get()); assertTimeoutPreemptively(Duration.ofSeconds(10), displaced::join);
clientPresenceManager.setPresent(accountUuid, deviceId, displacementListener);
assertEquals(1, displacementCounter.get());
}
@Test(timeout = 10_000)
public void testRemoteDisplacement() throws InterruptedException {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
final AtomicBoolean displaced = new AtomicBoolean(false);
clientPresenceManager.start();
try {
clientPresenceManager.setPresent(accountUuid, deviceId, () -> {
synchronized (displaced) {
displaced.set(true);
displaced.notifyAll();
}
});
getRedisCluster().useCluster(connection -> connection.sync().set(ClientPresenceManager.getPresenceKey(accountUuid, deviceId),
UUID.randomUUID().toString()));
synchronized (displaced) {
while (!displaced.get()) {
displaced.wait();
}
}
} finally {
clientPresenceManager.stop();
}
}
@Test(timeout = 10_000)
public void testRemoteDisplacementAfterTopologyChange() throws InterruptedException {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
final AtomicBoolean displaced = new AtomicBoolean(false);
clientPresenceManager.start();
try {
clientPresenceManager.setPresent(accountUuid, deviceId, () -> {
synchronized (displaced) {
displaced.set(true);
displaced.notifyAll();
}
});
clientPresenceManager.getPubSubConnection().usePubSubConnection(connection -> connection.getResources().eventBus().publish(new ClusterTopologyChangedEvent(List.of(), List.of())));
getRedisCluster().useCluster(connection -> connection.sync().set(ClientPresenceManager.getPresenceKey(accountUuid, deviceId),
UUID.randomUUID().toString()));
synchronized (displaced) {
while (!displaced.get()) {
displaced.wait();
}
}
} finally {
clientPresenceManager.stop();
}
}
@Test
public void testClearPresence() {
final UUID accountUuid = UUID.randomUUID();
final long deviceId = 1;
assertFalse(clientPresenceManager.isPresent(accountUuid, deviceId));
clientPresenceManager.setPresent(accountUuid, deviceId, NO_OP);
assertTrue(clientPresenceManager.clearPresence(accountUuid, deviceId));
clientPresenceManager.setPresent(accountUuid, deviceId, NO_OP);
getRedisCluster().useCluster(connection -> connection.sync().set(ClientPresenceManager.getPresenceKey(accountUuid, deviceId),
UUID.randomUUID().toString()));
assertFalse(clientPresenceManager.clearPresence(accountUuid, deviceId));
}
@Test
public void testPruneMissingPeers() {
final String presentPeerId = UUID.randomUUID().toString();
final String missingPeerId = UUID.randomUUID().toString();
getRedisCluster().useCluster(connection -> {
connection.sync().sadd(ClientPresenceManager.MANAGER_SET_KEY, presentPeerId);
connection.sync().sadd(ClientPresenceManager.MANAGER_SET_KEY, missingPeerId);
});
for (int i = 0; i < 10; i++) {
addClientPresence(presentPeerId);
addClientPresence(missingPeerId);
}
clientPresenceManager.getPubSubConnection().usePubSubConnection(connection -> connection.sync().upstream().commands().subscribe(ClientPresenceManager.getManagerPresenceChannel(presentPeerId)));
clientPresenceManager.pruneMissingPeers();
assertEquals(1, (long)getRedisCluster().withCluster(connection -> connection.sync().exists(ClientPresenceManager.getConnectedClientSetKey(presentPeerId))));
assertTrue(getRedisCluster().withCluster(connection -> connection.sync().sismember(ClientPresenceManager.MANAGER_SET_KEY, presentPeerId)));
assertEquals(0, (long)getRedisCluster().withCluster(connection -> connection.sync().exists(ClientPresenceManager.getConnectedClientSetKey(missingPeerId))));
assertFalse(getRedisCluster().withCluster(connection -> connection.sync().sismember(ClientPresenceManager.MANAGER_SET_KEY, missingPeerId)));
}
private void addClientPresence(final String managerId) {
final String clientPresenceKey = ClientPresenceManager.getPresenceKey(UUID.randomUUID(), 7);
getRedisCluster().useCluster(connection -> {
connection.sync().set(clientPresenceKey, managerId);
connection.sync().sadd(ClientPresenceManager.getConnectedClientSetKey(managerId), clientPresenceKey);
});
}
@Test
public void testClearAllOnStop() {
final int localAccounts = 10;
final UUID[] localUuids = new UUID[localAccounts];
final long[] localDeviceIds = new long[localAccounts];
for (int i = 0; i < localAccounts; i++) {
localUuids[i] = UUID.randomUUID();
localDeviceIds[i] = i;
clientPresenceManager.setPresent(localUuids[i], localDeviceIds[i], NO_OP);
}
final UUID displacedAccountUuid = UUID.randomUUID();
final long displacedAccountDeviceId = 7;
clientPresenceManager.setPresent(displacedAccountUuid, displacedAccountDeviceId, NO_OP);
getRedisCluster().useCluster(connection -> connection.sync().set(ClientPresenceManager.getPresenceKey(displacedAccountUuid, displacedAccountDeviceId),
UUID.randomUUID().toString()));
clientPresenceManager.stop();
for (int i = 0; i < localAccounts; i++) {
localUuids[i] = UUID.randomUUID();
localDeviceIds[i] = i;
assertFalse(clientPresenceManager.isPresent(localUuids[i], localDeviceIds[i]));
}
assertTrue(clientPresenceManager.isPresent(displacedAccountUuid, displacedAccountDeviceId));
} }
}
} }

View File

@ -135,7 +135,7 @@ public class RedisClusterExtension implements BeforeAllCallback, BeforeEachCallb
try { try {
final StatefulRedisConnection<String, String> connection = meetClient.connect(); final StatefulRedisConnection<String, String> connection = meetClient.connect();
final RedisCommands<String, String> commands = connection.sync(); final RedisCommands<String, String> commands = connection.sync();
for (int i = 1; i < nodes.length; i++) { for (int i = 1; i < nodes.length; i++) {
commands.clusterMeet("127.0.0.1", nodes[i].ports().get(0)); commands.clusterMeet("127.0.0.1", nodes[i].ports().get(0));
@ -148,7 +148,7 @@ public class RedisClusterExtension implements BeforeAllCallback, BeforeEachCallb
for (int i = 0; i < nodes.length; i++) { for (int i = 0; i < nodes.length; i++) {
final int startInclusive = i * slotsPerNode; final int startInclusive = i * slotsPerNode;
final int endExclusive = i == nodes.length - 1 ? SlotHash.SLOT_COUNT : (i + 1) * slotsPerNode; final int endExclusive = i == nodes.length - 1 ? SlotHash.SLOT_COUNT : (i + 1) * slotsPerNode;
final RedisClient assignSlotClient = RedisClient.create(RedisURI.create("127.0.0.1", nodes[i].ports().get(0))); final RedisClient assignSlotClient = RedisClient.create(RedisURI.create("127.0.0.1", nodes[i].ports().get(0)));

View File

@ -265,7 +265,7 @@ class AccountsManagerChangeNumberIntegrationTest {
assertEquals(secondNumber, accountsManager.getByAccountIdentifier(originalUuid).map(Account::getNumber).orElseThrow()); assertEquals(secondNumber, accountsManager.getByAccountIdentifier(originalUuid).map(Account::getNumber).orElseThrow());
verify(clientPresenceManager).displacePresence(existingAccountUuid, Device.MASTER_ID); verify(clientPresenceManager).disconnectPresence(existingAccountUuid, Device.MASTER_ID);
assertEquals(Optional.of(existingAccountUuid), deletedAccounts.findUuid(originalNumber)); assertEquals(Optional.of(existingAccountUuid), deletedAccounts.findUuid(originalNumber));
assertEquals(Optional.empty(), deletedAccounts.findUuid(secondNumber)); assertEquals(Optional.empty(), deletedAccounts.findUuid(secondNumber));

View File

@ -61,15 +61,16 @@ import org.whispersystems.textsecuregcm.util.VerificationCode;
@ExtendWith(DropwizardExtensionsSupport.class) @ExtendWith(DropwizardExtensionsSupport.class)
class DeviceControllerTest { class DeviceControllerTest {
@Path("/v1/devices") @Path("/v1/devices")
static class DumbVerificationDeviceController extends DeviceController { static class DumbVerificationDeviceController extends DeviceController {
public DumbVerificationDeviceController(StoredVerificationCodeManager pendingDevices, public DumbVerificationDeviceController(StoredVerificationCodeManager pendingDevices,
AccountsManager accounts, AccountsManager accounts,
MessagesManager messages, MessagesManager messages,
Keys keys, Keys keys,
RateLimiters rateLimiters, RateLimiters rateLimiters,
Map<String, Integer> deviceConfiguration) Map<String, Integer> deviceConfiguration) {
{
super(pendingDevices, accounts, messages, keys, rateLimiters, deviceConfiguration); super(pendingDevices, accounts, messages, keys, rateLimiters, deviceConfiguration);
} }
@ -80,17 +81,17 @@ class DeviceControllerTest {
} }
private static StoredVerificationCodeManager pendingDevicesManager = mock(StoredVerificationCodeManager.class); private static StoredVerificationCodeManager pendingDevicesManager = mock(StoredVerificationCodeManager.class);
private static AccountsManager accountsManager = mock(AccountsManager.class ); private static AccountsManager accountsManager = mock(AccountsManager.class);
private static MessagesManager messagesManager = mock(MessagesManager.class); private static MessagesManager messagesManager = mock(MessagesManager.class);
private static Keys keys = mock(Keys.class); private static Keys keys = mock(Keys.class);
private static RateLimiters rateLimiters = mock(RateLimiters.class ); private static RateLimiters rateLimiters = mock(RateLimiters.class);
private static RateLimiter rateLimiter = mock(RateLimiter.class ); private static RateLimiter rateLimiter = mock(RateLimiter.class);
private static Account account = mock(Account.class ); private static Account account = mock(Account.class);
private static Account maxedAccount = mock(Account.class); private static Account maxedAccount = mock(Account.class);
private static Device masterDevice = mock(Device.class); private static Device masterDevice = mock(Device.class);
private static ClientPresenceManager clientPresenceManager = mock(ClientPresenceManager.class); private static ClientPresenceManager clientPresenceManager = mock(ClientPresenceManager.class);
private static Map<String, Integer> deviceConfiguration = new HashMap<>(); private static Map<String, Integer> deviceConfiguration = new HashMap<>();
private static final ResourceExtension resources = ResourceExtension.builder() private static final ResourceExtension resources = ResourceExtension.builder()
.addProvider(AuthHelper.getAuthFilter()) .addProvider(AuthHelper.getAuthFilter())
@ -162,27 +163,27 @@ class DeviceControllerTest {
when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(Set.of(existingDevice)); when(AuthHelper.VALID_ACCOUNT.getDevices()).thenReturn(Set.of(existingDevice));
VerificationCode deviceCode = resources.getJerseyTest() VerificationCode deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code") .target("/v1/devices/provisioning/code")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class); .get(VerificationCode.class);
assertThat(deviceCode).isEqualTo(new VerificationCode(5678901)); assertThat(deviceCode).isEqualTo(new VerificationCode(5678901));
DeviceResponse response = resources.getJerseyTest() DeviceResponse response = resources.getJerseyTest()
.target("/v1/devices/5678901") .target("/v1/devices/5678901")
.request() .request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1")) .header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1"))
.put(Entity.entity(new AccountAttributes(false, 1234, null, .put(Entity.entity(new AccountAttributes(false, 1234, null,
null, true, null), null, true, null),
MediaType.APPLICATION_JSON_TYPE), MediaType.APPLICATION_JSON_TYPE),
DeviceResponse.class); DeviceResponse.class);
assertThat(response.getDeviceId()).isEqualTo(42L); assertThat(response.getDeviceId()).isEqualTo(42L);
verify(pendingDevicesManager).remove(AuthHelper.VALID_NUMBER); verify(pendingDevicesManager).remove(AuthHelper.VALID_NUMBER);
verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(42L)); verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(42L));
verify(clientPresenceManager).displacePresence(AuthHelper.VALID_UUID, Device.MASTER_ID); verify(clientPresenceManager).disconnectPresence(AuthHelper.VALID_UUID, Device.MASTER_ID);
} }
@Test @Test
@ -201,30 +202,30 @@ class DeviceControllerTest {
@Test @Test
void disabledDeviceRegisterTest() { void disabledDeviceRegisterTest() {
Response response = resources.getJerseyTest() Response response = resources.getJerseyTest()
.target("/v1/devices/provisioning/code") .target("/v1/devices/provisioning/code")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.DISABLED_UUID, AuthHelper.DISABLED_PASSWORD))
.get(); .get();
assertThat(response.getStatus()).isEqualTo(401); assertThat(response.getStatus()).isEqualTo(401);
} }
@Test @Test
void invalidDeviceRegisterTest() { void invalidDeviceRegisterTest() {
VerificationCode deviceCode = resources.getJerseyTest() VerificationCode deviceCode = resources.getJerseyTest()
.target("/v1/devices/provisioning/code") .target("/v1/devices/provisioning/code")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get(VerificationCode.class); .get(VerificationCode.class);
assertThat(deviceCode).isEqualTo(new VerificationCode(5678901)); assertThat(deviceCode).isEqualTo(new VerificationCode(5678901));
Response response = resources.getJerseyTest() Response response = resources.getJerseyTest()
.target("/v1/devices/5678902") .target("/v1/devices/5678902")
.request() .request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1")) .header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1"))
.put(Entity.entity(new AccountAttributes(false, 1234, null, null, true, null), .put(Entity.entity(new AccountAttributes(false, 1234, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE)); MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(403); assertThat(response.getStatus()).isEqualTo(403);
@ -234,11 +235,12 @@ class DeviceControllerTest {
@Test @Test
void oldDeviceRegisterTest() { void oldDeviceRegisterTest() {
Response response = resources.getJerseyTest() Response response = resources.getJerseyTest()
.target("/v1/devices/1112223") .target("/v1/devices/1112223")
.request() .request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER_TWO, AuthHelper.VALID_PASSWORD_TWO)) .header("Authorization",
.put(Entity.entity(new AccountAttributes(false, 1234, null, null, true, null), AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER_TWO, AuthHelper.VALID_PASSWORD_TWO))
MediaType.APPLICATION_JSON_TYPE)); .put(Entity.entity(new AccountAttributes(false, 1234, null, null, true, null),
MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(403); assertThat(response.getStatus()).isEqualTo(403);
@ -248,10 +250,10 @@ class DeviceControllerTest {
@Test @Test
void maxDevicesTest() { void maxDevicesTest() {
Response response = resources.getJerseyTest() Response response = resources.getJerseyTest()
.target("/v1/devices/provisioning/code") .target("/v1/devices/provisioning/code")
.request() .request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_TWO, AuthHelper.VALID_PASSWORD_TWO)) .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID_TWO, AuthHelper.VALID_PASSWORD_TWO))
.get(); .get();
assertEquals(411, response.getStatus()); assertEquals(411, response.getStatus());
verifyNoMoreInteractions(messagesManager); verifyNoMoreInteractions(messagesManager);
@ -260,11 +262,13 @@ class DeviceControllerTest {
@Test @Test
void longNameTest() { void longNameTest() {
Response response = resources.getJerseyTest() Response response = resources.getJerseyTest()
.target("/v1/devices/5678901") .target("/v1/devices/5678901")
.request() .request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1")) .header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1"))
.put(Entity.entity(new AccountAttributes(false, 1234, "this is a really long name that is longer than 80 characters it's so long that it's even longer than 204 characters. that's a lot of characters. we're talking lots and lots and lots of characters. 12345678", null, true, null), .put(Entity.entity(new AccountAttributes(false, 1234,
MediaType.APPLICATION_JSON_TYPE)); "this is a really long name that is longer than 80 characters it's so long that it's even longer than 204 characters. that's a lot of characters. we're talking lots and lots and lots of characters. 12345678",
null, true, null),
MediaType.APPLICATION_JSON_TYPE));
assertEquals(response.getStatus(), 422); assertEquals(response.getStatus(), 422);
verifyNoMoreInteractions(messagesManager); verifyNoMoreInteractions(messagesManager);
@ -272,15 +276,17 @@ class DeviceControllerTest {
@ParameterizedTest @ParameterizedTest
@MethodSource @MethodSource
void deviceDowngradeCapabilitiesTest(final String userAgent, final boolean gv2, final boolean gv2_2, final boolean gv2_3, final int expectedStatus) { void deviceDowngradeCapabilitiesTest(final String userAgent, final boolean gv2, final boolean gv2_2,
DeviceCapabilities deviceCapabilities = new DeviceCapabilities(gv2, gv2_2, gv2_3, true, false, true, true, true, true); final boolean gv2_3, final int expectedStatus) {
DeviceCapabilities deviceCapabilities = new DeviceCapabilities(gv2, gv2_2, gv2_3, true, false, true, true, true,
true);
AccountAttributes accountAttributes = new AccountAttributes(false, 1234, null, null, true, deviceCapabilities); AccountAttributes accountAttributes = new AccountAttributes(false, 1234, null, null, true, deviceCapabilities);
Response response = resources.getJerseyTest() Response response = resources.getJerseyTest()
.target("/v1/devices/5678901") .target("/v1/devices/5678901")
.request() .request()
.header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1")) .header("Authorization", AuthHelper.getProvisioningAuthHeader(AuthHelper.VALID_NUMBER, "password1"))
.header("User-Agent", userAgent) .header("User-Agent", userAgent)
.put(Entity.entity(accountAttributes, MediaType.APPLICATION_JSON_TYPE)); .put(Entity.entity(accountAttributes, MediaType.APPLICATION_JSON_TYPE));
assertThat(response.getStatus()).isEqualTo(expectedStatus); assertThat(response.getStatus()).isEqualTo(expectedStatus);