Add a cluster-backed message cache.

This commit is contained in:
Jon Chambers 2020-07-09 09:34:20 -04:00 committed by Jon Chambers
parent 639898ec07
commit 6fc1b4c6c0
15 changed files with 690 additions and 59 deletions

View File

@ -47,8 +47,13 @@ directory:
reconciliationChunkIntervalMs: # CDS reconciliation chunk interval, in milliseconds
messageCache: # Redis server configuration for message store cache
url:
replicaUrls:
redis:
url:
replicaUrls:
cluster:
urls:
- redis://redis.example.com:6379/
messageStore: # Postgresql database configuration for message store
driverClass: org.postgresql.Driver

View File

@ -198,6 +198,12 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>pl.pragmatists</groupId>
<artifactId>JUnitParams</artifactId>
<version>1.1.1</version>
<scope>test</scope>
</dependency>
</dependencies>

View File

@ -50,7 +50,6 @@ import io.micrometer.datadog.DatadogMeterRegistry;
import io.micrometer.wavefront.WavefrontConfig;
import io.micrometer.wavefront.WavefrontMeterRegistry;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.coursera.metrics.datadog.DatadogReporter;
import org.eclipse.jetty.servlets.CrossOriginFilter;
import org.jdbi.v3.core.Jdbi;
import org.signal.zkgroup.ServerSecretParams;
@ -137,6 +136,7 @@ import org.whispersystems.textsecuregcm.storage.Profiles;
import org.whispersystems.textsecuregcm.storage.ProfilesManager;
import org.whispersystems.textsecuregcm.storage.PubSubManager;
import org.whispersystems.textsecuregcm.storage.PushFeedbackProcessor;
import org.whispersystems.textsecuregcm.storage.RedisClusterMessagesCache;
import org.whispersystems.textsecuregcm.storage.RemoteConfigs;
import org.whispersystems.textsecuregcm.storage.RemoteConfigsManager;
import org.whispersystems.textsecuregcm.storage.ReservedUsernames;
@ -300,7 +300,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
RedisClusterClient cacheClusterClient = RedisClusterClient.create(config.getCacheClusterConfiguration().getUrls().stream().map(RedisURI::create).collect(Collectors.toList()));
cacheClusterClient.setDefaultTimeout(config.getCacheClusterConfiguration().getTimeout());
FaultTolerantRedisCluster cacheCluster = new FaultTolerantRedisCluster("main_cache_cluster", config.getCacheClusterConfiguration().getUrls(), config.getCacheClusterConfiguration().getTimeout(), config.getCacheClusterConfiguration().getCircuitBreakerConfiguration());
FaultTolerantRedisCluster cacheCluster = new FaultTolerantRedisCluster("main_cache_cluster", config.getCacheClusterConfiguration().getUrls(), config.getCacheClusterConfiguration().getTimeout(), config.getCacheClusterConfiguration().getCircuitBreakerConfiguration());
FaultTolerantRedisCluster messagesCacheCluster = new FaultTolerantRedisCluster("messages_cluster", config.getMessageCacheConfiguration().getRedisClusterConfiguration().getUrls(), config.getMessageCacheConfiguration().getRedisClusterConfiguration().getTimeout(), config.getMessageCacheConfiguration().getRedisClusterConfiguration().getCircuitBreakerConfiguration());
DirectoryManager directory = new DirectoryManager(directoryClient);
DirectoryQueue directoryQueue = new DirectoryQueue(config.getDirectoryConfiguration().getSqsConfiguration());
@ -309,7 +310,8 @@ public class WhisperServerService extends Application<WhisperServerConfiguration
AccountsManager accountsManager = new AccountsManager(accounts, directory, cacheCluster);
UsernamesManager usernamesManager = new UsernamesManager(usernames, reservedUsernames, cacheCluster);
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);
MessagesCache messagesCache = new MessagesCache(messagesClient, messages, accountsManager, config.getMessageCacheConfiguration().getPersistDelayMinutes());
RedisClusterMessagesCache clusterMessagesCache = new RedisClusterMessagesCache(messagesCacheCluster);
MessagesCache messagesCache = new MessagesCache(messagesClient, messages, accountsManager, config.getMessageCacheConfiguration().getPersistDelayMinutes(), clusterMessagesCache);
MessagesManager messagesManager = new MessagesManager(messages, messagesCache);
RemoteConfigsManager remoteConfigsManager = new RemoteConfigsManager(remoteConfigs);
DeadLetterHandler deadLetterHandler = new DeadLetterHandler(messagesManager);

View File

@ -12,6 +12,11 @@ public class MessageCacheConfiguration {
@Valid
private RedisConfiguration redis;
@JsonProperty
@NotNull
@Valid
private RedisClusterConfiguration cluster;
@JsonProperty
private int persistDelayMinutes = 10;
@ -19,6 +24,10 @@ public class MessageCacheConfiguration {
return redis;
}
public RedisClusterConfiguration getRedisClusterConfiguration() {
return cluster;
}
public int getPersistDelayMinutes() {
return persistDelayMinutes;
}

View File

@ -3,6 +3,8 @@ package org.whispersystems.textsecuregcm.entities;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Arrays;
import java.util.Objects;
import java.util.UUID;
public class OutgoingMessageEntity {
@ -114,4 +116,30 @@ public class OutgoingMessageEntity {
return serverTimestamp;
}
@Override
public boolean equals(final Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
final OutgoingMessageEntity that = (OutgoingMessageEntity)o;
return id == that.id &&
cached == that.cached &&
type == that.type &&
timestamp == that.timestamp &&
sourceDevice == that.sourceDevice &&
serverTimestamp == that.serverTimestamp &&
Objects.equals(guid, that.guid) &&
Objects.equals(relay, that.relay) &&
Objects.equals(source, that.source) &&
Objects.equals(sourceUuid, that.sourceUuid) &&
Arrays.equals(message, that.message) &&
Arrays.equals(content, that.content);
}
@Override
public int hashCode() {
int result = Objects.hash(id, cached, guid, type, relay, timestamp, source, sourceUuid, sourceDevice, serverTimestamp);
result = 31 * result + Arrays.hashCode(message);
result = 31 * result + Arrays.hashCode(content);
return result;
}
}

View File

@ -5,10 +5,12 @@ import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.SharedMetricRegistries;
import com.codahale.metrics.Timer;
import com.google.protobuf.InvalidProtocolBufferException;
import io.dropwizard.lifecycle.Managed;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.experiment.Experiment;
import org.whispersystems.textsecuregcm.push.NotPushRegisteredException;
import org.whispersystems.textsecuregcm.push.PushSender;
import org.whispersystems.textsecuregcm.redis.LuaScript;
@ -17,6 +19,9 @@ import org.whispersystems.textsecuregcm.util.Constants;
import org.whispersystems.textsecuregcm.util.Pair;
import org.whispersystems.textsecuregcm.util.Util;
import org.whispersystems.textsecuregcm.websocket.WebsocketAddress;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.Tuple;
import redis.clients.util.SafeEncoder;
import java.io.IOException;
import java.util.Arrays;
@ -27,18 +32,17 @@ import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import static com.codahale.metrics.MetricRegistry.name;
import io.dropwizard.lifecycle.Managed;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.Tuple;
import redis.clients.util.SafeEncoder;
public class MessagesCache implements Managed {
public class MessagesCache implements Managed, UserMessagesCache {
private static final Logger logger = LoggerFactory.getLogger(MessagesCache.class);
private static final Logger logger = LoggerFactory.getLogger(MessagesCache.class);
private static final MetricRegistry metricRegistry = SharedMetricRegistries.getOrCreate(Constants.METRICS_NAME);
private static final Timer insertTimer = metricRegistry.timer(name(MessagesCache.class, "insert" ));
@ -54,50 +58,85 @@ public class MessagesCache implements Managed {
private final AccountsManager accountsManager;
private final int delayMinutes;
private InsertOperation insertOperation;
private RemoveOperation removeOperation;
private GetOperation getOperation;
private final InsertOperation insertOperation;
private final RemoveOperation removeOperation;
private final GetOperation getOperation;
private PubSubManager pubSubManager;
private PushSender pushSender;
private MessagePersister messagePersister;
public MessagesCache(ReplicatedJedisPool jedisPool, Messages database, AccountsManager accountsManager, int delayMinutes) {
this.jedisPool = jedisPool;
this.database = database;
this.accountsManager = accountsManager;
this.delayMinutes = delayMinutes;
private final RedisClusterMessagesCache clusterMessagesCache;
private final ExecutorService experimentExecutor = new ThreadPoolExecutor(8, 8, 0, TimeUnit.MILLISECONDS, new ArrayBlockingQueue<>(1_000));
private final Experiment insertExperiment = new Experiment("MessagesCache", "insert");
private final Experiment removeByIdExperiment = new Experiment("MessagesCache", "removeById");
private final Experiment removeBySenderExperiment = new Experiment("MessagesCache", "removeBySender");
private final Experiment removeByUuidExperiment = new Experiment("MessagesCache", "removeByUuid");
private final Experiment getMessagesExperiment = new Experiment("MessagesCache", "getMessages");
public MessagesCache(ReplicatedJedisPool jedisPool, Messages database, AccountsManager accountsManager, int delayMinutes, RedisClusterMessagesCache clusterMessagesCache) throws IOException {
this.jedisPool = jedisPool;
this.database = database;
this.accountsManager = accountsManager;
this.delayMinutes = delayMinutes;
this.insertOperation = new InsertOperation(jedisPool);
this.removeOperation = new RemoveOperation(jedisPool);
this.getOperation = new GetOperation(jedisPool);
this.clusterMessagesCache = clusterMessagesCache;
}
public void insert(UUID guid, String destination, long destinationDevice, Envelope message) {
message = message.toBuilder().setServerGuid(guid.toString()).build();
@Override
public long insert(UUID guid, String destination, long destinationDevice, Envelope message) {
final Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build();
Timer.Context timer = insertTimer.time();
try {
insertOperation.insert(guid, destination, destinationDevice, System.currentTimeMillis(), message);
final long messageId = insertOperation.insert(guid, destination, destinationDevice, System.currentTimeMillis(), messageWithGuid);
insertExperiment.compareSupplierResultAsync(messageId, () -> clusterMessagesCache.insert(guid, destination, destinationDevice, message, messageId), experimentExecutor);
return messageId;
} finally {
timer.stop();
}
}
public void remove(String destination, long destinationDevice, long id) {
@Override
public Optional<OutgoingMessageEntity> remove(String destination, long destinationDevice, long id) {
OutgoingMessageEntity removedMessageEntity = null;
try (Jedis jedis = jedisPool.getWriteResource();
Timer.Context ignored = removeByIdTimer.time())
{
removeOperation.remove(jedis, destination, destinationDevice, id);
byte[] serialized = removeOperation.remove(jedis, destination, destinationDevice, id);
if (serialized != null) {
removedMessageEntity = UserMessagesCache.constructEntityFromEnvelope(id, Envelope.parseFrom(serialized));
}
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
final Optional<OutgoingMessageEntity> maybeRemovedMessage = Optional.ofNullable(removedMessageEntity);
removeByIdExperiment.compareSupplierResultAsync(maybeRemovedMessage, () -> clusterMessagesCache.remove(destination, destinationDevice, id), experimentExecutor);
return maybeRemovedMessage;
}
@Override
public Optional<OutgoingMessageEntity> remove(String destination, long destinationDevice, String sender, long timestamp) {
OutgoingMessageEntity removedMessageEntity = null;
Timer.Context timer = removeByNameTimer.time();
try {
byte[] serialized = removeOperation.remove(destination, destinationDevice, sender, timestamp);
if (serialized != null) {
Envelope envelope = Envelope.parseFrom(serialized);
return Optional.of(constructEntityFromEnvelope(0, envelope));
removedMessageEntity = UserMessagesCache.constructEntityFromEnvelope(0, Envelope.parseFrom(serialized));
}
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
@ -105,18 +144,23 @@ public class MessagesCache implements Managed {
timer.stop();
}
return Optional.empty();
final Optional<OutgoingMessageEntity> maybeRemovedMessage = Optional.ofNullable(removedMessageEntity);
removeBySenderExperiment.compareSupplierResultAsync(maybeRemovedMessage, () -> clusterMessagesCache.remove(destination, destinationDevice, sender, timestamp), experimentExecutor);
return maybeRemovedMessage;
}
@Override
public Optional<OutgoingMessageEntity> remove(String destination, long destinationDevice, UUID guid) {
OutgoingMessageEntity removedMessageEntity = null;
Timer.Context timer = removeByGuidTimer.time();
try {
byte[] serialized = removeOperation.remove(destination, destinationDevice, guid);
if (serialized != null) {
Envelope envelope = Envelope.parseFrom(serialized);
return Optional.of(constructEntityFromEnvelope(0, envelope));
removedMessageEntity = UserMessagesCache.constructEntityFromEnvelope(0, Envelope.parseFrom(serialized));
}
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
@ -124,9 +168,14 @@ public class MessagesCache implements Managed {
timer.stop();
}
return Optional.empty();
final Optional<OutgoingMessageEntity> maybeRemovedMessage = Optional.ofNullable(removedMessageEntity);
removeByUuidExperiment.compareSupplierResultAsync(maybeRemovedMessage, () -> clusterMessagesCache.remove(destination, destinationDevice, guid), experimentExecutor);
return maybeRemovedMessage;
}
@Override
public List<OutgoingMessageEntity> get(String destination, long destinationDevice, int limit) {
Timer.Context timer = getTimer.time();
@ -139,18 +188,21 @@ public class MessagesCache implements Managed {
try {
long id = item.second().longValue();
Envelope message = Envelope.parseFrom(item.first());
results.add(constructEntityFromEnvelope(id, message));
results.add(UserMessagesCache.constructEntityFromEnvelope(id, message));
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
}
getMessagesExperiment.compareSupplierResultAsync(results, () -> clusterMessagesCache.get(destination, destinationDevice, limit), experimentExecutor);
return results;
} finally {
timer.stop();
}
}
@Override
public void clear(String destination) {
Timer.Context timer = clearAccountTimer.time();
@ -163,6 +215,7 @@ public class MessagesCache implements Managed {
}
}
@Override
public void clear(String destination, long deviceId) {
Timer.Context timer = clearDeviceTimer.time();
@ -180,11 +233,7 @@ public class MessagesCache implements Managed {
@Override
public void start() throws Exception {
this.insertOperation = new InsertOperation(jedisPool);
this.removeOperation = new RemoveOperation(jedisPool);
this.getOperation = new GetOperation(jedisPool);
this.messagePersister = new MessagePersister(jedisPool, database, pubSubManager, pushSender, accountsManager, delayMinutes, TimeUnit.MINUTES);
this.messagePersister.start();
}
@ -192,20 +241,8 @@ public class MessagesCache implements Managed {
public void stop() throws Exception {
messagePersister.shutdown();
logger.info("Message persister shut down...");
}
private OutgoingMessageEntity constructEntityFromEnvelope(long id, Envelope envelope) {
return new OutgoingMessageEntity(id, true,
envelope.hasServerGuid() ? UUID.fromString(envelope.getServerGuid()) : null,
envelope.getType().getNumber(),
envelope.getRelay(),
envelope.getTimestamp(),
envelope.getSource(),
envelope.hasSourceUuid() ? UUID.fromString(envelope.getSourceUuid()) : null,
envelope.getSourceDevice(),
envelope.hasLegacyMessage() ? envelope.getLegacyMessage().toByteArray() : null,
envelope.hasContent() ? envelope.getContent().toByteArray() : null,
envelope.hasServerTimestamp() ? envelope.getServerTimestamp() : 0);
this.experimentExecutor.shutdown();
}
private static class Key {
@ -271,14 +308,14 @@ public class MessagesCache implements Managed {
this.insert = LuaScript.fromResource(jedisPool, "lua/insert_item.lua");
}
public void insert(UUID guid, String destination, long destinationDevice, long timestamp, Envelope message) {
public long insert(UUID guid, String destination, long destinationDevice, long timestamp, Envelope message) {
Key key = new Key(destination, destinationDevice);
String sender = message.hasSource() ? (message.getSource() + "::" + message.getTimestamp()) : "nil";
List<byte[]> keys = Arrays.asList(key.getUserMessageQueue(), key.getUserMessageQueueMetadata(), Key.getUserMessageQueueIndex());
List<byte[]> args = Arrays.asList(message.toByteArray(), String.valueOf(timestamp).getBytes(), sender.getBytes(), guid.toString().getBytes());
insert.execute(keys, args);
return (long)insert.execute(keys, args);
}
}
@ -296,13 +333,13 @@ public class MessagesCache implements Managed {
this.removeQueue = LuaScript.fromResource(jedisPool, "lua/remove_queue.lua" );
}
public void remove(Jedis jedis, String destination, long destinationDevice, long id) {
public byte[] remove(Jedis jedis, String destination, long destinationDevice, long id) {
Key key = new Key(destination, destinationDevice);
List<byte[]> keys = Arrays.asList(key.getUserMessageQueue(), key.getUserMessageQueueMetadata(), Key.getUserMessageQueueIndex());
List<byte[]> args = Collections.singletonList(String.valueOf(id).getBytes());
this.removeById.execute(jedis, keys, args);
return (byte[])this.removeById.execute(jedis, keys, args);
}
public byte[] remove(String destination, long destinationDevice, String sender, long timestamp) {

View File

@ -0,0 +1,206 @@
package org.whispersystems.textsecuregcm.storage;
import com.google.protobuf.InvalidProtocolBufferException;
import io.lettuce.core.ScriptOutputType;
import io.micrometer.core.instrument.Metrics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.redis.ClusterLuaScript;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisCluster;
import org.whispersystems.textsecuregcm.util.RedisClusterUtil;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import static com.codahale.metrics.MetricRegistry.name;
public class RedisClusterMessagesCache implements UserMessagesCache {
private final ClusterLuaScript insertScript;
private final ClusterLuaScript removeByIdScript;
private final ClusterLuaScript removeBySenderScript;
private final ClusterLuaScript removeByGuidScript;
private final ClusterLuaScript getItemsScript;
private final ClusterLuaScript removeQueueScript;
private static final String INSERT_TIMER_NAME = name(RedisClusterMessagesCache.class, "insert");
private static final String REMOVE_TIMER_NAME = name(RedisClusterMessagesCache.class, "remove");
private static final String GET_TIMER_NAME = name(RedisClusterMessagesCache.class, "get");
private static final String CLEAR_TIMER_NAME = name(RedisClusterMessagesCache.class, "clear");
private static final String REMOVE_METHOD_TAG = "method";
private static final String REMOVE_METHOD_ID = "id";
private static final String REMOVE_METHOD_SENDER = "sender";
private static final String REMOVE_METHOD_UUID = "uuid";
private static final Logger logger = LoggerFactory.getLogger(RedisClusterMessagesCache.class);
public RedisClusterMessagesCache(final FaultTolerantRedisCluster redisCluster) throws IOException {
this.insertScript = ClusterLuaScript.fromResource(redisCluster, "lua/insert_item.lua", ScriptOutputType.INTEGER);
this.removeByIdScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_id.lua", ScriptOutputType.VALUE);
this.removeBySenderScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_sender.lua", ScriptOutputType.VALUE);
this.removeByGuidScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_item_by_guid.lua", ScriptOutputType.VALUE);
this.getItemsScript = ClusterLuaScript.fromResource(redisCluster, "lua/get_items.lua", ScriptOutputType.MULTI);
this.removeQueueScript = ClusterLuaScript.fromResource(redisCluster, "lua/remove_queue.lua", ScriptOutputType.STATUS);
}
@Override
public long insert(final UUID guid, final String destination, final long destinationDevice, final MessageProtos.Envelope message) {
final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build();
final String sender = message.hasSource() ? (message.getSource() + "::" + message.getTimestamp()) : "nil";
return (long)Metrics.timer(INSERT_TIMER_NAME).record(() ->
insertScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice),
getMessageQueueMetadataKey(destination, destinationDevice),
getQueueIndexKey(destination, destinationDevice)),
List.of(messageWithGuid.toByteArray(),
String.valueOf(message.getTimestamp()).getBytes(StandardCharsets.UTF_8),
sender.getBytes(StandardCharsets.UTF_8),
guid.toString().getBytes(StandardCharsets.UTF_8))));
}
public long insert(final UUID guid, final String destination, final long destinationDevice, final MessageProtos.Envelope message, final long messageId) {
final MessageProtos.Envelope messageWithGuid = message.toBuilder().setServerGuid(guid.toString()).build();
final String sender = message.hasSource() ? (message.getSource() + "::" + message.getTimestamp()) : "nil";
return (long)Metrics.timer(INSERT_TIMER_NAME).record(() ->
insertScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice),
getMessageQueueMetadataKey(destination, destinationDevice),
getQueueIndexKey(destination, destinationDevice)),
List.of(messageWithGuid.toByteArray(),
String.valueOf(message.getTimestamp()).getBytes(StandardCharsets.UTF_8),
sender.getBytes(StandardCharsets.UTF_8),
guid.toString().getBytes(StandardCharsets.UTF_8),
String.valueOf(messageId).getBytes(StandardCharsets.UTF_8))));
}
@Override
public Optional<OutgoingMessageEntity> remove(final String destination, final long destinationDevice, final long id) {
try {
final byte[] serialized = (byte[])Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, REMOVE_METHOD_ID).record(() ->
removeByIdScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice),
getMessageQueueMetadataKey(destination, destinationDevice),
getQueueIndexKey(destination, destinationDevice)),
List.of(String.valueOf(id).getBytes(StandardCharsets.UTF_8))));
if (serialized != null) {
return Optional.of(UserMessagesCache.constructEntityFromEnvelope(id, MessageProtos.Envelope.parseFrom(serialized)));
}
} catch (final InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
return Optional.empty();
}
@Override
public Optional<OutgoingMessageEntity> remove(final String destination, final long destinationDevice, final String sender, final long timestamp) {
try {
final byte[] serialized = (byte[])Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, REMOVE_METHOD_SENDER).record(() ->
removeBySenderScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice),
getMessageQueueMetadataKey(destination, destinationDevice),
getQueueIndexKey(destination, destinationDevice)),
List.of((sender + "::" + timestamp).getBytes(StandardCharsets.UTF_8))));
if (serialized != null) {
return Optional.of(UserMessagesCache.constructEntityFromEnvelope(0, MessageProtos.Envelope.parseFrom(serialized)));
}
} catch (final InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
return Optional.empty();
}
@Override
public Optional<OutgoingMessageEntity> remove(final String destination, final long destinationDevice, final UUID guid) {
try {
final byte[] serialized = (byte[])Metrics.timer(REMOVE_TIMER_NAME, REMOVE_METHOD_TAG, REMOVE_METHOD_UUID).record(() ->
removeByGuidScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice),
getMessageQueueMetadataKey(destination, destinationDevice),
getQueueIndexKey(destination, destinationDevice)),
List.of(guid.toString().getBytes(StandardCharsets.UTF_8))));
if (serialized != null) {
return Optional.of(UserMessagesCache.constructEntityFromEnvelope(0, MessageProtos.Envelope.parseFrom(serialized)));
}
} catch (final InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
return Optional.empty();
}
@Override
@SuppressWarnings("unchecked")
public List<OutgoingMessageEntity> get(String destination, long destinationDevice, int limit) {
return Metrics.timer(GET_TIMER_NAME).record(() -> {
final List<byte[]> queueItems = (List<byte[]>)getItemsScript.executeBinary(List.of(getMessageQueueKey(destination, destinationDevice),
getPersistInProgressKey(destination, destinationDevice)),
List.of(String.valueOf(limit).getBytes()));
final List<OutgoingMessageEntity> messageEntities;
if (queueItems.size() % 2 == 0) {
messageEntities = new ArrayList<>(queueItems.size() / 2);
for (int i = 0; i < queueItems.size() - 1; i += 2) {
try {
final MessageProtos.Envelope message = MessageProtos.Envelope.parseFrom(queueItems.get(i));
final long id = Long.parseLong(new String(queueItems.get(i + 1), StandardCharsets.UTF_8));
messageEntities.add(UserMessagesCache.constructEntityFromEnvelope(id, message));
} catch (InvalidProtocolBufferException e) {
logger.warn("Failed to parse envelope", e);
}
}
} else {
logger.error("\"Get messages\" operation returned a list with a non-even number of elements.");
messageEntities = Collections.emptyList();
}
return messageEntities;
});
}
@Override
public void clear(final String destination) {
for (int i = 1; i < 256; i++) {
clear(destination, i);
}
}
@Override
public void clear(final String destination, final long deviceId) {
Metrics.timer(CLEAR_TIMER_NAME).record(() ->
removeQueueScript.executeBinary(List.of(getMessageQueueKey(destination, deviceId),
getMessageQueueMetadataKey(destination, deviceId),
getQueueIndexKey(destination, deviceId)),
Collections.emptyList()));
}
private static byte[] getMessageQueueKey(final String address, final long deviceId) {
return ("user_queue::{" + address + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
}
private static byte[] getMessageQueueMetadataKey(final String address, final long deviceId) {
return ("user_queue_metadata::{" + address + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
}
private byte[] getQueueIndexKey(final String address, final long deviceId) {
return ("user_queue_index::{" + RedisClusterUtil.getMinimalHashTag(address + "::" + deviceId) + "}").getBytes(StandardCharsets.UTF_8);
}
private byte[] getPersistInProgressKey(final String address, final long deviceId) {
return ("user_queue_persisting::{" + address + "::" + deviceId + "}").getBytes(StandardCharsets.UTF_8);
}
}

View File

@ -0,0 +1,41 @@
package org.whispersystems.textsecuregcm.storage;
import com.google.common.annotations.VisibleForTesting;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.push.PushSender;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
public interface UserMessagesCache {
@VisibleForTesting
static OutgoingMessageEntity constructEntityFromEnvelope(long id, MessageProtos.Envelope envelope) {
return new OutgoingMessageEntity(id, true,
envelope.hasServerGuid() ? UUID.fromString(envelope.getServerGuid()) : null,
envelope.getType().getNumber(),
envelope.getRelay(),
envelope.getTimestamp(),
envelope.getSource(),
envelope.hasSourceUuid() ? UUID.fromString(envelope.getSourceUuid()) : null,
envelope.getSourceDevice(),
envelope.hasLegacyMessage() ? envelope.getLegacyMessage().toByteArray() : null,
envelope.hasContent() ? envelope.getContent().toByteArray() : null,
envelope.hasServerTimestamp() ? envelope.getServerTimestamp() : 0);
}
long insert(UUID guid, String destination, long destinationDevice, MessageProtos.Envelope message);
Optional<OutgoingMessageEntity> remove(String destination, long destinationDevice, long id);
Optional<OutgoingMessageEntity> remove(String destination, long destinationDevice, String sender, long timestamp);
Optional<OutgoingMessageEntity> remove(String destination, long destinationDevice, UUID guid);
List<OutgoingMessageEntity> get(String destination, long destinationDevice, int limit);
void clear(String destination);
void clear(String destination, long deviceId);
}

View File

@ -0,0 +1,37 @@
package org.whispersystems.textsecuregcm.util;
import io.lettuce.core.cluster.SlotHash;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
public class RedisClusterUtil {
private static final String[] HASHES_BY_SLOT = new String[SlotHash.SLOT_COUNT];
static {
int slotsCovered = 0;
int i = 0;
while (slotsCovered < HASHES_BY_SLOT.length) {
final String hash = Integer.toString(i++, 36);
final int slot = SlotHash.getSlot(hash);
if (HASHES_BY_SLOT[slot] == null) {
HASHES_BY_SLOT[slot] = hash;
slotsCovered += 1;
}
}
}
/**
* Returns a short Redis hash tag that maps to the same Redis cluster slot as the given key.
*
* @param key the key for which to find a matching hash tag
* @return a Redis hash tag that maps to the same Redis cluster slot as the given key
*
* @see <a href="https://redis.io/topics/cluster-spec#keys-hash-tags">Redis Cluster Specification - Keys hash tags</a>
*/
public static String getMinimalHashTag(final String key) {
return HASHES_BY_SLOT[SlotHash.getSlot(key)];
}
}

View File

@ -1,7 +1,16 @@
-- keys: queue_key [1], queue_metadata_key [2], queue_total_index [3]
-- argv: message [1], current_time [2], sender (possibly null) [3], guid [4]
-- argv: message [1], current_time [2], sender (possibly null) [3], guid [4], messageId (possibly null) [5]
local messageId
if ARGV[5] ~= nil then
-- TODO: Remove this branch (and ARGV[5]) once the migration to a clustered message cache is finished
messageId = tonumber(ARGV[5])
redis.call("HSET", KEYS[2], "counter", messageId)
else
messageId = redis.call("HINCRBY", KEYS[2], "counter", 1)
end
local messageId = redis.call("HINCRBY", KEYS[2], "counter", 1)
redis.call("ZADD", KEYS[1], "NX", messageId, ARGV[1])
if ARGV[3] ~= "nil" then

View File

@ -1,6 +1,7 @@
-- keys: queue_key, queue_metadata_key, queue_index
-- argv: index_to_remove
local envelope = redis.call("ZRANGEBYSCORE", KEYS[1], ARGV[1], ARGV[1], "LIMIT", 0, 1)
local removedCount = redis.call("ZREMRANGEBYSCORE", KEYS[1], ARGV[1], ARGV[1])
local senderIndex = redis.call("HGET", KEYS[2], ARGV[1])
local guidIndex = redis.call("HGET", KEYS[2], ARGV[1] .. "guid")
@ -19,4 +20,8 @@ if (redis.call("ZCARD", KEYS[1]) == 0) then
redis.call("ZREM", KEYS[3], KEYS[1])
end
return removedCount > 0
if envelope and next(envelope) then
return envelope[1]
else
return nil
end

View File

@ -36,7 +36,7 @@ public abstract class AbstractRedisClusterTest {
private FaultTolerantRedisCluster redisCluster;
@BeforeClass
public static void setUpBeforeClass() throws IOException, URISyntaxException, InterruptedException {
public static void setUpBeforeClass() throws Exception {
assumeFalse(System.getProperty("os.name").equalsIgnoreCase("windows"));
clusterNodes = new RedisServer[NODE_COUNT];
@ -50,7 +50,7 @@ public abstract class AbstractRedisClusterTest {
}
@Before
public void setUp() {
public void setUp() throws Exception {
final List<String> urls = Arrays.stream(clusterNodes)
.map(node -> String.format("redis://127.0.0.1:%d", node.ports().get(0)))
.collect(Collectors.toList());
@ -63,12 +63,12 @@ public abstract class AbstractRedisClusterTest {
}
@After
public void tearDown() {
public void tearDown() throws Exception {
redisCluster.stop();
}
@AfterClass
public static void tearDownAfterClass() {
public static void tearDownAfterClass() throws Exception {
for (final RedisServer node : clusterNodes) {
node.stop();
}

View File

@ -0,0 +1,157 @@
package org.whispersystems.textsecuregcm.storage;
import com.google.protobuf.ByteString;
import junitparams.JUnitParamsRunner;
import junitparams.Parameters;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.whispersystems.textsecuregcm.entities.MessageProtos;
import org.whispersystems.textsecuregcm.entities.OutgoingMessageEntity;
import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.UUID;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@RunWith(JUnitParamsRunner.class)
public abstract class AbstractMessagesCacheTest extends AbstractRedisClusterTest {
private static final String DESTINATION_ACCOUNT = "+18005551234";
private static final int DESTINATION_DEVICE_ID = 7;
private final Random random = new Random();
private long serialTimestamp = 0;
protected abstract UserMessagesCache getMessagesCache();
@Test
@Parameters({"true", "false"})
public void testInsert(final boolean sealedSender) {
final UUID messageGuid = UUID.randomUUID();
assertTrue(getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, generateRandomMessage(messageGuid, sealedSender)) > 0);
}
@Test
@Parameters({"true", "false"})
public void testRemoveById(final boolean sealedSender) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
final long messageId = getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, message);
final Optional<OutgoingMessageEntity> maybeRemovedMessage = getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageId);
assertTrue(maybeRemovedMessage.isPresent());
assertEquals(UserMessagesCache.constructEntityFromEnvelope(messageId, message), maybeRemovedMessage.get());
assertEquals(Optional.empty(), getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageId));
}
@Test
public void testRemoveBySender() {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, false);
getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, message);
final Optional<OutgoingMessageEntity> maybeRemovedMessage = getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, message.getSource(), message.getTimestamp());
assertTrue(maybeRemovedMessage.isPresent());
assertEquals(UserMessagesCache.constructEntityFromEnvelope(0, message), maybeRemovedMessage.get());
assertEquals(Optional.empty(), getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, message.getSource(), message.getTimestamp()));
}
@Test
@Parameters({"true", "false"})
public void testRemoveByUUID(final boolean sealedSender) {
final UUID messageGuid = UUID.randomUUID();
assertEquals(Optional.empty(), getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageGuid));
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, message);
final Optional<OutgoingMessageEntity> maybeRemovedMessage = getMessagesCache().remove(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageGuid);
assertTrue(maybeRemovedMessage.isPresent());
assertEquals(UserMessagesCache.constructEntityFromEnvelope(0, message), maybeRemovedMessage.get());
}
@Test
@Parameters({"true", "false"})
public void testGetMessages(final boolean sealedSender) {
final int messageCount = 100;
final List<OutgoingMessageEntity> expectedMessages = new ArrayList<>(messageCount);
for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
final long messageId = getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, message);
expectedMessages.add(UserMessagesCache.constructEntityFromEnvelope(messageId, message));
}
assertEquals(expectedMessages, getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageCount));
}
@Test
@Parameters({"true", "false"})
public void testClearQueueForDevice(final boolean sealedSender) {
final int messageCount = 100;
for (final int deviceId : new int[] { DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1 }) {
for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, deviceId, message);
}
}
getMessagesCache().clear(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID);
assertEquals(Collections.emptyList(), getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageCount));
assertEquals(messageCount, getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID + 1, messageCount).size());
}
@Test
@Parameters({"true", "false"})
public void testClearQueueForAccount(final boolean sealedSender) {
final int messageCount = 100;
for (final int deviceId : new int[] { DESTINATION_DEVICE_ID, DESTINATION_DEVICE_ID + 1 }) {
for (int i = 0; i < messageCount; i++) {
final UUID messageGuid = UUID.randomUUID();
final MessageProtos.Envelope message = generateRandomMessage(messageGuid, sealedSender);
getMessagesCache().insert(messageGuid, DESTINATION_ACCOUNT, deviceId, message);
}
}
getMessagesCache().clear(DESTINATION_ACCOUNT);
assertEquals(Collections.emptyList(), getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, messageCount));
assertEquals(Collections.emptyList(), getMessagesCache().get(DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID + 1, messageCount));
}
protected MessageProtos.Envelope generateRandomMessage(final UUID messageGuid, final boolean sealedSender) {
final MessageProtos.Envelope.Builder envelopeBuilder = MessageProtos.Envelope.newBuilder()
.setTimestamp(serialTimestamp++)
.setServerTimestamp(serialTimestamp++)
.setContent(ByteString.copyFromUtf8(RandomStringUtils.randomAlphanumeric(256)))
.setType(MessageProtos.Envelope.Type.CIPHERTEXT)
.setServerGuid(messageGuid.toString());
if (!sealedSender) {
envelopeBuilder.setSourceDevice(random.nextInt(256))
.setSource("+1" + RandomStringUtils.randomNumeric(10));
}
return envelopeBuilder.build();
}
}

View File

@ -0,0 +1,41 @@
package org.whispersystems.textsecuregcm.storage;
import org.junit.After;
import org.junit.Before;
import org.whispersystems.textsecuregcm.configuration.CircuitBreakerConfiguration;
import org.whispersystems.textsecuregcm.providers.RedisClientFactory;
import org.whispersystems.textsecuregcm.redis.AbstractRedisClusterTest;
import org.whispersystems.textsecuregcm.redis.ReplicatedJedisPool;
import redis.embedded.RedisServer;
import java.util.List;
import static org.mockito.Mockito.mock;
public class MessagesCacheTest extends AbstractMessagesCacheTest {
private RedisServer redisServer;
private MessagesCache messagesCache;
@Before
public void setUp() throws Exception {
redisServer = new RedisServer(AbstractRedisClusterTest.getNextRedisClusterPort());
redisServer.start();
final String redisUrl = String.format("redis://127.0.0.1:%d", redisServer.ports().get(0));
final RedisClientFactory clientFactory = new RedisClientFactory("message-cache-test", redisUrl, List.of(redisUrl), new CircuitBreakerConfiguration());
final ReplicatedJedisPool jedisPool = clientFactory.getRedisClientPool();
messagesCache = new MessagesCache(jedisPool, mock(Messages.class), mock(AccountsManager.class), 60, mock(RedisClusterMessagesCache.class));
}
@After
public void tearDown() {
redisServer.stop();
}
@Override
protected UserMessagesCache getMessagesCache() {
return messagesCache;
}
}

View File

@ -0,0 +1,48 @@
package org.whispersystems.textsecuregcm.storage;
import junitparams.Parameters;
import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
import java.util.UUID;
import static org.junit.Assert.assertEquals;
public class RedisClusterMessagesCacheTest extends AbstractMessagesCacheTest {
private static final String DESTINATION_ACCOUNT = "+18005551234";
private static final int DESTINATION_DEVICE_ID = 7;
private RedisClusterMessagesCache messagesCache;
@Override
@Before
public void setUp() throws Exception {
super.setUp();
try {
messagesCache = new RedisClusterMessagesCache(getRedisCluster());
} catch (final IOException e) {
throw new RuntimeException(e);
}
getRedisCluster().useWriteCluster(connection -> connection.sync().flushall());
}
@Override
protected UserMessagesCache getMessagesCache() {
return messagesCache;
}
@Test
@Parameters({"true", "false"})
public void testInsertWithPrescribedId(final boolean sealedSender) {
final UUID firstMessageGuid = UUID.randomUUID();
final UUID secondMessageGuid = UUID.randomUUID();
final long messageId = 74;
assertEquals(messageId, messagesCache.insert(firstMessageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, generateRandomMessage(firstMessageGuid, sealedSender), messageId));
assertEquals(messageId + 1, messagesCache.insert(secondMessageGuid, DESTINATION_ACCOUNT, DESTINATION_DEVICE_ID, generateRandomMessage(secondMessageGuid, sealedSender)));
}
}