diff --git a/service/config/sample.yml b/service/config/sample.yml
index efbf66102..dc86f1da0 100644
--- a/service/config/sample.yml
+++ b/service/config/sample.yml
@@ -47,8 +47,13 @@ directory:
reconciliationChunkIntervalMs: # CDS reconciliation chunk interval, in milliseconds
messageCache: # Redis server configuration for message store cache
- url:
- replicaUrls:
+ redis:
+ url:
+ replicaUrls:
+
+ cluster:
+ urls:
+ - redis://redis.example.com:6379/
messageStore: # Postgresql database configuration for message store
driverClass: org.postgresql.Driver
diff --git a/service/pom.xml b/service/pom.xml
index d3d5f9f82..d15272975 100644
--- a/service/pom.xml
+++ b/service/pom.xml
@@ -198,6 +198,12 @@
test
+
+ pl.pragmatists
+ JUnitParams
+ 1.1.1
+ test
+
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
index 5634ecab9..b84b97607 100644
--- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java
@@ -50,7 +50,6 @@ import io.micrometer.datadog.DatadogMeterRegistry;
import io.micrometer.wavefront.WavefrontConfig;
import io.micrometer.wavefront.WavefrontMeterRegistry;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
-import org.coursera.metrics.datadog.DatadogReporter;
import org.eclipse.jetty.servlets.CrossOriginFilter;
import org.jdbi.v3.core.Jdbi;
import org.signal.zkgroup.ServerSecretParams;
@@ -137,6 +136,7 @@ import org.whispersystems.textsecuregcm.storage.Profiles;
import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.PushFeedbackProcessor;
+import org.whispersystems.textsecuregcm.storage.RedisClusterMessagesCache;
import org.whispersystems.textsecuregcm.storage.RemoteConfigs;
import org.whispersystems.textsecuregcm.storage.RemoteConfigsManager;
import org.whispersystems.textsecuregcm.storage.ReservedUsernames;
@@ -300,7 +300,8 @@ public class WhisperServerService extends Application(1_000));
+
+ private final Experiment insertExperiment = new Experiment("MessagesCache", "insert");
+ private final Experiment removeByIdExperiment = new Experiment("MessagesCache", "removeById");
+ private final Experiment removeBySenderExperiment = new Experiment("MessagesCache", "removeBySender");
+ private final Experiment removeByUuidExperiment = new Experiment("MessagesCache", "removeByUuid");
+ private final Experiment getMessagesExperiment = new Experiment("MessagesCache", "getMessages");
+
+ public MessagesCache(ReplicatedJedisPool jedisPool, Messages database, AccountsManager accountsManager, int delayMinutes, RedisClusterMessagesCache clusterMessagesCache) throws IOException {
+ this.jedisPool = jedisPool;
+ this.database = database;
+ this.accountsManager = accountsManager;
+ this.delayMinutes = delayMinutes;
+
+ this.insertOperation = new InsertOperation(jedisPool);
+ this.removeOperation = new RemoveOperation(jedisPool);
+ this.getOperation = new GetOperation(jedisPool);
+
+ this.clusterMessagesCache = clusterMessagesCache;
}
- public void insert(UUID guid, String destination, long destinationDevice, Envelope message) {
- message = message.toBuilder().setServerGuid(guid.toString()).build();
+ @Override
+ public long insert(UUID guid, String destination, long destinationDevice, Envelope message) {
+ final Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build();
Timer.Context timer = insertTimer.time();
try {
- insertOperation.insert(guid, destination, destinationDevice, System.currentTimeMillis(), message);
+ final long messageId = insertOperation.insert(guid, destination, destinationDevice, System.currentTimeMillis(), messageWithGuid);
+ insertExperiment.compareSupplierResultAsync(messageId, () -> clusterMessagesCache.insert(guid, destination, destinationDevice, message, messageId), experimentExecutor);
+
+ return messageId;
} finally {
timer.stop();
}
}
- public void remove(String destination, long destinationDevice, long id) {
+ @Override
+ public Optional remove(String destination, long destinationDevice, long id) {
+ OutgoingMessageEntity removedMessageEntity = null;
+
try (Jedis jedis = jedisPool.getWriteResource();
Timer.Context ignored = removeByIdTimer.time())
{
- removeOperation.remove(jedis, destination, destinationDevice, id);
+ byte[] serialized = removeOperation.remove(jedis, destination, destinationDevice, id);
+
+ if (serialized != null) {
+ removedMessageEntity = UserMessagesCache.constructEntityFromEnvelope(id, Envelope.parseFrom(serialized));
+ }
+ } catch (InvalidProtocolBufferException e) {
+ logger.warn("Failed to parse envelope", e);
}
+
+ final Optional maybeRemovedMessage = Optional.ofNullable(removedMessageEntity);
+
+ removeByIdExperiment.compareSupplierResultAsync(maybeRemovedMessage, () -> clusterMessagesCache.remove(destination, destinationDevice, id), experimentExecutor);
+
+ return maybeRemovedMessage;
}
+ @Override
public Optional remove(String destination, long destinationDevice, String sender, long timestamp) {
+ OutgoingMessageEntity removedMessageEntity = null;
Timer.Context timer = removeByNameTimer.time();
try {
byte[] serialized = removeOperation.remove(destination, destinationDevice, sender, timestamp);
if (serialized != null) {
- Envelope envelope = Envelope.parseFrom(serialized);
- return Optional.of(constructEntityFromEnvelope(0, envelope));
+ removedMessageEntity = UserMessagesCache.constructEntityFromEnvelope(0, Envelope.parseFrom(serialized));
}
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
@@ -105,18 +144,23 @@ public class MessagesCache implements Managed {
timer.stop();
}
- return Optional.empty();
+ final Optional maybeRemovedMessage = Optional.ofNullable(removedMessageEntity);
+
+ removeBySenderExperiment.compareSupplierResultAsync(maybeRemovedMessage, () -> clusterMessagesCache.remove(destination, destinationDevice, sender, timestamp), experimentExecutor);
+
+ return maybeRemovedMessage;
}
+ @Override
public Optional remove(String destination, long destinationDevice, UUID guid) {
+ OutgoingMessageEntity removedMessageEntity = null;
Timer.Context timer = removeByGuidTimer.time();
try {
byte[] serialized = removeOperation.remove(destination, destinationDevice, guid);
if (serialized != null) {
- Envelope envelope = Envelope.parseFrom(serialized);
- return Optional.of(constructEntityFromEnvelope(0, envelope));
+ removedMessageEntity = UserMessagesCache.constructEntityFromEnvelope(0, Envelope.parseFrom(serialized));
}
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
@@ -124,9 +168,14 @@ public class MessagesCache implements Managed {
timer.stop();
}
- return Optional.empty();
+ final Optional maybeRemovedMessage = Optional.ofNullable(removedMessageEntity);
+
+ removeByUuidExperiment.compareSupplierResultAsync(maybeRemovedMessage, () -> clusterMessagesCache.remove(destination, destinationDevice, guid), experimentExecutor);
+
+ return maybeRemovedMessage;
}
+ @Override
public List get(String destination, long destinationDevice, int limit) {
Timer.Context timer = getTimer.time();
@@ -139,18 +188,21 @@ public class MessagesCache implements Managed {
try {
long id = item.second().longValue();
Envelope message = Envelope.parseFrom(item.first());
- results.add(constructEntityFromEnvelope(id, message));
+ results.add(UserMessagesCache.constructEntityFromEnvelope(id, message));
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
}
+ getMessagesExperiment.compareSupplierResultAsync(results, () -> clusterMessagesCache.get(destination, destinationDevice, limit), experimentExecutor);
+
return results;
} finally {
timer.stop();
}
}
+ @Override
public void clear(String destination) {
Timer.Context timer = clearAccountTimer.time();
@@ -163,6 +215,7 @@ public class MessagesCache implements Managed {
}
}
+ @Override
public void clear(String destination, long deviceId) {
Timer.Context timer = clearDeviceTimer.time();
@@ -180,11 +233,7 @@ public class MessagesCache implements Managed {
@Override
public void start() throws Exception {
- this.insertOperation = new InsertOperation(jedisPool);
- this.removeOperation = new RemoveOperation(jedisPool);
- this.getOperation = new GetOperation(jedisPool);
this.messagePersister = new MessagePersister(jedisPool, database, pubSubManager, pushSender, accountsManager, delayMinutes, TimeUnit.MINUTES);
-
this.messagePersister.start();
}
@@ -192,20 +241,8 @@ public class MessagesCache implements Managed {
public void stop() throws Exception {
messagePersister.shutdown();
logger.info("Message persister shut down...");
- }
- private OutgoingMessageEntity constructEntityFromEnvelope(long id, Envelope envelope) {
- return new OutgoingMessageEntity(id, true,
- envelope.hasServerGuid() ? UUID.fromString(envelope.getServerGuid()) : null,
- envelope.getType().getNumber(),
- envelope.getRelay(),
- envelope.getTimestamp(),
- envelope.getSource(),
- envelope.hasSourceUuid() ? UUID.fromString(envelope.getSourceUuid()) : null,
- envelope.getSourceDevice(),
- envelope.hasLegacyMessage() ? envelope.getLegacyMessage().toByteArray() : null,
- envelope.hasContent() ? envelope.getContent().toByteArray() : null,
- envelope.hasServerTimestamp() ? envelope.getServerTimestamp() : 0);
+ this.experimentExecutor.shutdown();
}
private static class Key {
@@ -271,14 +308,14 @@ public class MessagesCache implements Managed {
this.insert = LuaScript.fromResource(jedisPool, "lua/insert_item.lua");
}
- public void insert(UUID guid, String destination, long destinationDevice, long timestamp, Envelope message) {
+ public long insert(UUID guid, String destination, long destinationDevice, long timestamp, Envelope message) {
Key key = new Key(destination, destinationDevice);
String sender = message.hasSource() ? (message.getSource() + "::" + message.getTimestamp()) : "nil";
List keys = Arrays.asList(key.getUserMessageQueue(), key.getUserMessageQueueMetadata(), Key.getUserMessageQueueIndex());
List args = Arrays.asList(message.toByteArray(), String.valueOf(timestamp).getBytes(), sender.getBytes(), guid.toString().getBytes());
- insert.execute(keys, args);
+ return (long)insert.execute(keys, args);
}
}
@@ -296,13 +333,13 @@ public class MessagesCache implements Managed {
this.removeQueue = LuaScript.fromResource(jedisPool, "lua/remove_queue.lua" );
}
- public void remove(Jedis jedis, String destination, long destinationDevice, long id) {
+ public byte[] remove(Jedis jedis, String destination, long destinationDevice, long id) {
Key key = new Key(destination, destinationDevice);
List keys = Arrays.asList(key.getUserMessageQueue(), key.getUserMessageQueueMetadata(), Key.getUserMessageQueueIndex());
List args = Collections.singletonList(String.valueOf(id).getBytes());
- this.removeById.execute(jedis, keys, args);
+ return (byte[])this.removeById.execute(jedis, keys, args);
}
public byte[] remove(String destination, long destinationDevice, String sender, long timestamp) {
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCache.java
new file mode 100644
index 000000000..c6e2c8857
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCache.java
@@ -0,0 +1,206 @@
+package org.whispersystems.textsecuregcm.storage;
+
+import com.google.protobuf.InvalidProtocolBufferException;
+import io.lettuce.core.ScriptOutputType;
+import io.micrometer.core.instrument.Metrics;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.whispersystems.textsecuregcm.entities.MessageProtos;
+import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
+import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
+import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
+import org.whispersystems.textsecuregcm.util.RedisClusterUtil;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+import java.util.UUID;
+
+import static com.codahale.metrics.MetricRegistry.name;
+
+public class RedisClusterMessagesCache implements UserMessagesCache {
+
+ private final ClusterLuaScript insertScript;
+ private final ClusterLuaScript removeByIdScript;
+ private final ClusterLuaScript removeBySenderScript;
+ private final ClusterLuaScript removeByGuidScript;
+ private final ClusterLuaScript getItemsScript;
+ private final ClusterLuaScript removeQueueScript;
+
+ private static final String INSERT_TIMER_NAME = name(RedisClusterMessagesCache.class, "insert");
+ private static final String REMOVE_TIMER_NAME = name(RedisClusterMessagesCache.class, "remove");
+ private static final String GET_TIMER_NAME = name(RedisClusterMessagesCache.class, "get");
+ private static final String CLEAR_TIMER_NAME = name(RedisClusterMessagesCache.class, "clear");
+
+ private static final String REMOVE_METHOD_TAG = "method";
+ private static final String REMOVE_METHOD_ID = "id";
+ private static final String REMOVE_METHOD_SENDER = "sender";
+ private static final String REMOVE_METHOD_UUID = "uuid";
+
+ private static final Logger logger = LoggerFactory.getLogger(RedisClusterMessagesCache.class);
+
+ public RedisClusterMessagesCache(final FaultTolerantRedisCluster redisCluster) throws IOException {
+
+ this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER);
+ this.removeByIdScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_id.lua", ScriptOutputType.VALUE);
+ this.removeBySenderScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_sender.lua", ScriptOutputType.VALUE);
+ this.removeByGuidScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_guid.lua", ScriptOutputType.VALUE);
+ this.getItemsScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_items.lua", ScriptOutputType.MULTI);
+ this.removeQueueScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_queue.lua", ScriptOutputType.STATUS);
+ }
+
+ @Override
+ public long insert(final UUID guid, final String destination, final long destinationDevice, final MessageProtos.Envelope message) {
+ final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build();
+ final String sender = message.hasSource() ? (message.getSource() + "::" + message.getTimestamp()) : "nil";
+
+ return (long)Metrics.timer(INSERT_TIMER_NAME).record(() ->
+ insertScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice),
+ getMessageQueueMetadataKey(destination, destinationDevice),
+ getQueueIndexKey(destination, destinationDevice)),
+ List.of(messageWithGuid.toByteArray(),
+ String.valueOf(message.getTimestamp()).getBytes(StandardCharsets.UTF_8),
+ sender.getBytes(StandardCharsets.UTF_8),
+ guid.toString().getBytes(StandardCharsets.UTF_8))));
+ }
+
+ public long insert(final UUID guid, final String destination, final long destinationDevice, final MessageProtos.Envelope message, final long messageId) {
+ final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build();
+ final String sender = message.hasSource() ? (message.getSource() + "::" + message.getTimestamp()) : "nil";
+
+ return (long)Metrics.timer(INSERT_TIMER_NAME).record(() ->
+ insertScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice),
+ getMessageQueueMetadataKey(destination, destinationDevice),
+ getQueueIndexKey(destination, destinationDevice)),
+ List.of(messageWithGuid.toByteArray(),
+ String.valueOf(message.getTimestamp()).getBytes(StandardCharsets.UTF_8),
+ sender.getBytes(StandardCharsets.UTF_8),
+ guid.toString().getBytes(StandardCharsets.UTF_8),
+ String.valueOf(messageId).getBytes(StandardCharsets.UTF_8))));
+ }
+
+ @Override
+ public Optional remove(final String destination, final long destinationDevice, final long id) {
+ try {
+ final byte[] serialized = (byte[])Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, REMOVE_METHOD_ID).record(() ->
+ removeByIdScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice),
+ getMessageQueueMetadataKey(destination, destinationDevice),
+ getQueueIndexKey(destination, destinationDevice)),
+ List.of(String.valueOf(id).getBytes(StandardCharsets.UTF_8))));
+
+
+ if (serialized != null) {
+ return Optional.of(UserMessagesCache.constructEntityFromEnvelope(id, MessageProtos.Envelope.parseFrom(serialized)));
+ }
+ } catch (final InvalidProtocolBufferException e) {
+ logger.warn("Failed to parse envelope", e);
+ }
+
+ return Optional.empty();
+ }
+
+ @Override
+ public Optional remove(final String destination, final long destinationDevice, final String sender, final long timestamp) {
+ try {
+ final byte[] serialized = (byte[])Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, REMOVE_METHOD_SENDER).record(() ->
+ removeBySenderScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice),
+ getMessageQueueMetadataKey(destination, destinationDevice),
+ getQueueIndexKey(destination, destinationDevice)),
+ List.of((sender + "::" + timestamp).getBytes(StandardCharsets.UTF_8))));
+
+ if (serialized != null) {
+ return Optional.of(UserMessagesCache.constructEntityFromEnvelope(0, MessageProtos.Envelope.parseFrom(serialized)));
+ }
+ } catch (final InvalidProtocolBufferException e) {
+ logger.warn("Failed to parse envelope", e);
+ }
+
+ return Optional.empty();
+ }
+
+ @Override
+ public Optional remove(final String destination, final long destinationDevice, final UUID guid) {
+ try {
+ final byte[] serialized = (byte[])Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, REMOVE_METHOD_UUID).record(() ->
+ removeByGuidScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice),
+ getMessageQueueMetadataKey(destination, destinationDevice),
+ getQueueIndexKey(destination, destinationDevice)),
+ List.of(guid.toString().getBytes(StandardCharsets.UTF_8))));
+
+ if (serialized != null) {
+ return Optional.of(UserMessagesCache.constructEntityFromEnvelope(0, MessageProtos.Envelope.parseFrom(serialized)));
+ }
+ } catch (final InvalidProtocolBufferException e) {
+ logger.warn("Failed to parse envelope", e);
+ }
+
+ return Optional.empty();
+ }
+
+ @Override
+ @SuppressWarnings("unchecked")
+ public List get(String destination, long destinationDevice, int limit) {
+ return Metrics.timer(GET_TIMER_NAME).record(() -> {
+ final List queueItems = (List)getItemsScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice),
+ getPersistInProgressKey(destination, destinationDevice)),
+ List.of(String.valueOf(limit).getBytes()));
+
+ final List messageEntities;
+
+ if (queueItems.size() % 2 == 0) {
+ messageEntities = new ArrayList<>(queueItems.size() / 2);
+
+ for (int i = 0; i < queueItems.size() - 1; i += 2) {
+ try {
+ final MessageProtos.Envelope message = MessageProtos.Envelope.parseFrom(queueItems.get(i));
+ final long id = Long.parseLong(new String(queueItems.get(i + 1), StandardCharsets.UTF_8));
+
+ messageEntities.add(UserMessagesCache.constructEntityFromEnvelope(id, message));
+ } catch (InvalidProtocolBufferException e) {
+ logger.warn("Failed to parse envelope", e);
+ }
+ }
+ } else {
+ logger.error("\"Get messages\" operation returned a list with a non-even number of elements.");
+ messageEntities = Collections.emptyList();
+ }
+
+ return messageEntities;
+ });
+ }
+
+ @Override
+ public void clear(final String destination) {
+ for (int i = 1; i < 256; i++) {
+ clear(destination, i);
+ }
+ }
+
+ @Override
+ public void clear(final String destination, final long deviceId) {
+ Metrics.timer(CLEAR_TIMER_NAME).record(() ->
+ removeQueueScript.executeBinary(List.of(getMessageQueueKey(destination, deviceId),
+ getMessageQueueMetadataKey(destination, deviceId),
+ getQueueIndexKey(destination, deviceId)),
+ Collections.emptyList()));
+ }
+
+ private static byte[] getMessageQueueKey(final String address, final long deviceId) {
+ return ("user_queue::{" + address + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
+ }
+
+ private static byte[] getMessageQueueMetadataKey(final String address, final long deviceId) {
+ return ("user_queue_metadata::{" + address + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
+ }
+
+ private byte[] getQueueIndexKey(final String address, final long deviceId) {
+ return ("user_queue_index::{" + RedisClusterUtil.getMinimalHashTag(address + "::" + deviceId) + "}").getBytes(StandardCharsets.UTF_8);
+ }
+
+ private byte[] getPersistInProgressKey(final String address, final long deviceId) {
+ return ("user_queue_persisting::{" + address + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
+ }
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/UserMessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/UserMessagesCache.java
new file mode 100644
index 000000000..f3f813233
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/UserMessagesCache.java
@@ -0,0 +1,41 @@
+package org.whispersystems.textsecuregcm.storage;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.whispersystems.textsecuregcm.entities.MessageProtos;
+import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
+import org.whispersystems.textsecuregcm.push.PushSender;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.UUID;
+
+public interface UserMessagesCache {
+ @VisibleForTesting
+ static OutgoingMessageEntity constructEntityFromEnvelope(long id, MessageProtos.Envelope envelope) {
+ return new OutgoingMessageEntity(id, true,
+ envelope.hasServerGuid() ? UUID.fromString(envelope.getServerGuid()) : null,
+ envelope.getType().getNumber(),
+ envelope.getRelay(),
+ envelope.getTimestamp(),
+ envelope.getSource(),
+ envelope.hasSourceUuid() ? UUID.fromString(envelope.getSourceUuid()) : null,
+ envelope.getSourceDevice(),
+ envelope.hasLegacyMessage() ? envelope.getLegacyMessage().toByteArray() : null,
+ envelope.hasContent() ? envelope.getContent().toByteArray() : null,
+ envelope.hasServerTimestamp() ? envelope.getServerTimestamp() : 0);
+ }
+
+ long insert(UUID guid, String destination, long destinationDevice, MessageProtos.Envelope message);
+
+ Optional remove(String destination, long destinationDevice, long id);
+
+ Optional remove(String destination, long destinationDevice, String sender, long timestamp);
+
+ Optional remove(String destination, long destinationDevice, UUID guid);
+
+ List get(String destination, long destinationDevice, int limit);
+
+ void clear(String destination);
+
+ void clear(String destination, long deviceId);
+}
diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/RedisClusterUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/RedisClusterUtil.java
new file mode 100644
index 000000000..bb02d56dd
--- /dev/null
+++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/RedisClusterUtil.java
@@ -0,0 +1,37 @@
+package org.whispersystems.textsecuregcm.util;
+
+import io.lettuce.core.cluster.SlotHash;
+import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
+import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
+
+public class RedisClusterUtil {
+
+ private static final String[] HASHES_BY_SLOT = new String[SlotHash.SLOT_COUNT];
+
+ static {
+ int slotsCovered = 0;
+ int i = 0;
+
+ while (slotsCovered < HASHES_BY_SLOT.length) {
+ final String hash = Integer.toString(i++, 36);
+ final int slot = SlotHash.getSlot(hash);
+
+ if (HASHES_BY_SLOT[slot] == null) {
+ HASHES_BY_SLOT[slot] = hash;
+ slotsCovered += 1;
+ }
+ }
+ }
+
+ /**
+ * Returns a short Redis hash tag that maps to the same Redis cluster slot as the given key.
+ *
+ * @param key the key for which to find a matching hash tag
+ * @return a Redis hash tag that maps to the same Redis cluster slot as the given key
+ *
+ * @see Redis Cluster Specification - Keys hash tags
+ */
+ public static String getMinimalHashTag(final String key) {
+ return HASHES_BY_SLOT[SlotHash.getSlot(key)];
+ }
+}
diff --git a/service/src/main/resources/lua/insert_item.lua b/service/src/main/resources/lua/insert_item.lua
index 51c4dcaed..70ef4d248 100644
--- a/service/src/main/resources/lua/insert_item.lua
+++ b/service/src/main/resources/lua/insert_item.lua
@@ -1,7 +1,16 @@
-- keys: queue_key [1], queue_metadata_key [2], queue_total_index [3]
--- argv: message [1], current_time [2], sender (possibly null) [3], guid [4]
+-- argv: message [1], current_time [2], sender (possibly null) [3], guid [4], messageId (possibly null) [5]
+
+local messageId
+
+if ARGV[5] ~= nil then
+ -- TODO: Remove this branch (and ARGV[5]) once the migration to a clustered message cache is finished
+ messageId = tonumber(ARGV[5])
+ redis.call("HSET", KEYS[2], "counter", messageId)
+else
+ messageId = redis.call("HINCRBY", KEYS[2], "counter", 1)
+end
-local messageId = redis.call("HINCRBY", KEYS[2], "counter", 1)
redis.call("ZADD", KEYS[1], "NX", messageId, ARGV[1])
if ARGV[3] ~= "nil" then
diff --git a/service/src/main/resources/lua/remove_item_by_id.lua b/service/src/main/resources/lua/remove_item_by_id.lua
index e9647a9fb..d9769dfd0 100644
--- a/service/src/main/resources/lua/remove_item_by_id.lua
+++ b/service/src/main/resources/lua/remove_item_by_id.lua
@@ -1,6 +1,7 @@
-- keys: queue_key, queue_metadata_key, queue_index
-- argv: index_to_remove
+local envelope = redis.call("ZRANGEBYSCORE", KEYS[1], ARGV[1], ARGV[1], "LIMIT", 0, 1)
local removedCount = redis.call("ZREMRANGEBYSCORE", KEYS[1], ARGV[1], ARGV[1])
local senderIndex = redis.call("HGET", KEYS[2], ARGV[1])
local guidIndex = redis.call("HGET", KEYS[2], ARGV[1] .. "guid")
@@ -19,4 +20,8 @@ if (redis.call("ZCARD", KEYS[1]) == 0) then
redis.call("ZREM", KEYS[3], KEYS[1])
end
-return removedCount > 0
+if envelope and next(envelope) then
+ return envelope[1]
+else
+ return nil
+end
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisClusterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisClusterTest.java
index d3afb4371..b49598a92 100644
--- a/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisClusterTest.java
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisClusterTest.java
@@ -36,7 +36,7 @@ public abstract class AbstractRedisClusterTest {
private FaultTolerantRedisCluster redisCluster;
@BeforeClass
- public static void setUpBeforeClass() throws IOException, URISyntaxException, InterruptedException {
+ public static void setUpBeforeClass() throws Exception {
assumeFalse(System.getProperty("os.name").equalsIgnoreCase("windows"));
clusterNodes = new RedisServer[NODE_COUNT];
@@ -50,7 +50,7 @@ public abstract class AbstractRedisClusterTest {
}
@Before
- public void setUp() {
+ public void setUp() throws Exception {
final List urls = Arrays.stream(clusterNodes)
.map(node -> String.format("redis://127.0.0.1:%d", node.ports().get(0)))
.collect(Collectors.toList());
@@ -63,12 +63,12 @@ public abstract class AbstractRedisClusterTest {
}
@After
- public void tearDown() {
+ public void tearDown() throws Exception {
redisCluster.stop();
}
@AfterClass
- public static void tearDownAfterClass() {
+ public static void tearDownAfterClass() throws Exception {
for (final RedisServer node : clusterNodes) {
node.stop();
}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AbstractMessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AbstractMessagesCacheTest.java
new file mode 100644
index 000000000..734803496
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AbstractMessagesCacheTest.java
@@ -0,0 +1,157 @@
+package org.whispersystems.textsecuregcm.storage;
+
+import com.google.protobuf.ByteString;
+import junitparams.JUnitParamsRunner;
+import junitparams.Parameters;
+import org.apache.commons.lang3.RandomStringUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.whispersystems.textsecuregcm.entities.MessageProtos;
+import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
+import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+import java.util.Random;
+import java.util.UUID;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+@RunWith(JUnitParamsRunner.class)
+public abstract class AbstractMessagesCacheTest extends AbstractRedisClusterTest {
+
+ private static final String DESTINATION_ACCOUNT = "+18005551234";
+ private static final int DESTINATION_DEVICE_ID = 7;
+
+ private final Random random = new Random();
+ private long serialTimestamp = 0;
+
+ protected abstract UserMessagesCache getMessagesCache();
+
+ @Test
+ @Parameters({"true", "false"})
+ public void testInsert(final boolean sealedSender) {
+ final UUID messageGuid = UUID.randomUUID();
+ assertTrue(getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, generateRandomMessage(messageGuid, sealedSender)) > 0);
+ }
+
+ @Test
+ @Parameters({"true", "false"})
+ public void testRemoveById(final boolean sealedSender) {
+ final UUID messageGuid = UUID.randomUUID();
+ final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
+
+ final long messageId = getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, message);
+ final Optional maybeRemovedMessage = getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageId);
+
+ assertTrue(maybeRemovedMessage.isPresent());
+ assertEquals(UserMessagesCache.constructEntityFromEnvelope(messageId, message), maybeRemovedMessage.get());
+ assertEquals(Optional.empty(), getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageId));
+ }
+
+ @Test
+ public void testRemoveBySender() {
+ final UUID messageGuid = UUID.randomUUID();
+ final MessageProtos.Envelope message = generateRandomMessage(messageGuid, false);
+
+ getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, message);
+ final Optional maybeRemovedMessage = getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, message.getSource(), message.getTimestamp());
+
+ assertTrue(maybeRemovedMessage.isPresent());
+ assertEquals(UserMessagesCache.constructEntityFromEnvelope(0, message), maybeRemovedMessage.get());
+ assertEquals(Optional.empty(), getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, message.getSource(), message.getTimestamp()));
+ }
+
+ @Test
+ @Parameters({"true", "false"})
+ public void testRemoveByUUID(final boolean sealedSender) {
+ final UUID messageGuid = UUID.randomUUID();
+
+ assertEquals(Optional.empty(), getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageGuid));
+
+ final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
+
+ getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, message);
+ final Optional maybeRemovedMessage = getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageGuid);
+
+ assertTrue(maybeRemovedMessage.isPresent());
+ assertEquals(UserMessagesCache.constructEntityFromEnvelope(0, message), maybeRemovedMessage.get());
+ }
+
+ @Test
+ @Parameters({"true", "false"})
+ public void testGetMessages(final boolean sealedSender) {
+ final int messageCount = 100;
+
+ final List expectedMessages = new ArrayList<>(messageCount);
+
+ for (int i = 0; i < messageCount; i++) {
+ final UUID messageGuid = UUID.randomUUID();
+ final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
+ final long messageId = getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, message);
+
+ expectedMessages.add(UserMessagesCache.constructEntityFromEnvelope(messageId, message));
+ }
+
+ assertEquals(expectedMessages, getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageCount));
+ }
+
+ @Test
+ @Parameters({"true", "false"})
+ public void testClearQueueForDevice(final boolean sealedSender) {
+ final int messageCount = 100;
+
+ for (final int deviceId : new int[] { DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1 }) {
+ for (int i = 0; i < messageCount; i++) {
+ final UUID messageGuid = UUID.randomUUID();
+ final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
+
+ getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, deviceId, message);
+ }
+ }
+
+ getMessagesCache().clear(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID);
+
+ assertEquals(Collections.emptyList(), getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageCount));
+ assertEquals(messageCount, getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID + 1, messageCount).size());
+ }
+
+ @Test
+ @Parameters({"true", "false"})
+ public void testClearQueueForAccount(final boolean sealedSender) {
+ final int messageCount = 100;
+
+ for (final int deviceId : new int[] { DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1 }) {
+ for (int i = 0; i < messageCount; i++) {
+ final UUID messageGuid = UUID.randomUUID();
+ final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
+
+ getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, deviceId, message);
+ }
+ }
+
+ getMessagesCache().clear(DESTINATION_ACCOUNT);
+
+ assertEquals(Collections.emptyList(), getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageCount));
+ assertEquals(Collections.emptyList(), getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID + 1, messageCount));
+ }
+
+ protected MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final boolean sealedSender) {
+ final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder()
+ .setTimestamp(serialTimestamp++)
+ .setServerTimestamp(serialTimestamp++)
+ .setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
+ .setType(MessageProtos.Envelope.Type.CIPHERTEXT)
+ .setServerGuid(messageGuid.toString());
+
+ if (!sealedSender) {
+ envelopeBuilder.setSourceDevice(random.nextInt(256))
+ .setSource("+1" + RandomStringUtils.randomNumeric(10));
+ }
+
+ return envelopeBuilder.build();
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java
new file mode 100644
index 000000000..6a0c50ba8
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java
@@ -0,0 +1,41 @@
+package org.whispersystems.textsecuregcm.storage;
+
+import org.junit.After;
+import org.junit.Before;
+import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
+import org.whispersystems.textsecuregcm.providers.RedisClientFactory;
+import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest;
+import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
+import redis.embedded.RedisServer;
+
+import java.util.List;
+
+import static org.mockito.Mockito.mock;
+
+public class MessagesCacheTest extends AbstractMessagesCacheTest {
+
+ private RedisServer redisServer;
+ private MessagesCache messagesCache;
+
+ @Before
+ public void setUp() throws Exception {
+ redisServer = new RedisServer(AbstractRedisClusterTest.getNextRedisClusterPort());
+ redisServer.start();
+
+ final String redisUrl = String.format("redis://127.0.0.1:%d", redisServer.ports().get(0));
+ final RedisClientFactory clientFactory = new RedisClientFactory("message-cache-test", redisUrl, List.of(redisUrl), new CircuitBreakerConfiguration());
+ final ReplicatedJedisPool jedisPool = clientFactory.getRedisClientPool();
+
+ messagesCache = new MessagesCache(jedisPool, mock(Messages.class), mock(AccountsManager.class), 60, mock(RedisClusterMessagesCache.class));
+ }
+
+ @After
+ public void tearDown() {
+ redisServer.stop();
+ }
+
+ @Override
+ protected UserMessagesCache getMessagesCache() {
+ return messagesCache;
+ }
+}
diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCacheTest.java
new file mode 100644
index 000000000..dc167bf24
--- /dev/null
+++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCacheTest.java
@@ -0,0 +1,48 @@
+package org.whispersystems.textsecuregcm.storage;
+
+import junitparams.Parameters;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.UUID;
+
+import static org.junit.Assert.assertEquals;
+
+public class RedisClusterMessagesCacheTest extends AbstractMessagesCacheTest {
+
+ private static final String DESTINATION_ACCOUNT = "+18005551234";
+ private static final int DESTINATION_DEVICE_ID = 7;
+
+ private RedisClusterMessagesCache messagesCache;
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+
+ try {
+ messagesCache = new RedisClusterMessagesCache(getRedisCluster());
+ } catch (final IOException e) {
+ throw new RuntimeException(e);
+ }
+
+ getRedisCluster().useWriteCluster(connection -> connection.sync().flushall());
+ }
+
+ @Override
+ protected UserMessagesCache getMessagesCache() {
+ return messagesCache;
+ }
+
+ @Test
+ @Parameters({"true", "false"})
+ public void testInsertWithPrescribedId(final boolean sealedSender) {
+ final UUID firstMessageGuid = UUID.randomUUID();
+ final UUID secondMessageGuid = UUID.randomUUID();
+ final long messageId = 74;
+
+ assertEquals(messageId, messagesCache.insert(firstMessageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, generateRandomMessage(firstMessageGuid, sealedSender), messageId));
+ assertEquals(messageId + 1, messagesCache.insert(secondMessageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, generateRandomMessage(secondMessageGuid, sealedSender)));
+ }
+}