Return destination client presence when inserting messages

This commit is contained in:
Jon Chambers 2024-11-07 16:10:30 -05:00 committed by Jon Chambers
parent 1fa31b3974
commit eeeb565313
5 changed files with 73 additions and 23 deletions

View File

@ -145,7 +145,7 @@ public class PubSubClientEventManager extends RedisClusterPubSubAdapter<byte[],
} }
final UUID connectionId = UUID.randomUUID(); final UUID connectionId = UUID.randomUUID();
final byte[] clientPresenceKey = getClientPresenceKey(accountIdentifier, deviceId); final byte[] clientPresenceKey = getClientEventChannel(accountIdentifier, deviceId);
final AtomicReference<ClientEventListener> displacedListener = new AtomicReference<>(); final AtomicReference<ClientEventListener> displacedListener = new AtomicReference<>();
final AtomicReference<CompletionStage<Void>> subscribeFuture = new AtomicReference<>(); final AtomicReference<CompletionStage<Void>> subscribeFuture = new AtomicReference<>();
@ -216,7 +216,7 @@ public class PubSubClientEventManager extends RedisClusterPubSubAdapter<byte[],
listenersByAccountAndDeviceIdentifier.compute(new AccountAndDeviceIdentifier(accountIdentifier, deviceId), listenersByAccountAndDeviceIdentifier.compute(new AccountAndDeviceIdentifier(accountIdentifier, deviceId),
(ignored, existingListener) -> { (ignored, existingListener) -> {
unsubscribeFuture.set(pubSubConnection.withPubSubConnection(connection -> unsubscribeFuture.set(pubSubConnection.withPubSubConnection(connection ->
connection.async().sunsubscribe(getClientPresenceKey(accountIdentifier, deviceId))) connection.async().sunsubscribe(getClientEventChannel(accountIdentifier, deviceId)))
.thenRun(Util.NOOP)); .thenRun(Util.NOOP));
return null; return null;
@ -245,7 +245,7 @@ public class PubSubClientEventManager extends RedisClusterPubSubAdapter<byte[],
} }
return pubSubConnection.withPubSubConnection(connection -> return pubSubConnection.withPubSubConnection(connection ->
connection.async().spublish(getClientPresenceKey(accountIdentifier, deviceId), NEW_MESSAGE_EVENT_BYTES)) connection.async().spublish(getClientEventChannel(accountIdentifier, deviceId), NEW_MESSAGE_EVENT_BYTES))
.thenApply(listeners -> listeners > 0); .thenApply(listeners -> listeners > 0);
} }
@ -264,7 +264,7 @@ public class PubSubClientEventManager extends RedisClusterPubSubAdapter<byte[],
} }
return pubSubConnection.withPubSubConnection(connection -> return pubSubConnection.withPubSubConnection(connection ->
connection.async().spublish(getClientPresenceKey(accountIdentifier, deviceId), MESSAGES_PERSISTED_EVENT_BYTES)) connection.async().spublish(getClientEventChannel(accountIdentifier, deviceId), MESSAGES_PERSISTED_EVENT_BYTES))
.thenRun(Util.NOOP); .thenRun(Util.NOOP);
} }
@ -305,7 +305,7 @@ public class PubSubClientEventManager extends RedisClusterPubSubAdapter<byte[],
public CompletableFuture<Void> requestDisconnection(final UUID accountIdentifier, final Collection<Byte> deviceIds) { public CompletableFuture<Void> requestDisconnection(final UUID accountIdentifier, final Collection<Byte> deviceIds) {
return CompletableFuture.allOf(deviceIds.stream() return CompletableFuture.allOf(deviceIds.stream()
.map(deviceId -> { .map(deviceId -> {
final byte[] clientPresenceKey = getClientPresenceKey(accountIdentifier, deviceId); final byte[] clientPresenceKey = getClientEventChannel(accountIdentifier, deviceId);
return clusterClient.withBinaryCluster(connection -> connection.async() return clusterClient.withBinaryCluster(connection -> connection.async()
.spublish(clientPresenceKey, DISCONNECT_REQUESTED_EVENT_BYTES)) .spublish(clientPresenceKey, DISCONNECT_REQUESTED_EVENT_BYTES))
@ -323,12 +323,12 @@ public class PubSubClientEventManager extends RedisClusterPubSubAdapter<byte[],
// Organize subscriptions by slot so we can issue a smaller number of larger resubscription commands // Organize subscriptions by slot so we can issue a smaller number of larger resubscription commands
listenersByAccountAndDeviceIdentifier.keySet() listenersByAccountAndDeviceIdentifier.keySet()
.stream() .stream()
.map(accountAndDeviceIdentifier -> getClientPresenceKey(accountAndDeviceIdentifier.accountIdentifier(), accountAndDeviceIdentifier.deviceId())) .map(accountAndDeviceIdentifier -> getClientEventChannel(accountAndDeviceIdentifier.accountIdentifier(), accountAndDeviceIdentifier.deviceId()))
.forEach(clientPresenceKey -> { .forEach(clientEventChannel -> {
final int slot = SlotHash.getSlot(clientPresenceKey); final int slot = SlotHash.getSlot(clientEventChannel);
if (changedSlots[slot]) { if (changedSlots[slot]) {
clientPresenceKeysBySlot.computeIfAbsent(slot, ignored -> new ArrayList<>()).add(clientPresenceKey); clientPresenceKeysBySlot.computeIfAbsent(slot, ignored -> new ArrayList<>()).add(clientEventChannel);
} }
}); });
@ -380,8 +380,7 @@ public class PubSubClientEventManager extends RedisClusterPubSubAdapter<byte[],
} }
} }
@VisibleForTesting public static byte[] getClientEventChannel(final UUID accountIdentifier, final byte deviceId) {
static byte[] getClientPresenceKey(final UUID accountIdentifier, final byte deviceId) {
return ("client_presence::{" + accountIdentifier + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); return ("client_presence::{" + accountIdentifier + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
} }

View File

@ -13,36 +13,55 @@ import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.UUID; import java.util.UUID;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.push.ClientEvent;
import org.whispersystems.textsecuregcm.push.NewMessageAvailableEvent;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
/** /**
* Inserts an envelope into the message queue for a destination device. * Inserts an envelope into the message queue for a destination device and publishes a "new message available" event.
*/ */
class MessagesCacheInsertScript { class MessagesCacheInsertScript {
private final ClusterLuaScript insertScript; private final ClusterLuaScript insertScript;
private static final byte[] NEW_MESSAGE_EVENT_BYTES = ClientEvent.newBuilder()
.setNewMessageAvailable(NewMessageAvailableEvent.getDefaultInstance())
.build()
.toByteArray();
MessagesCacheInsertScript(FaultTolerantRedisClusterClient redisCluster) throws IOException { MessagesCacheInsertScript(FaultTolerantRedisClusterClient redisCluster) throws IOException {
this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER); this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.BOOLEAN);
} }
void execute(final UUID destinationUuid, final byte destinationDevice, final MessageProtos.Envelope envelope) { /**
* Inserts a message into the given device's message queue and publishes a "new message available" event.
*
* @param destinationUuid the account identifier for the receiving account
* @param destinationDevice the ID of the receiving device within the given account
* @param envelope the message to insert
* @return {@code true} if the destination device had a registered "presence"/event subscriber or {@code false}
* otherwise
*/
boolean execute(final UUID destinationUuid, final byte destinationDevice, final MessageProtos.Envelope envelope) {
assert envelope.hasServerGuid(); assert envelope.hasServerGuid();
assert envelope.hasServerTimestamp(); assert envelope.hasServerTimestamp();
final List<byte[]> keys = List.of( final List<byte[]> keys = List.of(
MessagesCache.getMessageQueueKey(destinationUuid, destinationDevice), // queueKey MessagesCache.getMessageQueueKey(destinationUuid, destinationDevice), // queueKey
MessagesCache.getMessageQueueMetadataKey(destinationUuid, destinationDevice), // queueMetadataKey MessagesCache.getMessageQueueMetadataKey(destinationUuid, destinationDevice), // queueMetadataKey
MessagesCache.getQueueIndexKey(destinationUuid, destinationDevice) // queueTotalIndexKey MessagesCache.getQueueIndexKey(destinationUuid, destinationDevice), // queueTotalIndexKey
PubSubClientEventManager.getClientEventChannel(destinationUuid, destinationDevice) // eventChannelKey
); );
final List<byte[]> args = new ArrayList<>(Arrays.asList( final List<byte[]> args = new ArrayList<>(Arrays.asList(
envelope.toByteArray(), // message envelope.toByteArray(), // message
String.valueOf(envelope.getServerTimestamp()).getBytes(StandardCharsets.UTF_8), // currentTime String.valueOf(envelope.getServerTimestamp()).getBytes(StandardCharsets.UTF_8), // currentTime
envelope.getServerGuid().getBytes(StandardCharsets.UTF_8) // guid envelope.getServerGuid().getBytes(StandardCharsets.UTF_8), // guid
NEW_MESSAGE_EVENT_BYTES // eventPayload
)); ));
insertScript.executeBinary(keys, args); return (boolean) insertScript.executeBinary(keys, args);
} }
} }

View File

@ -4,9 +4,11 @@
local queueKey = KEYS[1] -- sorted set of Envelopes for a device, by queue-local ID local queueKey = KEYS[1] -- sorted set of Envelopes for a device, by queue-local ID
local queueMetadataKey = KEYS[2] -- hash of message GUID to queue-local IDs local queueMetadataKey = KEYS[2] -- hash of message GUID to queue-local IDs
local queueTotalIndexKey = KEYS[3] -- sorted set of all queues in the shard, by timestamp of oldest message local queueTotalIndexKey = KEYS[3] -- sorted set of all queues in the shard, by timestamp of oldest message
local eventChannelKey = KEYS[4] -- pub/sub channel for message availability events
local message = ARGV[1] -- [bytes] the Envelope to insert local message = ARGV[1] -- [bytes] the Envelope to insert
local currentTime = ARGV[2] -- [number] the message timestamp, to sort the queue in the queueTotalIndex local currentTime = ARGV[2] -- [number] the message timestamp, to sort the queue in the queueTotalIndex
local guid = ARGV[3] -- [string] the message GUID local guid = ARGV[3] -- [string] the message GUID
local eventPayload = ARGV[4] -- [bytes] a protobuf payload for a "message available" pub/sub event
if redis.call("HEXISTS", queueMetadataKey, guid) == 1 then if redis.call("HEXISTS", queueMetadataKey, guid) == 1 then
return tonumber(redis.call("HGET", queueMetadataKey, guid)) return tonumber(redis.call("HGET", queueMetadataKey, guid))
@ -21,4 +23,5 @@ redis.call("EXPIRE", queueKey, 3974400) -- 46 days
redis.call("EXPIRE", queueMetadataKey, 3974400) -- 46 days redis.call("EXPIRE", queueMetadataKey, 3974400) -- 46 days
redis.call("ZADD", queueTotalIndexKey, "NX", currentTime, queueKey) redis.call("ZADD", queueTotalIndexKey, "NX", currentTime, queueKey)
return messageId
return redis.call("SPUBLISH", eventChannelKey, eventPayload) > 0

View File

@ -301,7 +301,7 @@ class PubSubClientEventManagerTest {
final UUID firstAccountIdentifier = UUID.randomUUID(); final UUID firstAccountIdentifier = UUID.randomUUID();
final byte firstDeviceId = Device.PRIMARY_ID; final byte firstDeviceId = Device.PRIMARY_ID;
final int firstSlot = SlotHash.getSlot(PubSubClientEventManager.getClientPresenceKey(firstAccountIdentifier, firstDeviceId)); final int firstSlot = SlotHash.getSlot(PubSubClientEventManager.getClientEventChannel(firstAccountIdentifier, firstDeviceId));
final UUID secondAccountIdentifier; final UUID secondAccountIdentifier;
final byte secondDeviceId = firstDeviceId + 1; final byte secondDeviceId = firstDeviceId + 1;
@ -312,7 +312,7 @@ class PubSubClientEventManagerTest {
do { do {
candidateIdentifier = UUID.randomUUID(); candidateIdentifier = UUID.randomUUID();
} while (SlotHash.getSlot(PubSubClientEventManager.getClientPresenceKey(candidateIdentifier, secondDeviceId)) == firstSlot); } while (SlotHash.getSlot(PubSubClientEventManager.getClientEventChannel(candidateIdentifier, secondDeviceId)) == firstSlot);
secondAccountIdentifier = candidateIdentifier; secondAccountIdentifier = candidateIdentifier;
} }
@ -320,7 +320,7 @@ class PubSubClientEventManagerTest {
presenceManager.handleClientConnected(firstAccountIdentifier, firstDeviceId, new ClientEventAdapter()).toCompletableFuture().join(); presenceManager.handleClientConnected(firstAccountIdentifier, firstDeviceId, new ClientEventAdapter()).toCompletableFuture().join();
presenceManager.handleClientConnected(secondAccountIdentifier, secondDeviceId, new ClientEventAdapter()).toCompletableFuture().join(); presenceManager.handleClientConnected(secondAccountIdentifier, secondDeviceId, new ClientEventAdapter()).toCompletableFuture().join();
final int secondSlot = SlotHash.getSlot(PubSubClientEventManager.getClientPresenceKey(secondAccountIdentifier, secondDeviceId)); final int secondSlot = SlotHash.getSlot(PubSubClientEventManager.getClientEventChannel(secondAccountIdentifier, secondDeviceId));
final String firstNodeId = UUID.randomUUID().toString(); final String firstNodeId = UUID.randomUUID().toString();
@ -343,7 +343,7 @@ class PubSubClientEventManagerTest {
List.of(firstBeforeNode), List.of(firstBeforeNode),
List.of(firstAfterNode, secondAfterNode))); List.of(firstAfterNode, secondAfterNode)));
verify(pubSubCommands).ssubscribe(PubSubClientEventManager.getClientPresenceKey(secondAccountIdentifier, secondDeviceId)); verify(pubSubCommands).ssubscribe(PubSubClientEventManager.getClientEventChannel(secondAccountIdentifier, secondDeviceId));
verify(pubSubCommands, never()).ssubscribe(PubSubClientEventManager.getClientPresenceKey(firstAccountIdentifier, firstDeviceId)); verify(pubSubCommands, never()).ssubscribe(PubSubClientEventManager.getClientEventChannel(firstAccountIdentifier, firstDeviceId));
} }
} }

View File

@ -6,6 +6,8 @@
package org.whispersystems.textsecuregcm.storage; package org.whispersystems.textsecuregcm.storage;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.io.IOException; import java.io.IOException;
import java.io.UncheckedIOException; import java.io.UncheckedIOException;
@ -18,6 +20,8 @@ import com.google.protobuf.InvalidProtocolBufferException;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.entities.MessageProtos; import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubClusterConnection;
import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; import org.whispersystems.textsecuregcm.redis.RedisClusterExtension;
class MessagesCacheInsertScriptTest { class MessagesCacheInsertScriptTest {
@ -76,4 +80,29 @@ class MessagesCacheInsertScriptTest {
return messages; return messages;
} }
@Test
void returnPresence() throws IOException {
final UUID destinationUuid = UUID.randomUUID();
final byte deviceId = 1;
final MessagesCacheInsertScript insertScript =
new MessagesCacheInsertScript(REDIS_CLUSTER_EXTENSION.getRedisCluster());
assertFalse(insertScript.execute(destinationUuid, deviceId, MessageProtos.Envelope.newBuilder()
.setServerTimestamp(Instant.now().getEpochSecond())
.setServerGuid(UUID.randomUUID().toString())
.build()));
final FaultTolerantPubSubClusterConnection<byte[], byte[]> pubSubClusterConnection =
REDIS_CLUSTER_EXTENSION.getRedisCluster().createBinaryPubSubConnection();
pubSubClusterConnection.usePubSubConnection(connection ->
connection.sync().ssubscribe(PubSubClientEventManager.getClientEventChannel(destinationUuid, deviceId)));
assertTrue(insertScript.execute(destinationUuid, deviceId, MessageProtos.Envelope.newBuilder()
.setServerTimestamp(Instant.now().getEpochSecond())
.setServerGuid(UUID.randomUUID().toString())
.build()));
}
} }