diff --git a/service/config/sample.yml b/service/config/sample.yml index efbf66102..dc86f1da0 100644 --- a/service/config/sample.yml +++ b/service/config/sample.yml @@ -47,8 +47,13 @@ directory: reconciliationChunkIntervalMs: # CDS reconciliation chunk interval, in milliseconds messageCache: # Redis server configuration for message store cache - url: - replicaUrls: + redis: + url: + replicaUrls: + + cluster: + urls: + - redis://redis.example.com:6379/ messageStore: # Postgresql database configuration for message store driverClass: org.postgresql.Driver diff --git a/service/pom.xml b/service/pom.xml index d3d5f9f82..d15272975 100644 --- a/service/pom.xml +++ b/service/pom.xml @@ -198,6 +198,12 @@ test + + pl.pragmatists + JUnitParams + 1.1.1 + test + diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 5634ecab9..b84b97607 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -50,7 +50,6 @@ import io.micrometer.datadog.DatadogMeterRegistry; import io.micrometer.wavefront.WavefrontConfig; import io.micrometer.wavefront.WavefrontMeterRegistry; import org.bouncycastle.jce.provider.BouncyCastleProvider; -import org.coursera.metrics.datadog.DatadogReporter; import org.eclipse.jetty.servlets.CrossOriginFilter; import org.jdbi.v3.core.Jdbi; import org.signal.zkgroup.ServerSecretParams; @@ -137,6 +136,7 @@ import org.whispersystems.textsecuregcm.storage.Profiles; import org.whispersystems.textsecuregcm.storage.ProfilesManager; import org.whispersystems.textsecuregcm.storage.PubSubManager; import org.whispersystems.textsecuregcm.storage.PushFeedbackProcessor; +import org.whispersystems.textsecuregcm.storage.RedisClusterMessagesCache; import org.whispersystems.textsecuregcm.storage.RemoteConfigs; import org.whispersystems.textsecuregcm.storage.RemoteConfigsManager; import org.whispersystems.textsecuregcm.storage.ReservedUsernames; @@ -300,7 +300,8 @@ public class WhisperServerService extends Application(1_000)); + + private final Experiment insertExperiment = new Experiment("MessagesCache", "insert"); + private final Experiment removeByIdExperiment = new Experiment("MessagesCache", "removeById"); + private final Experiment removeBySenderExperiment = new Experiment("MessagesCache", "removeBySender"); + private final Experiment removeByUuidExperiment = new Experiment("MessagesCache", "removeByUuid"); + private final Experiment getMessagesExperiment = new Experiment("MessagesCache", "getMessages"); + + public MessagesCache(ReplicatedJedisPool jedisPool, Messages database, AccountsManager accountsManager, int delayMinutes, RedisClusterMessagesCache clusterMessagesCache) throws IOException { + this.jedisPool = jedisPool; + this.database = database; + this.accountsManager = accountsManager; + this.delayMinutes = delayMinutes; + + this.insertOperation = new InsertOperation(jedisPool); + this.removeOperation = new RemoveOperation(jedisPool); + this.getOperation = new GetOperation(jedisPool); + + this.clusterMessagesCache = clusterMessagesCache; } - public void insert(UUID guid, String destination, long destinationDevice, Envelope message) { - message = message.toBuilder().setServerGuid(guid.toString()).build(); + @Override + public long insert(UUID guid, String destination, long destinationDevice, Envelope message) { + final Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build(); Timer.Context timer = insertTimer.time(); try { - insertOperation.insert(guid, destination, destinationDevice, System.currentTimeMillis(), message); + final long messageId = insertOperation.insert(guid, destination, destinationDevice, System.currentTimeMillis(), messageWithGuid); + insertExperiment.compareSupplierResultAsync(messageId, () -> clusterMessagesCache.insert(guid, destination, destinationDevice, message, messageId), experimentExecutor); + + return messageId; } finally { timer.stop(); } } - public void remove(String destination, long destinationDevice, long id) { + @Override + public Optional remove(String destination, long destinationDevice, long id) { + OutgoingMessageEntity removedMessageEntity = null; + try (Jedis jedis = jedisPool.getWriteResource(); Timer.Context ignored = removeByIdTimer.time()) { - removeOperation.remove(jedis, destination, destinationDevice, id); + byte[] serialized = removeOperation.remove(jedis, destination, destinationDevice, id); + + if (serialized != null) { + removedMessageEntity = UserMessagesCache.constructEntityFromEnvelope(id, Envelope.parseFrom(serialized)); + } + } catch (InvalidProtocolBufferException e) { + logger.warn("Failed to parse envelope", e); } + + final Optional maybeRemovedMessage = Optional.ofNullable(removedMessageEntity); + + removeByIdExperiment.compareSupplierResultAsync(maybeRemovedMessage, () -> clusterMessagesCache.remove(destination, destinationDevice, id), experimentExecutor); + + return maybeRemovedMessage; } + @Override public Optional remove(String destination, long destinationDevice, String sender, long timestamp) { + OutgoingMessageEntity removedMessageEntity = null; Timer.Context timer = removeByNameTimer.time(); try { byte[] serialized = removeOperation.remove(destination, destinationDevice, sender, timestamp); if (serialized != null) { - Envelope envelope = Envelope.parseFrom(serialized); - return Optional.of(constructEntityFromEnvelope(0, envelope)); + removedMessageEntity = UserMessagesCache.constructEntityFromEnvelope(0, Envelope.parseFrom(serialized)); } } catch (InvalidProtocolBufferException e) { logger.warn("Failed to parse envelope", e); @@ -105,18 +144,23 @@ public class MessagesCache implements Managed { timer.stop(); } - return Optional.empty(); + final Optional maybeRemovedMessage = Optional.ofNullable(removedMessageEntity); + + removeBySenderExperiment.compareSupplierResultAsync(maybeRemovedMessage, () -> clusterMessagesCache.remove(destination, destinationDevice, sender, timestamp), experimentExecutor); + + return maybeRemovedMessage; } + @Override public Optional remove(String destination, long destinationDevice, UUID guid) { + OutgoingMessageEntity removedMessageEntity = null; Timer.Context timer = removeByGuidTimer.time(); try { byte[] serialized = removeOperation.remove(destination, destinationDevice, guid); if (serialized != null) { - Envelope envelope = Envelope.parseFrom(serialized); - return Optional.of(constructEntityFromEnvelope(0, envelope)); + removedMessageEntity = UserMessagesCache.constructEntityFromEnvelope(0, Envelope.parseFrom(serialized)); } } catch (InvalidProtocolBufferException e) { logger.warn("Failed to parse envelope", e); @@ -124,9 +168,14 @@ public class MessagesCache implements Managed { timer.stop(); } - return Optional.empty(); + final Optional maybeRemovedMessage = Optional.ofNullable(removedMessageEntity); + + removeByUuidExperiment.compareSupplierResultAsync(maybeRemovedMessage, () -> clusterMessagesCache.remove(destination, destinationDevice, guid), experimentExecutor); + + return maybeRemovedMessage; } + @Override public List get(String destination, long destinationDevice, int limit) { Timer.Context timer = getTimer.time(); @@ -139,18 +188,21 @@ public class MessagesCache implements Managed { try { long id = item.second().longValue(); Envelope message = Envelope.parseFrom(item.first()); - results.add(constructEntityFromEnvelope(id, message)); + results.add(UserMessagesCache.constructEntityFromEnvelope(id, message)); } catch (InvalidProtocolBufferException e) { logger.warn("Failed to parse envelope", e); } } + getMessagesExperiment.compareSupplierResultAsync(results, () -> clusterMessagesCache.get(destination, destinationDevice, limit), experimentExecutor); + return results; } finally { timer.stop(); } } + @Override public void clear(String destination) { Timer.Context timer = clearAccountTimer.time(); @@ -163,6 +215,7 @@ public class MessagesCache implements Managed { } } + @Override public void clear(String destination, long deviceId) { Timer.Context timer = clearDeviceTimer.time(); @@ -180,11 +233,7 @@ public class MessagesCache implements Managed { @Override public void start() throws Exception { - this.insertOperation = new InsertOperation(jedisPool); - this.removeOperation = new RemoveOperation(jedisPool); - this.getOperation = new GetOperation(jedisPool); this.messagePersister = new MessagePersister(jedisPool, database, pubSubManager, pushSender, accountsManager, delayMinutes, TimeUnit.MINUTES); - this.messagePersister.start(); } @@ -192,20 +241,8 @@ public class MessagesCache implements Managed { public void stop() throws Exception { messagePersister.shutdown(); logger.info("Message persister shut down..."); - } - private OutgoingMessageEntity constructEntityFromEnvelope(long id, Envelope envelope) { - return new OutgoingMessageEntity(id, true, - envelope.hasServerGuid() ? UUID.fromString(envelope.getServerGuid()) : null, - envelope.getType().getNumber(), - envelope.getRelay(), - envelope.getTimestamp(), - envelope.getSource(), - envelope.hasSourceUuid() ? UUID.fromString(envelope.getSourceUuid()) : null, - envelope.getSourceDevice(), - envelope.hasLegacyMessage() ? envelope.getLegacyMessage().toByteArray() : null, - envelope.hasContent() ? envelope.getContent().toByteArray() : null, - envelope.hasServerTimestamp() ? envelope.getServerTimestamp() : 0); + this.experimentExecutor.shutdown(); } private static class Key { @@ -271,14 +308,14 @@ public class MessagesCache implements Managed { this.insert = LuaScript.fromResource(jedisPool, "lua/insert_item.lua"); } - public void insert(UUID guid, String destination, long destinationDevice, long timestamp, Envelope message) { + public long insert(UUID guid, String destination, long destinationDevice, long timestamp, Envelope message) { Key key = new Key(destination, destinationDevice); String sender = message.hasSource() ? (message.getSource() + "::" + message.getTimestamp()) : "nil"; List keys = Arrays.asList(key.getUserMessageQueue(), key.getUserMessageQueueMetadata(), Key.getUserMessageQueueIndex()); List args = Arrays.asList(message.toByteArray(), String.valueOf(timestamp).getBytes(), sender.getBytes(), guid.toString().getBytes()); - insert.execute(keys, args); + return (long)insert.execute(keys, args); } } @@ -296,13 +333,13 @@ public class MessagesCache implements Managed { this.removeQueue = LuaScript.fromResource(jedisPool, "lua/remove_queue.lua" ); } - public void remove(Jedis jedis, String destination, long destinationDevice, long id) { + public byte[] remove(Jedis jedis, String destination, long destinationDevice, long id) { Key key = new Key(destination, destinationDevice); List keys = Arrays.asList(key.getUserMessageQueue(), key.getUserMessageQueueMetadata(), Key.getUserMessageQueueIndex()); List args = Collections.singletonList(String.valueOf(id).getBytes()); - this.removeById.execute(jedis, keys, args); + return (byte[])this.removeById.execute(jedis, keys, args); } public byte[] remove(String destination, long destinationDevice, String sender, long timestamp) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCache.java new file mode 100644 index 000000000..c6e2c8857 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCache.java @@ -0,0 +1,206 @@ +package org.whispersystems.textsecuregcm.storage; + +import com.google.protobuf.InvalidProtocolBufferException; +import io.lettuce.core.ScriptOutputType; +import io.micrometer.core.instrument.Metrics; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; +import org.whispersystems.textsecuregcm.redis.ClusterLuaScript; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster; +import org.whispersystems.textsecuregcm.util.RedisClusterUtil; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.UUID; + +import static com.codahale.metrics.MetricRegistry.name; + +public class RedisClusterMessagesCache implements UserMessagesCache { + + private final ClusterLuaScript insertScript; + private final ClusterLuaScript removeByIdScript; + private final ClusterLuaScript removeBySenderScript; + private final ClusterLuaScript removeByGuidScript; + private final ClusterLuaScript getItemsScript; + private final ClusterLuaScript removeQueueScript; + + private static final String INSERT_TIMER_NAME = name(RedisClusterMessagesCache.class, "insert"); + private static final String REMOVE_TIMER_NAME = name(RedisClusterMessagesCache.class, "remove"); + private static final String GET_TIMER_NAME = name(RedisClusterMessagesCache.class, "get"); + private static final String CLEAR_TIMER_NAME = name(RedisClusterMessagesCache.class, "clear"); + + private static final String REMOVE_METHOD_TAG = "method"; + private static final String REMOVE_METHOD_ID = "id"; + private static final String REMOVE_METHOD_SENDER = "sender"; + private static final String REMOVE_METHOD_UUID = "uuid"; + + private static final Logger logger = LoggerFactory.getLogger(RedisClusterMessagesCache.class); + + public RedisClusterMessagesCache(final FaultTolerantRedisCluster redisCluster) throws IOException { + + this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER); + this.removeByIdScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_id.lua", ScriptOutputType.VALUE); + this.removeBySenderScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_sender.lua", ScriptOutputType.VALUE); + this.removeByGuidScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_guid.lua", ScriptOutputType.VALUE); + this.getItemsScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_items.lua", ScriptOutputType.MULTI); + this.removeQueueScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_queue.lua", ScriptOutputType.STATUS); + } + + @Override + public long insert(final UUID guid, final String destination, final long destinationDevice, final MessageProtos.Envelope message) { + final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build(); + final String sender = message.hasSource() ? (message.getSource() + "::" + message.getTimestamp()) : "nil"; + + return (long)Metrics.timer(INSERT_TIMER_NAME).record(() -> + insertScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice), + getMessageQueueMetadataKey(destination, destinationDevice), + getQueueIndexKey(destination, destinationDevice)), + List.of(messageWithGuid.toByteArray(), + String.valueOf(message.getTimestamp()).getBytes(StandardCharsets.UTF_8), + sender.getBytes(StandardCharsets.UTF_8), + guid.toString().getBytes(StandardCharsets.UTF_8)))); + } + + public long insert(final UUID guid, final String destination, final long destinationDevice, final MessageProtos.Envelope message, final long messageId) { + final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build(); + final String sender = message.hasSource() ? (message.getSource() + "::" + message.getTimestamp()) : "nil"; + + return (long)Metrics.timer(INSERT_TIMER_NAME).record(() -> + insertScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice), + getMessageQueueMetadataKey(destination, destinationDevice), + getQueueIndexKey(destination, destinationDevice)), + List.of(messageWithGuid.toByteArray(), + String.valueOf(message.getTimestamp()).getBytes(StandardCharsets.UTF_8), + sender.getBytes(StandardCharsets.UTF_8), + guid.toString().getBytes(StandardCharsets.UTF_8), + String.valueOf(messageId).getBytes(StandardCharsets.UTF_8)))); + } + + @Override + public Optional remove(final String destination, final long destinationDevice, final long id) { + try { + final byte[] serialized = (byte[])Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, REMOVE_METHOD_ID).record(() -> + removeByIdScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice), + getMessageQueueMetadataKey(destination, destinationDevice), + getQueueIndexKey(destination, destinationDevice)), + List.of(String.valueOf(id).getBytes(StandardCharsets.UTF_8)))); + + + if (serialized != null) { + return Optional.of(UserMessagesCache.constructEntityFromEnvelope(id, MessageProtos.Envelope.parseFrom(serialized))); + } + } catch (final InvalidProtocolBufferException e) { + logger.warn("Failed to parse envelope", e); + } + + return Optional.empty(); + } + + @Override + public Optional remove(final String destination, final long destinationDevice, final String sender, final long timestamp) { + try { + final byte[] serialized = (byte[])Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, REMOVE_METHOD_SENDER).record(() -> + removeBySenderScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice), + getMessageQueueMetadataKey(destination, destinationDevice), + getQueueIndexKey(destination, destinationDevice)), + List.of((sender + "::" + timestamp).getBytes(StandardCharsets.UTF_8)))); + + if (serialized != null) { + return Optional.of(UserMessagesCache.constructEntityFromEnvelope(0, MessageProtos.Envelope.parseFrom(serialized))); + } + } catch (final InvalidProtocolBufferException e) { + logger.warn("Failed to parse envelope", e); + } + + return Optional.empty(); + } + + @Override + public Optional remove(final String destination, final long destinationDevice, final UUID guid) { + try { + final byte[] serialized = (byte[])Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, REMOVE_METHOD_UUID).record(() -> + removeByGuidScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice), + getMessageQueueMetadataKey(destination, destinationDevice), + getQueueIndexKey(destination, destinationDevice)), + List.of(guid.toString().getBytes(StandardCharsets.UTF_8)))); + + if (serialized != null) { + return Optional.of(UserMessagesCache.constructEntityFromEnvelope(0, MessageProtos.Envelope.parseFrom(serialized))); + } + } catch (final InvalidProtocolBufferException e) { + logger.warn("Failed to parse envelope", e); + } + + return Optional.empty(); + } + + @Override + @SuppressWarnings("unchecked") + public List get(String destination, long destinationDevice, int limit) { + return Metrics.timer(GET_TIMER_NAME).record(() -> { + final List queueItems = (List)getItemsScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice), + getPersistInProgressKey(destination, destinationDevice)), + List.of(String.valueOf(limit).getBytes())); + + final List messageEntities; + + if (queueItems.size() % 2 == 0) { + messageEntities = new ArrayList<>(queueItems.size() / 2); + + for (int i = 0; i < queueItems.size() - 1; i += 2) { + try { + final MessageProtos.Envelope message = MessageProtos.Envelope.parseFrom(queueItems.get(i)); + final long id = Long.parseLong(new String(queueItems.get(i + 1), StandardCharsets.UTF_8)); + + messageEntities.add(UserMessagesCache.constructEntityFromEnvelope(id, message)); + } catch (InvalidProtocolBufferException e) { + logger.warn("Failed to parse envelope", e); + } + } + } else { + logger.error("\"Get messages\" operation returned a list with a non-even number of elements."); + messageEntities = Collections.emptyList(); + } + + return messageEntities; + }); + } + + @Override + public void clear(final String destination) { + for (int i = 1; i < 256; i++) { + clear(destination, i); + } + } + + @Override + public void clear(final String destination, final long deviceId) { + Metrics.timer(CLEAR_TIMER_NAME).record(() -> + removeQueueScript.executeBinary(List.of(getMessageQueueKey(destination, deviceId), + getMessageQueueMetadataKey(destination, deviceId), + getQueueIndexKey(destination, deviceId)), + Collections.emptyList())); + } + + private static byte[] getMessageQueueKey(final String address, final long deviceId) { + return ("user_queue::{" + address + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); + } + + private static byte[] getMessageQueueMetadataKey(final String address, final long deviceId) { + return ("user_queue_metadata::{" + address + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); + } + + private byte[] getQueueIndexKey(final String address, final long deviceId) { + return ("user_queue_index::{" + RedisClusterUtil.getMinimalHashTag(address + "::" + deviceId) + "}").getBytes(StandardCharsets.UTF_8); + } + + private byte[] getPersistInProgressKey(final String address, final long deviceId) { + return ("user_queue_persisting::{" + address + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/UserMessagesCache.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/UserMessagesCache.java new file mode 100644 index 000000000..f3f813233 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/UserMessagesCache.java @@ -0,0 +1,41 @@ +package org.whispersystems.textsecuregcm.storage; + +import com.google.common.annotations.VisibleForTesting; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; +import org.whispersystems.textsecuregcm.push.PushSender; + +import java.util.List; +import java.util.Optional; +import java.util.UUID; + +public interface UserMessagesCache { + @VisibleForTesting + static OutgoingMessageEntity constructEntityFromEnvelope(long id, MessageProtos.Envelope envelope) { + return new OutgoingMessageEntity(id, true, + envelope.hasServerGuid() ? UUID.fromString(envelope.getServerGuid()) : null, + envelope.getType().getNumber(), + envelope.getRelay(), + envelope.getTimestamp(), + envelope.getSource(), + envelope.hasSourceUuid() ? UUID.fromString(envelope.getSourceUuid()) : null, + envelope.getSourceDevice(), + envelope.hasLegacyMessage() ? envelope.getLegacyMessage().toByteArray() : null, + envelope.hasContent() ? envelope.getContent().toByteArray() : null, + envelope.hasServerTimestamp() ? envelope.getServerTimestamp() : 0); + } + + long insert(UUID guid, String destination, long destinationDevice, MessageProtos.Envelope message); + + Optional remove(String destination, long destinationDevice, long id); + + Optional remove(String destination, long destinationDevice, String sender, long timestamp); + + Optional remove(String destination, long destinationDevice, UUID guid); + + List get(String destination, long destinationDevice, int limit); + + void clear(String destination); + + void clear(String destination, long deviceId); +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/RedisClusterUtil.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/RedisClusterUtil.java new file mode 100644 index 000000000..bb02d56dd --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/RedisClusterUtil.java @@ -0,0 +1,37 @@ +package org.whispersystems.textsecuregcm.util; + +import io.lettuce.core.cluster.SlotHash; +import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; +import io.lettuce.core.cluster.models.partitions.RedisClusterNode; + +public class RedisClusterUtil { + + private static final String[] HASHES_BY_SLOT = new String[SlotHash.SLOT_COUNT]; + + static { + int slotsCovered = 0; + int i = 0; + + while (slotsCovered < HASHES_BY_SLOT.length) { + final String hash = Integer.toString(i++, 36); + final int slot = SlotHash.getSlot(hash); + + if (HASHES_BY_SLOT[slot] == null) { + HASHES_BY_SLOT[slot] = hash; + slotsCovered += 1; + } + } + } + + /** + * Returns a short Redis hash tag that maps to the same Redis cluster slot as the given key. + * + * @param key the key for which to find a matching hash tag + * @return a Redis hash tag that maps to the same Redis cluster slot as the given key + * + * @see Redis Cluster Specification - Keys hash tags + */ + public static String getMinimalHashTag(final String key) { + return HASHES_BY_SLOT[SlotHash.getSlot(key)]; + } +} diff --git a/service/src/main/resources/lua/insert_item.lua b/service/src/main/resources/lua/insert_item.lua index 51c4dcaed..70ef4d248 100644 --- a/service/src/main/resources/lua/insert_item.lua +++ b/service/src/main/resources/lua/insert_item.lua @@ -1,7 +1,16 @@ -- keys: queue_key [1], queue_metadata_key [2], queue_total_index [3] --- argv: message [1], current_time [2], sender (possibly null) [3], guid [4] +-- argv: message [1], current_time [2], sender (possibly null) [3], guid [4], messageId (possibly null) [5] + +local messageId + +if ARGV[5] ~= nil then + -- TODO: Remove this branch (and ARGV[5]) once the migration to a clustered message cache is finished + messageId = tonumber(ARGV[5]) + redis.call("HSET", KEYS[2], "counter", messageId) +else + messageId = redis.call("HINCRBY", KEYS[2], "counter", 1) +end -local messageId = redis.call("HINCRBY", KEYS[2], "counter", 1) redis.call("ZADD", KEYS[1], "NX", messageId, ARGV[1]) if ARGV[3] ~= "nil" then diff --git a/service/src/main/resources/lua/remove_item_by_id.lua b/service/src/main/resources/lua/remove_item_by_id.lua index e9647a9fb..d9769dfd0 100644 --- a/service/src/main/resources/lua/remove_item_by_id.lua +++ b/service/src/main/resources/lua/remove_item_by_id.lua @@ -1,6 +1,7 @@ -- keys: queue_key, queue_metadata_key, queue_index -- argv: index_to_remove +local envelope = redis.call("ZRANGEBYSCORE", KEYS[1], ARGV[1], ARGV[1], "LIMIT", 0, 1) local removedCount = redis.call("ZREMRANGEBYSCORE", KEYS[1], ARGV[1], ARGV[1]) local senderIndex = redis.call("HGET", KEYS[2], ARGV[1]) local guidIndex = redis.call("HGET", KEYS[2], ARGV[1] .. "guid") @@ -19,4 +20,8 @@ if (redis.call("ZCARD", KEYS[1]) == 0) then redis.call("ZREM", KEYS[3], KEYS[1]) end -return removedCount > 0 +if envelope and next(envelope) then + return envelope[1] +else + return nil +end diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisClusterTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisClusterTest.java index d3afb4371..b49598a92 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisClusterTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/AbstractRedisClusterTest.java @@ -36,7 +36,7 @@ public abstract class AbstractRedisClusterTest { private FaultTolerantRedisCluster redisCluster; @BeforeClass - public static void setUpBeforeClass() throws IOException, URISyntaxException, InterruptedException { + public static void setUpBeforeClass() throws Exception { assumeFalse(System.getProperty("os.name").equalsIgnoreCase("windows")); clusterNodes = new RedisServer[NODE_COUNT]; @@ -50,7 +50,7 @@ public abstract class AbstractRedisClusterTest { } @Before - public void setUp() { + public void setUp() throws Exception { final List urls = Arrays.stream(clusterNodes) .map(node -> String.format("redis://127.0.0.1:%d", node.ports().get(0))) .collect(Collectors.toList()); @@ -63,12 +63,12 @@ public abstract class AbstractRedisClusterTest { } @After - public void tearDown() { + public void tearDown() throws Exception { redisCluster.stop(); } @AfterClass - public static void tearDownAfterClass() { + public static void tearDownAfterClass() throws Exception { for (final RedisServer node : clusterNodes) { node.stop(); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AbstractMessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AbstractMessagesCacheTest.java new file mode 100644 index 000000000..734803496 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AbstractMessagesCacheTest.java @@ -0,0 +1,157 @@ +package org.whispersystems.textsecuregcm.storage; + +import com.google.protobuf.ByteString; +import junitparams.JUnitParamsRunner; +import junitparams.Parameters; +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.whispersystems.textsecuregcm.entities.MessageProtos; +import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity; +import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.Random; +import java.util.UUID; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +@RunWith(JUnitParamsRunner.class) +public abstract class AbstractMessagesCacheTest extends AbstractRedisClusterTest { + + private static final String DESTINATION_ACCOUNT = "+18005551234"; + private static final int DESTINATION_DEVICE_ID = 7; + + private final Random random = new Random(); + private long serialTimestamp = 0; + + protected abstract UserMessagesCache getMessagesCache(); + + @Test + @Parameters({"true", "false"}) + public void testInsert(final boolean sealedSender) { + final UUID messageGuid = UUID.randomUUID(); + assertTrue(getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, generateRandomMessage(messageGuid, sealedSender)) > 0); + } + + @Test + @Parameters({"true", "false"}) + public void testRemoveById(final boolean sealedSender) { + final UUID messageGuid = UUID.randomUUID(); + final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); + + final long messageId = getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, message); + final Optional maybeRemovedMessage = getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageId); + + assertTrue(maybeRemovedMessage.isPresent()); + assertEquals(UserMessagesCache.constructEntityFromEnvelope(messageId, message), maybeRemovedMessage.get()); + assertEquals(Optional.empty(), getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageId)); + } + + @Test + public void testRemoveBySender() { + final UUID messageGuid = UUID.randomUUID(); + final MessageProtos.Envelope message = generateRandomMessage(messageGuid, false); + + getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, message); + final Optional maybeRemovedMessage = getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, message.getSource(), message.getTimestamp()); + + assertTrue(maybeRemovedMessage.isPresent()); + assertEquals(UserMessagesCache.constructEntityFromEnvelope(0, message), maybeRemovedMessage.get()); + assertEquals(Optional.empty(), getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, message.getSource(), message.getTimestamp())); + } + + @Test + @Parameters({"true", "false"}) + public void testRemoveByUUID(final boolean sealedSender) { + final UUID messageGuid = UUID.randomUUID(); + + assertEquals(Optional.empty(), getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageGuid)); + + final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); + + getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, message); + final Optional maybeRemovedMessage = getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageGuid); + + assertTrue(maybeRemovedMessage.isPresent()); + assertEquals(UserMessagesCache.constructEntityFromEnvelope(0, message), maybeRemovedMessage.get()); + } + + @Test + @Parameters({"true", "false"}) + public void testGetMessages(final boolean sealedSender) { + final int messageCount = 100; + + final List expectedMessages = new ArrayList<>(messageCount); + + for (int i = 0; i < messageCount; i++) { + final UUID messageGuid = UUID.randomUUID(); + final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); + final long messageId = getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, message); + + expectedMessages.add(UserMessagesCache.constructEntityFromEnvelope(messageId, message)); + } + + assertEquals(expectedMessages, getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageCount)); + } + + @Test + @Parameters({"true", "false"}) + public void testClearQueueForDevice(final boolean sealedSender) { + final int messageCount = 100; + + for (final int deviceId : new int[] { DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1 }) { + for (int i = 0; i < messageCount; i++) { + final UUID messageGuid = UUID.randomUUID(); + final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); + + getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, deviceId, message); + } + } + + getMessagesCache().clear(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID); + + assertEquals(Collections.emptyList(), getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageCount)); + assertEquals(messageCount, getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID + 1, messageCount).size()); + } + + @Test + @Parameters({"true", "false"}) + public void testClearQueueForAccount(final boolean sealedSender) { + final int messageCount = 100; + + for (final int deviceId : new int[] { DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1 }) { + for (int i = 0; i < messageCount; i++) { + final UUID messageGuid = UUID.randomUUID(); + final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender); + + getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, deviceId, message); + } + } + + getMessagesCache().clear(DESTINATION_ACCOUNT); + + assertEquals(Collections.emptyList(), getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageCount)); + assertEquals(Collections.emptyList(), getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID + 1, messageCount)); + } + + protected MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final boolean sealedSender) { + final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder() + .setTimestamp(serialTimestamp++) + .setServerTimestamp(serialTimestamp++) + .setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256))) + .setType(MessageProtos.Envelope.Type.CIPHERTEXT) + .setServerGuid(messageGuid.toString()); + + if (!sealedSender) { + envelopeBuilder.setSourceDevice(random.nextInt(256)) + .setSource("+1" + RandomStringUtils.randomNumeric(10)); + } + + return envelopeBuilder.build(); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java new file mode 100644 index 000000000..6a0c50ba8 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/MessagesCacheTest.java @@ -0,0 +1,41 @@ +package org.whispersystems.textsecuregcm.storage; + +import org.junit.After; +import org.junit.Before; +import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration; +import org.whispersystems.textsecuregcm.providers.RedisClientFactory; +import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest; +import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool; +import redis.embedded.RedisServer; + +import java.util.List; + +import static org.mockito.Mockito.mock; + +public class MessagesCacheTest extends AbstractMessagesCacheTest { + + private RedisServer redisServer; + private MessagesCache messagesCache; + + @Before + public void setUp() throws Exception { + redisServer = new RedisServer(AbstractRedisClusterTest.getNextRedisClusterPort()); + redisServer.start(); + + final String redisUrl = String.format("redis://127.0.0.1:%d", redisServer.ports().get(0)); + final RedisClientFactory clientFactory = new RedisClientFactory("message-cache-test", redisUrl, List.of(redisUrl), new CircuitBreakerConfiguration()); + final ReplicatedJedisPool jedisPool = clientFactory.getRedisClientPool(); + + messagesCache = new MessagesCache(jedisPool, mock(Messages.class), mock(AccountsManager.class), 60, mock(RedisClusterMessagesCache.class)); + } + + @After + public void tearDown() { + redisServer.stop(); + } + + @Override + protected UserMessagesCache getMessagesCache() { + return messagesCache; + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCacheTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCacheTest.java new file mode 100644 index 000000000..dc167bf24 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RedisClusterMessagesCacheTest.java @@ -0,0 +1,48 @@ +package org.whispersystems.textsecuregcm.storage; + +import junitparams.Parameters; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.util.UUID; + +import static org.junit.Assert.assertEquals; + +public class RedisClusterMessagesCacheTest extends AbstractMessagesCacheTest { + + private static final String DESTINATION_ACCOUNT = "+18005551234"; + private static final int DESTINATION_DEVICE_ID = 7; + + private RedisClusterMessagesCache messagesCache; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + + try { + messagesCache = new RedisClusterMessagesCache(getRedisCluster()); + } catch (final IOException e) { + throw new RuntimeException(e); + } + + getRedisCluster().useWriteCluster(connection -> connection.sync().flushall()); + } + + @Override + protected UserMessagesCache getMessagesCache() { + return messagesCache; + } + + @Test + @Parameters({"true", "false"}) + public void testInsertWithPrescribedId(final boolean sealedSender) { + final UUID firstMessageGuid = UUID.randomUUID(); + final UUID secondMessageGuid = UUID.randomUUID(); + final long messageId = 74; + + assertEquals(messageId, messagesCache.insert(firstMessageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, generateRandomMessage(firstMessageGuid, sealedSender), messageId)); + assertEquals(messageId + 1, messagesCache.insert(secondMessageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, generateRandomMessage(secondMessageGuid, sealedSender))); + } +}