diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessageAvailabilityListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessageAvailabilityListener.java index 3eeffc089..c8a40526d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessageAvailabilityListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessageAvailabilityListener.java @@ -9,5 +9,7 @@ public interface MessageAvailabilityListener { void handleNewMessagesAvailable(); + void handleNewEphemeralMessageAvailable(); + void handleMessagesPersisted(); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java index 940a63a49..ba0d92ea8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/MessagesCache.java @@ -22,6 +22,7 @@ import org.whispersystems.textsecuregcm.util.RedisClusterUtil; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.time.Duration; import java.time.Instant; import java.util.ArrayList; import java.util.Collections; @@ -55,18 +56,24 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp private final Map messageListenersByQueueName = new HashMap<>(); private final Map queueNamesByMessageListener = new IdentityHashMap<>(); - private final Timer insertTimer = Metrics.timer(name(MessagesCache.class, "insert")); - private final Timer getMessagesTimer = Metrics.timer(name(MessagesCache.class, "get")); - private final Timer clearQueueTimer = Metrics.timer(name(MessagesCache.class, "clear")); - private final Counter pubSubMessageCounter = Metrics.counter(name(MessagesCache.class, "pubSubMessage")); - private final Counter newMessageNotificationCounter = Metrics.counter(name(MessagesCache.class, "newMessageNotification")); - private final Counter queuePersistedNotificationCounter = Metrics.counter(name(MessagesCache.class, "queuePersisted")); + private final Timer insertTimer = Metrics.timer(name(MessagesCache.class, "insert"), "ephemeral", "false"); + private final Timer insertEphemeralTimer = Metrics.timer(name(MessagesCache.class, "insert"), "ephemeral", "true"); + private final Timer getMessagesTimer = Metrics.timer(name(MessagesCache.class, "get")); + private final Timer clearQueueTimer = Metrics.timer(name(MessagesCache.class, "clear")); + private final Timer takeEphemeralMessageTimer = Metrics.timer(name(MessagesCache.class, "takeEphemeral")); + private final Counter pubSubMessageCounter = Metrics.counter(name(MessagesCache.class, "pubSubMessage")); + private final Counter newMessageNotificationCounter = Metrics.counter(name(MessagesCache.class, "newMessageNotification"), "ephemeral", "false"); + private final Counter ephemeralMessageNotificationCounter = Metrics.counter(name(MessagesCache.class, "newMessageNotification"), "ephemeral", "true"); + private final Counter queuePersistedNotificationCounter = Metrics.counter(name(MessagesCache.class, "queuePersisted")); static final String NEXT_SLOT_TO_PERSIST_KEY = "user_queue_persist_slot"; private static final byte[] LOCK_VALUE = "1".getBytes(StandardCharsets.UTF_8); - private static final String QUEUE_KEYSPACE_PREFIX = "__keyspace@0__:user_queue::"; - private static final String PERSISTING_KEYSPACE_PREFIX = "__keyspace@0__:user_queue_persisting::"; + private static final String QUEUE_KEYSPACE_PREFIX = "__keyspace@0__:user_queue::"; + private static final String EPHEMERAL_QUEUE_KEYSPACE_PREFIX = "__keyspace@0__:user_queue_ephemeral::"; + private static final String PERSISTING_KEYSPACE_PREFIX = "__keyspace@0__:user_queue_persisting::"; + + private static final Duration MAX_EPHEMERAL_MESSAGE_DELAY = Duration.ofSeconds(10); private static final String REMOVE_TIMER_NAME = name(MessagesCache.class, "remove"); @@ -137,6 +144,17 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp guid.toString().getBytes(StandardCharsets.UTF_8)))); } + public void insertEphemeral(final UUID destinationUuid, final long destinationDevice, final MessageProtos.Envelope message) { + insertEphemeralTimer.record(() -> { + final byte[] ephemeralQueueKey = getEphemeralMessageQueueKey(destinationUuid, destinationDevice); + + redisCluster.useBinaryCluster(connection -> { + connection.async().rpush(ephemeralQueueKey, message.toByteArray()); + connection.async().expire(ephemeralQueueKey, MAX_EPHEMERAL_MESSAGE_DELAY.toSeconds()); + }); + }); + } + public Optional remove(final UUID destinationUuid, final long destinationDevice, final long id) { try { final byte[] serialized = (byte[])Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, REMOVE_METHOD_ID).record(() -> @@ -252,6 +270,33 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp }); } + public Optional takeEphemeralMessage(final UUID destinationUuid, final long destinationDevice) { + return takeEphemeralMessage(destinationUuid, destinationDevice, System.currentTimeMillis()); + } + + @VisibleForTesting + Optional takeEphemeralMessage(final UUID destinationUuid, final long destinationDevice, final long currentTimeMillis) { + final long earliestAllowableTimestamp = currentTimeMillis - MAX_EPHEMERAL_MESSAGE_DELAY.toMillis(); + + return takeEphemeralMessageTimer.record(() -> redisCluster.withBinaryCluster(connection -> { + byte[] messageBytes; + + while ((messageBytes = connection.sync().lpop(getEphemeralMessageQueueKey(destinationUuid, destinationDevice))) != null) { + try { + final MessageProtos.Envelope message = MessageProtos.Envelope.parseFrom(messageBytes); + + if (message.getTimestamp() >= earliestAllowableTimestamp) { + return Optional.of(message); + } + } catch (final InvalidProtocolBufferException e) { + logger.warn("Failed to parse envelope", e); + } + } + + return Optional.empty(); + })); + } + public void clear(final UUID destinationUuid) { // TODO Remove null check in a fully UUID-based world if (destinationUuid != null) { @@ -316,23 +361,33 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp pubSubConnection.usePubSubConnection(connection -> connection.sync().nodes(node -> node.is(RedisClusterNode.NodeFlag.MASTER) && node.hasSlot(slot)) .commands() - .subscribe(QUEUE_KEYSPACE_PREFIX + "{" + queueName + "}", - PERSISTING_KEYSPACE_PREFIX + "{" + queueName + "}")); + .subscribe(getKeyspaceChannels(queueName))); } private void unsubscribeFromKeyspaceNotifications(final String queueName) { pubSubConnection.usePubSubConnection(connection -> connection.sync().masters() .commands() - .unsubscribe(QUEUE_KEYSPACE_PREFIX + "{" + queueName + "}", - PERSISTING_KEYSPACE_PREFIX + "{" + queueName + "}")); + .unsubscribe(getKeyspaceChannels(queueName))); + } + + private static String[] getKeyspaceChannels(final String queueName) { + return new String[] { + QUEUE_KEYSPACE_PREFIX + "{" + queueName + "}", + EPHEMERAL_QUEUE_KEYSPACE_PREFIX + "{" + queueName + "}", + PERSISTING_KEYSPACE_PREFIX + "{" + queueName + "}" + }; } @Override public void message(final RedisClusterNode node, final String channel, final String message) { pubSubMessageCounter.increment(); + if (channel.startsWith(QUEUE_KEYSPACE_PREFIX) && "zadd".equals(message)) { newMessageNotificationCounter.increment(); notificationExecutorService.execute(() -> findListener(channel).ifPresent(MessageAvailabilityListener::handleNewMessagesAvailable)); + } else if (channel.startsWith(EPHEMERAL_QUEUE_KEYSPACE_PREFIX) && "rpush".equals(message)) { + ephemeralMessageNotificationCounter.increment(); + notificationExecutorService.execute(() -> findListener(channel).ifPresent(MessageAvailabilityListener::handleNewEphemeralMessageAvailable)); } else if (channel.startsWith(PERSISTING_KEYSPACE_PREFIX) && "del".equals(message)) { queuePersistedNotificationCounter.increment(); notificationExecutorService.execute(() -> findListener(channel).ifPresent(MessageAvailabilityListener::handleMessagesPersisted)); @@ -380,6 +435,10 @@ public class MessagesCache extends RedisClusterPubSubAdapter imp return ("user_queue::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); } + static byte[] getEphemeralMessageQueueKey(final UUID accountUuid, final long deviceId) { + return ("user_queue_ephemeral::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); + } + private static byte[] getMessageQueueMetadataKey(final UUID accountUuid, final long deviceId) { return ("user_queue_metadata::{" + accountUuid.toString() + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java index b11401317..0367a0c55 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/websocket/WebSocketConnection.java @@ -42,14 +42,15 @@ import static org.whispersystems.textsecuregcm.storage.PubSubProtos.PubSubMessag @SuppressWarnings("OptionalUsedAsFieldOrParameterType") public class WebSocketConnection implements DispatchChannel, MessageAvailabilityListener, DisplacedPresenceListener { - private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); - public static final Histogram messageTime = metricRegistry.histogram(name(MessageController.class, "message_delivery_duration")); - private static final Meter sendMessageMeter = metricRegistry.meter(name(WebSocketConnection.class, "send_message")); - private static final Meter messageAvailableMeter = metricRegistry.meter(name(WebSocketConnection.class, "messagesAvailable")); - private static final Meter messagesPersistedMeter = metricRegistry.meter(name(WebSocketConnection.class, "messagesPersisted")); - private static final Meter pubSubNewMessageMeter = metricRegistry.meter(name(WebSocketConnection.class, "pubSubNewMessage")); - private static final Meter pubSubPersistedMeter = metricRegistry.meter(name(WebSocketConnection.class, "pubSubPersisted")); - private static final Meter displacementMeter = metricRegistry.meter(name(WebSocketConnection.class, "explicitDisplacement")); + private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME); + public static final Histogram messageTime = metricRegistry.histogram(name(MessageController.class, "message_delivery_duration")); + private static final Meter sendMessageMeter = metricRegistry.meter(name(WebSocketConnection.class, "send_message")); + private static final Meter messageAvailableMeter = metricRegistry.meter(name(WebSocketConnection.class, "messagesAvailable")); + private static final Meter ephemeralMessageAvailableMeter = metricRegistry.meter(name(WebSocketConnection.class, "ephemeralMessagesAvailable")); + private static final Meter messagesPersistedMeter = metricRegistry.meter(name(WebSocketConnection.class, "messagesPersisted")); + private static final Meter pubSubNewMessageMeter = metricRegistry.meter(name(WebSocketConnection.class, "pubSubNewMessage")); + private static final Meter pubSubPersistedMeter = metricRegistry.meter(name(WebSocketConnection.class, "pubSubPersisted")); + private static final Meter displacementMeter = metricRegistry.meter(name(WebSocketConnection.class, "explicitDisplacement")); private static final Logger logger = LoggerFactory.getLogger(WebSocketConnection.class); @@ -220,6 +221,11 @@ public class WebSocketConnection implements DispatchChannel, MessageAvailability messageAvailableMeter.mark(); } + @Override + public void handleNewEphemeralMessageAvailable() { + ephemeralMessageAvailableMeter.mark(); + } + @Override public void handleMessagesPersisted() { messagesPersistedMeter.mark(); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java index e001a1e5f..a0d14526a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java @@ -45,7 +45,7 @@ public class MessagesCacheTest extends AbstractRedisClusterTest { public void setUp() throws Exception { super.setUp(); - getRedisCluster().useCluster(connection -> connection.sync().masters().commands().configSet("notify-keyspace-events", "K$gz")); + getRedisCluster().useCluster(connection -> connection.sync().masters().commands().configSet("notify-keyspace-events", "Klgz")); notificationExecutorService = Executors.newSingleThreadExecutor(); messagesCache = new MessagesCache(getRedisCluster(), notificationExecutorService); @@ -171,10 +171,14 @@ public class MessagesCacheTest extends AbstractRedisClusterTest { assertEquals(Collections.emptyList(), messagesCache.get(DESTINATION_UUID, DESTINATION_DEVICE_ID + 1, messageCount)); } - protected MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final boolean sealedSender) { + private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final boolean sealedSender) { + return generateRandomMessage(messageGuid, sealedSender, serialTimestamp++); + } + + private MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final boolean sealedSender, final long timestamp) { final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder() - .setTimestamp(serialTimestamp++) - .setServerTimestamp(serialTimestamp++) + .setTimestamp(timestamp) + .setServerTimestamp(timestamp) .setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256))) .setType(MessageProtos.Envelope.Type.CIPHERTEXT) .setServerGuid(messageGuid.toString()); @@ -242,6 +246,10 @@ public class MessagesCacheTest extends AbstractRedisClusterTest { } } + @Override + public void handleNewEphemeralMessageAvailable() { + } + @Override public void handleMessagesPersisted() { } @@ -261,13 +269,17 @@ public class MessagesCacheTest extends AbstractRedisClusterTest { @Test(timeout = 5_000L) public void testNotifyListenerPersisted() throws InterruptedException { - final AtomicBoolean notified = new AtomicBoolean(false); + final AtomicBoolean notified = new AtomicBoolean(false); final MessageAvailabilityListener listener = new MessageAvailabilityListener() { @Override public void handleNewMessagesAvailable() { } + @Override + public void handleNewEphemeralMessageAvailable() { + } + @Override public void handleMessagesPersisted() { synchronized (notified) { @@ -290,4 +302,57 @@ public class MessagesCacheTest extends AbstractRedisClusterTest { assertTrue(notified.get()); } + + @Test(timeout = 5_000L) + public void testInsertAndNotifyEphemeralMessage() throws InterruptedException { + final AtomicBoolean notified = new AtomicBoolean(false); + final MessageProtos.Envelope message = generateRandomMessage(UUID.randomUUID(), true); + + final MessageAvailabilityListener listener = new MessageAvailabilityListener() { + @Override + public void handleNewMessagesAvailable() { + } + + @Override + public void handleNewEphemeralMessageAvailable() { + synchronized (notified) { + notified.set(true); + notified.notifyAll(); + } + } + + @Override + public void handleMessagesPersisted() { + } + }; + + messagesCache.addMessageAvailabilityListener(DESTINATION_UUID, DESTINATION_DEVICE_ID, listener); + messagesCache.insertEphemeral(DESTINATION_UUID, DESTINATION_DEVICE_ID, message); + + synchronized (notified) { + while (!notified.get()) { + notified.wait(); + } + } + + assertTrue(notified.get()); + } + + @Test + public void testTakeEphemeralMessage() { + final long currentTime = System.currentTimeMillis(); + final UUID messageGuid = UUID.randomUUID(); + final MessageProtos.Envelope message = generateRandomMessage(messageGuid, true, currentTime); + + assertEquals(Optional.empty(), messagesCache.takeEphemeralMessage(DESTINATION_UUID, DESTINATION_DEVICE_ID, currentTime)); + + messagesCache.insertEphemeral(DESTINATION_UUID, DESTINATION_DEVICE_ID, message); + + assertEquals(Optional.of(message), messagesCache.takeEphemeralMessage(DESTINATION_UUID, DESTINATION_DEVICE_ID, currentTime)); + assertEquals(Optional.empty(), messagesCache.takeEphemeralMessage(DESTINATION_UUID, DESTINATION_DEVICE_ID, currentTime)); + + messagesCache.insertEphemeral(DESTINATION_UUID, DESTINATION_DEVICE_ID, generateRandomMessage(UUID.randomUUID(), true, 0)); + + assertEquals(Optional.empty(), messagesCache.takeEphemeralMessage(DESTINATION_UUID, DESTINATION_DEVICE_ID, currentTime)); + } }