From 06754d6158d92a8b3e2c2b6c126097f4db747a59 Mon Sep 17 00:00:00 2001 From: Jon Chambers Date: Wed, 2 Sep 2020 18:22:01 -0400 Subject: [PATCH] Add a system for storing, retrieving, and notifying listeners about ephemeral (online) messages. --- .../storage/MessageAvailabilityListener.java | 4 + .../textsecuregcm/storage/MessagesCache.java | 101 ++++++++++++++---- .../websocket/WebSocketConnection.java | 23 ++-- .../storage/MessagesCacheTest.java | 60 ++++++++++- 4 files changed, 161 insertions(+), 27 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessageAvailabilityListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessageAvailabilityListener.java index 3eeffc089..9c2208c24 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessageAvailabilityListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessageAvailabilityListener.java @@ -1,5 +1,7 @@ package org.whispersystems.textsecuregcm.storage; +import java.util.UUID; + /** * A message availability listener is notified when new messages are available for a specific device for a specific * account. Availability listeners are also notified when messages are moved from the message cache to long-term storage @@ -9,5 +11,7 @@ public interface MessageAvailabilityListener { void handleNewMessagesAvailable(); + void handleEphemeralMessageAvailable(UUID ephemeralMessageGuid); + void handleMessagesPersisted(); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 940a63a49..872ca8e70 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -55,18 +55,22 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp private final Map messageListenersByQueueName = new HashMap<>(); private final Map queueNamesByMessageListener = new IdentityHashMap<>(); - private final Timer insertTimer = Metrics.timer(name(MessagesCache.class, "insert")); - private final Timer getMessagesTimer = Metrics.timer(name(MessagesCache.class, "get")); - private final Timer clearQueueTimer = Metrics.timer(name(MessagesCache.class, "clear")); - private final Counter pubSubMessageCounter = Metrics.counter(name(MessagesCache.class, "pubSubMessage")); - private final Counter newMessageNotificationCounter = Metrics.counter(name(MessagesCache.class, "newMessageNotification")); - private final Counter queuePersistedNotificationCounter = Metrics.counter(name(MessagesCache.class, "queuePersisted")); + private final Timer insertTimer = Metrics.timer(name(MessagesCache.class, "insert"), "ephemeral", "false"); + private final Timer insertEphemeralTimer = Metrics.timer(name(MessagesCache.class, "insert"), "epehmeral", "true"); + private final Timer getMessagesTimer = Metrics.timer(name(MessagesCache.class, "get")); + private final Timer clearQueueTimer = Metrics.timer(name(MessagesCache.class, "clear")); + private final Timer takeEphemeralMessageTimer = Metrics.timer(name(MessagesCache.class, "takeEphemeral")); + private final Counter pubSubMessageCounter = Metrics.counter(name(MessagesCache.class, "pubSubMessage")); + private final Counter newMessageNotificationCounter = Metrics.counter(name(MessagesCache.class, "newMessageNotification"), "ephemeral", "false"); + private final Counter ephemeralMessageNotificationCounter = Metrics.counter(name(MessagesCache.class, "newMessageNotification"), "ephemeral", "true"); + private final Counter queuePersistedNotificationCounter = Metrics.counter(name(MessagesCache.class, "queuePersisted")); static final String NEXT_SLOT_TO_PERSIST_KEY = "user_queue_persist_slot"; private static final byte[] LOCK_VALUE = "1".getBytes(StandardCharsets.UTF_8); private static final String QUEUE_KEYSPACE_PREFIX = "__keyspace@0__:user_queue::"; private static final String PERSISTING_KEYSPACE_PREFIX = "__keyspace@0__:user_queue_persisting::"; + private static final String EPHEMERAL_KEYSPACE_PREFIX = "__keyspace@0__:ephemeral_message::"; private static final String REMOVE_TIMER_NAME = name(MessagesCache.class, "remove"); @@ -123,8 +127,8 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp } } - public long insert(final UUID guid, final UUID destinationUuid, final long destinationDevice, final MessageProtos.Envelope message) { - final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build(); + public long insert(final UUID messageGuid, final UUID destinationUuid, final long destinationDevice, final MessageProtos.Envelope message) { + final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(messageGuid.toString()).build(); final String sender = message.hasSource() ? (message.getSource() + "::" + message.getTimestamp()) : "nil"; return (long)insertTimer.record(() -> @@ -134,7 +138,12 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp List.of(messageWithGuid.toByteArray(), String.valueOf(message.getTimestamp()).getBytes(StandardCharsets.UTF_8), sender.getBytes(StandardCharsets.UTF_8), - guid.toString().getBytes(StandardCharsets.UTF_8)))); + messageGuid.toString().getBytes(StandardCharsets.UTF_8)))); + } + + public void insertEphemeral(final UUID messageGuid, final UUID destinationUuid, final long destinationDevice, final MessageProtos.Envelope message) { + insertEphemeralTimer.record(() -> + redisCluster.useBinaryCluster(connection -> connection.async().setex(getEphemeralMessageKey(destinationUuid, destinationDevice, messageGuid), 10, message.toByteArray()))); } public Optional remove(final UUID destinationUuid, final long destinationDevice, final long id) { @@ -252,6 +261,33 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp }); } + public Optional takeEphemeralMessage(final UUID destinationUuid, final long destinationDevice, final UUID messageGuid) { + final byte[] ephemeralMessageKey = getEphemeralMessageKey(destinationUuid, destinationDevice, messageGuid); + + return takeEphemeralMessageTimer.record(() -> redisCluster.withBinaryCluster(connection -> { + final byte[] messageBytes = connection.sync().get(ephemeralMessageKey); + connection.sync().del(ephemeralMessageKey); + + final Optional maybeEnvelope; + + if (messageBytes != null) { + MessageProtos.Envelope parsedEnvelope = null; + + try { + parsedEnvelope = MessageProtos.Envelope.parseFrom(messageBytes); + } catch (final InvalidProtocolBufferException e) { + logger.warn("Failed to parse envelope", e); + } + + maybeEnvelope = Optional.ofNullable(parsedEnvelope); + } else { + maybeEnvelope = Optional.empty(); + } + + return maybeEnvelope; + })); + } + public void clear(final UUID destinationUuid) { // TODO Remove null check in a fully UUID-based world if (destinationUuid != null) { @@ -314,22 +350,35 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp private void subscribeForKeyspaceNotifications(final String queueName) { final int slot = SlotHash.getSlot(queueName); - pubSubConnection.usePubSubConnection(connection -> connection.sync().nodes(node -> node.is(RedisClusterNode.NodeFlag.MASTER) && node.hasSlot(slot)) - .commands() - .subscribe(QUEUE_KEYSPACE_PREFIX + "{" + queueName + "}", - PERSISTING_KEYSPACE_PREFIX + "{" + queueName + "}")); + pubSubConnection.usePubSubConnection(connection -> { + connection.sync().nodes(node -> node.is(RedisClusterNode.NodeFlag.MASTER) && node.hasSlot(slot)) + .commands() + .subscribe(QUEUE_KEYSPACE_PREFIX + "{" + queueName + "}", + PERSISTING_KEYSPACE_PREFIX + "{" + queueName + "}"); + + connection.sync().nodes(node -> node.is(RedisClusterNode.NodeFlag.MASTER) && node.hasSlot(slot)) + .commands() + .psubscribe(EPHEMERAL_KEYSPACE_PREFIX + "{" + queueName + "}::*"); + }); } private void unsubscribeFromKeyspaceNotifications(final String queueName) { - pubSubConnection.usePubSubConnection(connection -> connection.sync().masters() - .commands() - .unsubscribe(QUEUE_KEYSPACE_PREFIX + "{" + queueName + "}", - PERSISTING_KEYSPACE_PREFIX + "{" + queueName + "}")); + pubSubConnection.usePubSubConnection(connection -> { + connection.sync().masters() + .commands() + .unsubscribe(QUEUE_KEYSPACE_PREFIX + "{" + queueName + "}", + PERSISTING_KEYSPACE_PREFIX + "{" + queueName + "}"); + + connection.sync().masters() + .commands() + .punsubscribe(EPHEMERAL_KEYSPACE_PREFIX + "{" + queueName + "}::*"); + }); } @Override public void message(final RedisClusterNode node, final String channel, final String message) { pubSubMessageCounter.increment(); + if (channel.startsWith(QUEUE_KEYSPACE_PREFIX) && "zadd".equals(message)) { newMessageNotificationCounter.increment(); notificationExecutorService.execute(() -> findListener(channel).ifPresent(MessageAvailabilityListener::handleNewMessagesAvailable)); @@ -339,6 +388,18 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp } } + @Override + public void message(final RedisClusterNode node, final String pattern, final String channel, final String message) { + pubSubMessageCounter.increment(); + + if (channel.startsWith(EPHEMERAL_KEYSPACE_PREFIX) && "set".equals(message)) { + ephemeralMessageNotificationCounter.increment(); + + notificationExecutorService.execute(() -> findListener(channel).ifPresent(listener -> + listener.handleEphemeralMessageAvailable(UUID.fromString(channel.substring(channel.lastIndexOf("::") + 2))))); + } + } + private Optional findListener(final String keyspaceChannel) { final String queueName = getQueueNameFromKeyspaceChannel(keyspaceChannel); @@ -348,7 +409,7 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp } @VisibleForTesting - static OutgoingMessageEntity constructEntityFromEnvelope(long id, MessageProtos.Envelope envelope) { + static OutgoingMessageEntity constructEntityFromEnvelope(final long id, final MessageProtos.Envelope envelope) { return new OutgoingMessageEntity(id, true, envelope.hasServerGuid() ? UUID.fromString(envelope.getServerGuid()) : null, envelope.getType().getNumber(), @@ -380,6 +441,10 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp return ("user_queue::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); } + private static byte[] getEphemeralMessageKey(final UUID accountUuid, final long deviceId, final UUID messageGuid) { + return ("ephemeral_message::{" + accountUuid.toString() + "::" + deviceId + "}::" + messageGuid.toString()).getBytes(StandardCharsets.UTF_8); + } + private static byte[] getMessageQueueMetadataKey(final UUID accountUuid, final long deviceId) { return ("user_queue_metadata::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index b11401317..0701ba8bd 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -34,6 +34,7 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Optional; +import java.util.UUID; import static com.codahale.metrics.MetricRegistry.name; import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope; @@ -42,14 +43,15 @@ import static org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessag @SuppressWarnings("OptionalUsedAsFieldOrParameterType") public class WebSocketConnection implements DispatchChannel, MessageAvailabilityListener, DisplacedPresenceListener { - private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); - public static final Histogram messageTime = metricRegistry.histogram(name(MessageController.class, "message_delivery_duration")); - private static final Meter sendMessageMeter = metricRegistry.meter(name(WebSocketConnection.class, "send_message")); - private static final Meter messageAvailableMeter = metricRegistry.meter(name(WebSocketConnection.class, "messagesAvailable")); - private static final Meter messagesPersistedMeter = metricRegistry.meter(name(WebSocketConnection.class, "messagesPersisted")); - private static final Meter pubSubNewMessageMeter = metricRegistry.meter(name(WebSocketConnection.class, "pubSubNewMessage")); - private static final Meter pubSubPersistedMeter = metricRegistry.meter(name(WebSocketConnection.class, "pubSubPersisted")); - private static final Meter displacementMeter = metricRegistry.meter(name(WebSocketConnection.class, "explicitDisplacement")); + private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); + public static final Histogram messageTime = metricRegistry.histogram(name(MessageController.class, "message_delivery_duration")); + private static final Meter sendMessageMeter = metricRegistry.meter(name(WebSocketConnection.class, "send_message")); + private static final Meter messageAvailableMeter = metricRegistry.meter(name(WebSocketConnection.class, "messagesAvailable")); + private static final Meter ephemeralMessageAvailableMeter = metricRegistry.meter(name(WebSocketConnection.class, "ephemeralMessagesAvailable")); + private static final Meter messagesPersistedMeter = metricRegistry.meter(name(WebSocketConnection.class, "messagesPersisted")); + private static final Meter pubSubNewMessageMeter = metricRegistry.meter(name(WebSocketConnection.class, "pubSubNewMessage")); + private static final Meter pubSubPersistedMeter = metricRegistry.meter(name(WebSocketConnection.class, "pubSubPersisted")); + private static final Meter displacementMeter = metricRegistry.meter(name(WebSocketConnection.class, "explicitDisplacement")); private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class); @@ -220,6 +222,11 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability messageAvailableMeter.mark(); } + @Override + public void handleEphemeralMessageAvailable(final UUID ephemeralMessageGuid) { + ephemeralMessageAvailableMeter.mark(); + } + @Override public void handleMessagesPersisted() { messagesPersistedMeter.mark(); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java index e001a1e5f..a3a247c8a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java @@ -24,6 +24,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -242,6 +243,10 @@ public class MessagesCacheTest extends AbstractRedisClusterTest { } } + @Override + public void handleEphemeralMessageAvailable(final UUID ephemeralMessageGuid) { + } + @Override public void handleMessagesPersisted() { } @@ -261,13 +266,17 @@ public class MessagesCacheTest extends AbstractRedisClusterTest { @Test(timeout = 5_000L) public void testNotifyListenerPersisted() throws InterruptedException { - final AtomicBoolean notified = new AtomicBoolean(false); + final AtomicBoolean notified = new AtomicBoolean(false); final MessageAvailabilityListener listener = new MessageAvailabilityListener() { @Override public void handleNewMessagesAvailable() { } + @Override + public void handleEphemeralMessageAvailable(final UUID ephemeralMessageGuid) { + } + @Override public void handleMessagesPersisted() { synchronized (notified) { @@ -290,4 +299,53 @@ public class MessagesCacheTest extends AbstractRedisClusterTest { assertTrue(notified.get()); } + + @Test(timeout = 5_000L) + public void testInsertAndNotifyEphemeralMessage() throws InterruptedException { + final AtomicReference notifiedGuid = new AtomicReference<>(); + final UUID messageGuid = UUID.randomUUID(); + final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true); + + final MessageAvailabilityListener listener = new MessageAvailabilityListener() { + @Override + public void handleNewMessagesAvailable() { + } + + @Override + public void handleEphemeralMessageAvailable(final UUID ephemeralMessageGuid) { + synchronized (notifiedGuid) { + notifiedGuid.set(ephemeralMessageGuid); + notifiedGuid.notifyAll(); + } + } + + @Override + public void handleMessagesPersisted() { + } + }; + + messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener); + messagesCache.insertEphemeral(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); + + synchronized (notifiedGuid) { + while (notifiedGuid.get() == null) { + notifiedGuid.wait(); + } + } + + assertEquals(messageGuid, notifiedGuid.get()); + } + + @Test + public void testTakeEphemeralMessage() { + final UUID messageGuid = UUID.randomUUID(); + final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true); + + assertEquals(Optional.empty(), messagesCache.takeEphemeralMessage(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageGuid)); + + messagesCache.insertEphemeral(messageGuid, DESTINATION_UUID, DESTINATION_DEVICE_ID, message); + + assertEquals(Optional.of(message), messagesCache.takeEphemeralMessage(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageGuid)); + assertEquals(Optional.empty(), messagesCache.takeEphemeralMessage(DESTINATION_UUID, DESTINATION_DEVICE_ID, messageGuid)); + } }